diff options
| author | 2019-12-11 20:26:26 -0800 | |
|---|---|---|
| committer | 2019-12-11 20:29:15 -0800 | |
| commit | 9d551cc69c1935165389f26f52753895604dd3f5 (patch) | |
| tree | ed9e9e8ce474c7bd1656c9146469cf21b93f1976 | |
| parent | Merge pull request #682 from manusaurio/master (diff) | |
Add a generic converter for only allowing certain string values
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/moderation/management.py | 13 | ||||
| -rw-r--r-- | bot/converters.py | 23 | 
2 files changed, 23 insertions, 13 deletions
| diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..50bce3981 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,7 @@ from discord.ext import commands  from discord.ext.commands import Context  from bot import constants -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, string  from bot.pagination import LinePaginator  from bot.utils import time  from bot.utils.checks import in_channel_check, with_role_check @@ -22,15 +22,6 @@ log = logging.getLogger(__name__)  UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: -    """Only allow an expiration to be 'permanent' if it is a string.""" -    expires_at = expires_at.lower() -    if expires_at != "permanent": -        raise commands.BadArgument -    else: -        return expires_at - -  class ModManagement(commands.Cog):      """Management of infractions.""" @@ -61,7 +52,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction_id: int, -        duration: t.Union[utils.Expiry, permanent_duration, None], +        duration: t.Union[utils.Expiry, string("permanent"), None],          *,          reason: str = None      ) -> None: diff --git a/bot/converters.py b/bot/converters.py index cf0496541..2cfc42903 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@  import logging  import re +import typing as t  from datetime import datetime  from ssl import CertificateError -from typing import Union  import dateutil.parser  import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter  log = logging.getLogger(__name__) +def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]: +    """ +    Return a converter which only allows arguments equal to one of the given values. + +    Unless preserve_case is True, the argument is converter to lowercase. All values are then +    expected to have already been given in lowercase too. +    """ +    def converter(arg: str) -> str: +        if not preserve_case: +            arg = arg.lower() + +        if arg not in values: +            raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") +        else: +            return arg + +    return converter + +  class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):      """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""      @staticmethod -    async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: +    async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:          """Check if the argument is a Discord user, and if not, falls back to a string."""          try:              maybe_snowflake = arg.strip("<@!>") | 
