aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Mark <[email protected]>2019-12-12 01:21:58 -0800
committerGravatar GitHub <[email protected]>2019-12-12 01:21:58 -0800
commit5bfef1d683ddcf52924c7323ccb8950c6877312f (patch)
tree8313452e2a9b9ae34fee09f33742d99ded354226
parentToken and bad code (#500) (diff)
parentMerge branch 'master' into enh/mod/624/edit-recent-infraction (diff)
Merge pull request #693 from python-discord/enh/mod/624/edit-recent-infraction
Allow "recent" as infraction ID for infraction edit command
-rw-r--r--bot/cogs/moderation/management.py45
-rw-r--r--bot/converters.py23
2 files changed, 49 insertions, 19 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index feae00b7c..180d5219f 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -10,7 +10,7 @@ from discord.ext.commands import Context
from bot import constants
from bot.bot import Bot
-from bot.converters import InfractionSearchQuery
+from bot.converters import InfractionSearchQuery, allowed_strings
from bot.pagination import LinePaginator
from bot.utils import time
from bot.utils.checks import in_channel_check, with_role_check
@@ -23,15 +23,6 @@ log = logging.getLogger(__name__)
UserConverter = t.Union[discord.User, utils.proxy_user]
-def permanent_duration(expires_at: str) -> str:
- """Only allow an expiration to be 'permanent' if it is a string."""
- expires_at = expires_at.lower()
- if expires_at != "permanent":
- raise commands.BadArgument
- else:
- return expires_at
-
-
class ModManagement(commands.Cog):
"""Management of infractions."""
@@ -61,8 +52,8 @@ class ModManagement(commands.Cog):
async def infraction_edit(
self,
ctx: Context,
- infraction_id: int,
- duration: t.Union[utils.Expiry, permanent_duration, None],
+ infraction_id: t.Union[int, allowed_strings("l", "last", "recent")],
+ duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None],
*,
reason: str = None
) -> None:
@@ -79,21 +70,40 @@ class ModManagement(commands.Cog):
\u2003`M` - minutesāˆ—
\u2003`s` - seconds
- Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp
- can be provided for the duration.
+ Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction
+ authored by the command invoker should be edited.
+
+ Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601
+ timestamp can be provided for the duration.
"""
if duration is None and reason is None:
# Unlike UserInputError, the error handler will show a specified message for BadArgument
raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")
# Retrieve the previous infraction for its information.
- old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}')
+ if isinstance(infraction_id, str):
+ params = {
+ "actor__id": ctx.author.id,
+ "ordering": "-inserted_at"
+ }
+ infractions = await self.bot.api_client.get(f"bot/infractions", params=params)
+
+ if infractions:
+ old_infraction = infractions[0]
+ infraction_id = old_infraction["id"]
+ else:
+ await ctx.send(
+ f":x: Couldn't find most recent infraction; you have never given an infraction."
+ )
+ return
+ else:
+ old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}")
request_data = {}
confirm_messages = []
log_text = ""
- if duration == "permanent":
+ if isinstance(duration, str):
request_data['expires_at'] = None
confirm_messages.append("marked as permanent")
elif duration is not None:
@@ -130,7 +140,8 @@ class ModManagement(commands.Cog):
New expiry: {new_infraction['expires_at'] or "Permanent"}
""".rstrip()
- await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}")
+ changes = ' & '.join(confirm_messages)
+ await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}")
# Get information about the infraction's user
user_id = new_infraction['user']
diff --git a/bot/converters.py b/bot/converters.py
index cf0496541..8d2ab7eb8 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,8 +1,8 @@
import logging
import re
+import typing as t
from datetime import datetime
from ssl import CertificateError
-from typing import Union
import dateutil.parser
import dateutil.tz
@@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter
log = logging.getLogger(__name__)
+def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
+ """
+ Return a converter which only allows arguments equal to one of the given values.
+
+ Unless preserve_case is True, the argument is converted to lowercase. All values are then
+ expected to have already been given in lowercase too.
+ """
+ def converter(arg: str) -> str:
+ if not preserve_case:
+ arg = arg.lower()
+
+ if arg not in values:
+ raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```")
+ else:
+ return arg
+
+ return converter
+
+
class ValidPythonIdentifier(Converter):
"""
A converter that checks whether the given string is a valid Python identifier.
@@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):
"""A converter that checks if the argument is a Discord user, and if not, falls back to a string."""
@staticmethod
- async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]:
+ async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:
"""Check if the argument is a Discord user, and if not, falls back to a string."""
try:
maybe_snowflake = arg.strip("<@!>")