diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 101 |
1 files changed, 88 insertions, 13 deletions
diff --git a/bot/converters.py b/bot/converters.py index 37eb91c7f..0118cc48a 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import re import typing as t @@ -10,13 +12,17 @@ import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts from bot.api import ResponseCodeError from bot.constants import URLs from bot.exts.info.doc import _inventory_parser +from bot.utils.extensions import EXTENSIONS, unqualify from bot.utils.regex import INVITE_RE from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: + from bot.exts.info.source import SourceType log = logging.getLogger(__name__) @@ -127,6 +133,44 @@ class ValidFilterListType(Converter): return list_type +class Extension(Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + if argument == unqualify(ext): + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise BadArgument(f":x: Could not find the extension `{argument}`.") + + class PackageName(Converter): """ A converter that checks whether the given string is a valid package name. @@ -270,23 +314,36 @@ class TagNameConverter(Converter): return tag_name -class TagContentConverter(Converter): - """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" +class SourceConverter(Converter): + """Convert an argument into a help command, tag, command, or cog.""" @staticmethod - async def convert(ctx: Context, tag_content: str) -> str: - """ - Ensure tag_content is non-empty and contains at least one non-whitespace character. + async def convert(ctx: Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower() == "help": + return ctx.bot.help_command - If tag_content is valid, return the stripped version. - """ - tag_content = tag_content.strip() + cog = ctx.bot.get_cog(argument) + if cog: + return cog - # The tag contents should not be empty, or filled with whitespace. - if not tag_content: - raise BadArgument("Tag contents should not be empty, or filled with whitespace.") + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd - return tag_content + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + elif argument.lower() in tags_cog._cache: + return argument.lower() + + escaped_arg = escape_markdown(argument) + + raise BadArgument( + f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." + ) class DurationDelta(Converter): @@ -485,5 +542,23 @@ class Infraction(Converter): return await ctx.bot.api_client.get(f"bot/infractions/{arg}") +if t.TYPE_CHECKING: + ValidDiscordServerInvite = dict # noqa: F811 + ValidFilterListType = str # noqa: F811 + Extension = str # noqa: F811 + PackageName = str # noqa: F811 + ValidURL = str # noqa: F811 + Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811 + Snowflake = int # noqa: F811 + TagNameConverter = str # noqa: F811 + SourceConverter = SourceType # noqa: F811 + DurationDelta = relativedelta # noqa: F811 + Duration = datetime # noqa: F811 + OffTopicName = str # noqa: F811 + ISODateTime = datetime # noqa: F811 + HushDurationConverter = int # noqa: F811 + UserMentionOrID = discord.User # noqa: F811 + Infraction = t.Optional[dict] # noqa: F811 + Expiry = t.Union[Duration, ISODateTime] MemberOrUser = t.Union[discord.Member, discord.User] |