aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/converters.py101
-rw-r--r--bot/errors.py6
-rw-r--r--bot/exts/info/source.py35
-rw-r--r--bot/exts/moderation/modpings.py1
-rw-r--r--bot/exts/moderation/silence.py2
-rw-r--r--bot/exts/utils/extensions.py41
-rw-r--r--tests/bot/test_converters.py38
7 files changed, 96 insertions, 128 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 37eb91c7f..0118cc48a 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
import re
import typing as t
@@ -10,13 +12,17 @@ import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter
-from discord.utils import DISCORD_EPOCH, snowflake_time
+from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time
+from bot import exts
from bot.api import ResponseCodeError
from bot.constants import URLs
from bot.exts.info.doc import _inventory_parser
+from bot.utils.extensions import EXTENSIONS, unqualify
from bot.utils.regex import INVITE_RE
from bot.utils.time import parse_duration_string
+if t.TYPE_CHECKING:
+ from bot.exts.info.source import SourceType
log = logging.getLogger(__name__)
@@ -127,6 +133,44 @@ class ValidFilterListType(Converter):
return list_type
+class Extension(Converter):
+ """
+ Fully qualify the name of an extension and ensure it exists.
+
+ The * and ** values bypass this when used with the reload command.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Fully qualify the name of an extension and ensure it exists."""
+ # Special values to reload all extensions
+ if argument == "*" or argument == "**":
+ return argument
+
+ argument = argument.lower()
+
+ if argument in EXTENSIONS:
+ return argument
+ elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS:
+ return qualified_arg
+
+ matches = []
+ for ext in EXTENSIONS:
+ if argument == unqualify(ext):
+ matches.append(ext)
+
+ if len(matches) > 1:
+ matches.sort()
+ names = "\n".join(matches)
+ raise BadArgument(
+ f":x: `{argument}` is an ambiguous extension name. "
+ f"Please use one of the following fully-qualified names.```\n{names}```"
+ )
+ elif matches:
+ return matches[0]
+ else:
+ raise BadArgument(f":x: Could not find the extension `{argument}`.")
+
+
class PackageName(Converter):
"""
A converter that checks whether the given string is a valid package name.
@@ -270,23 +314,36 @@ class TagNameConverter(Converter):
return tag_name
-class TagContentConverter(Converter):
- """Ensure proposed tag content is not empty and contains at least one non-whitespace character."""
+class SourceConverter(Converter):
+ """Convert an argument into a help command, tag, command, or cog."""
@staticmethod
- async def convert(ctx: Context, tag_content: str) -> str:
- """
- Ensure tag_content is non-empty and contains at least one non-whitespace character.
+ async def convert(ctx: Context, argument: str) -> SourceType:
+ """Convert argument into source object."""
+ if argument.lower() == "help":
+ return ctx.bot.help_command
- If tag_content is valid, return the stripped version.
- """
- tag_content = tag_content.strip()
+ cog = ctx.bot.get_cog(argument)
+ if cog:
+ return cog
- # The tag contents should not be empty, or filled with whitespace.
- if not tag_content:
- raise BadArgument("Tag contents should not be empty, or filled with whitespace.")
+ cmd = ctx.bot.get_command(argument)
+ if cmd:
+ return cmd
- return tag_content
+ tags_cog = ctx.bot.get_cog("Tags")
+ show_tag = True
+
+ if not tags_cog:
+ show_tag = False
+ elif argument.lower() in tags_cog._cache:
+ return argument.lower()
+
+ escaped_arg = escape_markdown(argument)
+
+ raise BadArgument(
+ f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
+ )
class DurationDelta(Converter):
@@ -485,5 +542,23 @@ class Infraction(Converter):
return await ctx.bot.api_client.get(f"bot/infractions/{arg}")
+if t.TYPE_CHECKING:
+ ValidDiscordServerInvite = dict # noqa: F811
+ ValidFilterListType = str # noqa: F811
+ Extension = str # noqa: F811
+ PackageName = str # noqa: F811
+ ValidURL = str # noqa: F811
+ Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811
+ Snowflake = int # noqa: F811
+ TagNameConverter = str # noqa: F811
+ SourceConverter = SourceType # noqa: F811
+ DurationDelta = relativedelta # noqa: F811
+ Duration = datetime # noqa: F811
+ OffTopicName = str # noqa: F811
+ ISODateTime = datetime # noqa: F811
+ HushDurationConverter = int # noqa: F811
+ UserMentionOrID = discord.User # noqa: F811
+ Infraction = t.Optional[dict] # noqa: F811
+
Expiry = t.Union[Duration, ISODateTime]
MemberOrUser = t.Union[discord.Member, discord.User]
diff --git a/bot/errors.py b/bot/errors.py
index 08396ec3e..2633390a8 100644
--- a/bot/errors.py
+++ b/bot/errors.py
@@ -1,6 +1,8 @@
-from typing import Hashable
+from __future__ import annotations
-from bot.converters import MemberOrUser
+from typing import Hashable, TYPE_CHECKING
+if TYPE_CHECKING:
+ from bot.converters import MemberOrUser
class LockedResourceError(RuntimeError):
diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py
index ef07c77a1..8ce25b4e8 100644
--- a/bot/exts/info/source.py
+++ b/bot/exts/info/source.py
@@ -2,47 +2,16 @@ import inspect
from pathlib import Path
from typing import Optional, Tuple, Union
-from discord import Embed, utils
+from discord import Embed
from discord.ext import commands
from bot.bot import Bot
from bot.constants import URLs
+from bot.converters import SourceConverter
SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded]
-class SourceConverter(commands.Converter):
- """Convert an argument into a help command, tag, command, or cog."""
-
- @staticmethod
- async def convert(ctx: commands.Context, argument: str) -> SourceType:
- """Convert argument into source object."""
- if argument.lower() == "help":
- return ctx.bot.help_command
-
- cog = ctx.bot.get_cog(argument)
- if cog:
- return cog
-
- cmd = ctx.bot.get_command(argument)
- if cmd:
- return cmd
-
- tags_cog = ctx.bot.get_cog("Tags")
- show_tag = True
-
- if not tags_cog:
- show_tag = False
- elif argument.lower() in tags_cog._cache:
- return argument.lower()
-
- escaped_arg = utils.escape_markdown(argument)
-
- raise commands.BadArgument(
- f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
- )
-
-
class BotSource(commands.Cog):
"""Displays information about the bot's source code."""
diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py
index 29a5c1c8e..80c9f0c38 100644
--- a/bot/exts/moderation/modpings.py
+++ b/bot/exts/moderation/modpings.py
@@ -87,7 +87,6 @@ class ModPings(Cog):
The duration cannot be longer than 30 days.
"""
- duration: datetime.datetime
delta = duration - datetime.datetime.utcnow()
if delta > datetime.timedelta(days=30):
await ctx.send(":x: Cannot remove the role for longer than 30 days.")
diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py
index 8025f3df6..95e2792c3 100644
--- a/bot/exts/moderation/silence.py
+++ b/bot/exts/moderation/silence.py
@@ -202,8 +202,6 @@ class Silence(commands.Cog):
duration: HushDurationConverter
) -> typing.Tuple[TextOrVoiceChannel, Optional[int]]:
"""Helper method to parse the arguments of the silence command."""
- duration: Optional[int]
-
if duration_or_channel:
if isinstance(duration_or_channel, (TextChannel, VoiceChannel)):
channel = duration_or_channel
diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py
index 8a1ed98f4..f78664527 100644
--- a/bot/exts/utils/extensions.py
+++ b/bot/exts/utils/extensions.py
@@ -10,8 +10,9 @@ from discord.ext.commands import Context, group
from bot import exts
from bot.bot import Bot
from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
+from bot.converters import Extension
from bot.pagination import LinePaginator
-from bot.utils.extensions import EXTENSIONS, unqualify
+from bot.utils.extensions import EXTENSIONS
log = logging.getLogger(__name__)
@@ -29,44 +30,6 @@ class Action(Enum):
RELOAD = functools.partial(Bot.reload_extension)
-class Extension(commands.Converter):
- """
- Fully qualify the name of an extension and ensure it exists.
-
- The * and ** values bypass this when used with the reload command.
- """
-
- async def convert(self, ctx: Context, argument: str) -> str:
- """Fully qualify the name of an extension and ensure it exists."""
- # Special values to reload all extensions
- if argument == "*" or argument == "**":
- return argument
-
- argument = argument.lower()
-
- if argument in EXTENSIONS:
- return argument
- elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS:
- return qualified_arg
-
- matches = []
- for ext in EXTENSIONS:
- if argument == unqualify(ext):
- matches.append(ext)
-
- if len(matches) > 1:
- matches.sort()
- names = "\n".join(matches)
- raise commands.BadArgument(
- f":x: `{argument}` is an ambiguous extension name. "
- f"Please use one of the following fully-qualified names.```\n{names}```"
- )
- elif matches:
- return matches[0]
- else:
- raise commands.BadArgument(f":x: Could not find the extension `{argument}`.")
-
-
class Extensions(commands.Cog):
"""Extension management commands."""
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index 2a1c4e543..6e3a6b898 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -11,7 +11,6 @@ from bot.converters import (
HushDurationConverter,
ISODateTime,
PackageName,
- TagContentConverter,
TagNameConverter,
)
@@ -26,43 +25,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
- async def test_tag_content_converter_for_valid(self):
- """TagContentConverter should return correct values for valid input."""
- test_values = (
- ('hello', 'hello'),
- (' h ello ', 'h ello'),
- )
-
- for content, expected_conversion in test_values:
- with self.subTest(content=content, expected_conversion=expected_conversion):
- conversion = await TagContentConverter.convert(self.context, content)
- self.assertEqual(conversion, expected_conversion)
-
- async def test_tag_content_converter_for_invalid(self):
- """TagContentConverter should raise the proper exception for invalid input."""
- test_values = (
- ('', "Tag contents should not be empty, or filled with whitespace."),
- (' ', "Tag contents should not be empty, or filled with whitespace."),
- )
-
- for value, exception_message in test_values:
- with self.subTest(tag_content=value, exception_message=exception_message):
- with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
- await TagContentConverter.convert(self.context, value)
-
- async def test_tag_name_converter_for_valid(self):
- """TagNameConverter should return the correct values for valid tag names."""
- test_values = (
- ('tracebacks', 'tracebacks'),
- ('Tracebacks', 'tracebacks'),
- (' Tracebacks ', 'tracebacks'),
- )
-
- for name, expected_conversion in test_values:
- with self.subTest(name=name, expected_conversion=expected_conversion):
- conversion = await TagNameConverter.convert(self.context, name)
- self.assertEqual(conversion, expected_conversion)
-
async def test_tag_name_converter_for_invalid(self):
"""TagNameConverter should raise the correct exception for invalid tag names."""
test_values = (