aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Numerlor <[email protected]>2022-06-14 00:37:07 +0200
committerGravatar Numerlor <[email protected]>2022-06-15 00:02:08 +0200
commita3b0ffb72a1e72fc3be0e96c7407ddff2ade67c9 (patch)
tree5b19f75a36102181aaeafe9743f2abfea129483b
parentadd typing-extensions (diff)
Add decorator to block duplicate command invocations in a channel
-rw-r--r--botcore/utils/__init__.py3
-rw-r--r--botcore/utils/cooldown.py184
-rw-r--r--tests/botcore/utils/test_cooldown.py48
3 files changed, 234 insertions, 1 deletions
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py
index 6055d144..cfc5e99d 100644
--- a/botcore/utils/__init__.py
+++ b/botcore/utils/__init__.py
@@ -1,6 +1,6 @@
"""Useful utilities and tools for Discord bot development."""
-from botcore.utils import _monkey_patches, caching, channel, function, logging, members, regex, scheduling
+from botcore.utils import _monkey_patches, caching, channel, cooldown, function, logging, members, regex, scheduling
from botcore.utils._extensions import unqualify
@@ -24,6 +24,7 @@ __all__ = [
apply_monkey_patches,
caching,
channel,
+ cooldown,
function,
logging,
members,
diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py
new file mode 100644
index 00000000..a06dce46
--- /dev/null
+++ b/botcore/utils/cooldown.py
@@ -0,0 +1,184 @@
+"""Helpers for setting a cooldown on commands."""
+
+from __future__ import annotations
+
+import asyncio
+import math
+import random
+import time
+import typing
+from collections.abc import Awaitable, Hashable
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable
+
+import discord
+from discord.ext.commands import CommandError, Context
+
+from botcore.utils import scheduling
+from botcore.utils.function import command_wraps
+
+__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"]
+
+_ArgsTuple = tuple[object]
+
+if typing.TYPE_CHECKING:
+ from botcore import BotBase
+ import typing_extensions
+ P = typing_extensions.ParamSpec("P")
+ P.__constraints__ = ()
+else:
+ P = typing.TypeVar("P")
+ """The command's signature."""
+
+R = typing.TypeVar("R")
+"""The command's return value."""
+
+
+class CommandOnCooldown(CommandError, typing.Generic[P, R]):
+ """Raised when a command is invoked while on cooldown."""
+
+ def __init__(
+ self,
+ message: str | None,
+ function: Callable[P, Awaitable[R]],
+ *args: object,
+ **kwargs: object,
+ ):
+ super().__init__(message, function, args, kwargs)
+ self._function = function
+ self._args = args
+ self._kwargs = kwargs
+
+ async def call_without_cooldown(self) -> R:
+ """
+ Run the command this cooldown blocked.
+
+ Returns:
+ The command's return value.
+ """
+ return await self._function(*self._args, **self._kwargs)
+
+
+@dataclass
+class _CooldownItem:
+ call_arguments: _ArgsTuple
+ timeout_timestamp: float
+
+
+class _CommandCooldownManager:
+ """
+ Manage invocation cooldowns for a command through the arguments the command is called with.
+
+ A cooldown is set through `set_cooldown` for a channel with the given `call_arguments`,
+ if `is_on_cooldown` is checked within `cooldown_duration` seconds
+ of the call to `set_cooldown` with the same arguments, True is returned.
+ """
+
+ def __init__(self, *, cooldown_duration: float):
+ self._cooldowns = dict[tuple[Hashable, _ArgsTuple], float]()
+ self._cooldowns_non_hashable = dict[Hashable, list[_CooldownItem]]()
+ self._cooldown_duration = cooldown_duration
+ self.cleanup_task = scheduling.create_task(
+ self._periodical_cleanup(random.uniform(0, 10)),
+ name="CooldownManager cleanup",
+ )
+
+ def set_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> None:
+ """Set `call_arguments` arguments on cooldown in `channel`."""
+ timeout_timestamp = time.monotonic() + self._cooldown_duration
+
+ try:
+ self._cooldowns[(channel, call_arguments)] = timeout_timestamp
+ except TypeError:
+ cooldowns_list = self._cooldowns_non_hashable.setdefault(channel, [])
+ for item in cooldowns_list:
+ if item.call_arguments == call_arguments:
+ item.timeout_timestamp = timeout_timestamp
+ else:
+ cooldowns_list.append(_CooldownItem(call_arguments, timeout_timestamp))
+
+ def is_on_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> bool:
+ """Check whether ``call_arguments`` is on cooldown in ``channel``."""
+ current_time = time.monotonic()
+ try:
+ return self._cooldowns.get((channel, call_arguments), -math.inf) > current_time
+ except TypeError:
+ cooldowns_list = self._cooldowns_non_hashable.get(channel, None)
+ if cooldowns_list is None:
+ return False
+
+ for item in cooldowns_list:
+ if item.call_arguments == call_arguments:
+ return item.timeout_timestamp > current_time
+ return False
+
+ async def _periodical_cleanup(self, initial_delay: float) -> None:
+ """
+ Wait for `initial_delay`, after that delete stale items every hour.
+
+ The `initial_delay` ensures we're not running cleanups for every command at the same time.
+ """
+ await asyncio.sleep(initial_delay)
+ while True:
+ await asyncio.sleep(60 * 60)
+ self._delete_stale_items()
+
+ def _delete_stale_items(self) -> None:
+ """Remove expired items from internal collections."""
+ current_time = time.monotonic()
+
+ for key, timeout_timestamp in self._cooldowns.copy().items():
+ if timeout_timestamp < current_time:
+ del self._cooldowns[key]
+
+ for key, cooldowns_list in self._cooldowns_non_hashable.copy().items():
+ filtered_cooldowns = [
+ cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time
+ ]
+
+ if not filtered_cooldowns:
+ del self._cooldowns_non_hashable[key]
+ else:
+ self._cooldowns_non_hashable[key] = filtered_cooldowns
+
+
+def block_duplicate_invocations(
+ *, cooldown_duration: float = 5, send_notice: bool = False
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+ """
+ Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds.
+
+ Args:
+ cooldown_duration: Length of the cooldown in seconds.
+ send_notice: If True, the user is notified of the cooldown with a reply.
+
+ Returns:
+ A decorator that adds a wrapper which applies the cooldowns.
+
+ Warning:
+ The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown.
+ """
+
+ def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
+ mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration)
+
+ @command_wraps(func)
+ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ arg_tuple = (*args[2:], *kwargs.items())
+ ctx = typing.cast("Context[BotBase]", args[1])
+ channel = ctx.channel
+
+ if not isinstance(channel, discord.DMChannel):
+ if mgr.is_on_cooldown(ctx.channel, arg_tuple):
+ if send_notice:
+ with suppress(discord.NotFound):
+ await ctx.reply("The command is on cooldown with the given arguments.")
+ raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs)
+ mgr.set_cooldown(ctx.channel, arg_tuple)
+
+ return await func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
diff --git a/tests/botcore/utils/test_cooldown.py b/tests/botcore/utils/test_cooldown.py
new file mode 100644
index 00000000..e7fe0f59
--- /dev/null
+++ b/tests/botcore/utils/test_cooldown.py
@@ -0,0 +1,48 @@
+import unittest
+from unittest.mock import patch
+
+from botcore.utils.cooldown import _ArgsTuple, _CommandCooldownManager
+
+
+def create_argument_tuple(*args, **kwargs) -> _ArgsTuple:
+ return (*args, *kwargs.items())
+
+
+class CommandCooldownManagerTests(unittest.IsolatedAsyncioTestCase):
+ test_call_args = (
+ create_argument_tuple(0),
+ create_argument_tuple(a=0),
+ create_argument_tuple([]),
+ create_argument_tuple(a=[]),
+ create_argument_tuple(1, 2, 3, a=4, b=5, c=6),
+ create_argument_tuple([1], [2], [3], a=[4], b=[5], c=[6]),
+ create_argument_tuple([1], 2, [3], a=4, b=[5], c=6),
+ )
+
+ async def asyncSetUp(self):
+ self.cooldown_manager = _CommandCooldownManager(cooldown_duration=5)
+
+ def test_no_cooldown_on_unset(self):
+ for call_args in self.test_call_args:
+ with self.subTest(arguments_tuple=call_args, channel=0):
+ self.assertFalse(self.cooldown_manager.is_on_cooldown(0, call_args))
+
+ for call_args in self.test_call_args:
+ with self.subTest(arguments_tuple=call_args, channel=1):
+ self.assertFalse(self.cooldown_manager.is_on_cooldown(1, call_args))
+
+ @patch("time.monotonic")
+ def test_cooldown_is_set(self, monotonic):
+ monotonic.side_effect = lambda: 0
+ for call_args in self.test_call_args:
+ with self.subTest(arguments_tuple=call_args):
+ self.cooldown_manager.set_cooldown(0, call_args)
+ self.assertTrue(self.cooldown_manager.is_on_cooldown(0, call_args))
+
+ @patch("time.monotonic")
+ def test_cooldown_expires(self, monotonic):
+ for call_args in self.test_call_args:
+ monotonic.side_effect = (0, 1000)
+ with self.subTest(arguments_tuple=call_args):
+ self.cooldown_manager.set_cooldown(0, call_args)
+ self.assertFalse(self.cooldown_manager.is_on_cooldown(0, call_args))