diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 156 |
1 files changed, 151 insertions, 5 deletions
diff --git a/bot/converters.py b/bot/converters.py index 4deb59f87..1358cbf1e 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -9,8 +9,11 @@ import dateutil.tz import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Context, Converter, UserConverter +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from bot.api import ResponseCodeError +from bot.constants import URLs +from bot.utils.regex import INVITE_RE log = logging.getLogger(__name__) @@ -34,6 +37,90 @@ def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], s return converter +class ValidDiscordServerInvite(Converter): + """ + A converter that validates whether a given string is a valid Discord server invite. + + Raises 'BadArgument' if: + - The string is not a valid Discord server invite. + - The string is valid, but is an invite for a group DM. + - The string is valid, but is expired. + + Returns a (partial) guild object if: + - The string is a valid vanity + - The string is a full invite URI + - The string contains the invite code (the stuff after discord.gg/) + + See the Discord API docs for documentation on the guild object: + https://discord.com/developers/docs/resources/guild#guild-object + """ + + async def convert(self, ctx: Context, server_invite: str) -> dict: + """Check whether the string is a valid Discord server invite.""" + invite_code = INVITE_RE.search(server_invite) + if invite_code: + response = await ctx.bot.http_session.get( + f"{URLs.discord_invite_api}/{invite_code[1]}" + ) + if response.status != 404: + invite_data = await response.json() + return invite_data.get("guild") + + id_converter = IDConverter() + if id_converter._get_id_match(server_invite): + raise BadArgument("Guild IDs are not supported, only invites.") + + raise BadArgument("This does not appear to be a valid Discord server invite.") + + +class ValidFilterListType(Converter): + """ + A converter that checks whether the given string is a valid FilterList type. + + Raises `BadArgument` if the argument is not a valid FilterList type, and simply + passes through the given argument otherwise. + """ + + @staticmethod + async def get_valid_types(bot: Bot) -> list: + """ + Try to get a list of valid filter list types. + + Raise a BadArgument if the API can't respond. + """ + try: + valid_types = await bot.api_client.get('bot/filter-lists/get-types') + except ResponseCodeError: + raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") + + return [enum for enum, classname in valid_types] + + async def convert(self, ctx: Context, list_type: str) -> str: + """Checks whether the given string is a valid FilterList type.""" + valid_types = await self.get_valid_types(ctx.bot) + list_type = list_type.upper() + + if list_type not in valid_types: + + # Maybe the user is using the plural form of this type, + # e.g. "guild_invites" instead of "guild_invite". + # + # This code will support the simple plural form (a single 's' at the end), + # which works for all current list types, but if a list type is added in the future + # which has an irregular plural form (like 'ies'), this code will need to be + # refactored to support this. + if list_type.endswith("S") and list_type[:-1] in valid_types: + list_type = list_type[:-1] + + else: + valid_types_list = '\n'.join([f"โข {type_.lower()}" for type_ in valid_types]) + raise BadArgument( + f"You have provided an invalid list type!\n\n" + f"Please provide one of the following: \n{valid_types_list}" + ) + return list_type + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -181,8 +268,8 @@ class TagContentConverter(Converter): return tag_content -class Duration(Converter): - """Convert duration strings into UTC datetime.datetime objects.""" +class DurationDelta(Converter): + """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" duration_parser = re.compile( r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" @@ -194,9 +281,9 @@ class Duration(Converter): r"((?P<seconds>\d+?) ?(seconds|second|S|s))?" ) - async def convert(self, ctx: Context, duration: str) -> datetime: + async def convert(self, ctx: Context, duration: str) -> relativedelta: """ - Converts a `duration` string to a datetime object that's `duration` in the future. + Converts a `duration` string to a relativedelta object. The converter supports the following symbols for each unit of time: - years: `Y`, `y`, `year`, `years` @@ -215,6 +302,20 @@ class Duration(Converter): duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()} delta = relativedelta(**duration_dict) + + return delta + + +class Duration(DurationDelta): + """Convert duration strings into UTC datetime.datetime objects.""" + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the future. + + The converter supports the same symbols for each unit of time as its parent class. + """ + delta = await super().convert(ctx, duration) now = datetime.utcnow() try: @@ -223,6 +324,32 @@ class Duration(Converter): raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") +class OffTopicName(Converter): + """A converter that ensures an added off-topic name is valid.""" + + async def convert(self, ctx: Context, argument: str) -> str: + """Attempt to replace any invalid characters with their approximate Unicode equivalent.""" + allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" + + # Chain multiple words to a single one + argument = "-".join(argument.split()) + + if not (2 <= len(argument) <= 96): + raise BadArgument("Channel name must be between 2 and 96 chars long") + + elif not all(c.isalnum() or c in allowed_characters for c in argument): + raise BadArgument( + "Channel name must only consist of " + "alphanumeric characters, minus signs or apostrophes." + ) + + # Replace invalid characters with unicode alternatives. + table = str.maketrans( + allowed_characters, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-' + ) + return argument.translate(table) + + class ISODateTime(Converter): """Converts an ISO-8601 datetime string into a datetime.datetime.""" @@ -316,6 +443,25 @@ def proxy_user(user_id: str) -> discord.Object: return user +class UserMentionOrID(UserConverter): + """ + Converts to a `discord.User`, but only if a mention or userID is provided. + + Unlike the default `UserConverter`, it does allow conversion from name, or name#descrim. + + This is useful in cases where that lookup strategy would lead to ambiguity. + """ + + async def convert(self, ctx: Context, argument: str) -> discord.User: + """Convert the `arg` to a `discord.User`.""" + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) + + if match is not None: + return await super().convert(ctx, argument) + else: + raise BadArgument(f"`{argument}` is not a User mention or a User ID.") + + class FetchedUser(UserConverter): """ Converts to a `discord.User` or, if it fails, a `discord.Object`. |