diff options
| -rw-r--r-- | bot/cogs/moderation/management.py | 45 | ||||
| -rw-r--r-- | bot/converters.py | 23 | 
2 files changed, 49 insertions, 19 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index feae00b7c..180d5219f 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -10,7 +10,7 @@ from discord.ext.commands import Context  from bot import constants  from bot.bot import Bot -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, allowed_strings  from bot.pagination import LinePaginator  from bot.utils import time  from bot.utils.checks import in_channel_check, with_role_check @@ -23,15 +23,6 @@ log = logging.getLogger(__name__)  UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: -    """Only allow an expiration to be 'permanent' if it is a string.""" -    expires_at = expires_at.lower() -    if expires_at != "permanent": -        raise commands.BadArgument -    else: -        return expires_at - -  class ModManagement(commands.Cog):      """Management of infractions.""" @@ -61,8 +52,8 @@ class ModManagement(commands.Cog):      async def infraction_edit(          self,          ctx: Context, -        infraction_id: int, -        duration: t.Union[utils.Expiry, permanent_duration, None], +        infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], +        duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None],          *,          reason: str = None      ) -> None: @@ -79,21 +70,40 @@ class ModManagement(commands.Cog):          \u2003`M` - minutesā          \u2003`s` - seconds -        Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp -        can be provided for the duration. +        Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction +        authored by the command invoker should be edited. + +        Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 +        timestamp can be provided for the duration.          """          if duration is None and reason is None:              # Unlike UserInputError, the error handler will show a specified message for BadArgument              raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")          # Retrieve the previous infraction for its information. -        old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') +        if isinstance(infraction_id, str): +            params = { +                "actor__id": ctx.author.id, +                "ordering": "-inserted_at" +            } +            infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + +            if infractions: +                old_infraction = infractions[0] +                infraction_id = old_infraction["id"] +            else: +                await ctx.send( +                    f":x: Couldn't find most recent infraction; you have never given an infraction." +                ) +                return +        else: +            old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}")          request_data = {}          confirm_messages = []          log_text = "" -        if duration == "permanent": +        if isinstance(duration, str):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: @@ -130,7 +140,8 @@ class ModManagement(commands.Cog):                  New expiry: {new_infraction['expires_at'] or "Permanent"}              """.rstrip() -        await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") +        changes = ' & '.join(confirm_messages) +        await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}")          # Get information about the infraction's user          user_id = new_infraction['user'] diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@  import logging  import re +import typing as t  from datetime import datetime  from ssl import CertificateError -from typing import Union  import dateutil.parser  import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter  log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: +    """ +    Return a converter which only allows arguments equal to one of the given values. + +    Unless preserve_case is True, the argument is converted to lowercase. All values are then +    expected to have already been given in lowercase too. +    """ +    def converter(arg: str) -> str: +        if not preserve_case: +            arg = arg.lower() + +        if arg not in values: +            raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") +        else: +            return arg + +    return converter + +  class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):      """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""      @staticmethod -    async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: +    async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:          """Check if the argument is a Discord user, and if not, falls back to a string."""          try:              maybe_snowflake = arg.strip("<@!>")  |