diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/converters.py | 219 | 
1 files changed, 121 insertions, 98 deletions
diff --git a/bot/converters.py b/bot/converters.py index 2a3943831..bd4044c7e 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,9 @@ +from __future__ import annotations +  import logging  import re  import typing as t  from datetime import datetime -from functools import partial  from ssl import CertificateError  import dateutil.parser @@ -10,14 +11,18 @@ import dateutil.tz  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts  from bot.api import ResponseCodeError  from bot.constants import URLs  from bot.exts.info.doc import _inventory_parser +from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: +    from bot.exts.info.source import SourceType  log = logging.getLogger(__name__) @@ -128,6 +133,44 @@ class ValidFilterListType(Converter):          return list_type +class Extension(Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if argument in EXTENSIONS: +            return argument +        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: +            return qualified_arg + +        matches = [] +        for ext in EXTENSIONS: +            if argument == unqualify(ext): +                matches.append(ext) + +        if len(matches) > 1: +            matches.sort() +            names = "\n".join(matches) +            raise BadArgument( +                f":x: `{argument}` is an ambiguous extension name. " +                f"Please use one of the following fully-qualified names.```\n{names}```" +            ) +        elif matches: +            return matches[0] +        else: +            raise BadArgument(f":x: Could not find the extension `{argument}`.") + +  class PackageName(Converter):      """      A converter that checks whether the given string is a valid package name. @@ -271,23 +314,36 @@ class TagNameConverter(Converter):          return tag_name -class TagContentConverter(Converter): -    """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" +class SourceConverter(Converter): +    """Convert an argument into a help command, tag, command, or cog."""      @staticmethod -    async def convert(ctx: Context, tag_content: str) -> str: -        """ -        Ensure tag_content is non-empty and contains at least one non-whitespace character. +    async def convert(ctx: Context, argument: str) -> SourceType: +        """Convert argument into source object.""" +        if argument.lower() == "help": +            return ctx.bot.help_command -        If tag_content is valid, return the stripped version. -        """ -        tag_content = tag_content.strip() +        cog = ctx.bot.get_cog(argument) +        if cog: +            return cog + +        cmd = ctx.bot.get_command(argument) +        if cmd: +            return cmd -        # The tag contents should not be empty, or filled with whitespace. -        if not tag_content: -            raise BadArgument("Tag contents should not be empty, or filled with whitespace.") +        tags_cog = ctx.bot.get_cog("Tags") +        show_tag = True -        return tag_content +        if not tags_cog: +            show_tag = False +        elif argument.lower() in tags_cog._cache: +            return argument.lower() + +        escaped_arg = escape_markdown(argument) + +        raise BadArgument( +            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." +        )  class DurationDelta(Converter): @@ -416,11 +472,11 @@ class HushDurationConverter(Converter):      MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)") -    async def convert(self, ctx: Context, argument: str) -> t.Optional[int]: +    async def convert(self, ctx: Context, argument: str) -> int:          """          Convert `argument` to a duration that's max 15 minutes or None. -        If `"forever"` is passed, None is returned; otherwise an int of the extracted time. +        If `"forever"` is passed, -1 is returned; otherwise an int of the extracted time.          Accepted formats are:          * <duration>,          * <duration>m, @@ -428,7 +484,7 @@ class HushDurationConverter(Converter):          * forever.          """          if argument == "forever": -            return None +            return -1          match = self.MINUTES_RE.match(argument)          if not match:              raise BadArgument(f"{argument} is not a valid minutes duration.") @@ -439,103 +495,51 @@ class HushDurationConverter(Converter):          return duration -def proxy_user(user_id: str) -> discord.Object: -    """ -    Create a proxy user object from the given id. +def _is_an_unambiguous_user_argument(argument: str) -> bool: +    """Check if the provided argument is a user mention, user id, or username (name#discrim).""" +    has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument)) -    Used when a Member or User object cannot be resolved. -    """ -    log.trace(f"Attempting to create a proxy user for the user id {user_id}.") +    # Check to see if the author passed a username (a discriminator exists) +    argument = argument.removeprefix('@') +    has_username = len(argument) > 5 and argument[-5] == '#' -    try: -        user_id = int(user_id) -    except ValueError: -        log.debug(f"Failed to create proxy user {user_id}: could not convert to int.") -        raise BadArgument(f"User ID `{user_id}` is invalid - could not convert to an integer.") +    return has_id_or_mention or has_username -    user = discord.Object(user_id) -    user.mention = user.id -    user.display_name = f"<@{user.id}>" -    user.avatar_url_as = lambda static_format: None -    user.bot = False -    return user +AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format" +                          " `name#discriminator`.") -class UserMentionOrID(UserConverter): +class UnambiguousUser(UserConverter):      """ -    Converts to a `discord.User`, but only if a mention or userID is provided. +    Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. -    Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. -    This is useful in cases where that lookup strategy would lead to ambiguity. +    Unlike the default `UserConverter`, it doesn't allow conversion from a name. +    This is useful in cases where that lookup strategy would lead to too much ambiguity.      """      async def convert(self, ctx: Context, argument: str) -> discord.User: -        """Convert the `arg` to a `discord.User`.""" -        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) - -        if match is not None: +        """Convert the `argument` to a `discord.User`.""" +        if _is_an_unambiguous_user_argument(argument):              return await super().convert(ctx, argument)          else: -            raise BadArgument(f"`{argument}` is not a User mention or a User ID.") - - -class FetchedUser(UserConverter): -    """ -    Converts to a `discord.User` or, if it fails, a `discord.Object`. - -    Unlike the default `UserConverter`, which only does lookups via the global user cache, this -    converter attempts to fetch the user via an API call to Discord when the using the cache is -    unsuccessful. - -    If the fetch also fails and the error doesn't imply the user doesn't exist, then a -    `discord.Object` is returned via the `user_proxy` converter. - -    The lookup strategy is as follows (in order): - -    1. Lookup by ID. -    2. Lookup by mention. -    3. Lookup by name#discrim -    4. Lookup by name -    5. Lookup via API -    6. Create a proxy user with discord.Object -    """ +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) -    async def convert(self, ctx: Context, arg: str) -> t.Union[discord.User, discord.Object]: -        """Convert the `arg` to a `discord.User` or `discord.Object`.""" -        try: -            return await super().convert(ctx, arg) -        except BadArgument: -            pass - -        try: -            user_id = int(arg) -            log.trace(f"Fetching user {user_id}...") -            return await ctx.bot.fetch_user(user_id) -        except ValueError: -            log.debug(f"Failed to fetch user {arg}: could not convert to int.") -            raise BadArgument(f"The provided argument can't be turned into integer: `{arg}`") -        except discord.HTTPException as e: -            # If the Discord error isn't `Unknown user`, return a proxy instead -            if e.code != 10013: -                log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}") -                return proxy_user(arg) - -            log.debug(f"Failed to fetch user {arg}: user does not exist.") -            raise BadArgument(f"User `{arg}` does not exist") - -def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: +class UnambiguousMember(MemberConverter):      """ -    Extract the snowflake from `arg` using a regex `pattern` and return it as an int. +    Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. -    The snowflake is expected to be within the first capture group in `pattern`. +    Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. +    This is useful in cases where that lookup strategy would lead to too much ambiguity.      """ -    match = pattern.match(arg) -    if not match: -        raise BadArgument(f"Mention {str!r} is invalid.") -    return int(match.group(1)) +    async def convert(self, ctx: Context, argument: str) -> discord.Member: +        """Convert the `argument` to a `discord.Member`.""" +        if _is_an_unambiguous_user_argument(argument): +            return await super().convert(ctx, argument) +        else: +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))  class Infraction(Converter): @@ -567,6 +571,25 @@ class Infraction(Converter):              return await ctx.bot.api_client.get(f"bot/infractions/{arg}") +if t.TYPE_CHECKING: +    ValidDiscordServerInvite = dict  # noqa: F811 +    ValidFilterListType = str  # noqa: F811 +    Extension = str  # noqa: F811 +    PackageName = str  # noqa: F811 +    ValidURL = str  # noqa: F811 +    Inventory = t.Tuple[str, _inventory_parser.InventoryDict]  # noqa: F811 +    Snowflake = int  # noqa: F811 +    TagNameConverter = str  # noqa: F811 +    SourceConverter = SourceType  # noqa: F811 +    DurationDelta = relativedelta  # noqa: F811 +    Duration = datetime  # noqa: F811 +    OffTopicName = str  # noqa: F811 +    ISODateTime = datetime  # noqa: F811 +    HushDurationConverter = int  # noqa: F811 +    UnambiguousUser = discord.User  # noqa: F811 +    UnambiguousMember = discord.Member  # noqa: F811 +    Infraction = t.Optional[dict]  # noqa: F811 +  Expiry = t.Union[Duration, ISODateTime] -FetchedMember = t.Union[discord.Member, FetchedUser] -UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) +MemberOrUser = t.Union[discord.Member, discord.User] +UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser]  |