diff options
author | 2021-08-24 18:29:38 +0100 | |
---|---|---|
committer | 2021-08-24 18:29:38 +0100 | |
commit | 1782cfa6f9cfcb0d395521d361375f53dd55c091 (patch) | |
tree | 079a5b8a28aca2273350e4071db06d829bd3888a | |
parent | Merge pull request #1776 from python-discord/community-partners-access (diff) | |
parent | Merge branch 'main' into converter-typehints (diff) |
Merge pull request #1731 from Numerlor/converter-typehints
Add converter typehints
-rw-r--r-- | bot/converters.py | 101 | ||||
-rw-r--r-- | bot/errors.py | 6 | ||||
-rw-r--r-- | bot/exts/info/source.py | 35 | ||||
-rw-r--r-- | bot/exts/moderation/modpings.py | 1 | ||||
-rw-r--r-- | bot/exts/moderation/silence.py | 2 | ||||
-rw-r--r-- | bot/exts/utils/extensions.py | 41 | ||||
-rw-r--r-- | tests/bot/test_converters.py | 38 |
7 files changed, 96 insertions, 128 deletions
diff --git a/bot/converters.py b/bot/converters.py index 37eb91c7f..0118cc48a 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import re import typing as t @@ -10,13 +12,17 @@ import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter -from discord.utils import DISCORD_EPOCH, snowflake_time +from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time +from bot import exts from bot.api import ResponseCodeError from bot.constants import URLs from bot.exts.info.doc import _inventory_parser +from bot.utils.extensions import EXTENSIONS, unqualify from bot.utils.regex import INVITE_RE from bot.utils.time import parse_duration_string +if t.TYPE_CHECKING: + from bot.exts.info.source import SourceType log = logging.getLogger(__name__) @@ -127,6 +133,44 @@ class ValidFilterListType(Converter): return list_type +class Extension(Converter): + """ + Fully qualify the name of an extension and ensure it exists. + + The * and ** values bypass this when used with the reload command. + """ + + async def convert(self, ctx: Context, argument: str) -> str: + """Fully qualify the name of an extension and ensure it exists.""" + # Special values to reload all extensions + if argument == "*" or argument == "**": + return argument + + argument = argument.lower() + + if argument in EXTENSIONS: + return argument + elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + return qualified_arg + + matches = [] + for ext in EXTENSIONS: + if argument == unqualify(ext): + matches.append(ext) + + if len(matches) > 1: + matches.sort() + names = "\n".join(matches) + raise BadArgument( + f":x: `{argument}` is an ambiguous extension name. " + f"Please use one of the following fully-qualified names.```\n{names}```" + ) + elif matches: + return matches[0] + else: + raise BadArgument(f":x: Could not find the extension `{argument}`.") + + class PackageName(Converter): """ A converter that checks whether the given string is a valid package name. @@ -270,23 +314,36 @@ class TagNameConverter(Converter): return tag_name -class TagContentConverter(Converter): - """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" +class SourceConverter(Converter): + """Convert an argument into a help command, tag, command, or cog.""" @staticmethod - async def convert(ctx: Context, tag_content: str) -> str: - """ - Ensure tag_content is non-empty and contains at least one non-whitespace character. + async def convert(ctx: Context, argument: str) -> SourceType: + """Convert argument into source object.""" + if argument.lower() == "help": + return ctx.bot.help_command - If tag_content is valid, return the stripped version. - """ - tag_content = tag_content.strip() + cog = ctx.bot.get_cog(argument) + if cog: + return cog - # The tag contents should not be empty, or filled with whitespace. - if not tag_content: - raise BadArgument("Tag contents should not be empty, or filled with whitespace.") + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd - return tag_content + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + elif argument.lower() in tags_cog._cache: + return argument.lower() + + escaped_arg = escape_markdown(argument) + + raise BadArgument( + f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." + ) class DurationDelta(Converter): @@ -485,5 +542,23 @@ class Infraction(Converter): return await ctx.bot.api_client.get(f"bot/infractions/{arg}") +if t.TYPE_CHECKING: + ValidDiscordServerInvite = dict # noqa: F811 + ValidFilterListType = str # noqa: F811 + Extension = str # noqa: F811 + PackageName = str # noqa: F811 + ValidURL = str # noqa: F811 + Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811 + Snowflake = int # noqa: F811 + TagNameConverter = str # noqa: F811 + SourceConverter = SourceType # noqa: F811 + DurationDelta = relativedelta # noqa: F811 + Duration = datetime # noqa: F811 + OffTopicName = str # noqa: F811 + ISODateTime = datetime # noqa: F811 + HushDurationConverter = int # noqa: F811 + UserMentionOrID = discord.User # noqa: F811 + Infraction = t.Optional[dict] # noqa: F811 + Expiry = t.Union[Duration, ISODateTime] MemberOrUser = t.Union[discord.Member, discord.User] diff --git a/bot/errors.py b/bot/errors.py index 08396ec3e..2633390a8 100644 --- a/bot/errors.py +++ b/bot/errors.py @@ -1,6 +1,8 @@ -from typing import Hashable +from __future__ import annotations -from bot.converters import MemberOrUser +from typing import Hashable, TYPE_CHECKING +if TYPE_CHECKING: + from bot.converters import MemberOrUser class LockedResourceError(RuntimeError): diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index ef07c77a1..8ce25b4e8 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -2,47 +2,16 @@ import inspect from pathlib import Path from typing import Optional, Tuple, Union -from discord import Embed, utils +from discord import Embed from discord.ext import commands from bot.bot import Bot from bot.constants import URLs +from bot.converters import SourceConverter SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] -class SourceConverter(commands.Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - @staticmethod - async def convert(ctx: commands.Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower() == "help": - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - - escaped_arg = utils.escape_markdown(argument) - - raise commands.BadArgument( - f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." - ) - - class BotSource(commands.Cog): """Displays information about the bot's source code.""" diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py index 29a5c1c8e..80c9f0c38 100644 --- a/bot/exts/moderation/modpings.py +++ b/bot/exts/moderation/modpings.py @@ -87,7 +87,6 @@ class ModPings(Cog): The duration cannot be longer than 30 days. """ - duration: datetime.datetime delta = duration - datetime.datetime.utcnow() if delta > datetime.timedelta(days=30): await ctx.send(":x: Cannot remove the role for longer than 30 days.") diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 8025f3df6..95e2792c3 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -202,8 +202,6 @@ class Silence(commands.Cog): duration: HushDurationConverter ) -> typing.Tuple[TextOrVoiceChannel, Optional[int]]: """Helper method to parse the arguments of the silence command.""" - duration: Optional[int] - if duration_or_channel: if isinstance(duration_or_channel, (TextChannel, VoiceChannel)): channel = duration_or_channel diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 8a1ed98f4..f78664527 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -10,8 +10,9 @@ from discord.ext.commands import Context, group from bot import exts from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.converters import Extension from bot.pagination import LinePaginator -from bot.utils.extensions import EXTENSIONS, unqualify +from bot.utils.extensions import EXTENSIONS log = logging.getLogger(__name__) @@ -29,44 +30,6 @@ class Action(Enum): RELOAD = functools.partial(Bot.reload_extension) -class Extension(commands.Converter): - """ - Fully qualify the name of an extension and ensure it exists. - - The * and ** values bypass this when used with the reload command. - """ - - async def convert(self, ctx: Context, argument: str) -> str: - """Fully qualify the name of an extension and ensure it exists.""" - # Special values to reload all extensions - if argument == "*" or argument == "**": - return argument - - argument = argument.lower() - - if argument in EXTENSIONS: - return argument - elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: - return qualified_arg - - matches = [] - for ext in EXTENSIONS: - if argument == unqualify(ext): - matches.append(ext) - - if len(matches) > 1: - matches.sort() - names = "\n".join(matches) - raise commands.BadArgument( - f":x: `{argument}` is an ambiguous extension name. " - f"Please use one of the following fully-qualified names.```\n{names}```" - ) - elif matches: - return matches[0] - else: - raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") - - class Extensions(commands.Cog): """Extension management commands.""" diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 2a1c4e543..6e3a6b898 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -11,7 +11,6 @@ from bot.converters import ( HushDurationConverter, ISODateTime, PackageName, - TagContentConverter, TagNameConverter, ) @@ -26,43 +25,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') - async def test_tag_content_converter_for_valid(self): - """TagContentConverter should return correct values for valid input.""" - test_values = ( - ('hello', 'hello'), - (' h ello ', 'h ello'), - ) - - for content, expected_conversion in test_values: - with self.subTest(content=content, expected_conversion=expected_conversion): - conversion = await TagContentConverter.convert(self.context, content) - self.assertEqual(conversion, expected_conversion) - - async def test_tag_content_converter_for_invalid(self): - """TagContentConverter should raise the proper exception for invalid input.""" - test_values = ( - ('', "Tag contents should not be empty, or filled with whitespace."), - (' ', "Tag contents should not be empty, or filled with whitespace."), - ) - - for value, exception_message in test_values: - with self.subTest(tag_content=value, exception_message=exception_message): - with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): - await TagContentConverter.convert(self.context, value) - - async def test_tag_name_converter_for_valid(self): - """TagNameConverter should return the correct values for valid tag names.""" - test_values = ( - ('tracebacks', 'tracebacks'), - ('Tracebacks', 'tracebacks'), - (' Tracebacks ', 'tracebacks'), - ) - - for name, expected_conversion in test_values: - with self.subTest(name=name, expected_conversion=expected_conversion): - conversion = await TagNameConverter.convert(self.context, name) - self.assertEqual(conversion, expected_conversion) - async def test_tag_name_converter_for_invalid(self): """TagNameConverter should raise the correct exception for invalid tag names.""" test_values = ( |