aboutsummaryrefslogtreecommitdiffstats
path: root/botcore/utils/cooldown.py
diff options
context:
space:
mode:
authorGravatar Numerlor <[email protected]>2022-06-21 16:05:02 +0200
committerGravatar Numerlor <[email protected]>2022-06-21 19:23:33 +0200
commit977363d6945ceb9b6bd6750f16c23998bbf3edc1 (patch)
tree22a45bd19d56839c5e0a0933680dc79961ee3d7d /botcore/utils/cooldown.py
parentFix typehint (diff)
generalize handling of fully hashable args, and args with non-hashable parts
Diffstat (limited to 'botcore/utils/cooldown.py')
-rw-r--r--botcore/utils/cooldown.py91
1 files changed, 57 insertions, 34 deletions
diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py
index 34e88901..59b0722e 100644
--- a/botcore/utils/cooldown.py
+++ b/botcore/utils/cooldown.py
@@ -3,11 +3,10 @@
from __future__ import annotations
import asyncio
-import math
import random
import time
import typing
-from collections.abc import Awaitable, Hashable
+from collections.abc import Awaitable, Hashable, Iterable
from contextlib import suppress
from dataclasses import dataclass
from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable
@@ -20,7 +19,8 @@ from botcore.utils.function import command_wraps
__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"]
-_ArgsTuple = tuple[object, ...]
+_ArgsList = list[object]
+_HashableArgsTuple = tuple[Hashable, ...]
if typing.TYPE_CHECKING:
from botcore import BotBase
@@ -62,10 +62,34 @@ class CommandOnCooldown(CommandError, typing.Generic[P, R]):
@dataclass
class _CooldownItem:
- call_arguments: _ArgsTuple
+ arguments: _ArgsList
timeout_timestamp: float
+@dataclass
+class _SeparatedArguments:
+ """Arguments separated into their hashable and non-hashable parts."""
+
+ hashable: _HashableArgsTuple
+ non_hashable: _ArgsList
+
+ @classmethod
+ def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self:
+ """Create a new instance from full call arguments."""
+ hashable = list[Hashable]()
+ non_hashable = list[object]()
+
+ for item in call_arguments:
+ try:
+ hash(item)
+ except TypeError:
+ non_hashable.append(item)
+ else:
+ hashable.append(item)
+
+ return cls(tuple(hashable), non_hashable)
+
+
class _CommandCooldownManager:
"""
Manage invocation cooldowns for a command through the arguments the command is called with.
@@ -76,43 +100,46 @@ class _CommandCooldownManager:
"""
def __init__(self, *, cooldown_duration: float):
- self._cooldowns = dict[tuple[Hashable, _ArgsTuple], float]()
- self._cooldowns_non_hashable = dict[Hashable, list[_CooldownItem]]()
+ self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]()
self._cooldown_duration = cooldown_duration
self.cleanup_task = scheduling.create_task(
self._periodical_cleanup(random.uniform(0, 10)),
name="CooldownManager cleanup",
)
- def set_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> None:
+ def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None:
"""Set `call_arguments` arguments on cooldown in `channel`."""
timeout_timestamp = time.monotonic() + self._cooldown_duration
+ separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
+ cooldowns_list = self._cooldowns.setdefault(
+ (channel, separated_arguments.hashable),
+ []
+ )
- try:
- self._cooldowns[(channel, call_arguments)] = timeout_timestamp
- except TypeError:
- cooldowns_list = self._cooldowns_non_hashable.setdefault(channel, [])
- for item in cooldowns_list:
- if item.call_arguments == call_arguments:
- item.timeout_timestamp = timeout_timestamp
- else:
- cooldowns_list.append(_CooldownItem(call_arguments, timeout_timestamp))
+ for item in cooldowns_list:
+ if item.arguments == separated_arguments.non_hashable:
+ item.timeout_timestamp = timeout_timestamp
+ return
- def is_on_cooldown(self, channel: Hashable, call_arguments: _ArgsTuple) -> bool:
+ cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp))
+
+ def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool:
"""Check whether `call_arguments` is on cooldown in `channel`."""
current_time = time.monotonic()
- try:
- return self._cooldowns.get((channel, call_arguments), -math.inf) > current_time
- except TypeError:
- cooldowns_list = self._cooldowns_non_hashable.get(channel, None)
- if cooldowns_list is None:
- return False
-
- for item in cooldowns_list:
- if item.call_arguments == call_arguments:
- return item.timeout_timestamp > current_time
+ separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
+ cooldowns_list = self._cooldowns.get(
+ (channel, separated_arguments.hashable),
+ None
+ )
+
+ if cooldowns_list is None:
return False
+ for item in cooldowns_list:
+ if item.arguments == separated_arguments.non_hashable:
+ return item.timeout_timestamp > current_time
+ return False
+
async def _periodical_cleanup(self, initial_delay: float) -> None:
"""
Delete stale items every hour after waiting for `initial_delay`.
@@ -128,19 +155,15 @@ class _CommandCooldownManager:
"""Remove expired items from internal collections."""
current_time = time.monotonic()
- for key, timeout_timestamp in self._cooldowns.copy().items():
- if timeout_timestamp < current_time:
- del self._cooldowns[key]
-
- for key, cooldowns_list in self._cooldowns_non_hashable.copy().items():
+ for key, cooldowns_list in self._cooldowns.copy().items():
filtered_cooldowns = [
cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time
]
if not filtered_cooldowns:
- del self._cooldowns_non_hashable[key]
+ del self._cooldowns[key]
else:
- self._cooldowns_non_hashable[key] = filtered_cooldowns
+ self._cooldowns[key] = filtered_cooldowns
def block_duplicate_invocations(