diff options
author | 2022-06-21 16:05:02 +0200 | |
---|---|---|
committer | 2022-06-21 19:23:33 +0200 | |
commit | 977363d6945ceb9b6bd6750f16c23998bbf3edc1 (patch) | |
tree | 22a45bd19d56839c5e0a0933680dc79961ee3d7d /botcore/utils/cooldown.py | |
parent | Fix typehint (diff) |
generalize handling of fully hashable args, and args with non-hashable parts
Diffstat (limited to 'botcore/utils/cooldown.py')
-rw-r--r-- | botcore/utils/cooldown.py | 91 |
1 files changed, 57 insertions, 34 deletions
diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py index 34e88901..59b0722e 100644 --- a/botcore/utils/cooldown.py +++ b/botcore/utils/cooldown.py @@ -3,11 +3,10 @@ from __future__ import annotations import asyncio -import math import random import time import typing -from collections.abc import Awaitable, Hashable +from collections.abc import Awaitable, Hashable, Iterable from contextlib import suppress from dataclasses import dataclass from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable @@ -20,7 +19,8 @@ from botcore.utils.function import command_wraps __all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"] -_ArgsTuple = tuple[object, ...] +_ArgsList = list[object] +_HashableArgsTuple = tuple[Hashable, ...] if typing.TYPE_CHECKING: from botcore import BotBase @@ -62,10 +62,34 @@ class CommandOnCooldown(CommandError, typing.Generic[P, R]): @dataclass class _CooldownItem: - call_arguments: _ArgsTuple + arguments: _ArgsList timeout_timestamp: float +@dataclass +class _SeparatedArguments: + """Arguments separated into their hashable and non-hashable parts.""" + + hashable: _HashableArgsTuple + non_hashable: _ArgsList + + @classmethod + def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self: + """Create a new instance from full call arguments.""" + hashable = list[Hashable]() + non_hashable = list[object]() + + for item in call_arguments: + try: + hash(item) + except TypeError: + non_hashable.append(item) + else: + hashable.append(item) + + return cls(tuple(hashable), non_hashable) + + class _CommandCooldownManager: """ Manage invocation cooldowns for a command through the arguments the command is called with. @@ -76,43 +100,46 @@ class _CommandCooldownManager: """ def __init__(self, *, cooldown_duration: float): - self._cooldowns = dict[tuple[Hashable, _ArgsTuple], float]() - self._cooldowns_non_hashable = dict[Hashable, list[_CooldownItem]]() + self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]() self._cooldown_duration = cooldown_duration self.cleanup_task = scheduling.create_task( self._periodical_cleanup(random.uniform(0, 10)), name="CooldownManager cleanup", ) - def set_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> None: + def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None: """Set `call_arguments` arguments on cooldown in `channel`.""" timeout_timestamp = time.monotonic() + self._cooldown_duration + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.setdefault( + (channel, separated_arguments.hashable), + [] + ) - try: - self._cooldowns[(channel, call_arguments)] = timeout_timestamp - except TypeError: - cooldowns_list = self._cooldowns_non_hashable.setdefault(channel, []) - for item in cooldowns_list: - if item.call_arguments == call_arguments: - item.timeout_timestamp = timeout_timestamp - else: - cooldowns_list.append(_CooldownItem(call_arguments, timeout_timestamp)) + for item in cooldowns_list: + if item.arguments == separated_arguments.non_hashable: + item.timeout_timestamp = timeout_timestamp + return - def is_on_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> bool: + cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp)) + + def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool: """Check whether `call_arguments` is on cooldown in `channel`.""" current_time = time.monotonic() - try: - return self._cooldowns.get((channel, call_arguments), -math.inf) > current_time - except TypeError: - cooldowns_list = self._cooldowns_non_hashable.get(channel, None) - if cooldowns_list is None: - return False - - for item in cooldowns_list: - if item.call_arguments == call_arguments: - return item.timeout_timestamp > current_time + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.get( + (channel, separated_arguments.hashable), + None + ) + + if cooldowns_list is None: return False + for item in cooldowns_list: + if item.arguments == separated_arguments.non_hashable: + return item.timeout_timestamp > current_time + return False + async def _periodical_cleanup(self, initial_delay: float) -> None: """ Delete stale items every hour after waiting for `initial_delay`. @@ -128,19 +155,15 @@ class _CommandCooldownManager: """Remove expired items from internal collections.""" current_time = time.monotonic() - for key, timeout_timestamp in self._cooldowns.copy().items(): - if timeout_timestamp < current_time: - del self._cooldowns[key] - - for key, cooldowns_list in self._cooldowns_non_hashable.copy().items(): + for key, cooldowns_list in self._cooldowns.copy().items(): filtered_cooldowns = [ cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time ] if not filtered_cooldowns: - del self._cooldowns_non_hashable[key] + del self._cooldowns[key] else: - self._cooldowns_non_hashable[key] = filtered_cooldowns + self._cooldowns[key] = filtered_cooldowns def block_duplicate_invocations( |