diff options
| author | 2025-08-10 00:12:00 +0300 | |
|---|---|---|
| committer | 2025-08-10 00:12:00 +0300 | |
| commit | 876adc2ff060127e6496a16c5005dc3a0632eecd (patch) | |
| tree | c19461757f2900c8b63e15dc0ce4d52b7c8a02b8 | |
| parent | Update extensions functools.partial to enum.member (diff) | |
| parent | Merge branch 'main' into main (diff) | |
Merge pull request #3331 from b0nes1/main
Implemented optional duration parameter in slowmode command
| -rw-r--r-- | bot/exts/moderation/slowmode.py | 130 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_slowmode.py | 119 | 
2 files changed, 230 insertions, 19 deletions
diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py index 6efb710bb..a83692fa4 100644 --- a/bot/exts/moderation/slowmode.py +++ b/bot/exts/moderation/slowmode.py @@ -1,12 +1,16 @@ +from datetime import datetime  from typing import Literal +from async_rediscache import RedisCache  from dateutil.relativedelta import relativedelta  from discord import TextChannel, Thread  from discord.ext.commands import Cog, Context, group, has_any_role +from pydis_core.utils.channel import get_or_fetch_channel +from pydis_core.utils.scheduling import Scheduler  from bot.bot import Bot  from bot.constants import Channels, Emojis, MODERATION_ROLES -from bot.converters import DurationDelta +from bot.converters import Duration, DurationDelta  from bot.log import get_logger  from bot.utils import time @@ -26,8 +30,14 @@ MessageHolder = TextChannel | Thread | None  class Slowmode(Cog):      """Commands for getting and setting slowmode delays of text channels.""" +    # RedisCache[discord.channel.id : f"{delay}, {expiry}"] +    # `delay` is the slowmode delay assigned to the text channel. +    # `expiry` is a naïve ISO 8601 string which describes when the slowmode should be removed. +    slowmode_cache = RedisCache() +      def __init__(self, bot: Bot) -> None:          self.bot = bot +        self.scheduler = Scheduler(self.__class__.__name__)      @group(name="slowmode", aliases=["sm"], invoke_without_command=True)      async def slowmode_group(self, ctx: Context) -> None: @@ -42,8 +52,14 @@ class Slowmode(Cog):              channel = ctx.channel          humanized_delay = time.humanize_delta(seconds=channel.slowmode_delay) - -        await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.") +        original_delay, humanized_original_delay, expiration_timestamp = await self._fetch_sm_cache(channel.id) +        if original_delay is not None: +            await ctx.send( +                f"The slowmode delay for {channel.mention} is {humanized_delay}" +                f" and will revert to {humanized_original_delay} {expiration_timestamp}." +            ) +        else: +            await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")      @slowmode_group.command(name="set", aliases=["s"])      async def set_slowmode( @@ -51,8 +67,14 @@ class Slowmode(Cog):          ctx: Context,          channel: MessageHolder,          delay: DurationDelta | Literal["0s", "0seconds"], +        expiry: Duration | None = None      ) -> None: -        """Set the slowmode delay for a text channel.""" +        """ +        Set the slowmode delay for a text channel. + +        Supports temporary slowmodes with the `expiry` argument that automatically +        revert to the original delay after expiration. +        """          # Use the channel this command was invoked in if one was not given          if channel is None:              channel = ctx.channel @@ -62,31 +84,96 @@ class Slowmode(Cog):          if isinstance(delay, str):              delay = relativedelta(seconds=0) -        slowmode_delay = time.relativedelta_to_timedelta(delay).total_seconds() +        slowmode_delay = int(time.relativedelta_to_timedelta(delay).total_seconds())          humanized_delay = time.humanize_delta(delay)          # Ensure the delay is within discord's limits -        if slowmode_delay <= SLOWMODE_MAX_DELAY: -            log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.") - -            await channel.edit(slowmode_delay=slowmode_delay) -            if channel.id in COMMONLY_SLOWMODED_CHANNELS: -                log.info(f"Recording slowmode change in stats for {channel.name}.") -                self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay) +        if slowmode_delay > SLOWMODE_MAX_DELAY: +            log.info( +                f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, " +                "which is not between 0 and 6 hours." +            )              await ctx.send( -                f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}." +                f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."              ) +            return -        else: +        if expiry is not None: +            expiration_timestamp = time.format_relative(expiry) + +            original_delay, humanized_original_delay, _ = await self._fetch_sm_cache(channel.id) +            # Cache the channel's current delay if it has no expiry, otherwise use the cached original delay. +            if original_delay is None: +                original_delay = channel.slowmode_delay +                humanized_original_delay = time.humanize_delta(seconds=original_delay) +            else: +                self.scheduler.cancel(channel.id) +            await self.slowmode_cache.set(channel.id, f"{original_delay}, {expiry}") + +            self.scheduler.schedule_at(expiry, channel.id, self._revert_slowmode(channel.id))              log.info( -                f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, " -                "which is not between 0 and 6 hours." +                f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}" +                f" which will revert to {humanized_original_delay} in {time.humanize_delta(expiry)}." +            ) +            await channel.edit(slowmode_delay=slowmode_delay) +            await ctx.send( +                f"{Emojis.check_mark} The slowmode delay for {channel.mention}" +                f" is now {humanized_delay} and will revert to {humanized_original_delay} {expiration_timestamp}."              ) +        else: +            if await self.slowmode_cache.contains(channel.id): +                await self.slowmode_cache.delete(channel.id) +                self.scheduler.cancel(channel.id) +            log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.") +            await channel.edit(slowmode_delay=slowmode_delay)              await ctx.send( -                f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours." +                f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."              ) +        if channel.id in COMMONLY_SLOWMODED_CHANNELS: +            log.info(f"Recording slowmode change in stats for {channel.name}.") +            self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay) + +    async def _reschedule(self) -> None: +        log.trace("Rescheduling the expiration of temporary slowmodes from cache.") +        for channel_id, cached_data in await self.slowmode_cache.items(): +            expiration = cached_data.split(", ")[1] +            expiration_datetime = datetime.fromisoformat(expiration) +            channel = self.bot.get_channel(channel_id) +            log.info(f"Rescheduling slowmode expiration for #{channel} ({channel_id}).") +            self.scheduler.schedule_at(expiration_datetime, channel_id, self._revert_slowmode(channel_id)) + +    async def _fetch_sm_cache(self, channel_id: int) -> tuple[int | None, str, str]: +        """ +        Fetch the channel's info from the cache and decode it. + +        If no cache for the channel, the returned slowmode is None. +        """ +        cached_data = await self.slowmode_cache.get(channel_id, None) +        if not cached_data: +            return None, "", "" + +        original_delay, expiration_time = cached_data.split(", ") +        original_delay = int(original_delay) +        humanized_original_delay = time.humanize_delta(seconds=original_delay) +        expiration_timestamp = time.format_relative(expiration_time) + +        return original_delay, humanized_original_delay, expiration_timestamp + +    async def _revert_slowmode(self, channel_id: int) -> None: +        original_delay, humanized_original_delay, _ = await self._fetch_sm_cache(channel_id) +        channel = await get_or_fetch_channel(self.bot, channel_id) +        mod_channel = await get_or_fetch_channel(self.bot, Channels.mods) +        log.info( +            f"Slowmode in #{channel.name} ({channel.id}) has expired and has reverted to {humanized_original_delay}." +        ) +        await channel.edit(slowmode_delay=original_delay) +        await mod_channel.send( +            f"{Emojis.check_mark} A previously applied slowmode in {channel.jump_url} ({channel.id})" +            f" has expired and has been reverted to {humanized_original_delay}." +        ) +        await self.slowmode_cache.delete(channel.id)      @slowmode_group.command(name="reset", aliases=["r"])      async def reset_slowmode(self, ctx: Context, channel: MessageHolder) -> None: @@ -97,6 +184,15 @@ class Slowmode(Cog):          """Only allow moderators to invoke the commands in this cog."""          return await has_any_role(*MODERATION_ROLES).predicate(ctx) +    async def cog_load(self) -> None: +        """Wait for guild to become available and reschedule slowmodes which should expire.""" +        await self.bot.wait_until_guild_available() +        await self._reschedule() + +    async def cog_unload(self) -> None: +        """Cancel all scheduled tasks.""" +        self.scheduler.cancel_all() +  async def setup(bot: Bot) -> None:      """Load the Slowmode cog.""" diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py index cf5101e16..d75fcd2f1 100644 --- a/tests/bot/exts/moderation/test_slowmode.py +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -1,14 +1,16 @@ -import unittest +import asyncio +import datetime  from unittest import mock  from dateutil.relativedelta import relativedelta  from bot.constants import Emojis  from bot.exts.moderation.slowmode import Slowmode +from tests.base import RedisTestCase  from tests.helpers import MockBot, MockContext, MockTextChannel -class SlowmodeTests(unittest.IsolatedAsyncioTestCase): +class SlowmodeTests(RedisTestCase):      def setUp(self) -> None:          self.bot = MockBot() @@ -95,6 +97,119 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase):              self.ctx, text_channel, relativedelta(seconds=0)          ) +    @mock.patch("bot.exts.moderation.slowmode.datetime") +    async def test_set_slowmode_with_expiry(self, mock_datetime) -> None: +        """Set slowmode with an expiry""" +        fixed_datetime = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC) +        mock_datetime.now.return_value = fixed_datetime + +        test_cases = ( +            ("python-general", 6, 6000, f"{Emojis.check_mark} The slowmode delay for #python-general is now 6 seconds " +             "and will revert to 0 seconds <t:1748871600:R>."), +            ("mod-spam", 5, 600, f"{Emojis.check_mark} The slowmode delay for #mod-spam is now 5 seconds and will " +             "revert to 0 seconds <t:1748866200:R>."), +            ("changelog", 12, 7200, f"{Emojis.check_mark} The slowmode delay for #changelog is now 12 seconds and will " +             "revert to 0 seconds <t:1748872800:R>.") +        ) +        for channel_name, seconds, expiry, result_msg in test_cases: +            with self.subTest( +                channel_mention=channel_name, +                seconds=seconds, +                expiry=expiry, +                result_msg=result_msg +            ): +                text_channel = MockTextChannel(name=channel_name, slowmode_delay=0) +                await self.cog.set_slowmode( +                    self.cog, +                    self.ctx, +                    text_channel, +                    relativedelta(seconds=seconds), +                    fixed_datetime + relativedelta(seconds=expiry) +                ) +                text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) +                self.ctx.send.assert_called_once_with(result_msg) +            self.ctx.reset_mock() + +    async def test_callback_scheduled(self): +        """Schedule slowmode to be reverted""" +        self.cog.scheduler=mock.MagicMock(wraps=self.cog.scheduler) + +        text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123) +        expiry = datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10) +        await self.cog.set_slowmode( +            self.cog, +            self.ctx, +            text_channel, +            relativedelta(seconds=4), +            expiry +            ) + +        args = (expiry, text_channel.id, mock.ANY) +        self.cog.scheduler.schedule_at.assert_called_once_with(*args) + +    @mock.patch("bot.exts.moderation.slowmode.get_or_fetch_channel") +    async def test_revert_slowmode_callback(self, mock_get_or_fetch_channel) -> None: +        """Check that the slowmode is reverted""" +        text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123, jump_url="#python-general") +        mod_channel = MockTextChannel(name="mods", id=999, ) +        mock_get_or_fetch_channel.side_effect = [text_channel, mod_channel] + +        await self.cog.set_slowmode( +            self.cog, +            self.ctx, +            text_channel, +            relativedelta(seconds=4), +            datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10) +            ) +        await self.cog._revert_slowmode(text_channel.id) +        text_channel.edit.assert_awaited_with(slowmode_delay=2) +        mod_channel.send.assert_called_once_with( +            f"{Emojis.check_mark} A previously applied slowmode in {text_channel.jump_url} ({text_channel.id}) " +            "has expired and has been reverted to 2 seconds." +            ) + +    async def test_reschedule_slowmodes(self) -> None: +        """Does not reschedule if cache is empty""" +        self.cog.scheduler.schedule_at = mock.MagicMock() +        self.cog._reschedule = mock.AsyncMock() +        await self.cog.cog_unload() +        await self.cog.cog_load() + +        self.cog._reschedule.assert_called() +        self.cog.scheduler.schedule_at.assert_not_called() + +    async def test_reschedule_upon_reload(self) -> None: +        """ Check that method `_reschedule` is called upon cog reload""" +        self.cog._reschedule = mock.AsyncMock(wraps=self.cog._reschedule) +        await self.cog.cog_unload() +        await self.cog.cog_load() + +        self.cog._reschedule.assert_called() + +    async def test_reschedules_slowmodes(self) -> None: +        """Slowmodes are loaded from cache at cog reload and scheduled to be reverted.""" + +        now = datetime.datetime.now(tz=datetime.UTC) +        channels = {} +        slowmodes = ( +            (123, (now - datetime.timedelta(minutes=10)), 2), # expiration in the past +            (456, (now + datetime.timedelta(minutes=20)), 4), # expiration in the future +        ) +        for channel_id, expiration_datetime, delay in slowmodes: +            channel = MockTextChannel(slowmode_delay=delay, id=channel_id) +            channels[channel_id] = channel +            await self.cog.slowmode_cache.set(channel_id, f"{delay}, {expiration_datetime}") + +        self.bot.get_channel = mock.MagicMock(side_effect=lambda channel_id: channels.get(channel_id)) +        await self.cog.cog_unload() +        await self.cog.cog_load() +        for channel_id in channels: +            self.assertIn(channel_id, self.cog.scheduler) + +        await asyncio.sleep(1) # give scheduled task time to execute +        channels[123].edit.assert_awaited_once_with(slowmode_delay=channels[123].slowmode_delay) +        channels[456].edit.assert_not_called() +      @mock.patch("bot.exts.moderation.slowmode.has_any_role")      @mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3))      async def test_cog_check(self, role_check):  |