diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/decorators.py | 30 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 6 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 9 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/management.py | 7 | 
4 files changed, 40 insertions, 12 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index f4331264f..dd001d3ca 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -3,7 +3,9 @@ import functools  import types  import typing as t  from contextlib import suppress +from datetime import datetime +import arrow  from discord import Member, NotFound  from discord.ext import commands  from discord.ext.commands import Cog, Context @@ -236,3 +238,31 @@ def mock_in_debug(return_value: t.Any) -> t.Callable:              return await func(*args, **kwargs)          return wrapped      return decorator + + +def ensure_duration_in_future(duration_arg: function.Argument) -> t.Callable: +    """ +    Ensure the duration argument is in the future. + +    If the condition fails, a warning is sent to the invoking context. + +    `duration_arg` is the keyword name or position index of the parameter of the decorated command +    whose value is the target duration. + +    This decorator must go before (below) the `command` decorator. +    """ +    def decorator(func: types.FunctionType) -> types.FunctionType: +        @command_wraps(func) +        async def wrapper(*args, **kwargs) -> t.Any: +            bound_args = function.get_bound_args(func, args, kwargs) +            target = function.get_arg_value(duration_arg, bound_args) + +            ctx = function.get_arg_value(1, bound_args) + +            if isinstance(target, datetime) and target < arrow.utcnow(): +                await ctx.send(":x: Expiration is in the past.") +                return + +            return await func(*args, **kwargs) +        return wrapper +    return decorator diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index e607bf752..47b639421 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -137,14 +137,8 @@ class InfractionScheduler:          icon = _utils.INFRACTION_ICONS[infr_type][0]          reason = infraction["reason"]          expiry = time.format_with_duration(infraction["expires_at"]) -        expiry_datetime = arrow.get(infraction["expires_at"])          id_ = infraction['id'] -        now_datetime = arrow.utcnow() -        if expiry_datetime < now_datetime: -            await ctx.send(":x: Expiration is in the past.") -            return False -          if user_reason is None:              user_reason = reason diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index af42ab1b8..09610cb1a 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -10,7 +10,7 @@ from bot import constants  from bot.bot import Bot  from bot.constants import Event  from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser -from bot.decorators import respect_role_hierarchy +from bot.decorators import ensure_duration_in_future, respect_role_hierarchy  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.log import get_logger @@ -81,6 +81,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_kick(ctx, user, reason)      @command() +    @ensure_duration_in_future(duration_arg=3)      async def ban(          self,          ctx: Context, @@ -97,6 +98,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_ban(ctx, user, reason, expires_at=duration)      @command(aliases=("cban", "purgeban", "pban")) +    @ensure_duration_in_future(duration_arg=3)      async def cleanban(          self,          ctx: Context, @@ -161,6 +163,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `voicemute`?")      @command(aliases=("vmute",)) +    @ensure_duration_in_future(duration_arg=3)      async def voicemute(          self,          ctx: Context, @@ -180,6 +183,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Temporary infractions      @command(aliases=["mute"]) +    @ensure_duration_in_future(duration_arg=3)      async def tempmute(          self, ctx: Context,          user: UnambiguousMemberOrUser, @@ -213,6 +217,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_mute(ctx, user, reason, expires_at=duration)      @command(aliases=("tban",)) +    @ensure_duration_in_future(duration_arg=3)      async def tempban(          self,          ctx: Context, @@ -248,6 +253,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `tempvoicemute`?")      @command(aliases=("tempvmute", "tvmute")) +    @ensure_duration_in_future(duration_arg=3)      async def tempvoicemute(          self,          ctx: Context, @@ -294,6 +300,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Temporary shadow infractions      @command(hidden=True, aliases=["shadowtempban", "stempban", "stban"]) +    @ensure_duration_in_future(duration_arg=3)      async def shadow_tempban(          self,          ctx: Context, diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index a9dc231c1..4ec1039af 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -1,7 +1,6 @@  import textwrap  import typing as t -import arrow  import discord  from discord.ext import commands  from discord.ext.commands import Context @@ -10,6 +9,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot  from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser, allowed_strings +from bot.decorators import ensure_duration_in_future  from bot.errors import InvalidInfraction  from bot.exts.moderation.infraction.infractions import Infractions  from bot.exts.moderation.modlog import ModLog @@ -100,6 +100,7 @@ class ModManagement(commands.Cog):          await self.infraction_edit(ctx, infraction, duration, reason=reason)      @infraction_group.command(name='edit', aliases=('e',)) +    @ensure_duration_in_future(duration_arg=3)      async def infraction_edit(          self,          ctx: Context, @@ -147,10 +148,6 @@ class ModManagement(commands.Cog):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: -            now_datetime = arrow.utcnow() -            if duration < now_datetime: -                await ctx.send(":x: Expiration is in the past.") -                return              request_data['expires_at'] = duration.isoformat()              expiry = time.format_with_duration(duration)              confirm_messages.append(f"set to expire on {expiry}") | 
