aboutsummaryrefslogtreecommitdiffstats
path: root/bot/converters.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/converters.py')
-rw-r--r--bot/converters.py222
1 files changed, 207 insertions, 15 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 72c46fdf0..2e118d476 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -2,6 +2,7 @@ import logging
import re
import typing as t
from datetime import datetime
+from functools import partial
from ssl import CertificateError
import dateutil.parser
@@ -9,11 +10,18 @@ import dateutil.tz
import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
-from discord.ext.commands import BadArgument, Context, Converter, UserConverter
+from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter
+from discord.utils import DISCORD_EPOCH, snowflake_time
+from bot.api import ResponseCodeError
+from bot.constants import URLs
+from bot.utils.regex import INVITE_RE
log = logging.getLogger(__name__)
+DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000)
+RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")
+
def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
"""
@@ -34,6 +42,90 @@ def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], s
return converter
+class ValidDiscordServerInvite(Converter):
+ """
+ A converter that validates whether a given string is a valid Discord server invite.
+
+ Raises 'BadArgument' if:
+ - The string is not a valid Discord server invite.
+ - The string is valid, but is an invite for a group DM.
+ - The string is valid, but is expired.
+
+ Returns a (partial) guild object if:
+ - The string is a valid vanity
+ - The string is a full invite URI
+ - The string contains the invite code (the stuff after discord.gg/)
+
+ See the Discord API docs for documentation on the guild object:
+ https://discord.com/developers/docs/resources/guild#guild-object
+ """
+
+ async def convert(self, ctx: Context, server_invite: str) -> dict:
+ """Check whether the string is a valid Discord server invite."""
+ invite_code = INVITE_RE.search(server_invite)
+ if invite_code:
+ response = await ctx.bot.http_session.get(
+ f"{URLs.discord_invite_api}/{invite_code[1]}"
+ )
+ if response.status != 404:
+ invite_data = await response.json()
+ return invite_data.get("guild")
+
+ id_converter = IDConverter()
+ if id_converter._get_id_match(server_invite):
+ raise BadArgument("Guild IDs are not supported, only invites.")
+
+ raise BadArgument("This does not appear to be a valid Discord server invite.")
+
+
+class ValidFilterListType(Converter):
+ """
+ A converter that checks whether the given string is a valid FilterList type.
+
+ Raises `BadArgument` if the argument is not a valid FilterList type, and simply
+ passes through the given argument otherwise.
+ """
+
+ @staticmethod
+ async def get_valid_types(bot: Bot) -> list:
+ """
+ Try to get a list of valid filter list types.
+
+ Raise a BadArgument if the API can't respond.
+ """
+ try:
+ valid_types = await bot.api_client.get('bot/filter-lists/get-types')
+ except ResponseCodeError:
+ raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.")
+
+ return [enum for enum, classname in valid_types]
+
+ async def convert(self, ctx: Context, list_type: str) -> str:
+ """Checks whether the given string is a valid FilterList type."""
+ valid_types = await self.get_valid_types(ctx.bot)
+ list_type = list_type.upper()
+
+ if list_type not in valid_types:
+
+ # Maybe the user is using the plural form of this type,
+ # e.g. "guild_invites" instead of "guild_invite".
+ #
+ # This code will support the simple plural form (a single 's' at the end),
+ # which works for all current list types, but if a list type is added in the future
+ # which has an irregular plural form (like 'ies'), this code will need to be
+ # refactored to support this.
+ if list_type.endswith("S") and list_type[:-1] in valid_types:
+ list_type = list_type[:-1]
+
+ else:
+ valid_types_list = '\n'.join([f"โ€ข {type_.lower()}" for type_ in valid_types])
+ raise BadArgument(
+ f"You have provided an invalid list type!\n\n"
+ f"Please provide one of the following: \n{valid_types_list}"
+ )
+ return list_type
+
+
class ValidPythonIdentifier(Converter):
"""
A converter that checks whether the given string is a valid Python identifier.
@@ -85,17 +177,42 @@ class ValidURL(Converter):
return url
-class InfractionSearchQuery(Converter):
- """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""
+class Snowflake(IDConverter):
+ """
+ Converts to an int if the argument is a valid Discord snowflake.
+
+ A snowflake is valid if:
+
+ * It consists of 15-21 digits (0-9)
+ * Its parsed datetime is after the Discord epoch
+ * Its parsed datetime is less than 1 day after the current time
+ """
+
+ async def convert(self, ctx: Context, arg: str) -> int:
+ """
+ Ensure `arg` matches the ID pattern and its timestamp is in range.
+
+ Return `arg` as an int if it's a valid snowflake.
+ """
+ error = f"Invalid snowflake {arg!r}"
+
+ if not self._get_id_match(arg):
+ raise BadArgument(error)
+
+ snowflake = int(arg)
- @staticmethod
- async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:
- """Check if the argument is a Discord user, and if not, falls back to a string."""
try:
- maybe_snowflake = arg.strip("<@!>")
- return await ctx.bot.fetch_user(maybe_snowflake)
- except (discord.NotFound, discord.HTTPException):
- return arg
+ time = snowflake_time(snowflake)
+ except (OverflowError, OSError) as e:
+ # Not sure if this can ever even happen, but let's be safe.
+ raise BadArgument(f"{error}: {e}")
+
+ if time < DISCORD_EPOCH_DT:
+ raise BadArgument(f"{error}: timestamp is before the Discord epoch.")
+ elif (datetime.utcnow() - time).days < -1:
+ raise BadArgument(f"{error}: timestamp is too far into the future.")
+
+ return snowflake
class Subreddit(Converter):
@@ -181,8 +298,8 @@ class TagContentConverter(Converter):
return tag_content
-class Duration(Converter):
- """Convert duration strings into UTC datetime.datetime objects."""
+class DurationDelta(Converter):
+ """Convert duration strings into dateutil.relativedelta.relativedelta objects."""
duration_parser = re.compile(
r"((?P<years>\d+?) ?(years|year|Y|y) ?)?"
@@ -194,9 +311,9 @@ class Duration(Converter):
r"((?P<seconds>\d+?) ?(seconds|second|S|s))?"
)
- async def convert(self, ctx: Context, duration: str) -> datetime:
+ async def convert(self, ctx: Context, duration: str) -> relativedelta:
"""
- Converts a `duration` string to a datetime object that's `duration` in the future.
+ Converts a `duration` string to a relativedelta object.
The converter supports the following symbols for each unit of time:
- years: `Y`, `y`, `year`, `years`
@@ -215,9 +332,52 @@ class Duration(Converter):
duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()}
delta = relativedelta(**duration_dict)
+
+ return delta
+
+
+class Duration(DurationDelta):
+ """Convert duration strings into UTC datetime.datetime objects."""
+
+ async def convert(self, ctx: Context, duration: str) -> datetime:
+ """
+ Converts a `duration` string to a datetime object that's `duration` in the future.
+
+ The converter supports the same symbols for each unit of time as its parent class.
+ """
+ delta = await super().convert(ctx, duration)
now = datetime.utcnow()
- return now + delta
+ try:
+ return now + delta
+ except ValueError:
+ raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
+
+
+class OffTopicName(Converter):
+ """A converter that ensures an added off-topic name is valid."""
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Attempt to replace any invalid characters with their approximate Unicode equivalent."""
+ allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-"
+
+ # Chain multiple words to a single one
+ argument = "-".join(argument.split())
+
+ if not (2 <= len(argument) <= 96):
+ raise BadArgument("Channel name must be between 2 and 96 chars long")
+
+ elif not all(c.isalnum() or c in allowed_characters for c in argument):
+ raise BadArgument(
+ "Channel name must only consist of "
+ "alphanumeric characters, minus signs or apostrophes."
+ )
+
+ # Replace invalid characters with unicode alternatives.
+ table = str.maketrans(
+ allowed_characters, '๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-'
+ )
+ return argument.translate(table)
class ISODateTime(Converter):
@@ -313,6 +473,24 @@ def proxy_user(user_id: str) -> discord.Object:
return user
+class UserMentionOrID(UserConverter):
+ """
+ Converts to a `discord.User`, but only if a mention or userID is provided.
+
+ Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim.
+ This is useful in cases where that lookup strategy would lead to ambiguity.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> discord.User:
+ """Convert the `arg` to a `discord.User`."""
+ match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)
+
+ if match is not None:
+ return await super().convert(ctx, argument)
+ else:
+ raise BadArgument(f"`{argument}` is not a User mention or a User ID.")
+
+
class FetchedUser(UserConverter):
"""
Converts to a `discord.User` or, if it fails, a `discord.Object`.
@@ -358,5 +536,19 @@ class FetchedUser(UserConverter):
raise BadArgument(f"User `{arg}` does not exist")
+def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int:
+ """
+ Extract the snowflake from `arg` using a regex `pattern` and return it as an int.
+
+ The snowflake is expected to be within the first capture group in `pattern`.
+ """
+ match = pattern.match(arg)
+ if not match:
+ raise BadArgument(f"Mention {str!r} is invalid.")
+
+ return int(match.group(1))
+
+
Expiry = t.Union[Duration, ISODateTime]
FetchedMember = t.Union[discord.Member, FetchedUser]
+UserMention = partial(_snowflake_from_regex, RE_USER_MENTION)