diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/converters.py | 217 | 
1 files changed, 203 insertions, 14 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 4deb59f87..2e118d476 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -2,6 +2,7 @@ import logging  import re  import typing as t  from datetime import datetime +from functools import partial  from ssl import CertificateError  import dateutil.parser @@ -9,11 +10,18 @@ import dateutil.tz  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Context, Converter, UserConverter +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from discord.utils import DISCORD_EPOCH, snowflake_time +from bot.api import ResponseCodeError +from bot.constants import URLs +from bot.utils.regex import INVITE_RE  log = logging.getLogger(__name__) +DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") +  def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:      """ @@ -34,6 +42,90 @@ def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], s      return converter +class ValidDiscordServerInvite(Converter): +    """ +    A converter that validates whether a given string is a valid Discord server invite. + +    Raises 'BadArgument' if: +    - The string is not a valid Discord server invite. +    - The string is valid, but is an invite for a group DM. +    - The string is valid, but is expired. + +    Returns a (partial) guild object if: +    - The string is a valid vanity +    - The string is a full invite URI +    - The string contains the invite code (the stuff after discord.gg/) + +    See the Discord API docs for documentation on the guild object: +    https://discord.com/developers/docs/resources/guild#guild-object +    """ + +    async def convert(self, ctx: Context, server_invite: str) -> dict: +        """Check whether the string is a valid Discord server invite.""" +        invite_code = INVITE_RE.search(server_invite) +        if invite_code: +            response = await ctx.bot.http_session.get( +                f"{URLs.discord_invite_api}/{invite_code[1]}" +            ) +            if response.status != 404: +                invite_data = await response.json() +                return invite_data.get("guild") + +        id_converter = IDConverter() +        if id_converter._get_id_match(server_invite): +            raise BadArgument("Guild IDs are not supported, only invites.") + +        raise BadArgument("This does not appear to be a valid Discord server invite.") + + +class ValidFilterListType(Converter): +    """ +    A converter that checks whether the given string is a valid FilterList type. + +    Raises `BadArgument` if the argument is not a valid FilterList type, and simply +    passes through the given argument otherwise. +    """ + +    @staticmethod +    async def get_valid_types(bot: Bot) -> list: +        """ +        Try to get a list of valid filter list types. + +        Raise a BadArgument if the API can't respond. +        """ +        try: +            valid_types = await bot.api_client.get('bot/filter-lists/get-types') +        except ResponseCodeError: +            raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") + +        return [enum for enum, classname in valid_types] + +    async def convert(self, ctx: Context, list_type: str) -> str: +        """Checks whether the given string is a valid FilterList type.""" +        valid_types = await self.get_valid_types(ctx.bot) +        list_type = list_type.upper() + +        if list_type not in valid_types: + +            # Maybe the user is using the plural form of this type, +            # e.g. "guild_invites" instead of "guild_invite". +            # +            # This code will support the simple plural form (a single 's' at the end), +            # which works for all current list types, but if a list type is added in the future +            # which has an irregular plural form (like 'ies'), this code will need to be +            # refactored to support this. +            if list_type.endswith("S") and list_type[:-1] in valid_types: +                list_type = list_type[:-1] + +            else: +                valid_types_list = '\n'.join([f"โข {type_.lower()}" for type_ in valid_types]) +                raise BadArgument( +                    f"You have provided an invalid list type!\n\n" +                    f"Please provide one of the following: \n{valid_types_list}" +                ) +        return list_type + +  class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. @@ -85,17 +177,42 @@ class ValidURL(Converter):          return url -class InfractionSearchQuery(Converter): -    """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" +class Snowflake(IDConverter): +    """ +    Converts to an int if the argument is a valid Discord snowflake. + +    A snowflake is valid if: + +    * It consists of 15-21 digits (0-9) +    * Its parsed datetime is after the Discord epoch +    * Its parsed datetime is less than 1 day after the current time +    """ + +    async def convert(self, ctx: Context, arg: str) -> int: +        """ +        Ensure `arg` matches the ID pattern and its timestamp is in range. + +        Return `arg` as an int if it's a valid snowflake. +        """ +        error = f"Invalid snowflake {arg!r}" + +        if not self._get_id_match(arg): +            raise BadArgument(error) + +        snowflake = int(arg) -    @staticmethod -    async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: -        """Check if the argument is a Discord user, and if not, falls back to a string."""          try: -            maybe_snowflake = arg.strip("<@!>") -            return await ctx.bot.fetch_user(maybe_snowflake) -        except (discord.NotFound, discord.HTTPException): -            return arg +            time = snowflake_time(snowflake) +        except (OverflowError, OSError) as e: +            # Not sure if this can ever even happen, but let's be safe. +            raise BadArgument(f"{error}: {e}") + +        if time < DISCORD_EPOCH_DT: +            raise BadArgument(f"{error}: timestamp is before the Discord epoch.") +        elif (datetime.utcnow() - time).days < -1: +            raise BadArgument(f"{error}: timestamp is too far into the future.") + +        return snowflake  class Subreddit(Converter): @@ -181,8 +298,8 @@ class TagContentConverter(Converter):          return tag_content -class Duration(Converter): -    """Convert duration strings into UTC datetime.datetime objects.""" +class DurationDelta(Converter): +    """Convert duration strings into dateutil.relativedelta.relativedelta objects."""      duration_parser = re.compile(          r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" @@ -194,9 +311,9 @@ class Duration(Converter):          r"((?P<seconds>\d+?) ?(seconds|second|S|s))?"      ) -    async def convert(self, ctx: Context, duration: str) -> datetime: +    async def convert(self, ctx: Context, duration: str) -> relativedelta:          """ -        Converts a `duration` string to a datetime object that's `duration` in the future. +        Converts a `duration` string to a relativedelta object.          The converter supports the following symbols for each unit of time:          - years: `Y`, `y`, `year`, `years` @@ -215,6 +332,20 @@ class Duration(Converter):          duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()}          delta = relativedelta(**duration_dict) + +        return delta + + +class Duration(DurationDelta): +    """Convert duration strings into UTC datetime.datetime objects.""" + +    async def convert(self, ctx: Context, duration: str) -> datetime: +        """ +        Converts a `duration` string to a datetime object that's `duration` in the future. + +        The converter supports the same symbols for each unit of time as its parent class. +        """ +        delta = await super().convert(ctx, duration)          now = datetime.utcnow()          try: @@ -223,6 +354,32 @@ class Duration(Converter):              raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") +class OffTopicName(Converter): +    """A converter that ensures an added off-topic name is valid.""" + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Attempt to replace any invalid characters with their approximate Unicode equivalent.""" +        allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" + +        # Chain multiple words to a single one +        argument = "-".join(argument.split()) + +        if not (2 <= len(argument) <= 96): +            raise BadArgument("Channel name must be between 2 and 96 chars long") + +        elif not all(c.isalnum() or c in allowed_characters for c in argument): +            raise BadArgument( +                "Channel name must only consist of " +                "alphanumeric characters, minus signs or apostrophes." +            ) + +        # Replace invalid characters with unicode alternatives. +        table = str.maketrans( +            allowed_characters, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-' +        ) +        return argument.translate(table) + +  class ISODateTime(Converter):      """Converts an ISO-8601 datetime string into a datetime.datetime.""" @@ -316,6 +473,24 @@ def proxy_user(user_id: str) -> discord.Object:      return user +class UserMentionOrID(UserConverter): +    """ +    Converts to a `discord.User`, but only if a mention or userID is provided. + +    Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. +    This is useful in cases where that lookup strategy would lead to ambiguity. +    """ + +    async def convert(self, ctx: Context, argument: str) -> discord.User: +        """Convert the `arg` to a `discord.User`.""" +        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) + +        if match is not None: +            return await super().convert(ctx, argument) +        else: +            raise BadArgument(f"`{argument}` is not a User mention or a User ID.") + +  class FetchedUser(UserConverter):      """      Converts to a `discord.User` or, if it fails, a `discord.Object`. @@ -361,5 +536,19 @@ class FetchedUser(UserConverter):              raise BadArgument(f"User `{arg}` does not exist") +def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: +    """ +    Extract the snowflake from `arg` using a regex `pattern` and return it as an int. + +    The snowflake is expected to be within the first capture group in `pattern`. +    """ +    match = pattern.match(arg) +    if not match: +        raise BadArgument(f"Mention {str!r} is invalid.") + +    return int(match.group(1)) + +  Expiry = t.Union[Duration, ISODateTime]  FetchedMember = t.Union[discord.Member, FetchedUser] +UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) | 
