diff options
-rw-r--r-- | bot/exts/moderation/slowmode.py | 66 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_slowmode.py | 45 |
2 files changed, 55 insertions, 56 deletions
diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py index 296ae6742..9caf776b1 100644 --- a/bot/exts/moderation/slowmode.py +++ b/bot/exts/moderation/slowmode.py @@ -1,4 +1,4 @@ -from datetime import UTC, datetime, timedelta +from datetime import datetime from typing import Literal from async_rediscache import RedisCache @@ -9,7 +9,7 @@ from pydis_core.utils.scheduling import Scheduler from bot.bot import Bot from bot.constants import Channels, Emojis, MODERATION_ROLES -from bot.converters import DurationDelta +from bot.converters import Duration, DurationDelta from bot.log import get_logger from bot.utils import time from bot.utils.time import TimestampFormats, discord_timestamp @@ -30,11 +30,10 @@ MessageHolder = TextChannel | Thread | None class Slowmode(Cog): """Commands for getting and setting slowmode delays of text channels.""" - # Stores the expiration timestamp in POSIX format for active slowmodes, keyed by channel ID. - slowmode_expiration_cache = RedisCache() - - # Stores the original slowmode interval by channel ID, allowing its restoration after temporary slowmode expires. - original_slowmode_cache = RedisCache() + # RedisCache[discord.channel.id : f"{delay}, {expiry}"] + # `delay` is the slowmode delay assigned to the text channel. + # `expiry` is a naïve ISO 8601 string which describes when the slowmode should be removed. + slowmode_cache = RedisCache() def __init__(self, bot: Bot) -> None: self.bot = bot @@ -53,8 +52,8 @@ class Slowmode(Cog): channel = ctx.channel humanized_delay = time.humanize_delta(seconds=channel.slowmode_delay) - if await self.slowmode_expiration_cache.contains(channel.id): - expiration_time = await self.slowmode_expiration_cache.get(channel.id) + if await self.slowmode_cache.contains(channel.id): + expiration_time = await self.slowmode_cache.get(channel.id).split(", ")[1] expiration_timestamp = discord_timestamp(expiration_time, TimestampFormats.RELATIVE) await ctx.send( f"The slowmode delay for {channel.mention} is {humanized_delay} and expires in {expiration_timestamp}." @@ -68,12 +67,12 @@ class Slowmode(Cog): ctx: Context, channel: MessageHolder, delay: DurationDelta | Literal["0s", "0seconds"], - duration: DurationDelta | None = None + expiry: Duration | None = None ) -> None: """ Set the slowmode delay for a text channel. - Supports temporary slowmodes with the `duration` argument that automatically + Supports temporary slowmodes with the `expiry` argument that automatically revert to the original delay after expiration. """ # Use the channel this command was invoked in if one was not given @@ -100,22 +99,22 @@ class Slowmode(Cog): ) return - if duration is not None: - slowmode_duration = time.relativedelta_to_timedelta(duration).total_seconds() - humanized_duration = time.humanize_delta(duration) - - expiration_time = datetime.now(tz=UTC) + timedelta(seconds=slowmode_duration) - expiration_timestamp = discord_timestamp(expiration_time, TimestampFormats.RELATIVE) + if expiry is not None: + humanized_expiry = time.humanize_delta(expiry) + expiration_timestamp = discord_timestamp(expiry, TimestampFormats.RELATIVE) - # Only update original_slowmode_cache if the last slowmode was not temporary. - if not await self.slowmode_expiration_cache.contains(channel.id): - await self.original_slowmode_cache.set(channel.id, channel.slowmode_delay) - await self.slowmode_expiration_cache.set(channel.id, expiration_time.timestamp()) + # Only cache the original slowmode delay if there is not already an ongoing temporary slowmode. + if not await self.slowmode_cache.contains(channel.id): + await self.slowmode_cache.set(channel.id, f"{channel.slowmode_delay}, {expiry}") + else: + cached_delay = await self.slowmode_cache.get(channel.id) + await self.slowmode_cache.set(channel.id, f"{cached_delay}, {expiry}") + self.scheduler.cancel(channel.id) - self.scheduler.schedule_at(expiration_time, channel.id, self._revert_slowmode(channel.id)) + self.scheduler.schedule_at(expiry, channel.id, self._revert_slowmode(channel.id)) log.info( f"{ctx.author} set the slowmode delay for #{channel} to" - f"{humanized_delay} which expires in {humanized_duration}." + f"{humanized_delay} which expires in {humanized_expiry}." ) await channel.edit(slowmode_delay=slowmode_delay) await ctx.send( @@ -123,9 +122,8 @@ class Slowmode(Cog): f" is now {humanized_delay} and expires in {expiration_timestamp}." ) else: - if await self.slowmode_expiration_cache.contains(channel.id): - await self.slowmode_expiration_cache.delete(channel.id) - await self.original_slowmode_cache.delete(channel.id) + if await self.slowmode_cache.contains(channel.id): + await self.slowmode_cache.delete(channel.id) self.scheduler.cancel(channel.id) log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.") @@ -139,14 +137,16 @@ class Slowmode(Cog): async def _reschedule(self) -> None: log.trace("Rescheduling the expiration of temporary slowmodes from cache.") - for channel_id, expiration in await self.slowmode_expiration_cache.items(): - expiration_datetime = datetime.fromtimestamp(expiration, tz=UTC) + for channel_id, cached_data in await self.slowmode_cache.items(): + expiration = cached_data.split(", ")[1] + expiration_datetime = datetime.fromisoformat(expiration) channel = self.bot.get_channel(channel_id) log.info(f"Rescheduling slowmode expiration for #{channel} ({channel_id}).") self.scheduler.schedule_at(expiration_datetime, channel_id, self._revert_slowmode(channel_id)) async def _revert_slowmode(self, channel_id: int) -> None: - original_slowmode = await self.original_slowmode_cache.get(channel_id) + cached_data = await self.slowmode_cache.get(channel_id) + original_slowmode = int(cached_data.split(", ")[0]) slowmode_delay = time.humanize_delta(seconds=original_slowmode) channel = self.bot.get_channel(channel_id) log.info(f"Slowmode in #{channel} ({channel.id}) has expired and has reverted to {slowmode_delay}.") @@ -154,8 +154,7 @@ class Slowmode(Cog): await channel.send( f"{Emojis.check_mark} A previously applied slowmode has expired and has been reverted to {slowmode_delay}." ) - await self.slowmode_expiration_cache.delete(channel.id) - await self.original_slowmode_cache.delete(channel.id) + await self.slowmode_cache.delete(channel.id) @slowmode_group.command(name="reset", aliases=["r"]) async def reset_slowmode(self, ctx: Context, channel: MessageHolder) -> None: @@ -163,9 +162,8 @@ class Slowmode(Cog): await self.set_slowmode(ctx, channel, relativedelta(seconds=0)) if channel is None: channel = ctx.channel - if await self.slowmode_expiration_cache.contains(channel.id): - await self.slowmode_expiration_cache.delete(channel.id) - await self.original_slowmode_cache.delete(channel.id) + if await self.slowmode_cache.contains(channel.id): + await self.slowmode_cache.delete(channel.id) self.scheduler.cancel(channel.id) async def cog_check(self, ctx: Context) -> bool: diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py index 3d816e144..d88ffd784 100644 --- a/tests/bot/exts/moderation/test_slowmode.py +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -98,9 +98,11 @@ class SlowmodeTests(RedisTestCase): ) @mock.patch("bot.exts.moderation.slowmode.datetime") - async def test_set_slowmode_with_duration(self, mock_datetime) -> None: - """Set slowmode with a duration""" - mock_datetime.now.return_value = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC) + async def test_set_slowmode_with_expiry(self, mock_datetime) -> None: + """Set slowmode with an expiry""" + fixed_datetime = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC) + mock_datetime.now.return_value = fixed_datetime + test_cases = ( ("python-general", 6, 6000, f"{Emojis.check_mark} The slowmode delay for #python-general is now 6 seconds" " and expires in <t:1748871600:R>."), @@ -109,11 +111,11 @@ class SlowmodeTests(RedisTestCase): ("changelog", 12, 7200, f"{Emojis.check_mark} The slowmode delay for #changelog is now 12 seconds and" " expires in <t:1748872800:R>.") ) - for channel_name, seconds, duration, result_msg in test_cases: + for channel_name, seconds, expiry, result_msg in test_cases: with self.subTest( channel_mention=channel_name, seconds=seconds, - duration=duration, + expiry=expiry, result_msg=result_msg ): text_channel = MockTextChannel(name=channel_name, slowmode_delay=0) @@ -122,28 +124,27 @@ class SlowmodeTests(RedisTestCase): self.ctx, text_channel, relativedelta(seconds=seconds), - duration=relativedelta(seconds=duration) + fixed_datetime + relativedelta(seconds=expiry) ) text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) self.ctx.send.assert_called_once_with(result_msg) self.ctx.reset_mock() - @mock.patch("bot.exts.moderation.slowmode.datetime", wraps=datetime.datetime) - async def test_callback_scheduled(self, mock_datetime, ): + async def test_callback_scheduled(self): """Schedule slowmode to be reverted""" - mock_now = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC) - mock_datetime.now.return_value = mock_now self.cog.scheduler=mock.MagicMock(wraps=self.cog.scheduler) text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123) + expiry = datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10) await self.cog.set_slowmode( self.cog, self.ctx, text_channel, relativedelta(seconds=4), - relativedelta(seconds=10)) + expiry + ) - args = (mock_now+relativedelta(seconds=10), text_channel.id, mock.ANY) + args = (expiry, text_channel.id, mock.ANY) self.cog.scheduler.schedule_at.assert_called_once_with(*args) async def test_revert_slowmode_callback(self) -> None: @@ -151,7 +152,11 @@ class SlowmodeTests(RedisTestCase): text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123) self.bot.get_channel = mock.MagicMock(return_value=text_channel) await self.cog.set_slowmode( - self.cog, self.ctx, text_channel, relativedelta(seconds=4), relativedelta(seconds=10) + self.cog, + self.ctx, + text_channel, + relativedelta(seconds=4), + datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10) ) await self.cog._revert_slowmode(text_channel.id) text_channel.edit.assert_awaited_with(slowmode_delay=2) @@ -177,23 +182,19 @@ class SlowmodeTests(RedisTestCase): self.cog._reschedule.assert_called() - @mock.patch("bot.exts.moderation.slowmode.datetime", wraps=datetime.datetime) - async def test_reschedules_slowmodes(self, mock_datetime) -> None: + async def test_reschedules_slowmodes(self) -> None: """Slowmodes are loaded from cache at cog reload and scheduled to be reverted.""" - mock_datetime.now.return_value = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC) - mock_now = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC) + now = datetime.datetime.now(tz=datetime.UTC) channels = {} slowmodes = ( - (123, (mock_now - datetime.timedelta(10)).timestamp(), 2), # expiration in the past - (456, (mock_now + datetime.timedelta(20)).timestamp(), 4), # expiration in the future + (123, (now - datetime.timedelta(minutes=10)), 2), # expiration in the past + (456, (now + datetime.timedelta(minutes=20)), 4), # expiration in the future ) - for channel_id, expiration_datetime, delay in slowmodes: channel = MockTextChannel(slowmode_delay=delay, id=channel_id) channels[channel_id] = channel - await self.cog.slowmode_expiration_cache.set(channel_id, expiration_datetime) - await self.cog.original_slowmode_cache.set(channel_id, delay) + await self.cog.slowmode_cache.set(channel_id, f"{delay}, {expiration_datetime}") self.bot.get_channel = mock.MagicMock(side_effect=lambda channel_id: channels.get(channel_id)) await self.cog.cog_unload() |