aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_core/utils/cooldown.py
diff options
context:
space:
mode:
Diffstat (limited to 'pydis_core/utils/cooldown.py')
-rw-r--r--pydis_core/utils/cooldown.py220
1 files changed, 220 insertions, 0 deletions
diff --git a/pydis_core/utils/cooldown.py b/pydis_core/utils/cooldown.py
new file mode 100644
index 00000000..5129befd
--- /dev/null
+++ b/pydis_core/utils/cooldown.py
@@ -0,0 +1,220 @@
+"""Helpers for setting a cooldown on commands."""
+
+from __future__ import annotations
+
+import asyncio
+import random
+import time
+import typing
+import weakref
+from collections.abc import Awaitable, Callable, Hashable, Iterable
+from contextlib import suppress
+from dataclasses import dataclass
+
+import discord
+from discord.ext.commands import CommandError, Context
+
+from pydis_core.utils import scheduling
+from pydis_core.utils.function import command_wraps
+
+__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"]
+
+_KEYWORD_SEP_SENTINEL = object()
+
+_ArgsList = list[object]
+_HashableArgsTuple = tuple[Hashable, ...]
+
+if typing.TYPE_CHECKING:
+ import typing_extensions
+ from pydis_core import BotBase
+
+P = typing.ParamSpec("P")
+"""The command's signature."""
+R = typing.TypeVar("R")
+"""The command's return value."""
+
+
+class CommandOnCooldown(CommandError, typing.Generic[P, R]):
+ """Raised when a command is invoked while on cooldown."""
+
+ def __init__(
+ self,
+ message: str | None,
+ function: Callable[P, Awaitable[R]],
+ /,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ):
+ super().__init__(message, function, args, kwargs)
+ self._function = function
+ self._args = args
+ self._kwargs = kwargs
+
+ async def call_without_cooldown(self) -> R:
+ """
+ Run the command this cooldown blocked.
+
+ Returns:
+ The command's return value.
+ """
+ return await self._function(*self._args, **self._kwargs)
+
+
+@dataclass
+class _CooldownItem:
+ non_hashable_arguments: _ArgsList
+ timeout_timestamp: float
+
+
+@dataclass
+class _SeparatedArguments:
+ """Arguments separated into their hashable and non-hashable parts."""
+
+ hashable: _HashableArgsTuple
+ non_hashable: _ArgsList
+
+ @classmethod
+ def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self:
+ """Create a new instance from full call arguments."""
+ hashable = list[Hashable]()
+ non_hashable = list[object]()
+
+ for item in call_arguments:
+ try:
+ hash(item)
+ except TypeError:
+ non_hashable.append(item)
+ else:
+ hashable.append(item)
+
+ return cls(tuple(hashable), non_hashable)
+
+
+class _CommandCooldownManager:
+ """
+ Manage invocation cooldowns for a command through the arguments the command is called with.
+
+ Use `set_cooldown` to set a cooldown,
+ and `is_on_cooldown` to check for a cooldown for a channel with the given arguments.
+ A cooldown lasts for `cooldown_duration` seconds.
+ """
+
+ def __init__(self, *, cooldown_duration: float):
+ self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]()
+ self._cooldown_duration = cooldown_duration
+ self.cleanup_task = scheduling.create_task(
+ self._periodical_cleanup(random.uniform(0, 10)),
+ name="CooldownManager cleanup",
+ )
+ weakref.finalize(self, self.cleanup_task.cancel)
+
+ def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None:
+ """Set `call_arguments` arguments on cooldown in `channel`."""
+ timeout_timestamp = time.monotonic() + self._cooldown_duration
+ separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
+ cooldowns_list = self._cooldowns.setdefault(
+ (channel, separated_arguments.hashable),
+ [],
+ )
+
+ for item in cooldowns_list:
+ if item.non_hashable_arguments == separated_arguments.non_hashable:
+ item.timeout_timestamp = timeout_timestamp
+ return
+
+ cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp))
+
+ def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool:
+ """Check whether `call_arguments` is on cooldown in `channel`."""
+ current_time = time.monotonic()
+ separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
+ cooldowns_list = self._cooldowns.get(
+ (channel, separated_arguments.hashable),
+ [],
+ )
+
+ for item in cooldowns_list:
+ if item.non_hashable_arguments == separated_arguments.non_hashable:
+ return item.timeout_timestamp > current_time
+ return False
+
+ async def _periodical_cleanup(self, initial_delay: float) -> None:
+ """
+ Delete stale items every hour after waiting for `initial_delay`.
+
+ The `initial_delay` ensures cleanups are not running for every command at the same time.
+ A strong reference to self is only kept while cleanup is running.
+ """
+ weak_self = weakref.ref(self)
+ del self
+
+ await asyncio.sleep(initial_delay)
+ while True:
+ await asyncio.sleep(60 * 60)
+ weak_self()._delete_stale_items()
+
+ def _delete_stale_items(self) -> None:
+ """Remove expired items from internal collections."""
+ current_time = time.monotonic()
+
+ for key, cooldowns_list in self._cooldowns.copy().items():
+ filtered_cooldowns = [
+ cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time
+ ]
+
+ if not filtered_cooldowns:
+ del self._cooldowns[key]
+ else:
+ self._cooldowns[key] = filtered_cooldowns
+
+
+def _create_argument_tuple(*args: object, **kwargs: object) -> tuple[object, ...]:
+ return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items())
+
+
+def block_duplicate_invocations(
+ *,
+ cooldown_duration: float = 5,
+ send_notice: bool = False,
+ args_preprocessor: Callable[P, Iterable[object]] | None = None,
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+ """
+ Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds.
+
+ Args:
+ cooldown_duration: Length of the cooldown in seconds.
+ send_notice: If :obj:`True`, notify the user about the cooldown with a reply.
+ args_preprocessor: If specified, this function is called with the args and kwargs the function is called with,
+ its return value is then used to check for the cooldown instead of the raw arguments.
+
+ Returns:
+ A decorator that adds a wrapper which applies the cooldowns.
+
+ Warning:
+ The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown.
+ """
+
+ def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
+ mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration)
+
+ @command_wraps(func)
+ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ if args_preprocessor is not None:
+ all_args = args_preprocessor(*args, **kwargs)
+ else:
+ all_args = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command
+ ctx = typing.cast("Context[BotBase]", args[1])
+
+ if not isinstance(ctx.channel, discord.DMChannel):
+ if mgr.is_on_cooldown(ctx.channel, all_args):
+ if send_notice:
+ with suppress(discord.NotFound):
+ await ctx.reply("The command is on cooldown with the given arguments.")
+ raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs)
+ mgr.set_cooldown(ctx.channel, all_args)
+
+ return await func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator