diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 220 |
1 files changed, 122 insertions, 98 deletions
diff --git a/bot/converters.py b/bot/converters.py index 038d2a287..48a5e3dc2 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging import re import typing as t from datetime import datetime -from functools import partial from ssl import CertificateError import dateutil.parser @@ -10,14 +11,19 @@ import dateutil.tz import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts from bot.api import ResponseCodeError from bot.constants import URLs from bot.exts.info.doc import _inventory_parser +from bot.exts.info.tags import TagIdentifier +from bot.utils.extensions import EXTENSIONS, unqualify from bot.utils.regex import INVITE_RE from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: + from bot.exts.info.source import SourceType log = logging.getLogger(__name__) @@ -128,6 +134,44 @@ class ValidFilterListType(Converter): return list_type +class Extension(Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + if argument == unqualify(ext): + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise BadArgument(f":x: Could not find the extension `{argument}`.") + + class PackageName(Converter): """ A converter that checks whether the given string is a valid package name. @@ -236,23 +280,37 @@ class Snowflake(IDConverter): return snowflake -class TagContentConverter(Converter): - """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" +class SourceConverter(Converter): + """Convert an argument into a help command, tag, command, or cog.""" @staticmethod - async def convert(ctx: Context, tag_content: str) -> str: - """ - Ensure tag_content is non-empty and contains at least one non-whitespace character. + async def convert(ctx: Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower() == "help": + return ctx.bot.help_command - If tag_content is valid, return the stripped version. - """ - tag_content = tag_content.strip() + cog = ctx.bot.get_cog(argument) + if cog: + return cog - # The tag contents should not be empty, or filled with whitespace. - if not tag_content: - raise BadArgument("Tag contents should not be empty, or filled with whitespace.") + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd - return tag_content + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + else: + identifier = TagIdentifier.from_string(argument.lower()) + if identifier in tags_cog.tags: + return identifier + escaped_arg = escape_markdown(argument) + + raise BadArgument( + f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." + ) class DurationDelta(Converter): @@ -381,11 +439,11 @@ class HushDurationConverter(Converter): MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)") - async def convert(self, ctx: Context, argument: str) -> t.Optional[int]: + async def convert(self, ctx: Context, argument: str) -> int: """ Convert `argument` to a duration that's max 15 minutes or None. - If `"forever"` is passed, None is returned; otherwise an int of the extracted time. + If `"forever"` is passed, -1 is returned; otherwise an int of the extracted time. Accepted formats are: * <duration>, * <duration>m, @@ -393,7 +451,7 @@ class HushDurationConverter(Converter): * forever. """ if argument == "forever": - return None + return -1 match = self.MINUTES_RE.match(argument) if not match: raise BadArgument(f"{argument} is not a valid minutes duration.") @@ -404,103 +462,51 @@ class HushDurationConverter(Converter): return duration -def proxy_user(user_id: str) -> discord.Object: - """ - Create a proxy user object from the given id. +def _is_an_unambiguous_user_argument(argument: str) -> bool: + """Check if the provided argument is a user mention, user id, or username (name#discrim).""" + has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument)) - Used when a Member or User object cannot be resolved. - """ - log.trace(f"Attempting to create a proxy user for the user id {user_id}.") + # Check to see if the author passed a username (a discriminator exists) + argument = argument.removeprefix('@') + has_username = len(argument) > 5 and argument[-5] == '#' - try: - user_id = int(user_id) - except ValueError: - log.debug(f"Failed to create proxy user {user_id}: could not convert to int.") - raise BadArgument(f"User ID `{user_id}` is invalid - could not convert to an integer.") + return has_id_or_mention or has_username - user = discord.Object(user_id) - user.mention = user.id - user.display_name = f"<@{user.id}>" - user.avatar_url_as = lambda static_format: None - user.bot = False - return user +AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format" + " `name#discriminator`.") -class UserMentionOrID(UserConverter): +class UnambiguousUser(UserConverter): """ - Converts to a `discord.User`, but only if a mention or userID is provided. + Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. - Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. - This is useful in cases where that lookup strategy would lead to ambiguity. + Unlike the default `UserConverter`, it doesn't allow conversion from a name. + This is useful in cases where that lookup strategy would lead to too much ambiguity. """ async def convert(self, ctx: Context, argument: str) -> discord.User: - """Convert the `arg` to a `discord.User`.""" - match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) - - if match is not None: + """Convert the `argument` to a `discord.User`.""" + if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: - raise BadArgument(f"`{argument}` is not a User mention or a User ID.") - - -class FetchedUser(UserConverter): - """ - Converts to a `discord.User` or, if it fails, a `discord.Object`. - - Unlike the default `UserConverter`, which only does lookups via the global user cache, this - converter attempts to fetch the user via an API call to Discord when the using the cache is - unsuccessful. - - If the fetch also fails and the error doesn't imply the user doesn't exist, then a - `discord.Object` is returned via the `user_proxy` converter. - - The lookup strategy is as follows (in order): - - 1. Lookup by ID. - 2. Lookup by mention. - 3. Lookup by name#discrim - 4. Lookup by name - 5. Lookup via API - 6. Create a proxy user with discord.Object - """ + raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) - async def convert(self, ctx: Context, arg: str) -> t.Union[discord.User, discord.Object]: - """Convert the `arg` to a `discord.User` or `discord.Object`.""" - try: - return await super().convert(ctx, arg) - except BadArgument: - pass - - try: - user_id = int(arg) - log.trace(f"Fetching user {user_id}...") - return await ctx.bot.fetch_user(user_id) - except ValueError: - log.debug(f"Failed to fetch user {arg}: could not convert to int.") - raise BadArgument(f"The provided argument can't be turned into integer: `{arg}`") - except discord.HTTPException as e: - # If the Discord error isn't `Unknown user`, return a proxy instead - if e.code != 10013: - log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}") - return proxy_user(arg) - - log.debug(f"Failed to fetch user {arg}: user does not exist.") - raise BadArgument(f"User `{arg}` does not exist") - -def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: +class UnambiguousMember(MemberConverter): """ - Extract the snowflake from `arg` using a regex `pattern` and return it as an int. + Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. - The snowflake is expected to be within the first capture group in `pattern`. + Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. + This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - match = pattern.match(arg) - if not match: - raise BadArgument(f"Mention {str!r} is invalid.") - return int(match.group(1)) + async def convert(self, ctx: Context, argument: str) -> discord.Member: + """Convert the `argument` to a `discord.Member`.""" + if _is_an_unambiguous_user_argument(argument): + return await super().convert(ctx, argument) + else: + raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) class Infraction(Converter): @@ -532,6 +538,24 @@ class Infraction(Converter): return await ctx.bot.api_client.get(f"bot/infractions/{arg}") +if t.TYPE_CHECKING: + ValidDiscordServerInvite = dict # noqa: F811 + ValidFilterListType = str # noqa: F811 + Extension = str # noqa: F811 + PackageName = str # noqa: F811 + ValidURL = str # noqa: F811 + Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811 + Snowflake = int # noqa: F811 + SourceConverter = SourceType # noqa: F811 + DurationDelta = relativedelta # noqa: F811 + Duration = datetime # noqa: F811 + OffTopicName = str # noqa: F811 + ISODateTime = datetime # noqa: F811 + HushDurationConverter = int # noqa: F811 + UnambiguousUser = discord.User # noqa: F811 + UnambiguousMember = discord.Member # noqa: F811 + Infraction = t.Optional[dict] # noqa: F811 + Expiry = t.Union[Duration, ISODateTime] -FetchedMember = t.Union[discord.Member, FetchedUser] -UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) +MemberOrUser = t.Union[discord.Member, discord.User] +UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] |