aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-08-11 16:37:33 -0700
committerGravatar MarkKoz <[email protected]>2020-08-11 16:37:33 -0700
commitb01e854e3870eb90ef2cb9dec70040f4a673387d (patch)
treeed157f5ae9383675b648fbca13afa0c9f4ba7d69
parentReplace InfractionSearchQuery with a generic Snowflake converter (diff)
Create a UserMention converter
-rw-r--r--bot/cogs/moderation/management.py6
-rw-r--r--bot/converters.py18
2 files changed, 20 insertions, 4 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index af736d4de..c2cca5352 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -10,7 +10,7 @@ from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
-from bot.converters import Expiry, Snowflake, allowed_strings, proxy_user
+from bot.converters import Expiry, Snowflake, UserMention, allowed_strings, proxy_user
from bot.pagination import LinePaginator
from bot.utils import messages, time
from bot.utils.checks import in_whitelist_check, with_role_check
@@ -177,9 +177,9 @@ class ModManagement(commands.Cog):
# region: Search infractions
@infraction_group.group(name="search", invoke_without_command=True)
- async def infraction_search_group(self, ctx: Context, query: Snowflake) -> None:
+ async def infraction_search_group(self, ctx: Context, query: t.Union[UserMention, Snowflake, str]) -> None:
"""Searches for infractions in the database."""
- if isinstance(query, discord.User):
+ if isinstance(query, int):
await ctx.invoke(self.search_user, query)
else:
await ctx.invoke(self.search_reason, query)
diff --git a/bot/converters.py b/bot/converters.py
index 4c41d0ece..4cfd663ba 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -2,6 +2,7 @@ import logging
import re
import typing as t
from datetime import datetime
+from functools import partial
from ssl import CertificateError
import dateutil.parser
@@ -19,6 +20,7 @@ from bot.utils.regex import INVITE_RE
log = logging.getLogger(__name__)
DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000)
+RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")
def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
@@ -481,7 +483,7 @@ class UserMentionOrID(UserConverter):
async def convert(self, ctx: Context, argument: str) -> discord.User:
"""Convert the `arg` to a `discord.User`."""
- match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
+ match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)
if match is not None:
return await super().convert(ctx, argument)
@@ -534,5 +536,19 @@ class FetchedUser(UserConverter):
raise BadArgument(f"User `{arg}` does not exist")
+def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int:
+ """
+ Extract the snowflake from `arg` using a regex `pattern` and return it as an int.
+
+ The snowflake is expected to be within the first capture group in `pattern`.
+ """
+ match = pattern.match(arg)
+ if not match:
+ raise BadArgument(f"Mention {str!r} is invalid.")
+
+ return int(match.group(1))
+
+
Expiry = t.Union[Duration, ISODateTime]
FetchedMember = t.Union[discord.Member, FetchedUser]
+UserMention = partial(_snowflake_from_regex, RE_USER_MENTION)