diff options
Diffstat (limited to 'bot/converters.py')
| -rw-r--r-- | bot/converters.py | 78 | 
1 files changed, 77 insertions, 1 deletions
diff --git a/bot/converters.py b/bot/converters.py index 595809517..23aa9eab8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations +  import logging  import re  import typing as t @@ -11,13 +13,17 @@ import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta  from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts  from bot.api import ResponseCodeError  from bot.constants import URLs  from bot.exts.info.doc import _inventory_parser +from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: +    from bot.exts.info.source import SourceType  log = logging.getLogger(__name__) @@ -128,6 +134,44 @@ class ValidFilterListType(Converter):          return list_type +class Extension(Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if argument in EXTENSIONS: +            return argument +        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: +            return qualified_arg + +        matches = [] +        for ext in EXTENSIONS: +            if argument == unqualify(ext): +                matches.append(ext) + +        if len(matches) > 1: +            matches.sort() +            names = "\n".join(matches) +            raise BadArgument( +                f":x: `{argument}` is an ambiguous extension name. " +                f"Please use one of the following fully-qualified names.```\n{names}```" +            ) +        elif matches: +            return matches[0] +        else: +            raise BadArgument(f":x: Could not find the extension `{argument}`.") + +  class PackageName(Converter):      """      A converter that checks whether the given string is a valid package name. @@ -290,6 +334,38 @@ class TagContentConverter(Converter):          return tag_content +class SourceConverter(Converter): +    """Convert an argument into a help command, tag, command, or cog.""" + +    @staticmethod +    async def convert(ctx: Context, argument: str) -> SourceType: +        """Convert argument into source object.""" +        if argument.lower() == "help": +            return ctx.bot.help_command + +        cog = ctx.bot.get_cog(argument) +        if cog: +            return cog + +        cmd = ctx.bot.get_command(argument) +        if cmd: +            return cmd + +        tags_cog = ctx.bot.get_cog("Tags") +        show_tag = True + +        if not tags_cog: +            show_tag = False +        elif argument.lower() in tags_cog._cache: +            return argument.lower() + +        escaped_arg = escape_markdown(argument) + +        raise BadArgument( +            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." +        ) + +  class DurationDelta(Converter):      """Convert duration strings into dateutil.relativedelta.relativedelta objects."""  |