diff options
Diffstat (limited to 'bot/converters.py')
| -rw-r--r-- | bot/converters.py | 67 | 
1 files changed, 55 insertions, 12 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 1358cbf1e..2e118d476 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -2,6 +2,7 @@ import logging  import re  import typing as t  from datetime import datetime +from functools import partial  from ssl import CertificateError  import dateutil.parser @@ -10,6 +11,7 @@ import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta  from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from discord.utils import DISCORD_EPOCH, snowflake_time  from bot.api import ResponseCodeError  from bot.constants import URLs @@ -17,6 +19,9 @@ from bot.utils.regex import INVITE_RE  log = logging.getLogger(__name__) +DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000) +RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") +  def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:      """ @@ -172,17 +177,42 @@ class ValidURL(Converter):          return url -class InfractionSearchQuery(Converter): -    """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" +class Snowflake(IDConverter): +    """ +    Converts to an int if the argument is a valid Discord snowflake. + +    A snowflake is valid if: + +    * It consists of 15-21 digits (0-9) +    * Its parsed datetime is after the Discord epoch +    * Its parsed datetime is less than 1 day after the current time +    """ + +    async def convert(self, ctx: Context, arg: str) -> int: +        """ +        Ensure `arg` matches the ID pattern and its timestamp is in range. + +        Return `arg` as an int if it's a valid snowflake. +        """ +        error = f"Invalid snowflake {arg!r}" + +        if not self._get_id_match(arg): +            raise BadArgument(error) + +        snowflake = int(arg) -    @staticmethod -    async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: -        """Check if the argument is a Discord user, and if not, falls back to a string."""          try: -            maybe_snowflake = arg.strip("<@!>") -            return await ctx.bot.fetch_user(maybe_snowflake) -        except (discord.NotFound, discord.HTTPException): -            return arg +            time = snowflake_time(snowflake) +        except (OverflowError, OSError) as e: +            # Not sure if this can ever even happen, but let's be safe. +            raise BadArgument(f"{error}: {e}") + +        if time < DISCORD_EPOCH_DT: +            raise BadArgument(f"{error}: timestamp is before the Discord epoch.") +        elif (datetime.utcnow() - time).days < -1: +            raise BadArgument(f"{error}: timestamp is too far into the future.") + +        return snowflake  class Subreddit(Converter): @@ -447,14 +477,13 @@ class UserMentionOrID(UserConverter):      """      Converts to a `discord.User`, but only if a mention or userID is provided. -    Unlike the default `UserConverter`, it does allow conversion from name, or name#descrim. - +    Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim.      This is useful in cases where that lookup strategy would lead to ambiguity.      """      async def convert(self, ctx: Context, argument: str) -> discord.User:          """Convert the `arg` to a `discord.User`.""" -        match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) +        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)          if match is not None:              return await super().convert(ctx, argument) @@ -507,5 +536,19 @@ class FetchedUser(UserConverter):              raise BadArgument(f"User `{arg}` does not exist") +def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: +    """ +    Extract the snowflake from `arg` using a regex `pattern` and return it as an int. + +    The snowflake is expected to be within the first capture group in `pattern`. +    """ +    match = pattern.match(arg) +    if not match: +        raise BadArgument(f"Mention {str!r} is invalid.") + +    return int(match.group(1)) + +  Expiry = t.Union[Duration, ISODateTime]  FetchedMember = t.Union[discord.Member, FetchedUser] +UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) | 
