diff options
| author | 2021-08-09 20:46:41 +0200 | |
|---|---|---|
| committer | 2021-08-09 21:37:54 +0200 | |
| commit | da85add68b993136dbe1c3eb9da33a3a8ab1862b (patch) | |
| tree | 9f2cc023ef6f6b3919f3493e51be45c64fadca69 | |
| parent | Merge pull request #1727 from onerandomusername/patch-1 (diff) | |
Move all converters to converters.py
| -rw-r--r-- | bot/converters.py | 78 | ||||
| -rw-r--r-- | bot/exts/info/source.py | 35 | ||||
| -rw-r--r-- | bot/exts/utils/extensions.py | 41 | 
3 files changed, 81 insertions, 73 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 595809517..23aa9eab8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations +  import logging  import re  import typing as t @@ -11,13 +13,17 @@ import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta  from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts  from bot.api import ResponseCodeError  from bot.constants import URLs  from bot.exts.info.doc import _inventory_parser +from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: +    from bot.exts.info.source import SourceType  log = logging.getLogger(__name__) @@ -128,6 +134,44 @@ class ValidFilterListType(Converter):          return list_type +class Extension(Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if argument in EXTENSIONS: +            return argument +        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: +            return qualified_arg + +        matches = [] +        for ext in EXTENSIONS: +            if argument == unqualify(ext): +                matches.append(ext) + +        if len(matches) > 1: +            matches.sort() +            names = "\n".join(matches) +            raise BadArgument( +                f":x: `{argument}` is an ambiguous extension name. " +                f"Please use one of the following fully-qualified names.```\n{names}```" +            ) +        elif matches: +            return matches[0] +        else: +            raise BadArgument(f":x: Could not find the extension `{argument}`.") + +  class PackageName(Converter):      """      A converter that checks whether the given string is a valid package name. @@ -290,6 +334,38 @@ class TagContentConverter(Converter):          return tag_content +class SourceConverter(Converter): +    """Convert an argument into a help command, tag, command, or cog.""" + +    @staticmethod +    async def convert(ctx: Context, argument: str) -> SourceType: +        """Convert argument into source object.""" +        if argument.lower() == "help": +            return ctx.bot.help_command + +        cog = ctx.bot.get_cog(argument) +        if cog: +            return cog + +        cmd = ctx.bot.get_command(argument) +        if cmd: +            return cmd + +        tags_cog = ctx.bot.get_cog("Tags") +        show_tag = True + +        if not tags_cog: +            show_tag = False +        elif argument.lower() in tags_cog._cache: +            return argument.lower() + +        escaped_arg = escape_markdown(argument) + +        raise BadArgument( +            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." +        ) + +  class DurationDelta(Converter):      """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index ef07c77a1..8ce25b4e8 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -2,47 +2,16 @@ import inspect  from pathlib import Path  from typing import Optional, Tuple, Union -from discord import Embed, utils +from discord import Embed  from discord.ext import commands  from bot.bot import Bot  from bot.constants import URLs +from bot.converters import SourceConverter  SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] -class SourceConverter(commands.Converter): -    """Convert an argument into a help command, tag, command, or cog.""" - -    @staticmethod -    async def convert(ctx: commands.Context, argument: str) -> SourceType: -        """Convert argument into source object.""" -        if argument.lower() == "help": -            return ctx.bot.help_command - -        cog = ctx.bot.get_cog(argument) -        if cog: -            return cog - -        cmd = ctx.bot.get_command(argument) -        if cmd: -            return cmd - -        tags_cog = ctx.bot.get_cog("Tags") -        show_tag = True - -        if not tags_cog: -            show_tag = False -        elif argument.lower() in tags_cog._cache: -            return argument.lower() - -        escaped_arg = utils.escape_markdown(argument) - -        raise commands.BadArgument( -            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." -        ) - -  class BotSource(commands.Cog):      """Displays information about the bot's source code.""" diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 8a1ed98f4..f78664527 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -10,8 +10,9 @@ from discord.ext.commands import Context, group  from bot import exts  from bot.bot import Bot  from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.converters import Extension  from bot.pagination import LinePaginator -from bot.utils.extensions import EXTENSIONS, unqualify +from bot.utils.extensions import EXTENSIONS  log = logging.getLogger(__name__) @@ -29,44 +30,6 @@ class Action(Enum):      RELOAD = functools.partial(Bot.reload_extension) -class Extension(commands.Converter): -    """ -    Fully qualify the name of an extension and ensure it exists. - -    The * and ** values bypass this when used with the reload command. -    """ - -    async def convert(self, ctx: Context, argument: str) -> str: -        """Fully qualify the name of an extension and ensure it exists.""" -        # Special values to reload all extensions -        if argument == "*" or argument == "**": -            return argument - -        argument = argument.lower() - -        if argument in EXTENSIONS: -            return argument -        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: -            return qualified_arg - -        matches = [] -        for ext in EXTENSIONS: -            if argument == unqualify(ext): -                matches.append(ext) - -        if len(matches) > 1: -            matches.sort() -            names = "\n".join(matches) -            raise commands.BadArgument( -                f":x: `{argument}` is an ambiguous extension name. " -                f"Please use one of the following fully-qualified names.```\n{names}```" -            ) -        elif matches: -            return matches[0] -        else: -            raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - -  class Extensions(commands.Cog):      """Extension management commands.""" | 
