aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2025-08-10 00:12:00 +0300
committerGravatar GitHub <[email protected]>2025-08-10 00:12:00 +0300
commit876adc2ff060127e6496a16c5005dc3a0632eecd (patch)
treec19461757f2900c8b63e15dc0ce4d52b7c8a02b8
parentUpdate extensions functools.partial to enum.member (diff)
parentMerge branch 'main' into main (diff)
Merge pull request #3331 from b0nes1/main
Implemented optional duration parameter in slowmode command
Diffstat (limited to '')
-rw-r--r--bot/exts/moderation/slowmode.py130
-rw-r--r--tests/bot/exts/moderation/test_slowmode.py119
2 files changed, 230 insertions, 19 deletions
diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py
index 6efb710bb..a83692fa4 100644
--- a/bot/exts/moderation/slowmode.py
+++ b/bot/exts/moderation/slowmode.py
@@ -1,12 +1,16 @@
+from datetime import datetime
from typing import Literal
+from async_rediscache import RedisCache
from dateutil.relativedelta import relativedelta
from discord import TextChannel, Thread
from discord.ext.commands import Cog, Context, group, has_any_role
+from pydis_core.utils.channel import get_or_fetch_channel
+from pydis_core.utils.scheduling import Scheduler
from bot.bot import Bot
from bot.constants import Channels, Emojis, MODERATION_ROLES
-from bot.converters import DurationDelta
+from bot.converters import Duration, DurationDelta
from bot.log import get_logger
from bot.utils import time
@@ -26,8 +30,14 @@ MessageHolder = TextChannel | Thread | None
class Slowmode(Cog):
"""Commands for getting and setting slowmode delays of text channels."""
+ # RedisCache[discord.channel.id : f"{delay}, {expiry}"]
+ # `delay` is the slowmode delay assigned to the text channel.
+ # `expiry` is a naïve ISO 8601 string which describes when the slowmode should be removed.
+ slowmode_cache = RedisCache()
+
def __init__(self, bot: Bot) -> None:
self.bot = bot
+ self.scheduler = Scheduler(self.__class__.__name__)
@group(name="slowmode", aliases=["sm"], invoke_without_command=True)
async def slowmode_group(self, ctx: Context) -> None:
@@ -42,8 +52,14 @@ class Slowmode(Cog):
channel = ctx.channel
humanized_delay = time.humanize_delta(seconds=channel.slowmode_delay)
-
- await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")
+ original_delay, humanized_original_delay, expiration_timestamp = await self._fetch_sm_cache(channel.id)
+ if original_delay is not None:
+ await ctx.send(
+ f"The slowmode delay for {channel.mention} is {humanized_delay}"
+ f" and will revert to {humanized_original_delay} {expiration_timestamp}."
+ )
+ else:
+ await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")
@slowmode_group.command(name="set", aliases=["s"])
async def set_slowmode(
@@ -51,8 +67,14 @@ class Slowmode(Cog):
ctx: Context,
channel: MessageHolder,
delay: DurationDelta | Literal["0s", "0seconds"],
+ expiry: Duration | None = None
) -> None:
- """Set the slowmode delay for a text channel."""
+ """
+ Set the slowmode delay for a text channel.
+
+ Supports temporary slowmodes with the `expiry` argument that automatically
+ revert to the original delay after expiration.
+ """
# Use the channel this command was invoked in if one was not given
if channel is None:
channel = ctx.channel
@@ -62,31 +84,96 @@ class Slowmode(Cog):
if isinstance(delay, str):
delay = relativedelta(seconds=0)
- slowmode_delay = time.relativedelta_to_timedelta(delay).total_seconds()
+ slowmode_delay = int(time.relativedelta_to_timedelta(delay).total_seconds())
humanized_delay = time.humanize_delta(delay)
# Ensure the delay is within discord's limits
- if slowmode_delay <= SLOWMODE_MAX_DELAY:
- log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.")
-
- await channel.edit(slowmode_delay=slowmode_delay)
- if channel.id in COMMONLY_SLOWMODED_CHANNELS:
- log.info(f"Recording slowmode change in stats for {channel.name}.")
- self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay)
+ if slowmode_delay > SLOWMODE_MAX_DELAY:
+ log.info(
+ f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, "
+ "which is not between 0 and 6 hours."
+ )
await ctx.send(
- f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."
+ f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."
)
+ return
- else:
+ if expiry is not None:
+ expiration_timestamp = time.format_relative(expiry)
+
+ original_delay, humanized_original_delay, _ = await self._fetch_sm_cache(channel.id)
+ # Cache the channel's current delay if it has no expiry, otherwise use the cached original delay.
+ if original_delay is None:
+ original_delay = channel.slowmode_delay
+ humanized_original_delay = time.humanize_delta(seconds=original_delay)
+ else:
+ self.scheduler.cancel(channel.id)
+ await self.slowmode_cache.set(channel.id, f"{original_delay}, {expiry}")
+
+ self.scheduler.schedule_at(expiry, channel.id, self._revert_slowmode(channel.id))
log.info(
- f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, "
- "which is not between 0 and 6 hours."
+ f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}"
+ f" which will revert to {humanized_original_delay} in {time.humanize_delta(expiry)}."
+ )
+ await channel.edit(slowmode_delay=slowmode_delay)
+ await ctx.send(
+ f"{Emojis.check_mark} The slowmode delay for {channel.mention}"
+ f" is now {humanized_delay} and will revert to {humanized_original_delay} {expiration_timestamp}."
)
+ else:
+ if await self.slowmode_cache.contains(channel.id):
+ await self.slowmode_cache.delete(channel.id)
+ self.scheduler.cancel(channel.id)
+ log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.")
+ await channel.edit(slowmode_delay=slowmode_delay)
await ctx.send(
- f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."
+ f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."
)
+ if channel.id in COMMONLY_SLOWMODED_CHANNELS:
+ log.info(f"Recording slowmode change in stats for {channel.name}.")
+ self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay)
+
+ async def _reschedule(self) -> None:
+ log.trace("Rescheduling the expiration of temporary slowmodes from cache.")
+ for channel_id, cached_data in await self.slowmode_cache.items():
+ expiration = cached_data.split(", ")[1]
+ expiration_datetime = datetime.fromisoformat(expiration)
+ channel = self.bot.get_channel(channel_id)
+ log.info(f"Rescheduling slowmode expiration for #{channel} ({channel_id}).")
+ self.scheduler.schedule_at(expiration_datetime, channel_id, self._revert_slowmode(channel_id))
+
+ async def _fetch_sm_cache(self, channel_id: int) -> tuple[int | None, str, str]:
+ """
+ Fetch the channel's info from the cache and decode it.
+
+ If no cache for the channel, the returned slowmode is None.
+ """
+ cached_data = await self.slowmode_cache.get(channel_id, None)
+ if not cached_data:
+ return None, "", ""
+
+ original_delay, expiration_time = cached_data.split(", ")
+ original_delay = int(original_delay)
+ humanized_original_delay = time.humanize_delta(seconds=original_delay)
+ expiration_timestamp = time.format_relative(expiration_time)
+
+ return original_delay, humanized_original_delay, expiration_timestamp
+
+ async def _revert_slowmode(self, channel_id: int) -> None:
+ original_delay, humanized_original_delay, _ = await self._fetch_sm_cache(channel_id)
+ channel = await get_or_fetch_channel(self.bot, channel_id)
+ mod_channel = await get_or_fetch_channel(self.bot, Channels.mods)
+ log.info(
+ f"Slowmode in #{channel.name} ({channel.id}) has expired and has reverted to {humanized_original_delay}."
+ )
+ await channel.edit(slowmode_delay=original_delay)
+ await mod_channel.send(
+ f"{Emojis.check_mark} A previously applied slowmode in {channel.jump_url} ({channel.id})"
+ f" has expired and has been reverted to {humanized_original_delay}."
+ )
+ await self.slowmode_cache.delete(channel.id)
@slowmode_group.command(name="reset", aliases=["r"])
async def reset_slowmode(self, ctx: Context, channel: MessageHolder) -> None:
@@ -97,6 +184,15 @@ class Slowmode(Cog):
"""Only allow moderators to invoke the commands in this cog."""
return await has_any_role(*MODERATION_ROLES).predicate(ctx)
+ async def cog_load(self) -> None:
+ """Wait for guild to become available and reschedule slowmodes which should expire."""
+ await self.bot.wait_until_guild_available()
+ await self._reschedule()
+
+ async def cog_unload(self) -> None:
+ """Cancel all scheduled tasks."""
+ self.scheduler.cancel_all()
+
async def setup(bot: Bot) -> None:
"""Load the Slowmode cog."""
diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py
index cf5101e16..d75fcd2f1 100644
--- a/tests/bot/exts/moderation/test_slowmode.py
+++ b/tests/bot/exts/moderation/test_slowmode.py
@@ -1,14 +1,16 @@
-import unittest
+import asyncio
+import datetime
from unittest import mock
from dateutil.relativedelta import relativedelta
from bot.constants import Emojis
from bot.exts.moderation.slowmode import Slowmode
+from tests.base import RedisTestCase
from tests.helpers import MockBot, MockContext, MockTextChannel
-class SlowmodeTests(unittest.IsolatedAsyncioTestCase):
+class SlowmodeTests(RedisTestCase):
def setUp(self) -> None:
self.bot = MockBot()
@@ -95,6 +97,119 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase):
self.ctx, text_channel, relativedelta(seconds=0)
)
+ @mock.patch("bot.exts.moderation.slowmode.datetime")
+ async def test_set_slowmode_with_expiry(self, mock_datetime) -> None:
+ """Set slowmode with an expiry"""
+ fixed_datetime = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC)
+ mock_datetime.now.return_value = fixed_datetime
+
+ test_cases = (
+ ("python-general", 6, 6000, f"{Emojis.check_mark} The slowmode delay for #python-general is now 6 seconds "
+ "and will revert to 0 seconds <t:1748871600:R>."),
+ ("mod-spam", 5, 600, f"{Emojis.check_mark} The slowmode delay for #mod-spam is now 5 seconds and will "
+ "revert to 0 seconds <t:1748866200:R>."),
+ ("changelog", 12, 7200, f"{Emojis.check_mark} The slowmode delay for #changelog is now 12 seconds and will "
+ "revert to 0 seconds <t:1748872800:R>.")
+ )
+ for channel_name, seconds, expiry, result_msg in test_cases:
+ with self.subTest(
+ channel_mention=channel_name,
+ seconds=seconds,
+ expiry=expiry,
+ result_msg=result_msg
+ ):
+ text_channel = MockTextChannel(name=channel_name, slowmode_delay=0)
+ await self.cog.set_slowmode(
+ self.cog,
+ self.ctx,
+ text_channel,
+ relativedelta(seconds=seconds),
+ fixed_datetime + relativedelta(seconds=expiry)
+ )
+ text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds))
+ self.ctx.send.assert_called_once_with(result_msg)
+ self.ctx.reset_mock()
+
+ async def test_callback_scheduled(self):
+ """Schedule slowmode to be reverted"""
+ self.cog.scheduler=mock.MagicMock(wraps=self.cog.scheduler)
+
+ text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123)
+ expiry = datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10)
+ await self.cog.set_slowmode(
+ self.cog,
+ self.ctx,
+ text_channel,
+ relativedelta(seconds=4),
+ expiry
+ )
+
+ args = (expiry, text_channel.id, mock.ANY)
+ self.cog.scheduler.schedule_at.assert_called_once_with(*args)
+
+ @mock.patch("bot.exts.moderation.slowmode.get_or_fetch_channel")
+ async def test_revert_slowmode_callback(self, mock_get_or_fetch_channel) -> None:
+ """Check that the slowmode is reverted"""
+ text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123, jump_url="#python-general")
+ mod_channel = MockTextChannel(name="mods", id=999, )
+ mock_get_or_fetch_channel.side_effect = [text_channel, mod_channel]
+
+ await self.cog.set_slowmode(
+ self.cog,
+ self.ctx,
+ text_channel,
+ relativedelta(seconds=4),
+ datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10)
+ )
+ await self.cog._revert_slowmode(text_channel.id)
+ text_channel.edit.assert_awaited_with(slowmode_delay=2)
+ mod_channel.send.assert_called_once_with(
+ f"{Emojis.check_mark} A previously applied slowmode in {text_channel.jump_url} ({text_channel.id}) "
+ "has expired and has been reverted to 2 seconds."
+ )
+
+ async def test_reschedule_slowmodes(self) -> None:
+ """Does not reschedule if cache is empty"""
+ self.cog.scheduler.schedule_at = mock.MagicMock()
+ self.cog._reschedule = mock.AsyncMock()
+ await self.cog.cog_unload()
+ await self.cog.cog_load()
+
+ self.cog._reschedule.assert_called()
+ self.cog.scheduler.schedule_at.assert_not_called()
+
+ async def test_reschedule_upon_reload(self) -> None:
+ """ Check that method `_reschedule` is called upon cog reload"""
+ self.cog._reschedule = mock.AsyncMock(wraps=self.cog._reschedule)
+ await self.cog.cog_unload()
+ await self.cog.cog_load()
+
+ self.cog._reschedule.assert_called()
+
+ async def test_reschedules_slowmodes(self) -> None:
+ """Slowmodes are loaded from cache at cog reload and scheduled to be reverted."""
+
+ now = datetime.datetime.now(tz=datetime.UTC)
+ channels = {}
+ slowmodes = (
+ (123, (now - datetime.timedelta(minutes=10)), 2), # expiration in the past
+ (456, (now + datetime.timedelta(minutes=20)), 4), # expiration in the future
+ )
+ for channel_id, expiration_datetime, delay in slowmodes:
+ channel = MockTextChannel(slowmode_delay=delay, id=channel_id)
+ channels[channel_id] = channel
+ await self.cog.slowmode_cache.set(channel_id, f"{delay}, {expiration_datetime}")
+
+ self.bot.get_channel = mock.MagicMock(side_effect=lambda channel_id: channels.get(channel_id))
+ await self.cog.cog_unload()
+ await self.cog.cog_load()
+ for channel_id in channels:
+ self.assertIn(channel_id, self.cog.scheduler)
+
+ await asyncio.sleep(1) # give scheduled task time to execute
+ channels[123].edit.assert_awaited_once_with(slowmode_delay=channels[123].slowmode_delay)
+ channels[456].edit.assert_not_called()
+
@mock.patch("bot.exts.moderation.slowmode.has_any_role")
@mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3))
async def test_cog_check(self, role_check):