diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 81 |
1 files changed, 23 insertions, 58 deletions
diff --git a/bot/converters.py b/bot/converters.py index 0984fa0a3..6f35d2fe4 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -6,22 +6,22 @@ from datetime import datetime, timezone from ssl import CertificateError import dateutil.parser -import dateutil.tz -import discord +import disnake from aiohttp import ClientConnectorError +from botcore.utils.regex import DISCORD_INVITE from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter -from discord.utils import escape_markdown, snowflake_time +from disnake.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from disnake.utils import escape_markdown, snowflake_time from bot import exts from bot.api import ResponseCodeError from bot.constants import URLs from bot.errors import InvalidInfraction from bot.exts.info.doc import _inventory_parser +from bot.exts.info.tags import TagIdentifier from bot.log import get_logger +from bot.utils import time from bot.utils.extensions import EXTENSIONS, unqualify -from bot.utils.regex import INVITE_RE -from bot.utils.time import parse_duration_string if t.TYPE_CHECKING: from bot.exts.info.source import SourceType @@ -71,7 +71,7 @@ class ValidDiscordServerInvite(Converter): async def convert(self, ctx: Context, server_invite: str) -> dict: """Check whether the string is a valid Discord server invite.""" - invite_code = INVITE_RE.match(server_invite) + invite_code = DISCORD_INVITE.match(server_invite) if invite_code: response = await ctx.bot.http_session.get( f"{URLs.discord_invite_api}/{invite_code.group('invite')}" @@ -286,41 +286,6 @@ class Snowflake(IDConverter): return snowflake -class TagNameConverter(Converter): - """ - Ensure that a proposed tag name is valid. - - Valid tag names meet the following conditions: - * All ASCII characters - * Has at least one non-whitespace character - * Not solely numeric - * Shorter than 127 characters - """ - - @staticmethod - async def convert(ctx: Context, tag_name: str) -> str: - """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" - tag_name = tag_name.lower().strip() - - # The tag name has at least one invalid character. - if ascii(tag_name)[1:-1] != tag_name: - raise BadArgument("Don't be ridiculous, you can't use that character!") - - # The tag name is either empty, or consists of nothing but whitespace. - elif not tag_name: - raise BadArgument("Tag names should not be empty, or filled with whitespace.") - - # The tag name is longer than 127 characters. - elif len(tag_name) > 127: - raise BadArgument("Are you insane? That's way too long!") - - # The tag name is ascii but does not contain any letters. - elif not any(character.isalpha() for character in tag_name): - raise BadArgument("Tag names must contain at least one letter.") - - return tag_name - - class SourceConverter(Converter): """Convert an argument into a help command, tag, command, or cog.""" @@ -343,9 +308,10 @@ class SourceConverter(Converter): if not tags_cog: show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - + else: + identifier = TagIdentifier.from_string(argument.lower()) + if identifier in tags_cog.tags: + return identifier escaped_arg = escape_markdown(argument) raise BadArgument( @@ -371,7 +337,7 @@ class DurationDelta(Converter): The units need to be provided in descending order of magnitude. """ - if not (delta := parse_duration_string(duration)): + if not (delta := time.parse_duration_string(duration)): raise BadArgument(f"`{duration}` is not a valid duration string.") return delta @@ -487,9 +453,9 @@ class ISODateTime(Converter): raise BadArgument(f"`{datetime_string}` is not a valid ISO-8601 datetime string") if dt.tzinfo: - dt = dt.astimezone(dateutil.tz.UTC) + dt = dt.astimezone(timezone.utc) else: # Without a timezone, assume it represents UTC. - dt = dt.replace(tzinfo=dateutil.tz.UTC) + dt = dt.replace(tzinfo=timezone.utc) return dt @@ -539,14 +505,14 @@ AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Us class UnambiguousUser(UserConverter): """ - Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `disnake.User`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `UserConverter`, it doesn't allow conversion from a name. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> discord.User: - """Convert the `argument` to a `discord.User`.""" + async def convert(self, ctx: Context, argument: str) -> disnake.User: + """Convert the `argument` to a `disnake.User`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -555,14 +521,14 @@ class UnambiguousUser(UserConverter): class UnambiguousMember(MemberConverter): """ - Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. + Converts to a `disnake.Member`, but only if a mention, userID or a username (name#discrim) is provided. Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. This is useful in cases where that lookup strategy would lead to too much ambiguity. """ - async def convert(self, ctx: Context, argument: str) -> discord.Member: - """Convert the `argument` to a `discord.Member`.""" + async def convert(self, ctx: Context, argument: str) -> disnake.Member: + """Convert the `argument` to a `disnake.Member`.""" if _is_an_unambiguous_user_argument(argument): return await super().convert(ctx, argument) else: @@ -615,7 +581,6 @@ if t.TYPE_CHECKING: ValidURL = str # noqa: F811 Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811 Snowflake = int # noqa: F811 - TagNameConverter = str # noqa: F811 SourceConverter = SourceType # noqa: F811 DurationDelta = relativedelta # noqa: F811 Duration = datetime # noqa: F811 @@ -623,10 +588,10 @@ if t.TYPE_CHECKING: OffTopicName = str # noqa: F811 ISODateTime = datetime # noqa: F811 HushDurationConverter = int # noqa: F811 - UnambiguousUser = discord.User # noqa: F811 - UnambiguousMember = discord.Member # noqa: F811 + UnambiguousUser = disnake.User # noqa: F811 + UnambiguousMember = disnake.Member # noqa: F811 Infraction = t.Optional[dict] # noqa: F811 Expiry = t.Union[Duration, ISODateTime] -MemberOrUser = t.Union[discord.Member, discord.User] +MemberOrUser = t.Union[disnake.Member, disnake.User] UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] |