diff options
author | 2022-06-14 00:37:07 +0200 | |
---|---|---|
committer | 2022-06-15 00:02:08 +0200 | |
commit | a3b0ffb72a1e72fc3be0e96c7407ddff2ade67c9 (patch) | |
tree | 5b19f75a36102181aaeafe9743f2abfea129483b /botcore | |
parent | add typing-extensions (diff) |
Add decorator to block duplicate command invocations in a channel
Diffstat (limited to 'botcore')
-rw-r--r-- | botcore/utils/__init__.py | 3 | ||||
-rw-r--r-- | botcore/utils/cooldown.py | 184 |
2 files changed, 186 insertions, 1 deletions
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index 6055d144..cfc5e99d 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,6 @@ """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, function, logging, members, regex, scheduling +from botcore.utils import _monkey_patches, caching, channel, cooldown, function, logging, members, regex, scheduling from botcore.utils._extensions import unqualify @@ -24,6 +24,7 @@ __all__ = [ apply_monkey_patches, caching, channel, + cooldown, function, logging, members, diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py new file mode 100644 index 00000000..a06dce46 --- /dev/null +++ b/botcore/utils/cooldown.py @@ -0,0 +1,184 @@ +"""Helpers for setting a cooldown on commands.""" + +from __future__ import annotations + +import asyncio +import math +import random +import time +import typing +from collections.abc import Awaitable, Hashable +from contextlib import suppress +from dataclasses import dataclass +from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable + +import discord +from discord.ext.commands import CommandError, Context + +from botcore.utils import scheduling +from botcore.utils.function import command_wraps + +__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"] + +_ArgsTuple = tuple[object] + +if typing.TYPE_CHECKING: + from botcore import BotBase + import typing_extensions + P = typing_extensions.ParamSpec("P") + P.__constraints__ = () +else: + P = typing.TypeVar("P") + """The command's signature.""" + +R = typing.TypeVar("R") +"""The command's return value.""" + + +class CommandOnCooldown(CommandError, typing.Generic[P, R]): + """Raised when a command is invoked while on cooldown.""" + + def __init__( + self, + message: str | None, + function: Callable[P, Awaitable[R]], + *args: object, + **kwargs: object, + ): + super().__init__(message, function, args, kwargs) + self._function = function + self._args = args + self._kwargs = kwargs + + async def call_without_cooldown(self) -> R: + """ + Run the command this cooldown blocked. + + Returns: + The command's return value. + """ + return await self._function(*self._args, **self._kwargs) + + +@dataclass +class _CooldownItem: + call_arguments: _ArgsTuple + timeout_timestamp: float + + +class _CommandCooldownManager: + """ + Manage invocation cooldowns for a command through the arguments the command is called with. + + A cooldown is set through `set_cooldown` for a channel with the given `call_arguments`, + if `is_on_cooldown` is checked within `cooldown_duration` seconds + of the call to `set_cooldown` with the same arguments, True is returned. + """ + + def __init__(self, *, cooldown_duration: float): + self._cooldowns = dict[tuple[Hashable, _ArgsTuple], float]() + self._cooldowns_non_hashable = dict[Hashable, list[_CooldownItem]]() + self._cooldown_duration = cooldown_duration + self.cleanup_task = scheduling.create_task( + self._periodical_cleanup(random.uniform(0, 10)), + name="CooldownManager cleanup", + ) + + def set_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> None: + """Set `call_arguments` arguments on cooldown in `channel`.""" + timeout_timestamp = time.monotonic() + self._cooldown_duration + + try: + self._cooldowns[(channel, call_arguments)] = timeout_timestamp + except TypeError: + cooldowns_list = self._cooldowns_non_hashable.setdefault(channel, []) + for item in cooldowns_list: + if item.call_arguments == call_arguments: + item.timeout_timestamp = timeout_timestamp + else: + cooldowns_list.append(_CooldownItem(call_arguments, timeout_timestamp)) + + def is_on_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> bool: + """Check whether ``call_arguments`` is on cooldown in ``channel``.""" + current_time = time.monotonic() + try: + return self._cooldowns.get((channel, call_arguments), -math.inf) > current_time + except TypeError: + cooldowns_list = self._cooldowns_non_hashable.get(channel, None) + if cooldowns_list is None: + return False + + for item in cooldowns_list: + if item.call_arguments == call_arguments: + return item.timeout_timestamp > current_time + return False + + async def _periodical_cleanup(self, initial_delay: float) -> None: + """ + Wait for `initial_delay`, after that delete stale items every hour. + + The `initial_delay` ensures we're not running cleanups for every command at the same time. + """ + await asyncio.sleep(initial_delay) + while True: + await asyncio.sleep(60 * 60) + self._delete_stale_items() + + def _delete_stale_items(self) -> None: + """Remove expired items from internal collections.""" + current_time = time.monotonic() + + for key, timeout_timestamp in self._cooldowns.copy().items(): + if timeout_timestamp < current_time: + del self._cooldowns[key] + + for key, cooldowns_list in self._cooldowns_non_hashable.copy().items(): + filtered_cooldowns = [ + cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time + ] + + if not filtered_cooldowns: + del self._cooldowns_non_hashable[key] + else: + self._cooldowns_non_hashable[key] = filtered_cooldowns + + +def block_duplicate_invocations( + *, cooldown_duration: float = 5, send_notice: bool = False +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """ + Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds. + + Args: + cooldown_duration: Length of the cooldown in seconds. + send_notice: If True, the user is notified of the cooldown with a reply. + + Returns: + A decorator that adds a wrapper which applies the cooldowns. + + Warning: + The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown. + """ + + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration) + + @command_wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + arg_tuple = (*args[2:], *kwargs.items()) + ctx = typing.cast("Context[BotBase]", args[1]) + channel = ctx.channel + + if not isinstance(channel, discord.DMChannel): + if mgr.is_on_cooldown(ctx.channel, arg_tuple): + if send_notice: + with suppress(discord.NotFound): + await ctx.reply("The command is on cooldown with the given arguments.") + raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs) + mgr.set_cooldown(ctx.channel, arg_tuple) + + return await func(*args, **kwargs) + + return wrapper + + return decorator |