diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 332 |
1 files changed, 179 insertions, 153 deletions
diff --git a/bot/converters.py b/bot/converters.py index 2a3943831..559e759e1 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ -import logging +from __future__ import annotations + import re import typing as t -from datetime import datetime -from functools import partial +from datetime import datetime, timezone from ssl import CertificateError import dateutil.parser @@ -10,18 +10,26 @@ import dateutil.tz import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.utils import escape_markdown, snowflake_time +from bot import exts from bot.api import ResponseCodeError from bot.constants import URLs +from bot.errors import InvalidInfraction from bot.exts.info.doc import _inventory_parser +from bot.exts.info.tags import TagIdentifier +from bot.log import get_logger +from bot.utils.extensions import EXTENSIONS, unqualify from bot.utils.regex import INVITE_RE from bot.utils.time import parse_duration_string -log = logging.getLogger(__name__) +if t.TYPE_CHECKING: + from bot.exts.info.source import SourceType + +log = get_logger(__name__) -DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +DISCORD_EPOCH_DT = snowflake_time(0) RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") @@ -64,10 +72,10 @@ class ValidDiscordServerInvite(Converter): async def convert(self, ctx: Context, server_invite: str) -> dict: """Check whether the string is a valid Discord server invite.""" - invite_code = INVITE_RE.search(server_invite) + invite_code = INVITE_RE.match(server_invite) if invite_code: response = await ctx.bot.http_session.get( - f"{URLs.discord_invite_api}/{invite_code[1]}" + f"{URLs.discord_invite_api}/{invite_code.group('invite')}" ) if response.status != 404: invite_data = await response.json() @@ -128,6 +136,44 @@ class ValidFilterListType(Converter): return list_type +class Extension(Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + if argument == unqualify(ext): + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise BadArgument(f":x: Could not find the extension `{argument}`.") + + class PackageName(Converter): """ A converter that checks whether the given string is a valid package name. @@ -191,11 +237,16 @@ class Inventory(Converter): async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]: """Convert url to Intersphinx inventory URL.""" await ctx.trigger_typing() - if (inventory := await _inventory_parser.fetch_inventory(url)) is None: - raise BadArgument( - f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." - ) - return url, inventory + try: + inventory = await _inventory_parser.fetch_inventory(url) + except _inventory_parser.InvalidHeaderError: + raise BadArgument("Unable to parse inventory because of invalid header, check if URL is correct.") + else: + if inventory is None: + raise BadArgument( + f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." + ) + return url, inventory class Snowflake(IDConverter): @@ -230,64 +281,43 @@ class Snowflake(IDConverter): if time < DISCORD_EPOCH_DT: raise BadArgument(f"{error}: timestamp is before the Discord epoch.") - elif (datetime.utcnow() - time).days < -1: + elif (datetime.now(timezone.utc) - time).days < -1: raise BadArgument(f"{error}: timestamp is too far into the future.") return snowflake -class TagNameConverter(Converter): - """ - Ensure that a proposed tag name is valid. - - Valid tag names meet the following conditions: - * All ASCII characters - * Has at least one non-whitespace character - * Not solely numeric - * Shorter than 127 characters - """ +class SourceConverter(Converter): + """Convert an argument into a help command, tag, command, or cog.""" @staticmethod - async def convert(ctx: Context, tag_name: str) -> str: - """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" - tag_name = tag_name.lower().strip() - - # The tag name has at least one invalid character. - if ascii(tag_name)[1:-1] != tag_name: - raise BadArgument("Don't be ridiculous, you can't use that character!") - - # The tag name is either empty, or consists of nothing but whitespace. - elif not tag_name: - raise BadArgument("Tag names should not be empty, or filled with whitespace.") - - # The tag name is longer than 127 characters. - elif len(tag_name) > 127: - raise BadArgument("Are you insane? That's way too long!") - - # The tag name is ascii but does not contain any letters. - elif not any(character.isalpha() for character in tag_name): - raise BadArgument("Tag names must contain at least one letter.") + async def convert(ctx: Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower() == "help": + return ctx.bot.help_command - return tag_name + cog = ctx.bot.get_cog(argument) + if cog: + return cog + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd -class TagContentConverter(Converter): - """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True - @staticmethod - async def convert(ctx: Context, tag_content: str) -> str: - """ - Ensure tag_content is non-empty and contains at least one non-whitespace character. - - If tag_content is valid, return the stripped version. - """ - tag_content = tag_content.strip() - - # The tag contents should not be empty, or filled with whitespace. - if not tag_content: - raise BadArgument("Tag contents should not be empty, or filled with whitespace.") + if not tags_cog: + show_tag = False + else: + identifier = TagIdentifier.from_string(argument.lower()) + if identifier in tags_cog.tags: + return identifier + escaped_arg = escape_markdown(argument) - return tag_content + raise BadArgument( + f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." + ) class DurationDelta(Converter): @@ -324,7 +354,7 @@ class Duration(DurationDelta): The converter supports the same symbols for each unit of time as its parent class. """ delta = await super().convert(ctx, duration) - now = datetime.utcnow() + now = datetime.now(timezone.utc) try: return now + delta @@ -332,10 +362,29 @@ class Duration(DurationDelta): raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") +class Age(DurationDelta): + """Convert duration strings into UTC datetime.datetime objects.""" + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the past. + + The converter supports the same symbols for each unit of time as its parent class. + """ + delta = await super().convert(ctx, duration) + now = datetime.now(timezone.utc) + + try: + return now - delta + except (ValueError, OverflowError): + raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") + + class OffTopicName(Converter): """A converter that ensures an added off-topic name is valid.""" - ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" + ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>" + TRANSLATED_CHARACTERS = "๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-๏ผ๏ผ" @classmethod def translate_name(cls, name: str, *, from_unicode: bool = True) -> str: @@ -345,9 +394,9 @@ class OffTopicName(Converter): If `from_unicode` is True, the name is translated from a discord-safe format, back to normalized text. """ if from_unicode: - table = str.maketrans(cls.ALLOWED_CHARACTERS, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-') + table = str.maketrans(cls.ALLOWED_CHARACTERS, cls.TRANSLATED_CHARACTERS) else: - table = str.maketrans('๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-', cls.ALLOWED_CHARACTERS) + table = str.maketrans(cls.TRANSLATED_CHARACTERS, cls.ALLOWED_CHARACTERS) return name.translate(table) @@ -379,8 +428,8 @@ class ISODateTime(Converter): The converter is flexible in the formats it accepts, as it uses the `isoparse` method of `dateutil.parser`. In general, it accepts datetime strings that start with a date, optionally followed by a time. Specifying a timezone offset in the datetime string is - supported, but the `datetime` object will be converted to UTC and will be returned without - `tzinfo` as a timezone-unaware `datetime` object. + supported, but the `datetime` object will be converted to UTC. If no timezone is specified, the datetime will + be assumed to be in UTC already. In all cases, the returned object will have the UTC timezone. See: https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.isoparse @@ -406,7 +455,8 @@ class ISODateTime(Converter): if dt.tzinfo: dt = dt.astimezone(dateutil.tz.UTC) - dt = dt.replace(tzinfo=None) + else: # Without a timezone, assume it represents UTC. + dt = dt.replace(tzinfo=dateutil.tz.UTC) return dt @@ -416,11 +466,11 @@ class HushDurationConverter(Converter): MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)") - async def convert(self, ctx: Context, argument: str) -> t.Optional[int]: + async def convert(self, ctx: Context, argument: str) -> int: """ Convert `argument` to a duration that's max 15 minutes or None. - If `"forever"` is passed, None is returned; otherwise an int of the extracted time. + If `"forever"` is passed, -1 is returned; otherwise an int of the extracted time. Accepted formats are: * <duration>, * <duration>m, @@ -428,7 +478,7 @@ class HushDurationConverter(Converter): * forever. """ if argument == "forever": - return None + return -1 match = self.MINUTES_RE.match(argument) if not match: raise BadArgument(f"{argument} is not a valid minutes duration.") @@ -439,103 +489,51 @@ class HushDurationConverter(Converter): return duration -def proxy_user(user_id: str) -> discord.Object: - """ - Create a proxy user object from the given id. +def _is_an_unambiguous_user_argument(argument: str) -> bool: + """Check if the provided argument is a user mention, user id, or username (name#discrim).""" + has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument)) - Used when a Member or User object cannot be resolved. - """ - log.trace(f"Attempting to create a proxy user for the user id {user_id}.") + # Check to see if the author passed a username (a discriminator exists) + argument = argument.removeprefix('@') + has_username = len(argument) > 5 and argument[-5] == '#' - try: - user_id = int(user_id) - except ValueError: - log.debug(f"Failed to create proxy user {user_id}: could not convert to int.") - raise BadArgument(f"User ID `{user_id}` is invalid - could not convert to an integer.") + return has_id_or_mention or has_username - user = discord.Object(user_id) - user.mention = user.id - user.display_name = f"<@{user.id}>" - user.avatar_url_as = lambda static_format: None - user.bot = False - return user +AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format" + " `name#discriminator`.") -class UserMentionOrID(UserConverter): +class UnambiguousUser(UserConverter): """ - Converts to a `discord.User`, but only if a mention or userID is provided. + Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. - Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. - This is useful in cases where that lookup strategy would lead to ambiguity. + Unlike the default `UserConverter`, it doesn't allow conversion from a name. + This is useful in cases where that lookup strategy would lead to too much ambiguity. """ async def convert(self, ctx: Context, argument: str) -> discord.User: - """Convert the `arg` to a `discord.User`.""" - match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) - - if match is not None: + """Convert the `argument` to a `discord.User`.""" + if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: - raise BadArgument(f"`{argument}` is not a User mention or a User ID.") + raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) -class FetchedUser(UserConverter): +class UnambiguousMember(MemberConverter): """ - Converts to a `discord.User` or, if it fails, a `discord.Object`. - - Unlike the default `UserConverter`, which only does lookups via the global user cache, this - converter attempts to fetch the user via an API call to Discord when the using the cache is - unsuccessful. - - If the fetch also fails and the error doesn't imply the user doesn't exist, then a - `discord.Object` is returned via the `user_proxy` converter. + Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. - The lookup strategy is as follows (in order): - - 1. Lookup by ID. - 2. Lookup by mention. - 3. Lookup by name#discrim - 4. Lookup by name - 5. Lookup via API - 6. Create a proxy user with discord.Object - """ - - async def convert(self, ctx: Context, arg: str) -> t.Union[discord.User, discord.Object]: - """Convert the `arg` to a `discord.User` or `discord.Object`.""" - try: - return await super().convert(ctx, arg) - except BadArgument: - pass - - try: - user_id = int(arg) - log.trace(f"Fetching user {user_id}...") - return await ctx.bot.fetch_user(user_id) - except ValueError: - log.debug(f"Failed to fetch user {arg}: could not convert to int.") - raise BadArgument(f"The provided argument can't be turned into integer: `{arg}`") - except discord.HTTPException as e: - # If the Discord error isn't `Unknown user`, return a proxy instead - if e.code != 10013: - log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}") - return proxy_user(arg) - - log.debug(f"Failed to fetch user {arg}: user does not exist.") - raise BadArgument(f"User `{arg}` does not exist") - - -def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: - """ - Extract the snowflake from `arg` using a regex `pattern` and return it as an int. - - The snowflake is expected to be within the first capture group in `pattern`. + Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. + This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - match = pattern.match(arg) - if not match: - raise BadArgument(f"Mention {str!r} is invalid.") - return int(match.group(1)) + async def convert(self, ctx: Context, argument: str) -> discord.Member: + """Convert the `argument` to a `discord.Member`.""" + if _is_an_unambiguous_user_argument(argument): + return await super().convert(ctx, argument) + else: + raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) class Infraction(Converter): @@ -554,7 +552,7 @@ class Infraction(Converter): "ordering": "-inserted_at" } - infractions = await ctx.bot.api_client.get("bot/infractions", params=params) + infractions = await ctx.bot.api_client.get("bot/infractions/expanded", params=params) if not infractions: raise BadArgument( @@ -564,9 +562,37 @@ class Infraction(Converter): return infractions[0] else: - return await ctx.bot.api_client.get(f"bot/infractions/{arg}") - + try: + return await ctx.bot.api_client.get(f"bot/infractions/{arg}/expanded") + except ResponseCodeError as e: + if e.status == 404: + raise InvalidInfraction( + converter=Infraction, + original=e, + infraction_arg=arg + ) + raise e + + +if t.TYPE_CHECKING: + ValidDiscordServerInvite = dict # noqa: F811 + ValidFilterListType = str # noqa: F811 + Extension = str # noqa: F811 + PackageName = str # noqa: F811 + ValidURL = str # noqa: F811 + Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811 + Snowflake = int # noqa: F811 + SourceConverter = SourceType # noqa: F811 + DurationDelta = relativedelta # noqa: F811 + Duration = datetime # noqa: F811 + Age = datetime # noqa: F811 + OffTopicName = str # noqa: F811 + ISODateTime = datetime # noqa: F811 + HushDurationConverter = int # noqa: F811 + UnambiguousUser = discord.User # noqa: F811 + UnambiguousMember = discord.Member # noqa: F811 + Infraction = t.Optional[dict] # noqa: F811 Expiry = t.Union[Duration, ISODateTime] -FetchedMember = t.Union[discord.Member, FetchedUser] -UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) +MemberOrUser = t.Union[discord.Member, discord.User] +UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] |