diff options
Diffstat (limited to 'bot/converters.py')
| -rw-r--r-- | bot/converters.py | 98 | 
1 files changed, 73 insertions, 25 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 546f6e8f4..5aaff4e76 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,6 +1,5 @@  from __future__ import annotations -import logging  import re  import typing as t  from datetime import datetime @@ -11,20 +10,23 @@ import dateutil.tz  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter  from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time  from bot import exts  from bot.api import ResponseCodeError  from bot.constants import URLs +from bot.errors import InvalidInfraction  from bot.exts.info.doc import _inventory_parser +from bot.log import get_logger  from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string +  if t.TYPE_CHECKING:      from bot.exts.info.source import SourceType -log = logging.getLogger(__name__) +log = get_logger(__name__)  DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000)  RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") @@ -69,10 +71,10 @@ class ValidDiscordServerInvite(Converter):      async def convert(self, ctx: Context, server_invite: str) -> dict:          """Check whether the string is a valid Discord server invite.""" -        invite_code = INVITE_RE.search(server_invite) +        invite_code = INVITE_RE.match(server_invite)          if invite_code:              response = await ctx.bot.http_session.get( -                f"{URLs.discord_invite_api}/{invite_code[1]}" +                f"{URLs.discord_invite_api}/{invite_code.group('invite')}"              )              if response.status != 404:                  invite_data = await response.json() @@ -234,11 +236,16 @@ class Inventory(Converter):      async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]:          """Convert url to Intersphinx inventory URL."""          await ctx.trigger_typing() -        if (inventory := await _inventory_parser.fetch_inventory(url)) is None: -            raise BadArgument( -                f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." -            ) -        return url, inventory +        try: +            inventory = await _inventory_parser.fetch_inventory(url) +        except _inventory_parser.InvalidHeaderError: +            raise BadArgument("Unable to parse inventory because of invalid header, check if URL is correct.") +        else: +            if inventory is None: +                raise BadArgument( +                    f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." +                ) +            return url, inventory  class Snowflake(IDConverter): @@ -266,7 +273,7 @@ class Snowflake(IDConverter):          snowflake = int(arg)          try: -            time = snowflake_time(snowflake) +            time = snowflake_time(snowflake).replace(tzinfo=None)          except (OverflowError, OSError) as e:              # Not sure if this can ever even happen, but let's be safe.              raise BadArgument(f"{error}: {e}") @@ -409,7 +416,8 @@ class Age(DurationDelta):  class OffTopicName(Converter):      """A converter that ensures an added off-topic name is valid.""" -    ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" +    ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>" +    TRANSLATED_CHARACTERS = "๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-๏ผ๏ผ"      @classmethod      def translate_name(cls, name: str, *, from_unicode: bool = True) -> str: @@ -419,9 +427,9 @@ class OffTopicName(Converter):          If `from_unicode` is True, the name is translated from a discord-safe format, back to normalized text.          """          if from_unicode: -            table = str.maketrans(cls.ALLOWED_CHARACTERS, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-') +            table = str.maketrans(cls.ALLOWED_CHARACTERS, cls.TRANSLATED_CHARACTERS)          else: -            table = str.maketrans('๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-', cls.ALLOWED_CHARACTERS) +            table = str.maketrans(cls.TRANSLATED_CHARACTERS, cls.ALLOWED_CHARACTERS)          return name.translate(table) @@ -513,22 +521,51 @@ class HushDurationConverter(Converter):          return duration -class UserMentionOrID(UserConverter): +def _is_an_unambiguous_user_argument(argument: str) -> bool: +    """Check if the provided argument is a user mention, user id, or username (name#discrim).""" +    has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument)) + +    # Check to see if the author passed a username (a discriminator exists) +    argument = argument.removeprefix('@') +    has_username = len(argument) > 5 and argument[-5] == '#' + +    return has_id_or_mention or has_username + + +AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format" +                          " `name#discriminator`.") + + +class UnambiguousUser(UserConverter):      """ -    Converts to a `discord.User`, but only if a mention or userID is provided. +    Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. -    Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. -    This is useful in cases where that lookup strategy would lead to ambiguity. +    Unlike the default `UserConverter`, it doesn't allow conversion from a name. +    This is useful in cases where that lookup strategy would lead to too much ambiguity.      """      async def convert(self, ctx: Context, argument: str) -> discord.User: -        """Convert the `arg` to a `discord.User`.""" -        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) +        """Convert the `argument` to a `discord.User`.""" +        if _is_an_unambiguous_user_argument(argument): +            return await super().convert(ctx, argument) +        else: +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) + -        if match is not None: +class UnambiguousMember(MemberConverter): +    """ +    Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. + +    Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. +    This is useful in cases where that lookup strategy would lead to too much ambiguity. +    """ + +    async def convert(self, ctx: Context, argument: str) -> discord.Member: +        """Convert the `argument` to a `discord.Member`.""" +        if _is_an_unambiguous_user_argument(argument):              return await super().convert(ctx, argument)          else: -            raise BadArgument(f"`{argument}` is not a User mention or a User ID.") +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))  class Infraction(Converter): @@ -547,7 +584,7 @@ class Infraction(Converter):                  "ordering": "-inserted_at"              } -            infractions = await ctx.bot.api_client.get("bot/infractions", params=params) +            infractions = await ctx.bot.api_client.get("bot/infractions/expanded", params=params)              if not infractions:                  raise BadArgument( @@ -557,7 +594,16 @@ class Infraction(Converter):                  return infractions[0]          else: -            return await ctx.bot.api_client.get(f"bot/infractions/{arg}") +            try: +                return await ctx.bot.api_client.get(f"bot/infractions/{arg}/expanded") +            except ResponseCodeError as e: +                if e.status == 404: +                    raise InvalidInfraction( +                        converter=Infraction, +                        original=e, +                        infraction_arg=arg +                    ) +                raise e  if t.TYPE_CHECKING: @@ -576,8 +622,10 @@ if t.TYPE_CHECKING:      OffTopicName = str  # noqa: F811      ISODateTime = datetime  # noqa: F811      HushDurationConverter = int  # noqa: F811 -    UserMentionOrID = discord.User  # noqa: F811 +    UnambiguousUser = discord.User  # noqa: F811 +    UnambiguousMember = discord.Member  # noqa: F811      Infraction = t.Optional[dict]  # noqa: F811  Expiry = t.Union[Duration, ISODateTime]  MemberOrUser = t.Union[discord.Member, discord.User] +UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] | 
