aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/utils/lock.py62
1 files changed, 25 insertions, 37 deletions
diff --git a/bot/utils/lock.py b/bot/utils/lock.py
index 7aaafbc88..e44776340 100644
--- a/bot/utils/lock.py
+++ b/bot/utils/lock.py
@@ -1,3 +1,4 @@
+import asyncio
import inspect
import logging
from collections import defaultdict
@@ -16,39 +17,21 @@ _IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
ResourceId = Union[Hashable, _IdCallable]
-class LockGuard:
- """
- A context manager which acquires and releases a lock (mutex).
-
- Raise RuntimeError if trying to acquire a locked lock.
- """
-
- def __init__(self):
- self._locked = False
-
- @property
- def locked(self) -> bool:
- """Return True if currently locked or False if unlocked."""
- return self._locked
-
- def __enter__(self):
- if self._locked:
- raise RuntimeError("Cannot acquire a locked lock.")
-
- self._locked = True
-
- def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001
- self._locked = False
- return False # Indicate any raised exception shouldn't be suppressed.
-
-
-def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable:
+def lock(
+ namespace: Hashable,
+ resource_id: ResourceId,
+ *,
+ raise_error: bool = False,
+ wait: bool = False,
+) -> Callable:
"""
Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
- If any other mutually exclusive function currently holds the lock for a resource, do not run the
- decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if
- the lock cannot be acquired.
+ If `wait` is True, wait until the lock becomes available. Otherwise, if any other mutually
+ exclusive function currently holds the lock for a resource, do not run the decorated function
+ and return None.
+
+ If `raise_error` is True, raise `LockedResourceError` if the lock cannot be acquired.
`namespace` is an identifier used to prevent collisions among resource IDs.
@@ -78,15 +61,19 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa
else:
id_ = resource_id
- log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}")
+ log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}")
# Get the lock for the ID. Create a lock if one doesn't exist yet.
locks = __lock_dicts[namespace]
- lock_guard = locks.setdefault(id_, LockGuard())
-
- if not lock_guard.locked:
- log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...")
- with lock_guard:
+ lock_ = locks.setdefault(id_, asyncio.Lock())
+
+ # It's safe to check an asyncio.Lock is free before acquiring it because:
+ # 1. Synchronous code like `if not lock_.locked()` does not yield execution
+ # 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free
+ # 3. awaits only yield execution to the event loop at actual I/O boundaries
+ if wait or not lock_.locked():
+ log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...")
+ async with lock_:
return await func(*args, **kwargs)
else:
log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
@@ -103,6 +90,7 @@ def lock_arg(
func: Callable[[Any], _IdCallableReturn] = None,
*,
raise_error: bool = False,
+ wait: bool = False,
) -> Callable:
"""
Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
@@ -110,5 +98,5 @@ def lock_arg(
`func` is an optional callable or awaitable which will return the ID given the argument value.
See `lock` docs for more information.
"""
- decorator_func = partial(lock, namespace, raise_error=raise_error)
+ decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait)
return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)