diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/utils/lock.py | 62 | 
1 files changed, 25 insertions, 37 deletions
| diff --git a/bot/utils/lock.py b/bot/utils/lock.py index 7aaafbc88..e44776340 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -1,3 +1,4 @@ +import asyncio  import inspect  import logging  from collections import defaultdict @@ -16,39 +17,21 @@ _IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]  ResourceId = Union[Hashable, _IdCallable] -class LockGuard: -    """ -    A context manager which acquires and releases a lock (mutex). - -    Raise RuntimeError if trying to acquire a locked lock. -    """ - -    def __init__(self): -        self._locked = False - -    @property -    def locked(self) -> bool: -        """Return True if currently locked or False if unlocked.""" -        return self._locked - -    def __enter__(self): -        if self._locked: -            raise RuntimeError("Cannot acquire a locked lock.") - -        self._locked = True - -    def __exit__(self, _exc_type, _exc_value, _traceback):  # noqa: ANN001 -        self._locked = False -        return False  # Indicate any raised exception shouldn't be suppressed. - - -def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable: +def lock( +    namespace: Hashable, +    resource_id: ResourceId, +    *, +    raise_error: bool = False, +    wait: bool = False, +) -> Callable:      """      Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. -    If any other mutually exclusive function currently holds the lock for a resource, do not run the -    decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if -    the lock cannot be acquired. +    If `wait` is True, wait until the lock becomes available. Otherwise, if any other mutually +    exclusive function currently holds the lock for a resource, do not run the decorated function +    and return None. + +    If `raise_error` is True, raise `LockedResourceError` if the lock cannot be acquired.      `namespace` is an identifier used to prevent collisions among resource IDs. @@ -78,15 +61,19 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa              else:                  id_ = resource_id -            log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}") +            log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}")              # Get the lock for the ID. Create a lock if one doesn't exist yet.              locks = __lock_dicts[namespace] -            lock_guard = locks.setdefault(id_, LockGuard()) - -            if not lock_guard.locked: -                log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...") -                with lock_guard: +            lock_ = locks.setdefault(id_, asyncio.Lock()) + +            # It's safe to check an asyncio.Lock is free before acquiring it because: +            #   1. Synchronous code like `if not lock_.locked()` does not yield execution +            #   2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free +            #   3. awaits only yield execution to the event loop at actual I/O boundaries +            if wait or not lock_.locked(): +                log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...") +                async with lock_:                      return await func(*args, **kwargs)              else:                  log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") @@ -103,6 +90,7 @@ def lock_arg(      func: Callable[[Any], _IdCallableReturn] = None,      *,      raise_error: bool = False, +    wait: bool = False,  ) -> Callable:      """      Apply the `lock` decorator using the value of the arg at the given name/position as the ID. @@ -110,5 +98,5 @@ def lock_arg(      `func` is an optional callable or awaitable which will return the ID given the argument value.      See `lock` docs for more information.      """ -    decorator_func = partial(lock, namespace, raise_error=raise_error) +    decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait)      return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) | 
