aboutsummaryrefslogtreecommitdiffstats
path: root/bot/converters.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/converters.py')
-rw-r--r--bot/converters.py67
1 files changed, 55 insertions, 12 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 1358cbf1e..2e118d476 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -2,6 +2,7 @@ import logging
import re
import typing as t
from datetime import datetime
+from functools import partial
from ssl import CertificateError
import dateutil.parser
@@ -10,6 +11,7 @@ import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter
+from discord.utils import DISCORD_EPOCH, snowflake_time
from bot.api import ResponseCodeError
from bot.constants import URLs
@@ -17,6 +19,9 @@ from bot.utils.regex import INVITE_RE
log = logging.getLogger(__name__)
+DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000)
+RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")
+
def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
"""
@@ -172,17 +177,42 @@ class ValidURL(Converter):
return url
-class InfractionSearchQuery(Converter):
- """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""
+class Snowflake(IDConverter):
+ """
+ Converts to an int if the argument is a valid Discord snowflake.
+
+ A snowflake is valid if:
+
+ * It consists of 15-21 digits (0-9)
+ * Its parsed datetime is after the Discord epoch
+ * Its parsed datetime is less than 1 day after the current time
+ """
+
+ async def convert(self, ctx: Context, arg: str) -> int:
+ """
+ Ensure `arg` matches the ID pattern and its timestamp is in range.
+
+ Return `arg` as an int if it's a valid snowflake.
+ """
+ error = f"Invalid snowflake {arg!r}"
+
+ if not self._get_id_match(arg):
+ raise BadArgument(error)
+
+ snowflake = int(arg)
- @staticmethod
- async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:
- """Check if the argument is a Discord user, and if not, falls back to a string."""
try:
- maybe_snowflake = arg.strip("<@!>")
- return await ctx.bot.fetch_user(maybe_snowflake)
- except (discord.NotFound, discord.HTTPException):
- return arg
+ time = snowflake_time(snowflake)
+ except (OverflowError, OSError) as e:
+ # Not sure if this can ever even happen, but let's be safe.
+ raise BadArgument(f"{error}: {e}")
+
+ if time < DISCORD_EPOCH_DT:
+ raise BadArgument(f"{error}: timestamp is before the Discord epoch.")
+ elif (datetime.utcnow() - time).days < -1:
+ raise BadArgument(f"{error}: timestamp is too far into the future.")
+
+ return snowflake
class Subreddit(Converter):
@@ -447,14 +477,13 @@ class UserMentionOrID(UserConverter):
"""
Converts to a `discord.User`, but only if a mention or userID is provided.
- Unlike the default `UserConverter`, it does allow conversion from name, or name#descrim.
-
+ Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim.
This is useful in cases where that lookup strategy would lead to ambiguity.
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
"""Convert the `arg` to a `discord.User`."""
- match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
+ match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)
if match is not None:
return await super().convert(ctx, argument)
@@ -507,5 +536,19 @@ class FetchedUser(UserConverter):
raise BadArgument(f"User `{arg}` does not exist")
+def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int:
+ """
+ Extract the snowflake from `arg` using a regex `pattern` and return it as an int.
+
+ The snowflake is expected to be within the first capture group in `pattern`.
+ """
+ match = pattern.match(arg)
+ if not match:
+ raise BadArgument(f"Mention {str!r} is invalid.")
+
+ return int(match.group(1))
+
+
Expiry = t.Union[Duration, ISODateTime]
FetchedMember = t.Union[discord.Member, FetchedUser]
+UserMention = partial(_snowflake_from_regex, RE_USER_MENTION)