diff options
Diffstat (limited to 'bot/converters.py')
| -rw-r--r-- | bot/converters.py | 307 |
1 files changed, 269 insertions, 38 deletions
diff --git a/bot/converters.py b/bot/converters.py index cca57a02d..d0a9731d6 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -2,6 +2,7 @@ import logging import re import typing as t from datetime import datetime +from functools import partial from ssl import CertificateError import dateutil.parser @@ -9,11 +10,18 @@ import dateutil.tz import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Context, Converter, UserConverter +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from discord.utils import DISCORD_EPOCH, snowflake_time +from bot.api import ResponseCodeError +from bot.constants import URLs +from bot.utils.regex import INVITE_RE log = logging.getLogger(__name__) +DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") + def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: """ @@ -34,6 +42,90 @@ def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], s return converter +class ValidDiscordServerInvite(Converter): + """ + A converter that validates whether a given string is a valid Discord server invite. + + Raises 'BadArgument' if: + - The string is not a valid Discord server invite. + - The string is valid, but is an invite for a group DM. + - The string is valid, but is expired. + + Returns a (partial) guild object if: + - The string is a valid vanity + - The string is a full invite URI + - The string contains the invite code (the stuff after discord.gg/) + + See the Discord API docs for documentation on the guild object: + https://discord.com/developers/docs/resources/guild#guild-object + """ + + async def convert(self, ctx: Context, server_invite: str) -> dict: + """Check whether the string is a valid Discord server invite.""" + invite_code = INVITE_RE.search(server_invite) + if invite_code: + response = await ctx.bot.http_session.get( + f"{URLs.discord_invite_api}/{invite_code[1]}" + ) + if response.status != 404: + invite_data = await response.json() + return invite_data.get("guild") + + id_converter = IDConverter() + if id_converter._get_id_match(server_invite): + raise BadArgument("Guild IDs are not supported, only invites.") + + raise BadArgument("This does not appear to be a valid Discord server invite.") + + +class ValidFilterListType(Converter): + """ + A converter that checks whether the given string is a valid FilterList type. + + Raises `BadArgument` if the argument is not a valid FilterList type, and simply + passes through the given argument otherwise. + """ + + @staticmethod + async def get_valid_types(bot: Bot) -> list: + """ + Try to get a list of valid filter list types. + + Raise a BadArgument if the API can't respond. + """ + try: + valid_types = await bot.api_client.get('bot/filter-lists/get-types') + except ResponseCodeError: + raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") + + return [enum for enum, classname in valid_types] + + async def convert(self, ctx: Context, list_type: str) -> str: + """Checks whether the given string is a valid FilterList type.""" + valid_types = await self.get_valid_types(ctx.bot) + list_type = list_type.upper() + + if list_type not in valid_types: + + # Maybe the user is using the plural form of this type, + # e.g. "guild_invites" instead of "guild_invite". + # + # This code will support the simple plural form (a single 's' at the end), + # which works for all current list types, but if a list type is added in the future + # which has an irregular plural form (like 'ies'), this code will need to be + # refactored to support this. + if list_type.endswith("S") and list_type[:-1] in valid_types: + list_type = list_type[:-1] + + else: + valid_types_list = '\n'.join([f"โข {type_.lower()}" for type_ in valid_types]) + raise BadArgument( + f"You have provided an invalid list type!\n\n" + f"Please provide one of the following: \n{valid_types_list}" + ) + return list_type + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -85,17 +177,42 @@ class ValidURL(Converter): return url -class InfractionSearchQuery(Converter): - """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" +class Snowflake(IDConverter): + """ + Converts to an int if the argument is a valid Discord snowflake. + + A snowflake is valid if: + + * It consists of 15-21 digits (0-9) + * Its parsed datetime is after the Discord epoch + * Its parsed datetime is less than 1 day after the current time + """ + + async def convert(self, ctx: Context, arg: str) -> int: + """ + Ensure `arg` matches the ID pattern and its timestamp is in range. + + Return `arg` as an int if it's a valid snowflake. + """ + error = f"Invalid snowflake {arg!r}" + + if not self._get_id_match(arg): + raise BadArgument(error) + + snowflake = int(arg) - @staticmethod - async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: - """Check if the argument is a Discord user, and if not, falls back to a string.""" try: - maybe_snowflake = arg.strip("<@!>") - return await ctx.bot.fetch_user(maybe_snowflake) - except (discord.NotFound, discord.HTTPException): - return arg + time = snowflake_time(snowflake) + except (OverflowError, OSError) as e: + # Not sure if this can ever even happen, but let's be safe. + raise BadArgument(f"{error}: {e}") + + if time < DISCORD_EPOCH_DT: + raise BadArgument(f"{error}: timestamp is before the Discord epoch.") + elif (datetime.utcnow() - time).days < -1: + raise BadArgument(f"{error}: timestamp is too far into the future.") + + return snowflake class Subreddit(Converter): @@ -141,40 +258,24 @@ class TagNameConverter(Converter): @staticmethod async def convert(ctx: Context, tag_name: str) -> str: """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" - def is_number(value: str) -> bool: - """Check to see if the input string is numeric.""" - try: - float(value) - except ValueError: - return False - return True - tag_name = tag_name.lower().strip() # The tag name has at least one invalid character. if ascii(tag_name)[1:-1] != tag_name: - log.warning(f"{ctx.author} tried to put an invalid character in a tag name. " - "Rejecting the request.") raise BadArgument("Don't be ridiculous, you can't use that character!") # The tag name is either empty, or consists of nothing but whitespace. elif not tag_name: - log.warning(f"{ctx.author} tried to create a tag with a name consisting only of whitespace. " - "Rejecting the request.") raise BadArgument("Tag names should not be empty, or filled with whitespace.") - # The tag name is a number of some kind, we don't allow that. - elif is_number(tag_name): - log.warning(f"{ctx.author} tried to create a tag with a digit as its name. " - "Rejecting the request.") - raise BadArgument("Tag names can't be numbers.") - # The tag name is longer than 127 characters. elif len(tag_name) > 127: - log.warning(f"{ctx.author} tried to request a tag name with over 127 characters. " - "Rejecting the request.") raise BadArgument("Are you insane? That's way too long!") + # The tag name is ascii but does not contain any letters. + elif not any(character.isalpha() for character in tag_name): + raise BadArgument("Tag names must contain at least one letter.") + return tag_name @@ -192,15 +293,13 @@ class TagContentConverter(Converter): # The tag contents should not be empty, or filled with whitespace. if not tag_content: - log.warning(f"{ctx.author} tried to create a tag containing only whitespace. " - "Rejecting the request.") raise BadArgument("Tag contents should not be empty, or filled with whitespace.") return tag_content -class Duration(Converter): - """Convert duration strings into UTC datetime.datetime objects.""" +class DurationDelta(Converter): + """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" duration_parser = re.compile( r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" @@ -212,9 +311,9 @@ class Duration(Converter): r"((?P<seconds>\d+?) ?(seconds|second|S|s))?" ) - async def convert(self, ctx: Context, duration: str) -> datetime: + async def convert(self, ctx: Context, duration: str) -> relativedelta: """ - Converts a `duration` string to a datetime object that's `duration` in the future. + Converts a `duration` string to a relativedelta object. The converter supports the following symbols for each unit of time: - years: `Y`, `y`, `year`, `years` @@ -233,9 +332,52 @@ class Duration(Converter): duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()} delta = relativedelta(**duration_dict) + + return delta + + +class Duration(DurationDelta): + """Convert duration strings into UTC datetime.datetime objects.""" + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the future. + + The converter supports the same symbols for each unit of time as its parent class. + """ + delta = await super().convert(ctx, duration) now = datetime.utcnow() - return now + delta + try: + return now + delta + except ValueError: + raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") + + +class OffTopicName(Converter): + """A converter that ensures an added off-topic name is valid.""" + + async def convert(self, ctx: Context, argument: str) -> str: + """Attempt to replace any invalid characters with their approximate Unicode equivalent.""" + allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" + + # Chain multiple words to a single one + argument = "-".join(argument.split()) + + if not (2 <= len(argument) <= 96): + raise BadArgument("Channel name must be between 2 and 96 chars long") + + elif not all(c.isalnum() or c in allowed_characters for c in argument): + raise BadArgument( + "Channel name must only consist of " + "alphanumeric characters, minus signs or apostrophes." + ) + + # Replace invalid characters with unicode alternatives. + table = str.maketrans( + allowed_characters, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-' + ) + return argument.translate(table) class ISODateTime(Converter): @@ -280,6 +422,34 @@ class ISODateTime(Converter): return dt +class HushDurationConverter(Converter): + """Convert passed duration to `int` minutes or `None`.""" + + MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)") + + async def convert(self, ctx: Context, argument: str) -> t.Optional[int]: + """ + Convert `argument` to a duration that's max 15 minutes or None. + + If `"forever"` is passed, None is returned; otherwise an int of the extracted time. + Accepted formats are: + * <duration>, + * <duration>m, + * <duration>M, + * forever. + """ + if argument == "forever": + return None + match = self.MINUTES_RE.match(argument) + if not match: + raise BadArgument(f"{argument} is not a valid minutes duration.") + + duration = int(match.group(1)) + if duration > 15: + raise BadArgument("Duration must be at most 15 minutes.") + return duration + + def proxy_user(user_id: str) -> discord.Object: """ Create a proxy user object from the given id. @@ -303,6 +473,24 @@ def proxy_user(user_id: str) -> discord.Object: return user +class UserMentionOrID(UserConverter): + """ + Converts to a `discord.User`, but only if a mention or userID is provided. + + Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. + This is useful in cases where that lookup strategy would lead to ambiguity. + """ + + async def convert(self, ctx: Context, argument: str) -> discord.User: + """Convert the `arg` to a `discord.User`.""" + match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) + + if match is not None: + return await super().convert(ctx, argument) + else: + raise BadArgument(f"`{argument}` is not a User mention or a User ID.") + + class FetchedUser(UserConverter): """ Converts to a `discord.User` or, if it fails, a `discord.Object`. @@ -341,12 +529,55 @@ class FetchedUser(UserConverter): except discord.HTTPException as e: # If the Discord error isn't `Unknown user`, return a proxy instead if e.code != 10013: - log.warning(f"Failed to fetch user, returning a proxy instead: status {e.status}") + log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}") return proxy_user(arg) log.debug(f"Failed to fetch user {arg}: user does not exist.") raise BadArgument(f"User `{arg}` does not exist") +def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: + """ + Extract the snowflake from `arg` using a regex `pattern` and return it as an int. + + The snowflake is expected to be within the first capture group in `pattern`. + """ + match = pattern.match(arg) + if not match: + raise BadArgument(f"Mention {str!r} is invalid.") + + return int(match.group(1)) + + +class Infraction(Converter): + """ + Attempts to convert a given infraction ID into an infraction. + + Alternatively, `l`, `last`, or `recent` can be passed in order to + obtain the most recent infraction by the actor. + """ + + async def convert(self, ctx: Context, arg: str) -> t.Optional[dict]: + """Attempts to convert `arg` into an infraction `dict`.""" + if arg in ("l", "last", "recent"): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + + infractions = await ctx.bot.api_client.get("bot/infractions", params=params) + + if not infractions: + raise BadArgument( + "Couldn't find most recent infraction; you have never given an infraction." + ) + else: + return infractions[0] + + else: + return await ctx.bot.api_client.get(f"bot/infractions/{arg}") + + Expiry = t.Union[Duration, ISODateTime] FetchedMember = t.Union[discord.Member, FetchedUser] +UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) |