aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2019-12-11 20:26:26 -0800
committerGravatar MarkKoz <[email protected]>2019-12-11 20:29:15 -0800
commit9d551cc69c1935165389f26f52753895604dd3f5 (patch)
treeed9e9e8ce474c7bd1656c9146469cf21b93f1976
parentMerge pull request #682 from manusaurio/master (diff)
Add a generic converter for only allowing certain string values
-rw-r--r--bot/cogs/moderation/management.py13
-rw-r--r--bot/converters.py23
2 files changed, 23 insertions, 13 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index abfe5c2b3..50bce3981 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -9,7 +9,7 @@ from discord.ext import commands
from discord.ext.commands import Context
from bot import constants
-from bot.converters import InfractionSearchQuery
+from bot.converters import InfractionSearchQuery, string
from bot.pagination import LinePaginator
from bot.utils import time
from bot.utils.checks import in_channel_check, with_role_check
@@ -22,15 +22,6 @@ log = logging.getLogger(__name__)
UserConverter = t.Union[discord.User, utils.proxy_user]
-def permanent_duration(expires_at: str) -> str:
- """Only allow an expiration to be 'permanent' if it is a string."""
- expires_at = expires_at.lower()
- if expires_at != "permanent":
- raise commands.BadArgument
- else:
- return expires_at
-
-
class ModManagement(commands.Cog):
"""Management of infractions."""
@@ -61,7 +52,7 @@ class ModManagement(commands.Cog):
self,
ctx: Context,
infraction_id: int,
- duration: t.Union[utils.Expiry, permanent_duration, None],
+ duration: t.Union[utils.Expiry, string("permanent"), None],
*,
reason: str = None
) -> None:
diff --git a/bot/converters.py b/bot/converters.py
index cf0496541..2cfc42903 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,8 +1,8 @@
import logging
import re
+import typing as t
from datetime import datetime
from ssl import CertificateError
-from typing import Union
import dateutil.parser
import dateutil.tz
@@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter
log = logging.getLogger(__name__)
+def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
+ """
+ Return a converter which only allows arguments equal to one of the given values.
+
+ Unless preserve_case is True, the argument is converter to lowercase. All values are then
+ expected to have already been given in lowercase too.
+ """
+ def converter(arg: str) -> str:
+ if not preserve_case:
+ arg = arg.lower()
+
+ if arg not in values:
+ raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```")
+ else:
+ return arg
+
+ return converter
+
+
class ValidPythonIdentifier(Converter):
"""
A converter that checks whether the given string is a valid Python identifier.
@@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):
"""A converter that checks if the argument is a Discord user, and if not, falls back to a string."""
@staticmethod
- async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]:
+ async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:
"""Check if the argument is a Discord user, and if not, falls back to a string."""
try:
maybe_snowflake = arg.strip("<@!>")