diff options
author | 2019-12-11 20:26:26 -0800 | |
---|---|---|
committer | 2019-12-11 20:29:15 -0800 | |
commit | 9d551cc69c1935165389f26f52753895604dd3f5 (patch) | |
tree | ed9e9e8ce474c7bd1656c9146469cf21b93f1976 | |
parent | Merge pull request #682 from manusaurio/master (diff) |
Add a generic converter for only allowing certain string values
-rw-r--r-- | bot/cogs/moderation/management.py | 13 | ||||
-rw-r--r-- | bot/converters.py | 23 |
2 files changed, 23 insertions, 13 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..50bce3981 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,7 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, string from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -22,15 +22,6 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" @@ -61,7 +52,7 @@ class ModManagement(commands.Cog): self, ctx: Context, infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + duration: t.Union[utils.Expiry, string("permanent"), None], *, reason: str = None ) -> None: diff --git a/bot/converters.py b/bot/converters.py index cf0496541..2cfc42903 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converter to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") |