aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/reminders.py2
-rw-r--r--bot/decorators.py85
-rw-r--r--bot/utils/__init__.py3
-rw-r--r--bot/utils/lock.py90
4 files changed, 94 insertions, 86 deletions
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 734e0bd2d..25b2c9421 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -14,9 +14,9 @@ from discord.ext.commands import Cog, Context, Greedy, group
from bot.bot import Bot
from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, STAFF_ROLES
from bot.converters import Duration
-from bot.decorators import lock_arg
from bot.pagination import LinePaginator
from bot.utils.checks import without_role_check
+from bot.utils.lock import lock_arg
from bot.utils.messages import send_denial
from bot.utils.scheduling import Scheduler
from bot.utils.time import humanize_delta
diff --git a/bot/decorators.py b/bot/decorators.py
index aabbe2cc9..2ec0cb122 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,26 +1,17 @@
import asyncio
-import inspect
import logging
import typing as t
-from collections import defaultdict
from contextlib import suppress
-from functools import partial, wraps
-from weakref import WeakValueDictionary
+from functools import wraps
from discord import Member, NotFound
from discord.ext.commands import Cog, Context, check
from bot.constants import Channels, RedirectOutput
-from bot.errors import LockedResourceError
-from bot.utils import LockGuard, function
+from bot.utils import function
from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check
log = logging.getLogger(__name__)
-__lock_dicts = defaultdict(WeakValueDictionary)
-
-_IdCallableReturn = t.Union[t.Hashable, t.Awaitable[t.Hashable]]
-_IdCallable = t.Callable[[function.BoundArgs], _IdCallableReturn]
-ResourceId = t.Union[t.Hashable, _IdCallable]
def in_whitelist(
@@ -66,78 +57,6 @@ def without_role(*role_ids: int) -> t.Callable:
return check(predicate)
-def lock(namespace: t.Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> t.Callable:
- """
- Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
-
- If any other mutually exclusive function currently holds the lock for a resource, do not run the
- decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if
- the lock cannot be acquired.
-
- `namespace` is an identifier used to prevent collisions among resource IDs.
-
- `resource_id` identifies a resource on which to perform a mutually exclusive operation.
- It may also be a callable or awaitable which will return the resource ID given an ordered
- mapping of the parameters' names to arguments' values.
-
- If decorating a command, this decorator must go before (below) the `command` decorator.
- """
- def decorator(func: t.Callable) -> t.Callable:
- name = func.__name__
-
- @wraps(func)
- async def wrapper(*args, **kwargs) -> t.Any:
- log.trace(f"{name}: mutually exclusive decorator called")
-
- if callable(resource_id):
- log.trace(f"{name}: binding args to signature")
- bound_args = function.get_bound_args(func, args, kwargs)
-
- log.trace(f"{name}: calling the given callable to get the resource ID")
- id_ = resource_id(bound_args)
-
- if inspect.isawaitable(id_):
- log.trace(f"{name}: awaiting to get resource ID")
- id_ = await id_
- else:
- id_ = resource_id
-
- log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}")
-
- # Get the lock for the ID. Create a lock if one doesn't exist yet.
- locks = __lock_dicts[namespace]
- lock = locks.setdefault(id_, LockGuard())
-
- if not lock.locked():
- log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...")
- with lock:
- return await func(*args, **kwargs)
- else:
- log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
- if raise_error:
- raise LockedResourceError(str(namespace), id_)
-
- return wrapper
- return decorator
-
-
-def lock_arg(
- namespace: t.Hashable,
- name_or_pos: function.Argument,
- func: t.Callable[[t.Any], _IdCallableReturn] = None,
- *,
- raise_error: bool = False,
-) -> t.Callable:
- """
- Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
-
- `func` is an optional callable or awaitable which will return the ID given the argument value.
- See `lock` docs for more information.
- """
- decorator_func = partial(lock, namespace, raise_error=raise_error)
- return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)
-
-
def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable:
"""
Changes the channel in the context of the command to redirect the output to a certain channel.
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 0dd9605e8..b73410e96 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -2,10 +2,9 @@ from abc import ABCMeta
from discord.ext.commands import CogMeta
-from bot.utils.lock import LockGuard
from bot.utils.redis_cache import RedisCache
-__all__ = ["CogABCMeta", "LockGuard", "RedisCache"]
+__all__ = ["CogABCMeta", "RedisCache"]
class CogABCMeta(CogMeta, ABCMeta):
diff --git a/bot/utils/lock.py b/bot/utils/lock.py
index 8f1b738aa..5c9dd3725 100644
--- a/bot/utils/lock.py
+++ b/bot/utils/lock.py
@@ -1,3 +1,21 @@
+import inspect
+import logging
+from collections import defaultdict
+from functools import partial, wraps
+from typing import Any, Awaitable, Callable, Hashable, Union
+from weakref import WeakValueDictionary
+
+from bot.errors import LockedResourceError
+from bot.utils import function
+
+log = logging.getLogger(__name__)
+__lock_dicts = defaultdict(WeakValueDictionary)
+
+_IdCallableReturn = Union[Hashable, Awaitable[Hashable]]
+_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
+ResourceId = Union[Hashable, _IdCallable]
+
+
class LockGuard:
"""
A context manager which acquires and releases a lock (mutex).
@@ -21,3 +39,75 @@ class LockGuard:
def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001
self._locked = False
return False # Indicate any raised exception shouldn't be suppressed.
+
+
+def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable:
+ """
+ Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
+
+ If any other mutually exclusive function currently holds the lock for a resource, do not run the
+ decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if
+ the lock cannot be acquired.
+
+ `namespace` is an identifier used to prevent collisions among resource IDs.
+
+ `resource_id` identifies a resource on which to perform a mutually exclusive operation.
+ It may also be a callable or awaitable which will return the resource ID given an ordered
+ mapping of the parameters' names to arguments' values.
+
+ If decorating a command, this decorator must go before (below) the `command` decorator.
+ """
+ def decorator(func: Callable) -> Callable:
+ name = func.__name__
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs) -> Any:
+ log.trace(f"{name}: mutually exclusive decorator called")
+
+ if callable(resource_id):
+ log.trace(f"{name}: binding args to signature")
+ bound_args = function.get_bound_args(func, args, kwargs)
+
+ log.trace(f"{name}: calling the given callable to get the resource ID")
+ id_ = resource_id(bound_args)
+
+ if inspect.isawaitable(id_):
+ log.trace(f"{name}: awaiting to get resource ID")
+ id_ = await id_
+ else:
+ id_ = resource_id
+
+ log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}")
+
+ # Get the lock for the ID. Create a lock if one doesn't exist yet.
+ locks = __lock_dicts[namespace]
+ lock = locks.setdefault(id_, LockGuard())
+
+ if not lock.locked():
+ log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...")
+ with lock:
+ return await func(*args, **kwargs)
+ else:
+ log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
+ if raise_error:
+ raise LockedResourceError(str(namespace), id_)
+
+ return wrapper
+ return decorator
+
+
+def lock_arg(
+ namespace: Hashable,
+ name_or_pos: function.Argument,
+ func: Callable[[Any], _IdCallableReturn] = None,
+ *,
+ raise_error: bool = False,
+) -> Callable:
+ """
+ Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
+
+ `func` is an optional callable or awaitable which will return the ID given the argument value.
+ See `lock` docs for more information.
+ """
+ decorator_func = partial(lock, namespace, raise_error=raise_error)
+ return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)