diff options
author | 2022-03-01 18:47:08 +0530 | |
---|---|---|
committer | 2022-03-01 18:47:08 +0530 | |
commit | cc270dfa60ffed5e6251c9cfec82c69249c2931c (patch) | |
tree | 638740fe742ab6541add2f767bea33789a0680b5 | |
parent | remove unused imports (diff) |
move duration check to a decorator
-rw-r--r-- | bot/decorators.py | 30 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 6 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 9 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/management.py | 7 |
4 files changed, 40 insertions, 12 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index f4331264f..dd001d3ca 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -3,7 +3,9 @@ import functools import types import typing as t from contextlib import suppress +from datetime import datetime +import arrow from discord import Member, NotFound from discord.ext import commands from discord.ext.commands import Cog, Context @@ -236,3 +238,31 @@ def mock_in_debug(return_value: t.Any) -> t.Callable: return await func(*args, **kwargs) return wrapped return decorator + + +def ensure_duration_in_future(duration_arg: function.Argument) -> t.Callable: + """ + Ensure the duration argument is in the future. + + If the condition fails, a warning is sent to the invoking context. + + `duration_arg` is the keyword name or position index of the parameter of the decorated command + whose value is the target duration. + + This decorator must go before (below) the `command` decorator. + """ + def decorator(func: types.FunctionType) -> types.FunctionType: + @command_wraps(func) + async def wrapper(*args, **kwargs) -> t.Any: + bound_args = function.get_bound_args(func, args, kwargs) + target = function.get_arg_value(duration_arg, bound_args) + + ctx = function.get_arg_value(1, bound_args) + + if isinstance(target, datetime) and target < arrow.utcnow(): + await ctx.send(":x: Expiration is in the past.") + return + + return await func(*args, **kwargs) + return wrapper + return decorator diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index e607bf752..47b639421 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -137,14 +137,8 @@ class InfractionScheduler: icon = _utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] expiry = time.format_with_duration(infraction["expires_at"]) - expiry_datetime = arrow.get(infraction["expires_at"]) id_ = infraction['id'] - now_datetime = arrow.utcnow() - if expiry_datetime < now_datetime: - await ctx.send(":x: Expiration is in the past.") - return False - if user_reason is None: user_reason = reason diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index af42ab1b8..09610cb1a 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -10,7 +10,7 @@ from bot import constants from bot.bot import Bot from bot.constants import Event from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser -from bot.decorators import respect_role_hierarchy +from bot.decorators import ensure_duration_in_future, respect_role_hierarchy from bot.exts.moderation.infraction import _utils from bot.exts.moderation.infraction._scheduler import InfractionScheduler from bot.log import get_logger @@ -81,6 +81,7 @@ class Infractions(InfractionScheduler, commands.Cog): await self.apply_kick(ctx, user, reason) @command() + @ensure_duration_in_future(duration_arg=3) async def ban( self, ctx: Context, @@ -97,6 +98,7 @@ class Infractions(InfractionScheduler, commands.Cog): await self.apply_ban(ctx, user, reason, expires_at=duration) @command(aliases=("cban", "purgeban", "pban")) + @ensure_duration_in_future(duration_arg=3) async def cleanban( self, ctx: Context, @@ -161,6 +163,7 @@ class Infractions(InfractionScheduler, commands.Cog): await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `voicemute`?") @command(aliases=("vmute",)) + @ensure_duration_in_future(duration_arg=3) async def voicemute( self, ctx: Context, @@ -180,6 +183,7 @@ class Infractions(InfractionScheduler, commands.Cog): # region: Temporary infractions @command(aliases=["mute"]) + @ensure_duration_in_future(duration_arg=3) async def tempmute( self, ctx: Context, user: UnambiguousMemberOrUser, @@ -213,6 +217,7 @@ class Infractions(InfractionScheduler, commands.Cog): await self.apply_mute(ctx, user, reason, expires_at=duration) @command(aliases=("tban",)) + @ensure_duration_in_future(duration_arg=3) async def tempban( self, ctx: Context, @@ -248,6 +253,7 @@ class Infractions(InfractionScheduler, commands.Cog): await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `tempvoicemute`?") @command(aliases=("tempvmute", "tvmute")) + @ensure_duration_in_future(duration_arg=3) async def tempvoicemute( self, ctx: Context, @@ -294,6 +300,7 @@ class Infractions(InfractionScheduler, commands.Cog): # region: Temporary shadow infractions @command(hidden=True, aliases=["shadowtempban", "stempban", "stban"]) + @ensure_duration_in_future(duration_arg=3) async def shadow_tempban( self, ctx: Context, diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index a9dc231c1..4ec1039af 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -1,7 +1,6 @@ import textwrap import typing as t -import arrow import discord from discord.ext import commands from discord.ext.commands import Context @@ -10,6 +9,7 @@ from discord.utils import escape_markdown from bot import constants from bot.bot import Bot from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser, allowed_strings +from bot.decorators import ensure_duration_in_future from bot.errors import InvalidInfraction from bot.exts.moderation.infraction.infractions import Infractions from bot.exts.moderation.modlog import ModLog @@ -100,6 +100,7 @@ class ModManagement(commands.Cog): await self.infraction_edit(ctx, infraction, duration, reason=reason) @infraction_group.command(name='edit', aliases=('e',)) + @ensure_duration_in_future(duration_arg=3) async def infraction_edit( self, ctx: Context, @@ -147,10 +148,6 @@ class ModManagement(commands.Cog): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: - now_datetime = arrow.utcnow() - if duration < now_datetime: - await ctx.send(":x: Expiration is in the past.") - return request_data['expires_at'] = duration.isoformat() expiry = time.format_with_duration(duration) confirm_messages.append(f"set to expire on {expiry}") |