aboutsummaryrefslogtreecommitdiffstats
path: root/bot/converters.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/converters.py')
-rw-r--r--bot/converters.py156
1 files changed, 151 insertions, 5 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 4deb59f87..1358cbf1e 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -9,8 +9,11 @@ import dateutil.tz
import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
-from discord.ext.commands import BadArgument, Context, Converter, UserConverter
+from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter
+from bot.api import ResponseCodeError
+from bot.constants import URLs
+from bot.utils.regex import INVITE_RE
log = logging.getLogger(__name__)
@@ -34,6 +37,90 @@ def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], s
return converter
+class ValidDiscordServerInvite(Converter):
+ """
+ A converter that validates whether a given string is a valid Discord server invite.
+
+ Raises 'BadArgument' if:
+ - The string is not a valid Discord server invite.
+ - The string is valid, but is an invite for a group DM.
+ - The string is valid, but is expired.
+
+ Returns a (partial) guild object if:
+ - The string is a valid vanity
+ - The string is a full invite URI
+ - The string contains the invite code (the stuff after discord.gg/)
+
+ See the Discord API docs for documentation on the guild object:
+ https://discord.com/developers/docs/resources/guild#guild-object
+ """
+
+ async def convert(self, ctx: Context, server_invite: str) -> dict:
+ """Check whether the string is a valid Discord server invite."""
+ invite_code = INVITE_RE.search(server_invite)
+ if invite_code:
+ response = await ctx.bot.http_session.get(
+ f"{URLs.discord_invite_api}/{invite_code[1]}"
+ )
+ if response.status != 404:
+ invite_data = await response.json()
+ return invite_data.get("guild")
+
+ id_converter = IDConverter()
+ if id_converter._get_id_match(server_invite):
+ raise BadArgument("Guild IDs are not supported, only invites.")
+
+ raise BadArgument("This does not appear to be a valid Discord server invite.")
+
+
+class ValidFilterListType(Converter):
+ """
+ A converter that checks whether the given string is a valid FilterList type.
+
+ Raises `BadArgument` if the argument is not a valid FilterList type, and simply
+ passes through the given argument otherwise.
+ """
+
+ @staticmethod
+ async def get_valid_types(bot: Bot) -> list:
+ """
+ Try to get a list of valid filter list types.
+
+ Raise a BadArgument if the API can't respond.
+ """
+ try:
+ valid_types = await bot.api_client.get('bot/filter-lists/get-types')
+ except ResponseCodeError:
+ raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.")
+
+ return [enum for enum, classname in valid_types]
+
+ async def convert(self, ctx: Context, list_type: str) -> str:
+ """Checks whether the given string is a valid FilterList type."""
+ valid_types = await self.get_valid_types(ctx.bot)
+ list_type = list_type.upper()
+
+ if list_type not in valid_types:
+
+ # Maybe the user is using the plural form of this type,
+ # e.g. "guild_invites" instead of "guild_invite".
+ #
+ # This code will support the simple plural form (a single 's' at the end),
+ # which works for all current list types, but if a list type is added in the future
+ # which has an irregular plural form (like 'ies'), this code will need to be
+ # refactored to support this.
+ if list_type.endswith("S") and list_type[:-1] in valid_types:
+ list_type = list_type[:-1]
+
+ else:
+ valid_types_list = '\n'.join([f"โ€ข {type_.lower()}" for type_ in valid_types])
+ raise BadArgument(
+ f"You have provided an invalid list type!\n\n"
+ f"Please provide one of the following: \n{valid_types_list}"
+ )
+ return list_type
+
+
class ValidPythonIdentifier(Converter):
"""
A converter that checks whether the given string is a valid Python identifier.
@@ -181,8 +268,8 @@ class TagContentConverter(Converter):
return tag_content
-class Duration(Converter):
- """Convert duration strings into UTC datetime.datetime objects."""
+class DurationDelta(Converter):
+ """Convert duration strings into dateutil.relativedelta.relativedelta objects."""
duration_parser = re.compile(
r"((?P<years>\d+?) ?(years|year|Y|y) ?)?"
@@ -194,9 +281,9 @@ class Duration(Converter):
r"((?P<seconds>\d+?) ?(seconds|second|S|s))?"
)
- async def convert(self, ctx: Context, duration: str) -> datetime:
+ async def convert(self, ctx: Context, duration: str) -> relativedelta:
"""
- Converts a `duration` string to a datetime object that's `duration` in the future.
+ Converts a `duration` string to a relativedelta object.
The converter supports the following symbols for each unit of time:
- years: `Y`, `y`, `year`, `years`
@@ -215,6 +302,20 @@ class Duration(Converter):
duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()}
delta = relativedelta(**duration_dict)
+
+ return delta
+
+
+class Duration(DurationDelta):
+ """Convert duration strings into UTC datetime.datetime objects."""
+
+ async def convert(self, ctx: Context, duration: str) -> datetime:
+ """
+ Converts a `duration` string to a datetime object that's `duration` in the future.
+
+ The converter supports the same symbols for each unit of time as its parent class.
+ """
+ delta = await super().convert(ctx, duration)
now = datetime.utcnow()
try:
@@ -223,6 +324,32 @@ class Duration(Converter):
raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
+class OffTopicName(Converter):
+ """A converter that ensures an added off-topic name is valid."""
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Attempt to replace any invalid characters with their approximate Unicode equivalent."""
+ allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-"
+
+ # Chain multiple words to a single one
+ argument = "-".join(argument.split())
+
+ if not (2 <= len(argument) <= 96):
+ raise BadArgument("Channel name must be between 2 and 96 chars long")
+
+ elif not all(c.isalnum() or c in allowed_characters for c in argument):
+ raise BadArgument(
+ "Channel name must only consist of "
+ "alphanumeric characters, minus signs or apostrophes."
+ )
+
+ # Replace invalid characters with unicode alternatives.
+ table = str.maketrans(
+ allowed_characters, '๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-'
+ )
+ return argument.translate(table)
+
+
class ISODateTime(Converter):
"""Converts an ISO-8601 datetime string into a datetime.datetime."""
@@ -316,6 +443,25 @@ def proxy_user(user_id: str) -> discord.Object:
return user
+class UserMentionOrID(UserConverter):
+ """
+ Converts to a `discord.User`, but only if a mention or userID is provided.
+
+ Unlike the default `UserConverter`, it does allow conversion from name, or name#descrim.
+
+ This is useful in cases where that lookup strategy would lead to ambiguity.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> discord.User:
+ """Convert the `arg` to a `discord.User`."""
+ match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
+
+ if match is not None:
+ return await super().convert(ctx, argument)
+ else:
+ raise BadArgument(f"`{argument}` is not a User mention or a User ID.")
+
+
class FetchedUser(UserConverter):
"""
Converts to a `discord.User` or, if it fails, a `discord.Object`.