diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 48 |
1 files changed, 14 insertions, 34 deletions
diff --git a/bot/converters.py b/bot/converters.py index 559e759e1..8a140e0c2 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -6,23 +6,22 @@ from datetime import datetime, timezone from ssl import CertificateError import dateutil.parser -import dateutil.tz import discord from aiohttp import ClientConnectorError +from botcore.site_api import ResponseCodeError +from botcore.utils import unqualify +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter from discord.utils import escape_markdown, snowflake_time -from bot import exts -from bot.api import ResponseCodeError +from bot import exts, instance as bot_instance from bot.constants import URLs from bot.errors import InvalidInfraction from bot.exts.info.doc import _inventory_parser from bot.exts.info.tags import TagIdentifier from bot.log import get_logger -from bot.utils.extensions import EXTENSIONS, unqualify -from bot.utils.regex import INVITE_RE -from bot.utils.time import parse_duration_string +from bot.utils import time if t.TYPE_CHECKING: from bot.exts.info.source import SourceType @@ -33,25 +32,6 @@ DISCORD_EPOCH_DT = snowflake_time(0) RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$") -def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: - """ - Return a converter which only allows arguments equal to one of the given values. - - Unless preserve_case is True, the argument is converted to lowercase. All values are then - expected to have already been given in lowercase too. - """ - def converter(arg: str) -> str: - if not preserve_case: - arg = arg.lower() - - if arg not in values: - raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") - else: - return arg - - return converter - - class ValidDiscordServerInvite(Converter): """ A converter that validates whether a given string is a valid Discord server invite. @@ -72,7 +52,7 @@ class ValidDiscordServerInvite(Converter): async def convert(self, ctx: Context, server_invite: str) -> dict: """Check whether the string is a valid Discord server invite.""" - invite_code = INVITE_RE.match(server_invite) + invite_code = DISCORD_INVITE.match(server_invite) if invite_code: response = await ctx.bot.http_session.get( f"{URLs.discord_invite_api}/{invite_code.group('invite')}" @@ -151,13 +131,13 @@ class Extension(Converter): argument = argument.lower() - if argument in EXTENSIONS: + if argument in bot_instance.all_extensions: return argument - elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + elif (qualified_arg := f"{exts.__name__}.{argument}") in bot_instance.all_extensions: return qualified_arg matches = [] - for ext in EXTENSIONS: + for ext in bot_instance.all_extensions: if argument == unqualify(ext): matches.append(ext) @@ -338,7 +318,7 @@ class DurationDelta(Converter): The units need to be provided in descending order of magnitude. """ - if not (delta := parse_duration_string(duration)): + if not (delta := time.parse_duration_string(duration)): raise BadArgument(f"`{duration}` is not a valid duration string.") return delta @@ -383,8 +363,8 @@ class Age(DurationDelta): class OffTopicName(Converter): """A converter that ensures an added off-topic name is valid.""" - ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>" - TRANSLATED_CHARACTERS = "๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-๏ผ๏ผ" + ALLOWED_CHARACTERS = r"ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>\/" + TRANSLATED_CHARACTERS = "๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-๏ผ๏ผโงนโงธ" @classmethod def translate_name(cls, name: str, *, from_unicode: bool = True) -> str: @@ -454,9 +434,9 @@ class ISODateTime(Converter): raise BadArgument(f"`{datetime_string}` is not a valid ISO-8601 datetime string") if dt.tzinfo: - dt = dt.astimezone(dateutil.tz.UTC) + dt = dt.astimezone(timezone.utc) else: # Without a timezone, assume it represents UTC. - dt = dt.replace(tzinfo=dateutil.tz.UTC) + dt = dt.replace(tzinfo=timezone.utc) return dt |