diff options
| -rw-r--r-- | bot/cogs/moderation/management.py | 6 | ||||
| -rw-r--r-- | bot/converters.py | 18 | 
2 files changed, 20 insertions, 4 deletions
| diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index af736d4de..c2cca5352 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -10,7 +10,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry, Snowflake, allowed_strings, proxy_user +from bot.converters import Expiry, Snowflake, UserMention, allowed_strings, proxy_user  from bot.pagination import LinePaginator  from bot.utils import messages, time  from bot.utils.checks import in_whitelist_check, with_role_check @@ -177,9 +177,9 @@ class ModManagement(commands.Cog):      # region: Search infractions      @infraction_group.group(name="search", invoke_without_command=True) -    async def infraction_search_group(self, ctx: Context, query: Snowflake) -> None: +    async def infraction_search_group(self, ctx: Context, query: t.Union[UserMention, Snowflake, str]) -> None:          """Searches for infractions in the database.""" -        if isinstance(query, discord.User): +        if isinstance(query, int):              await ctx.invoke(self.search_user, query)          else:              await ctx.invoke(self.search_reason, query) diff --git a/bot/converters.py b/bot/converters.py index 4c41d0ece..4cfd663ba 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -2,6 +2,7 @@ import logging  import re  import typing as t  from datetime import datetime +from functools import partial  from ssl import CertificateError  import dateutil.parser @@ -19,6 +20,7 @@ from bot.utils.regex import INVITE_RE  log = logging.getLogger(__name__)  DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")  def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: @@ -481,7 +483,7 @@ class UserMentionOrID(UserConverter):      async def convert(self, ctx: Context, argument: str) -> discord.User:          """Convert the `arg` to a `discord.User`.""" -        match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) +        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)          if match is not None:              return await super().convert(ctx, argument) @@ -534,5 +536,19 @@ class FetchedUser(UserConverter):              raise BadArgument(f"User `{arg}` does not exist") +def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: +    """ +    Extract the snowflake from `arg` using a regex `pattern` and return it as an int. + +    The snowflake is expected to be within the first capture group in `pattern`. +    """ +    match = pattern.match(arg) +    if not match: +        raise BadArgument(f"Mention {str!r} is invalid.") + +    return int(match.group(1)) + +  Expiry = t.Union[Duration, ISODateTime]  FetchedMember = t.Union[discord.Member, FetchedUser] +UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) | 
