diff options
author | 2019-12-12 01:21:58 -0800 | |
---|---|---|
committer | 2019-12-12 01:21:58 -0800 | |
commit | 5bfef1d683ddcf52924c7323ccb8950c6877312f (patch) | |
tree | 8313452e2a9b9ae34fee09f33742d99ded354226 | |
parent | Token and bad code (#500) (diff) | |
parent | Merge branch 'master' into enh/mod/624/edit-recent-infraction (diff) |
Merge pull request #693 from python-discord/enh/mod/624/edit-recent-infraction
Allow "recent" as infraction ID for infraction edit command
-rw-r--r-- | bot/cogs/moderation/management.py | 45 | ||||
-rw-r--r-- | bot/converters.py | 23 |
2 files changed, 49 insertions, 19 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index feae00b7c..180d5219f 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -10,7 +10,7 @@ from discord.ext.commands import Context from bot import constants from bot.bot import Bot -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, allowed_strings from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -23,15 +23,6 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" @@ -61,8 +52,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], + duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None], *, reason: str = None ) -> None: @@ -79,21 +70,40 @@ class ModManagement(commands.Cog): \u2003`M` - minutesā \u2003`s` - seconds - Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp - can be provided for the duration. + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. """ if duration is None and reason is None: # Unlike UserInputError, the error handler will show a specified message for BadArgument raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + f":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") request_data = {} confirm_messages = [] log_text = "" - if duration == "permanent": + if isinstance(duration, str): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: @@ -130,7 +140,8 @@ class ModManagement(commands.Cog): New expiry: {new_infraction['expires_at'] or "Permanent"} """.rstrip() - await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") # Get information about the infraction's user user_id = new_infraction['user'] diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converted to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") |