diff options
| -rw-r--r-- | bot/converters.py | 97 | ||||
| -rw-r--r-- | bot/exts/info/source.py | 35 | ||||
| -rw-r--r-- | bot/exts/moderation/modpings.py | 1 | ||||
| -rw-r--r-- | bot/exts/moderation/silence.py | 2 | ||||
| -rw-r--r-- | bot/exts/utils/extensions.py | 41 | 
5 files changed, 100 insertions, 76 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 37eb91c7f..1c0fd673d 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations +  import logging  import re  import typing as t @@ -10,13 +12,17 @@ import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta  from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts  from bot.api import ResponseCodeError  from bot.constants import URLs  from bot.exts.info.doc import _inventory_parser +from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: +    from bot.exts.info.source import SourceType  log = logging.getLogger(__name__) @@ -127,6 +133,44 @@ class ValidFilterListType(Converter):          return list_type +class Extension(Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if argument in EXTENSIONS: +            return argument +        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: +            return qualified_arg + +        matches = [] +        for ext in EXTENSIONS: +            if argument == unqualify(ext): +                matches.append(ext) + +        if len(matches) > 1: +            matches.sort() +            names = "\n".join(matches) +            raise BadArgument( +                f":x: `{argument}` is an ambiguous extension name. " +                f"Please use one of the following fully-qualified names.```\n{names}```" +            ) +        elif matches: +            return matches[0] +        else: +            raise BadArgument(f":x: Could not find the extension `{argument}`.") + +  class PackageName(Converter):      """      A converter that checks whether the given string is a valid package name. @@ -289,6 +333,38 @@ class TagContentConverter(Converter):          return tag_content +class SourceConverter(Converter): +    """Convert an argument into a help command, tag, command, or cog.""" + +    @staticmethod +    async def convert(ctx: Context, argument: str) -> SourceType: +        """Convert argument into source object.""" +        if argument.lower() == "help": +            return ctx.bot.help_command + +        cog = ctx.bot.get_cog(argument) +        if cog: +            return cog + +        cmd = ctx.bot.get_command(argument) +        if cmd: +            return cmd + +        tags_cog = ctx.bot.get_cog("Tags") +        show_tag = True + +        if not tags_cog: +            show_tag = False +        elif argument.lower() in tags_cog._cache: +            return argument.lower() + +        escaped_arg = escape_markdown(argument) + +        raise BadArgument( +            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." +        ) + +  class DurationDelta(Converter):      """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" @@ -485,5 +561,24 @@ class Infraction(Converter):              return await ctx.bot.api_client.get(f"bot/infractions/{arg}") +if t.TYPE_CHECKING: +    ValidDiscordServerInvite = dict  # noqa: F811 +    ValidFilterListType = str  # noqa: F811 +    Extension = str  # noqa: F811 +    PackageName = str  # noqa: F811 +    ValidURL = str  # noqa: F811 +    Inventory = t.Tuple[str, _inventory_parser.InventoryDict]  # noqa: F811 +    Snowflake = int  # noqa: F811 +    TagNameConverter = str  # noqa: F811 +    TagContentConverter = str  # noqa: F811 +    SourceConverter = SourceType  # noqa: F811 +    DurationDelta = relativedelta  # noqa: F811 +    Duration = datetime  # noqa: F811 +    OffTopicName = str  # noqa: F811 +    ISODateTime = datetime  # noqa: F811 +    HushDurationConverter = int  # noqa: F811 +    UserMentionOrID = discord.User  # noqa: F811 +    Infraction = t.Optional[dict]  # noqa: F811 +  Expiry = t.Union[Duration, ISODateTime]  MemberOrUser = t.Union[discord.Member, discord.User] diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index ef07c77a1..8ce25b4e8 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -2,47 +2,16 @@ import inspect  from pathlib import Path  from typing import Optional, Tuple, Union -from discord import Embed, utils +from discord import Embed  from discord.ext import commands  from bot.bot import Bot  from bot.constants import URLs +from bot.converters import SourceConverter  SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] -class SourceConverter(commands.Converter): -    """Convert an argument into a help command, tag, command, or cog.""" - -    @staticmethod -    async def convert(ctx: commands.Context, argument: str) -> SourceType: -        """Convert argument into source object.""" -        if argument.lower() == "help": -            return ctx.bot.help_command - -        cog = ctx.bot.get_cog(argument) -        if cog: -            return cog - -        cmd = ctx.bot.get_command(argument) -        if cmd: -            return cmd - -        tags_cog = ctx.bot.get_cog("Tags") -        show_tag = True - -        if not tags_cog: -            show_tag = False -        elif argument.lower() in tags_cog._cache: -            return argument.lower() - -        escaped_arg = utils.escape_markdown(argument) - -        raise commands.BadArgument( -            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." -        ) - -  class BotSource(commands.Cog):      """Displays information about the bot's source code.""" diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index 29a5c1c8e..80c9f0c38 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -87,7 +87,6 @@ class ModPings(Cog):          The duration cannot be longer than 30 days.          """ -        duration: datetime.datetime          delta = duration - datetime.datetime.utcnow()          if delta > datetime.timedelta(days=30):              await ctx.send(":x: Cannot remove the role for longer than 30 days.") diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 8025f3df6..95e2792c3 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -202,8 +202,6 @@ class Silence(commands.Cog):          duration: HushDurationConverter      ) -> typing.Tuple[TextOrVoiceChannel, Optional[int]]:          """Helper method to parse the arguments of the silence command.""" -        duration: Optional[int] -          if duration_or_channel:              if isinstance(duration_or_channel, (TextChannel, VoiceChannel)):                  channel = duration_or_channel diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 8a1ed98f4..f78664527 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -10,8 +10,9 @@ from discord.ext.commands import Context, group  from bot import exts  from bot.bot import Bot  from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.converters import Extension  from bot.pagination import LinePaginator -from bot.utils.extensions import EXTENSIONS, unqualify +from bot.utils.extensions import EXTENSIONS  log = logging.getLogger(__name__) @@ -29,44 +30,6 @@ class Action(Enum):      RELOAD = functools.partial(Bot.reload_extension) -class Extension(commands.Converter): -    """ -    Fully qualify the name of an extension and ensure it exists. - -    The * and ** values bypass this when used with the reload command. -    """ - -    async def convert(self, ctx: Context, argument: str) -> str: -        """Fully qualify the name of an extension and ensure it exists.""" -        # Special values to reload all extensions -        if argument == "*" or argument == "**": -            return argument - -        argument = argument.lower() - -        if argument in EXTENSIONS: -            return argument -        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: -            return qualified_arg - -        matches = [] -        for ext in EXTENSIONS: -            if argument == unqualify(ext): -                matches.append(ext) - -        if len(matches) > 1: -            matches.sort() -            names = "\n".join(matches) -            raise commands.BadArgument( -                f":x: `{argument}` is an ambiguous extension name. " -                f"Please use one of the following fully-qualified names.```\n{names}```" -            ) -        elif matches: -            return matches[0] -        else: -            raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - -  class Extensions(commands.Cog):      """Extension management commands.""" | 
