aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Shakya Majumdar <[email protected]>2022-03-01 18:47:08 +0530
committerGravatar Shakya Majumdar <[email protected]>2022-03-01 18:47:08 +0530
commitcc270dfa60ffed5e6251c9cfec82c69249c2931c (patch)
tree638740fe742ab6541add2f767bea33789a0680b5
parentremove unused imports (diff)
move duration check to a decorator
-rw-r--r--bot/decorators.py30
-rw-r--r--bot/exts/moderation/infraction/_scheduler.py6
-rw-r--r--bot/exts/moderation/infraction/infractions.py9
-rw-r--r--bot/exts/moderation/infraction/management.py7
4 files changed, 40 insertions, 12 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index f4331264f..dd001d3ca 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -3,7 +3,9 @@ import functools
import types
import typing as t
from contextlib import suppress
+from datetime import datetime
+import arrow
from discord import Member, NotFound
from discord.ext import commands
from discord.ext.commands import Cog, Context
@@ -236,3 +238,31 @@ def mock_in_debug(return_value: t.Any) -> t.Callable:
return await func(*args, **kwargs)
return wrapped
return decorator
+
+
+def ensure_duration_in_future(duration_arg: function.Argument) -> t.Callable:
+ """
+ Ensure the duration argument is in the future.
+
+ If the condition fails, a warning is sent to the invoking context.
+
+ `duration_arg` is the keyword name or position index of the parameter of the decorated command
+ whose value is the target duration.
+
+ This decorator must go before (below) the `command` decorator.
+ """
+ def decorator(func: types.FunctionType) -> types.FunctionType:
+ @command_wraps(func)
+ async def wrapper(*args, **kwargs) -> t.Any:
+ bound_args = function.get_bound_args(func, args, kwargs)
+ target = function.get_arg_value(duration_arg, bound_args)
+
+ ctx = function.get_arg_value(1, bound_args)
+
+ if isinstance(target, datetime) and target < arrow.utcnow():
+ await ctx.send(":x: Expiration is in the past.")
+ return
+
+ return await func(*args, **kwargs)
+ return wrapper
+ return decorator
diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py
index e607bf752..47b639421 100644
--- a/bot/exts/moderation/infraction/_scheduler.py
+++ b/bot/exts/moderation/infraction/_scheduler.py
@@ -137,14 +137,8 @@ class InfractionScheduler:
icon = _utils.INFRACTION_ICONS[infr_type][0]
reason = infraction["reason"]
expiry = time.format_with_duration(infraction["expires_at"])
- expiry_datetime = arrow.get(infraction["expires_at"])
id_ = infraction['id']
- now_datetime = arrow.utcnow()
- if expiry_datetime < now_datetime:
- await ctx.send(":x: Expiration is in the past.")
- return False
-
if user_reason is None:
user_reason = reason
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index af42ab1b8..09610cb1a 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -10,7 +10,7 @@ from bot import constants
from bot.bot import Bot
from bot.constants import Event
from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser
-from bot.decorators import respect_role_hierarchy
+from bot.decorators import ensure_duration_in_future, respect_role_hierarchy
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._scheduler import InfractionScheduler
from bot.log import get_logger
@@ -81,6 +81,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_kick(ctx, user, reason)
@command()
+ @ensure_duration_in_future(duration_arg=3)
async def ban(
self,
ctx: Context,
@@ -97,6 +98,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_ban(ctx, user, reason, expires_at=duration)
@command(aliases=("cban", "purgeban", "pban"))
+ @ensure_duration_in_future(duration_arg=3)
async def cleanban(
self,
ctx: Context,
@@ -161,6 +163,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `voicemute`?")
@command(aliases=("vmute",))
+ @ensure_duration_in_future(duration_arg=3)
async def voicemute(
self,
ctx: Context,
@@ -180,6 +183,7 @@ class Infractions(InfractionScheduler, commands.Cog):
# region: Temporary infractions
@command(aliases=["mute"])
+ @ensure_duration_in_future(duration_arg=3)
async def tempmute(
self, ctx: Context,
user: UnambiguousMemberOrUser,
@@ -213,6 +217,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_mute(ctx, user, reason, expires_at=duration)
@command(aliases=("tban",))
+ @ensure_duration_in_future(duration_arg=3)
async def tempban(
self,
ctx: Context,
@@ -248,6 +253,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `tempvoicemute`?")
@command(aliases=("tempvmute", "tvmute"))
+ @ensure_duration_in_future(duration_arg=3)
async def tempvoicemute(
self,
ctx: Context,
@@ -294,6 +300,7 @@ class Infractions(InfractionScheduler, commands.Cog):
# region: Temporary shadow infractions
@command(hidden=True, aliases=["shadowtempban", "stempban", "stban"])
+ @ensure_duration_in_future(duration_arg=3)
async def shadow_tempban(
self,
ctx: Context,
diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py
index a9dc231c1..4ec1039af 100644
--- a/bot/exts/moderation/infraction/management.py
+++ b/bot/exts/moderation/infraction/management.py
@@ -1,7 +1,6 @@
import textwrap
import typing as t
-import arrow
import discord
from discord.ext import commands
from discord.ext.commands import Context
@@ -10,6 +9,7 @@ from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser, allowed_strings
+from bot.decorators import ensure_duration_in_future
from bot.errors import InvalidInfraction
from bot.exts.moderation.infraction.infractions import Infractions
from bot.exts.moderation.modlog import ModLog
@@ -100,6 +100,7 @@ class ModManagement(commands.Cog):
await self.infraction_edit(ctx, infraction, duration, reason=reason)
@infraction_group.command(name='edit', aliases=('e',))
+ @ensure_duration_in_future(duration_arg=3)
async def infraction_edit(
self,
ctx: Context,
@@ -147,10 +148,6 @@ class ModManagement(commands.Cog):
request_data['expires_at'] = None
confirm_messages.append("marked as permanent")
elif duration is not None:
- now_datetime = arrow.utcnow()
- if duration < now_datetime:
- await ctx.send(":x: Expiration is in the past.")
- return
request_data['expires_at'] = duration.isoformat()
expiry = time.format_with_duration(duration)
confirm_messages.append(f"set to expire on {expiry}")