diff options
| author | 2020-08-11 16:37:33 -0700 | |
|---|---|---|
| committer | 2020-08-11 16:37:33 -0700 | |
| commit | b01e854e3870eb90ef2cb9dec70040f4a673387d (patch) | |
| tree | ed157f5ae9383675b648fbca13afa0c9f4ba7d69 | |
| parent | Replace InfractionSearchQuery with a generic Snowflake converter (diff) | |
Create a UserMention converter
| -rw-r--r-- | bot/cogs/moderation/management.py | 6 | ||||
| -rw-r--r-- | bot/converters.py | 18 |
2 files changed, 20 insertions, 4 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index af736d4de..c2cca5352 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -10,7 +10,7 @@ from discord.utils import escape_markdown from bot import constants from bot.bot import Bot -from bot.converters import Expiry, Snowflake, allowed_strings, proxy_user +from bot.converters import Expiry, Snowflake, UserMention, allowed_strings, proxy_user from bot.pagination import LinePaginator from bot.utils import messages, time from bot.utils.checks import in_whitelist_check, with_role_check @@ -177,9 +177,9 @@ class ModManagement(commands.Cog): # region: Search infractions @infraction_group.group(name="search", invoke_without_command=True) - async def infraction_search_group(self, ctx: Context, query: Snowflake) -> None: + async def infraction_search_group(self, ctx: Context, query: t.Union[UserMention, Snowflake, str]) -> None: """Searches for infractions in the database.""" - if isinstance(query, discord.User): + if isinstance(query, int): await ctx.invoke(self.search_user, query) else: await ctx.invoke(self.search_reason, query) diff --git a/bot/converters.py b/bot/converters.py index 4c41d0ece..4cfd663ba 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -2,6 +2,7 @@ import logging import re import typing as t from datetime import datetime +from functools import partial from ssl import CertificateError import dateutil.parser @@ -19,6 +20,7 @@ from bot.utils.regex import INVITE_RE log = logging.getLogger(__name__) DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: @@ -481,7 +483,7 @@ class UserMentionOrID(UserConverter): async def convert(self, ctx: Context, argument: str) -> discord.User: """Convert the `arg` to a `discord.User`.""" - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) + match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) if match is not None: return await super().convert(ctx, argument) @@ -534,5 +536,19 @@ class FetchedUser(UserConverter): raise BadArgument(f"User `{arg}` does not exist") +def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: + """ + Extract the snowflake from `arg` using a regex `pattern` and return it as an int. + + The snowflake is expected to be within the first capture group in `pattern`. + """ + match = pattern.match(arg) + if not match: + raise BadArgument(f"Mention {str!r} is invalid.") + + return int(match.group(1)) + + Expiry = t.Union[Duration, ISODateTime] FetchedMember = t.Union[discord.Member, FetchedUser] +UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) |