diff options
| -rw-r--r-- | bot/cogs/reminders.py | 2 | ||||
| -rw-r--r-- | bot/decorators.py | 85 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 3 | ||||
| -rw-r--r-- | bot/utils/lock.py | 90 |
4 files changed, 94 insertions, 86 deletions
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 734e0bd2d..25b2c9421 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -14,9 +14,9 @@ from discord.ext.commands import Cog, Context, Greedy, group from bot.bot import Bot from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES from bot.converters import Duration -from bot.decorators import lock_arg from bot.pagination import LinePaginator from bot.utils.checks import without_role_check +from bot.utils.lock import lock_arg from bot.utils.messages import send_denial from bot.utils.scheduling import Scheduler from bot.utils.time import humanize_delta diff --git a/bot/decorators.py b/bot/decorators.py index aabbe2cc9..2ec0cb122 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,26 +1,17 @@ import asyncio -import inspect import logging import typing as t -from collections import defaultdict from contextlib import suppress -from functools import partial, wraps -from weakref import WeakValueDictionary +from functools import wraps from discord import Member, NotFound from discord.ext.commands import Cog, Context, check from bot.constants import Channels, RedirectOutput -from bot.errors import LockedResourceError -from bot.utils import LockGuard, function +from bot.utils import function from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check log = logging.getLogger(__name__) -__lock_dicts = defaultdict(WeakValueDictionary) - -_IdCallableReturn = t.Union[t.Hashable, t.Awaitable[t.Hashable]] -_IdCallable = t.Callable[[function.BoundArgs], _IdCallableReturn] -ResourceId = t.Union[t.Hashable, _IdCallable] def in_whitelist( @@ -66,78 +57,6 @@ def without_role(*role_ids: int) -> t.Callable: return check(predicate) -def lock(namespace: t.Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> t.Callable: - """ - Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. - - If any other mutually exclusive function currently holds the lock for a resource, do not run the - decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if - the lock cannot be acquired. - - `namespace` is an identifier used to prevent collisions among resource IDs. - - `resource_id` identifies a resource on which to perform a mutually exclusive operation. - It may also be a callable or awaitable which will return the resource ID given an ordered - mapping of the parameters' names to arguments' values. - - If decorating a command, this decorator must go before (below) the `command` decorator. - """ - def decorator(func: t.Callable) -> t.Callable: - name = func.__name__ - - @wraps(func) - async def wrapper(*args, **kwargs) -> t.Any: - log.trace(f"{name}: mutually exclusive decorator called") - - if callable(resource_id): - log.trace(f"{name}: binding args to signature") - bound_args = function.get_bound_args(func, args, kwargs) - - log.trace(f"{name}: calling the given callable to get the resource ID") - id_ = resource_id(bound_args) - - if inspect.isawaitable(id_): - log.trace(f"{name}: awaiting to get resource ID") - id_ = await id_ - else: - id_ = resource_id - - log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}") - - # Get the lock for the ID. Create a lock if one doesn't exist yet. - locks = __lock_dicts[namespace] - lock = locks.setdefault(id_, LockGuard()) - - if not lock.locked(): - log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...") - with lock: - return await func(*args, **kwargs) - else: - log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") - if raise_error: - raise LockedResourceError(str(namespace), id_) - - return wrapper - return decorator - - -def lock_arg( - namespace: t.Hashable, - name_or_pos: function.Argument, - func: t.Callable[[t.Any], _IdCallableReturn] = None, - *, - raise_error: bool = False, -) -> t.Callable: - """ - Apply the `lock` decorator using the value of the arg at the given name/position as the ID. - - `func` is an optional callable or awaitable which will return the ID given the argument value. - See `lock` docs for more information. - """ - decorator_func = partial(lock, namespace, raise_error=raise_error) - return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) - - def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable: """ Changes the channel in the context of the command to redirect the output to a certain channel. diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 0dd9605e8..b73410e96 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -2,10 +2,9 @@ from abc import ABCMeta from discord.ext.commands import CogMeta -from bot.utils.lock import LockGuard from bot.utils.redis_cache import RedisCache -__all__ = ["CogABCMeta", "LockGuard", "RedisCache"] +__all__ = ["CogABCMeta", "RedisCache"] class CogABCMeta(CogMeta, ABCMeta): diff --git a/bot/utils/lock.py b/bot/utils/lock.py index 8f1b738aa..5c9dd3725 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -1,3 +1,21 @@ +import inspect +import logging +from collections import defaultdict +from functools import partial, wraps +from typing import Any, Awaitable, Callable, Hashable, Union +from weakref import WeakValueDictionary + +from bot.errors import LockedResourceError +from bot.utils import function + +log = logging.getLogger(__name__) +__lock_dicts = defaultdict(WeakValueDictionary) + +_IdCallableReturn = Union[Hashable, Awaitable[Hashable]] +_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn] +ResourceId = Union[Hashable, _IdCallable] + + class LockGuard: """ A context manager which acquires and releases a lock (mutex). @@ -21,3 +39,75 @@ class LockGuard: def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001 self._locked = False return False # Indicate any raised exception shouldn't be suppressed. + + +def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable: + """ + Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. + + If any other mutually exclusive function currently holds the lock for a resource, do not run the + decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if + the lock cannot be acquired. + + `namespace` is an identifier used to prevent collisions among resource IDs. + + `resource_id` identifies a resource on which to perform a mutually exclusive operation. + It may also be a callable or awaitable which will return the resource ID given an ordered + mapping of the parameters' names to arguments' values. + + If decorating a command, this decorator must go before (below) the `command` decorator. + """ + def decorator(func: Callable) -> Callable: + name = func.__name__ + + @wraps(func) + async def wrapper(*args, **kwargs) -> Any: + log.trace(f"{name}: mutually exclusive decorator called") + + if callable(resource_id): + log.trace(f"{name}: binding args to signature") + bound_args = function.get_bound_args(func, args, kwargs) + + log.trace(f"{name}: calling the given callable to get the resource ID") + id_ = resource_id(bound_args) + + if inspect.isawaitable(id_): + log.trace(f"{name}: awaiting to get resource ID") + id_ = await id_ + else: + id_ = resource_id + + log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}") + + # Get the lock for the ID. Create a lock if one doesn't exist yet. + locks = __lock_dicts[namespace] + lock = locks.setdefault(id_, LockGuard()) + + if not lock.locked(): + log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...") + with lock: + return await func(*args, **kwargs) + else: + log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") + if raise_error: + raise LockedResourceError(str(namespace), id_) + + return wrapper + return decorator + + +def lock_arg( + namespace: Hashable, + name_or_pos: function.Argument, + func: Callable[[Any], _IdCallableReturn] = None, + *, + raise_error: bool = False, +) -> Callable: + """ + Apply the `lock` decorator using the value of the arg at the given name/position as the ID. + + `func` is an optional callable or awaitable which will return the ID given the argument value. + See `lock` docs for more information. + """ + decorator_func = partial(lock, namespace, raise_error=raise_error) + return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) |