diff options
Diffstat (limited to 'botcore')
-rw-r--r-- | botcore/utils/__init__.py | 16 | ||||
-rw-r--r-- | botcore/utils/cooldown.py | 218 | ||||
-rw-r--r-- | botcore/utils/function.py | 112 |
3 files changed, 345 insertions, 1 deletions
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index 95e89d20..09aaa45f 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,18 @@ """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, commands, interactions, logging, members, regex, scheduling +from botcore.utils import ( + _monkey_patches, + caching, + channel, + commands, + cooldown, + function, + interactions, + logging, + members, + regex, + scheduling, +) from botcore.utils._extensions import unqualify @@ -25,6 +37,8 @@ __all__ = [ caching, channel, commands, + cooldown, + function, interactions, logging, members, diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py new file mode 100644 index 00000000..ee65033d --- /dev/null +++ b/botcore/utils/cooldown.py @@ -0,0 +1,218 @@ +"""Helpers for setting a cooldown on commands.""" + +from __future__ import annotations + +import asyncio +import random +import time +import typing +import weakref +from collections.abc import Awaitable, Callable, Hashable, Iterable +from contextlib import suppress +from dataclasses import dataclass + +import discord +from discord.ext.commands import CommandError, Context + +from botcore.utils import scheduling +from botcore.utils.function import command_wraps + +__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"] + +_KEYWORD_SEP_SENTINEL = object() + +_ArgsList = list[object] +_HashableArgsTuple = tuple[Hashable, ...] + +if typing.TYPE_CHECKING: + from botcore import BotBase + import typing_extensions + P = typing_extensions.ParamSpec("P") + P.__constraints__ = () +else: + P = typing.TypeVar("P") + """The command's signature.""" + +R = typing.TypeVar("R") +"""The command's return value.""" + + +class CommandOnCooldown(CommandError, typing.Generic[P, R]): + """Raised when a command is invoked while on cooldown.""" + + def __init__( + self, + message: str | None, + function: Callable[P, Awaitable[R]], + *args: P.args, + **kwargs: P.kwargs, + ): + super().__init__(message, function, args, kwargs) + self._function = function + self._args = args + self._kwargs = kwargs + + async def call_without_cooldown(self) -> R: + """ + Run the command this cooldown blocked. + + Returns: + The command's return value. + """ + return await self._function(*self._args, **self._kwargs) + + +@dataclass +class _CooldownItem: + arguments: _ArgsList + timeout_timestamp: float + + +@dataclass +class _SeparatedArguments: + """Arguments separated into their hashable and non-hashable parts.""" + + hashable: _HashableArgsTuple + non_hashable: _ArgsList + + @classmethod + def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self: + """Create a new instance from full call arguments.""" + hashable = list[Hashable]() + non_hashable = list[object]() + + for item in call_arguments: + try: + hash(item) + except TypeError: + non_hashable.append(item) + else: + hashable.append(item) + + return cls(tuple(hashable), non_hashable) + + +class _CommandCooldownManager: + """ + Manage invocation cooldowns for a command through the arguments the command is called with. + + Use `set_cooldown` to set a cooldown, + and `is_on_cooldown` to check for a cooldown for a channel with the given arguments. + A cooldown lasts for `cooldown_duration` seconds. + """ + + def __init__(self, *, cooldown_duration: float): + self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]() + self._cooldown_duration = cooldown_duration + self.cleanup_task = scheduling.create_task( + self._periodical_cleanup(random.uniform(0, 10)), + name="CooldownManager cleanup", + ) + weakref.finalize(self, self.cleanup_task.cancel) + + def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None: + """Set `call_arguments` arguments on cooldown in `channel`.""" + timeout_timestamp = time.monotonic() + self._cooldown_duration + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.setdefault( + (channel, separated_arguments.hashable), + [] + ) + + for item in cooldowns_list: + if item.arguments == separated_arguments.non_hashable: + item.timeout_timestamp = timeout_timestamp + return + + cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp)) + + def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool: + """Check whether `call_arguments` is on cooldown in `channel`.""" + current_time = time.monotonic() + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.get( + (channel, separated_arguments.hashable), + None + ) + + if cooldowns_list is None: + return False + + for item in cooldowns_list: + if item.arguments == separated_arguments.non_hashable: + return item.timeout_timestamp > current_time + return False + + async def _periodical_cleanup(self, initial_delay: float) -> None: + """ + Delete stale items every hour after waiting for `initial_delay`. + + The `initial_delay` ensures cleanups are not running for every command at the same time. + A strong reference to self is only kept while cleanup is running. + """ + weak_self = weakref.ref(self) + del self + + await asyncio.sleep(initial_delay) + while True: + await asyncio.sleep(60 * 60) + weak_self()._delete_stale_items() + + def _delete_stale_items(self) -> None: + """Remove expired items from internal collections.""" + current_time = time.monotonic() + + for key, cooldowns_list in self._cooldowns.copy().items(): + filtered_cooldowns = [ + cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time + ] + + if not filtered_cooldowns: + del self._cooldowns[key] + else: + self._cooldowns[key] = filtered_cooldowns + + +def _create_argument_tuple(*args: object, **kwargs: object) -> Iterable[object]: + return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items()) + + +def block_duplicate_invocations( + *, cooldown_duration: float = 5, send_notice: bool = False +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """ + Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds. + + Args: + cooldown_duration: Length of the cooldown in seconds. + send_notice: If :obj:`True`, notify the user about the cooldown with a reply. + + Returns: + A decorator that adds a wrapper which applies the cooldowns. + + Warning: + The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown. + """ + + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration) + + @command_wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + arg_tuple = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command + ctx = typing.cast("Context[BotBase]", args[1]) + channel = ctx.channel + + if not isinstance(channel, discord.DMChannel): + if mgr.is_on_cooldown(ctx.channel, arg_tuple): + if send_notice: + with suppress(discord.NotFound): + await ctx.reply("The command is on cooldown with the given arguments.") + raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs) + mgr.set_cooldown(ctx.channel, arg_tuple) + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/botcore/utils/function.py b/botcore/utils/function.py new file mode 100644 index 00000000..0e90d4c5 --- /dev/null +++ b/botcore/utils/function.py @@ -0,0 +1,112 @@ +"""Utils for manipulating functions.""" + +from __future__ import annotations + +import functools +import types +import typing +from collections.abc import Callable, Sequence, Set + +__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] + + +if typing.TYPE_CHECKING: + import typing_extensions + _P = typing_extensions.ParamSpec("_P") + _R = typing.TypeVar("_R") + + +class GlobalNameConflictError(Exception): + """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" + + +def update_wrapper_globals( + wrapper: Callable[_P, _R], + wrapped: Callable[_P, _R], + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[_P, _R]: + r""" + Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals. + + For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function + to resolve their values. This breaks for decorators that replace the function because they have + their own globals. + + .. warning:: + This function captures the state of ``wrapped``\'s module's globals when it's called; + changes won't be reflected in the new function's globals. + + Args: + wrapper: The function to wrap. + wrapped: The function to wrap with. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Raises: + :exc:`GlobalNameConflictError`: + If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints, + and is not in ``ignored_conflict_names``. + """ + wrapped = typing.cast(types.FunctionType, wrapped) + wrapper = typing.cast(types.FunctionType, wrapper) + + annotation_global_names = ( + ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) + ) + # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. + shared_globals = ( + set(wrapper.__code__.co_names) + & set(annotation_global_names) + & set(wrapped.__globals__) + & set(wrapper.__globals__) + - ignored_conflict_names + ) + if shared_globals: + raise GlobalNameConflictError( + f"wrapper and the wrapped function share the following " + f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " + f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." + ) + + new_globals = wrapper.__globals__.copy() + new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) + return types.FunctionType( + code=wrapper.__code__, + globals=new_globals, + name=wrapper.__name__, + argdefs=wrapper.__defaults__, + closure=wrapper.__closure__, + ) + + +def command_wraps( + wrapped: Callable[_P, _R], + assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS, + updated: Sequence[str] = functools.WRAPPER_UPDATES, + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + r""" + Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation. + + See :func:`update_wrapper_globals` for more details on how the globals are updated. + + Args: + wrapped: The function to wrap with. + assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``. + updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Returns: + A decorator that behaves like :func:`functools.wraps`, + with the wrapper replaced with the function :func:`update_wrapper_globals` returned. + """ # noqa: D200 + def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]: + return functools.update_wrapper( + update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), + wrapped, + assigned, + updated, + ) + + return decorator |