diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/converters.py | 332 | 
1 files changed, 179 insertions, 153 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 2a3943831..559e759e1 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ -import logging +from __future__ import annotations +  import re  import typing as t -from datetime import datetime -from functools import partial +from datetime import datetime, timezone  from ssl import CertificateError  import dateutil.parser @@ -10,18 +10,26 @@ import dateutil.tz  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.utils import escape_markdown, snowflake_time +from bot import exts  from bot.api import ResponseCodeError  from bot.constants import URLs +from bot.errors import InvalidInfraction  from bot.exts.info.doc import _inventory_parser +from bot.exts.info.tags import TagIdentifier +from bot.log import get_logger +from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string -log = logging.getLogger(__name__) +if t.TYPE_CHECKING: +    from bot.exts.info.source import SourceType + +log = get_logger(__name__) -DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +DISCORD_EPOCH_DT = snowflake_time(0)  RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") @@ -64,10 +72,10 @@ class ValidDiscordServerInvite(Converter):      async def convert(self, ctx: Context, server_invite: str) -> dict:          """Check whether the string is a valid Discord server invite.""" -        invite_code = INVITE_RE.search(server_invite) +        invite_code = INVITE_RE.match(server_invite)          if invite_code:              response = await ctx.bot.http_session.get( -                f"{URLs.discord_invite_api}/{invite_code[1]}" +                f"{URLs.discord_invite_api}/{invite_code.group('invite')}"              )              if response.status != 404:                  invite_data = await response.json() @@ -128,6 +136,44 @@ class ValidFilterListType(Converter):          return list_type +class Extension(Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if argument in EXTENSIONS: +            return argument +        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: +            return qualified_arg + +        matches = [] +        for ext in EXTENSIONS: +            if argument == unqualify(ext): +                matches.append(ext) + +        if len(matches) > 1: +            matches.sort() +            names = "\n".join(matches) +            raise BadArgument( +                f":x: `{argument}` is an ambiguous extension name. " +                f"Please use one of the following fully-qualified names.```\n{names}```" +            ) +        elif matches: +            return matches[0] +        else: +            raise BadArgument(f":x: Could not find the extension `{argument}`.") + +  class PackageName(Converter):      """      A converter that checks whether the given string is a valid package name. @@ -191,11 +237,16 @@ class Inventory(Converter):      async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]:          """Convert url to Intersphinx inventory URL."""          await ctx.trigger_typing() -        if (inventory := await _inventory_parser.fetch_inventory(url)) is None: -            raise BadArgument( -                f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." -            ) -        return url, inventory +        try: +            inventory = await _inventory_parser.fetch_inventory(url) +        except _inventory_parser.InvalidHeaderError: +            raise BadArgument("Unable to parse inventory because of invalid header, check if URL is correct.") +        else: +            if inventory is None: +                raise BadArgument( +                    f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." +                ) +            return url, inventory  class Snowflake(IDConverter): @@ -230,64 +281,43 @@ class Snowflake(IDConverter):          if time < DISCORD_EPOCH_DT:              raise BadArgument(f"{error}: timestamp is before the Discord epoch.") -        elif (datetime.utcnow() - time).days < -1: +        elif (datetime.now(timezone.utc) - time).days < -1:              raise BadArgument(f"{error}: timestamp is too far into the future.")          return snowflake -class TagNameConverter(Converter): -    """ -    Ensure that a proposed tag name is valid. - -    Valid tag names meet the following conditions: -        * All ASCII characters -        * Has at least one non-whitespace character -        * Not solely numeric -        * Shorter than 127 characters -    """ +class SourceConverter(Converter): +    """Convert an argument into a help command, tag, command, or cog."""      @staticmethod -    async def convert(ctx: Context, tag_name: str) -> str: -        """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" -        tag_name = tag_name.lower().strip() - -        # The tag name has at least one invalid character. -        if ascii(tag_name)[1:-1] != tag_name: -            raise BadArgument("Don't be ridiculous, you can't use that character!") - -        # The tag name is either empty, or consists of nothing but whitespace. -        elif not tag_name: -            raise BadArgument("Tag names should not be empty, or filled with whitespace.") - -        # The tag name is longer than 127 characters. -        elif len(tag_name) > 127: -            raise BadArgument("Are you insane? That's way too long!") - -        # The tag name is ascii but does not contain any letters. -        elif not any(character.isalpha() for character in tag_name): -            raise BadArgument("Tag names must contain at least one letter.") +    async def convert(ctx: Context, argument: str) -> SourceType: +        """Convert argument into source object.""" +        if argument.lower() == "help": +            return ctx.bot.help_command -        return tag_name +        cog = ctx.bot.get_cog(argument) +        if cog: +            return cog +        cmd = ctx.bot.get_command(argument) +        if cmd: +            return cmd -class TagContentConverter(Converter): -    """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" +        tags_cog = ctx.bot.get_cog("Tags") +        show_tag = True -    @staticmethod -    async def convert(ctx: Context, tag_content: str) -> str: -        """ -        Ensure tag_content is non-empty and contains at least one non-whitespace character. - -        If tag_content is valid, return the stripped version. -        """ -        tag_content = tag_content.strip() - -        # The tag contents should not be empty, or filled with whitespace. -        if not tag_content: -            raise BadArgument("Tag contents should not be empty, or filled with whitespace.") +        if not tags_cog: +            show_tag = False +        else: +            identifier = TagIdentifier.from_string(argument.lower()) +            if identifier in tags_cog.tags: +                return identifier +        escaped_arg = escape_markdown(argument) -        return tag_content +        raise BadArgument( +            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." +        )  class DurationDelta(Converter): @@ -324,7 +354,7 @@ class Duration(DurationDelta):          The converter supports the same symbols for each unit of time as its parent class.          """          delta = await super().convert(ctx, duration) -        now = datetime.utcnow() +        now = datetime.now(timezone.utc)          try:              return now + delta @@ -332,10 +362,29 @@ class Duration(DurationDelta):              raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") +class Age(DurationDelta): +    """Convert duration strings into UTC datetime.datetime objects.""" + +    async def convert(self, ctx: Context, duration: str) -> datetime: +        """ +        Converts a `duration` string to a datetime object that's `duration` in the past. + +        The converter supports the same symbols for each unit of time as its parent class. +        """ +        delta = await super().convert(ctx, duration) +        now = datetime.now(timezone.utc) + +        try: +            return now - delta +        except (ValueError, OverflowError): +            raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") + +  class OffTopicName(Converter):      """A converter that ensures an added off-topic name is valid.""" -    ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" +    ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>" +    TRANSLATED_CHARACTERS = "๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-๏ผ๏ผ"      @classmethod      def translate_name(cls, name: str, *, from_unicode: bool = True) -> str: @@ -345,9 +394,9 @@ class OffTopicName(Converter):          If `from_unicode` is True, the name is translated from a discord-safe format, back to normalized text.          """          if from_unicode: -            table = str.maketrans(cls.ALLOWED_CHARACTERS, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-') +            table = str.maketrans(cls.ALLOWED_CHARACTERS, cls.TRANSLATED_CHARACTERS)          else: -            table = str.maketrans('๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-', cls.ALLOWED_CHARACTERS) +            table = str.maketrans(cls.TRANSLATED_CHARACTERS, cls.ALLOWED_CHARACTERS)          return name.translate(table) @@ -379,8 +428,8 @@ class ISODateTime(Converter):          The converter is flexible in the formats it accepts, as it uses the `isoparse` method of          `dateutil.parser`. In general, it accepts datetime strings that start with a date,          optionally followed by a time. Specifying a timezone offset in the datetime string is -        supported, but the `datetime` object will be converted to UTC and will be returned without -        `tzinfo` as a timezone-unaware `datetime` object. +        supported, but the `datetime` object will be converted to UTC. If no timezone is specified, the datetime will +        be assumed to be in UTC already. In all cases, the returned object will have the UTC timezone.          See: https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.isoparse @@ -406,7 +455,8 @@ class ISODateTime(Converter):          if dt.tzinfo:              dt = dt.astimezone(dateutil.tz.UTC) -            dt = dt.replace(tzinfo=None) +        else:  # Without a timezone, assume it represents UTC. +            dt = dt.replace(tzinfo=dateutil.tz.UTC)          return dt @@ -416,11 +466,11 @@ class HushDurationConverter(Converter):      MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)") -    async def convert(self, ctx: Context, argument: str) -> t.Optional[int]: +    async def convert(self, ctx: Context, argument: str) -> int:          """          Convert `argument` to a duration that's max 15 minutes or None. -        If `"forever"` is passed, None is returned; otherwise an int of the extracted time. +        If `"forever"` is passed, -1 is returned; otherwise an int of the extracted time.          Accepted formats are:          * <duration>,          * <duration>m, @@ -428,7 +478,7 @@ class HushDurationConverter(Converter):          * forever.          """          if argument == "forever": -            return None +            return -1          match = self.MINUTES_RE.match(argument)          if not match:              raise BadArgument(f"{argument} is not a valid minutes duration.") @@ -439,103 +489,51 @@ class HushDurationConverter(Converter):          return duration -def proxy_user(user_id: str) -> discord.Object: -    """ -    Create a proxy user object from the given id. +def _is_an_unambiguous_user_argument(argument: str) -> bool: +    """Check if the provided argument is a user mention, user id, or username (name#discrim).""" +    has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument)) -    Used when a Member or User object cannot be resolved. -    """ -    log.trace(f"Attempting to create a proxy user for the user id {user_id}.") +    # Check to see if the author passed a username (a discriminator exists) +    argument = argument.removeprefix('@') +    has_username = len(argument) > 5 and argument[-5] == '#' -    try: -        user_id = int(user_id) -    except ValueError: -        log.debug(f"Failed to create proxy user {user_id}: could not convert to int.") -        raise BadArgument(f"User ID `{user_id}` is invalid - could not convert to an integer.") +    return has_id_or_mention or has_username -    user = discord.Object(user_id) -    user.mention = user.id -    user.display_name = f"<@{user.id}>" -    user.avatar_url_as = lambda static_format: None -    user.bot = False -    return user +AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format" +                          " `name#discriminator`.") -class UserMentionOrID(UserConverter): +class UnambiguousUser(UserConverter):      """ -    Converts to a `discord.User`, but only if a mention or userID is provided. +    Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. -    Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. -    This is useful in cases where that lookup strategy would lead to ambiguity. +    Unlike the default `UserConverter`, it doesn't allow conversion from a name. +    This is useful in cases where that lookup strategy would lead to too much ambiguity.      """      async def convert(self, ctx: Context, argument: str) -> discord.User: -        """Convert the `arg` to a `discord.User`.""" -        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) - -        if match is not None: +        """Convert the `argument` to a `discord.User`.""" +        if _is_an_unambiguous_user_argument(argument):              return await super().convert(ctx, argument)          else: -            raise BadArgument(f"`{argument}` is not a User mention or a User ID.") +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) -class FetchedUser(UserConverter): +class UnambiguousMember(MemberConverter):      """ -    Converts to a `discord.User` or, if it fails, a `discord.Object`. - -    Unlike the default `UserConverter`, which only does lookups via the global user cache, this -    converter attempts to fetch the user via an API call to Discord when the using the cache is -    unsuccessful. - -    If the fetch also fails and the error doesn't imply the user doesn't exist, then a -    `discord.Object` is returned via the `user_proxy` converter. +    Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. -    The lookup strategy is as follows (in order): - -    1. Lookup by ID. -    2. Lookup by mention. -    3. Lookup by name#discrim -    4. Lookup by name -    5. Lookup via API -    6. Create a proxy user with discord.Object -    """ - -    async def convert(self, ctx: Context, arg: str) -> t.Union[discord.User, discord.Object]: -        """Convert the `arg` to a `discord.User` or `discord.Object`.""" -        try: -            return await super().convert(ctx, arg) -        except BadArgument: -            pass - -        try: -            user_id = int(arg) -            log.trace(f"Fetching user {user_id}...") -            return await ctx.bot.fetch_user(user_id) -        except ValueError: -            log.debug(f"Failed to fetch user {arg}: could not convert to int.") -            raise BadArgument(f"The provided argument can't be turned into integer: `{arg}`") -        except discord.HTTPException as e: -            # If the Discord error isn't `Unknown user`, return a proxy instead -            if e.code != 10013: -                log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}") -                return proxy_user(arg) - -            log.debug(f"Failed to fetch user {arg}: user does not exist.") -            raise BadArgument(f"User `{arg}` does not exist") - - -def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: -    """ -    Extract the snowflake from `arg` using a regex `pattern` and return it as an int. - -    The snowflake is expected to be within the first capture group in `pattern`. +    Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. +    This is useful in cases where that lookup strategy would lead to too much ambiguity.      """ -    match = pattern.match(arg) -    if not match: -        raise BadArgument(f"Mention {str!r} is invalid.") -    return int(match.group(1)) +    async def convert(self, ctx: Context, argument: str) -> discord.Member: +        """Convert the `argument` to a `discord.Member`.""" +        if _is_an_unambiguous_user_argument(argument): +            return await super().convert(ctx, argument) +        else: +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))  class Infraction(Converter): @@ -554,7 +552,7 @@ class Infraction(Converter):                  "ordering": "-inserted_at"              } -            infractions = await ctx.bot.api_client.get("bot/infractions", params=params) +            infractions = await ctx.bot.api_client.get("bot/infractions/expanded", params=params)              if not infractions:                  raise BadArgument( @@ -564,9 +562,37 @@ class Infraction(Converter):                  return infractions[0]          else: -            return await ctx.bot.api_client.get(f"bot/infractions/{arg}") - +            try: +                return await ctx.bot.api_client.get(f"bot/infractions/{arg}/expanded") +            except ResponseCodeError as e: +                if e.status == 404: +                    raise InvalidInfraction( +                        converter=Infraction, +                        original=e, +                        infraction_arg=arg +                    ) +                raise e + + +if t.TYPE_CHECKING: +    ValidDiscordServerInvite = dict  # noqa: F811 +    ValidFilterListType = str  # noqa: F811 +    Extension = str  # noqa: F811 +    PackageName = str  # noqa: F811 +    ValidURL = str  # noqa: F811 +    Inventory = t.Tuple[str, _inventory_parser.InventoryDict]  # noqa: F811 +    Snowflake = int  # noqa: F811 +    SourceConverter = SourceType  # noqa: F811 +    DurationDelta = relativedelta  # noqa: F811 +    Duration = datetime  # noqa: F811 +    Age = datetime  # noqa: F811 +    OffTopicName = str  # noqa: F811 +    ISODateTime = datetime  # noqa: F811 +    HushDurationConverter = int  # noqa: F811 +    UnambiguousUser = discord.User  # noqa: F811 +    UnambiguousMember = discord.Member  # noqa: F811 +    Infraction = t.Optional[dict]  # noqa: F811  Expiry = t.Union[Duration, ISODateTime] -FetchedMember = t.Union[discord.Member, FetchedUser] -UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) +MemberOrUser = t.Union[discord.Member, discord.User] +UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] | 
