diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/converters.py | 103 | 
1 files changed, 102 insertions, 1 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 4a0633951..c9f525dd1 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -9,8 +9,11 @@ import dateutil.tz  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Context, Converter, UserConverter +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from bot.api import ResponseCodeError +from bot.constants import URLs +from bot.utils.regex import INVITE_RE  log = logging.getLogger(__name__) @@ -34,6 +37,78 @@ def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], s      return converter +class ValidDiscordServerInvite(Converter): +    """ +    A converter that validates whether a given string is a valid Discord server invite. + +    Raises 'BadArgument' if: +    - The string is not a valid Discord server invite. +    - The string is valid, but is an invite for a group DM. +    - The string is valid, but is expired. + +    Returns a (partial) guild object if: +    - The string is a valid vanity +    - The string is a full invite URI +    - The string contains the invite code (the stuff after discord.gg/) + +    See the Discord API docs for documentation on the guild object: +    https://discord.com/developers/docs/resources/guild#guild-object +    """ + +    async def convert(self, ctx: Context, server_invite: str) -> dict: +        """Check whether the string is a valid Discord server invite.""" +        invite_code = INVITE_RE.search(server_invite) +        if invite_code: +            response = await ctx.bot.http_session.get( +                f"{URLs.discord_invite_api}/{invite_code[1]}" +            ) +            if response.status != 404: +                invite_data = await response.json() +                return invite_data.get("guild") + +        id_converter = IDConverter() +        if id_converter._get_id_match(server_invite): +            raise BadArgument("Guild IDs are not supported, only invites.") + +        raise BadArgument("This does not appear to be a valid Discord server invite.") + + +class ValidFilterListType(Converter): +    """ +    A converter that checks whether the given string is a valid FilterList type. + +    Raises `BadArgument` if the argument is not a valid FilterList type, and simply +    passes through the given argument otherwise. +    """ + +    @staticmethod +    async def get_valid_types(bot: Bot) -> list: +        """ +        Try to get a list of valid filter list types. + +        Raise a BadArgument if the API can't respond. +        """ +        try: +            valid_types = await bot.api_client.get('bot/filter-lists/get-types') +        except ResponseCodeError: +            raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") + +        return [enum for enum, classname in valid_types] + +    async def convert(self, ctx: Context, list_type: str) -> str: +        """Checks whether the given string is a valid FilterList type.""" +        valid_types = await self.get_valid_types(ctx.bot) +        list_type = list_type.upper() + +        if list_type not in valid_types: +            valid_types_list = '\n'.join([f"โข {type_.lower()}" for type_ in valid_types]) +            raise BadArgument( +                f"You have provided an invalid list type!\n\n" +                f"Please provide one of the following: \n{valid_types_list}" +            ) +        return list_type + +  class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. @@ -237,6 +312,32 @@ class Duration(DurationDelta):              raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") +class OffTopicName(Converter): +    """A converter that ensures an added off-topic name is valid.""" + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Attempt to replace any invalid characters with their approximate Unicode equivalent.""" +        allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" + +        # Chain multiple words to a single one +        argument = "-".join(argument.split()) + +        if not (2 <= len(argument) <= 96): +            raise BadArgument("Channel name must be between 2 and 96 chars long") + +        elif not all(c.isalnum() or c in allowed_characters for c in argument): +            raise BadArgument( +                "Channel name must only consist of " +                "alphanumeric characters, minus signs or apostrophes." +            ) + +        # Replace invalid characters with unicode alternatives. +        table = str.maketrans( +            allowed_characters, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-' +        ) +        return argument.translate(table) + +  class ISODateTime(Converter):      """Converts an ISO-8601 datetime string into a datetime.datetime.""" | 
