diff options
author | 2022-11-06 10:51:45 +0000 | |
---|---|---|
committer | 2024-01-30 21:38:26 +0000 | |
commit | 08dfb4f8e4fe7aa169c05a3ab8aa2e1e1b3165fa (patch) | |
tree | e6a64e2de9ed5a21ce2e2c27ee6dd01be1849113 /pydis_core/utils/lock.py | |
parent | Satisfy new ruff linting rules (diff) |
Add lock utils
This includes some additional function utils too.
Co-authored-by: Numerlor <[email protected]>
Co-authored-by: MarkKoz <[email protected]>
Diffstat (limited to 'pydis_core/utils/lock.py')
-rw-r--r-- | pydis_core/utils/lock.py | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/pydis_core/utils/lock.py b/pydis_core/utils/lock.py new file mode 100644 index 00000000..83146235 --- /dev/null +++ b/pydis_core/utils/lock.py @@ -0,0 +1,156 @@ +import asyncio +import inspect +import types +from collections import defaultdict +from collections.abc import Awaitable, Callable, Hashable +from functools import partial +from typing import Any +from weakref import WeakValueDictionary + +from pydis_core.utils import function +from pydis_core.utils.function import command_wraps +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) +__lock_dicts = defaultdict(WeakValueDictionary) + +_IdCallableReturn = Hashable | Awaitable[Hashable] +_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn] +ResourceId = Hashable | _IdCallable + + +class LockedResourceError(RuntimeError): + """ + Exception raised when an operation is attempted on a locked resource. + + Attributes: + type (str): Name of the locked resource's type + id (typing.Hashable): ID of the locked resource + """ + + def __init__(self, resource_type: str, resource_id: Hashable): + self.type = resource_type + self.id = resource_id + + super().__init__( + f"Cannot operate on {self.type.lower()} `{self.id}`; " + "it is currently locked and in use by another operation." + ) + + +class SharedEvent: + """ + Context manager managing an internal event exposed through the wait coro. + + While any code is executing in this context manager, the underlying event will not be set; + when all of the holders finish the event will be set. + """ + + def __init__(self): + self._active_count = 0 + self._event = asyncio.Event() + self._event.set() + + def __enter__(self): + """Increment the count of the active holders and clear the internal event.""" + self._active_count += 1 + self._event.clear() + + def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001 + """Decrement the count of the active holders; if 0 is reached set the internal event.""" + self._active_count -= 1 + if not self._active_count: + self._event.set() + + async def wait(self) -> None: + """Wait for all active holders to exit.""" + await self._event.wait() + + +def lock( + namespace: Hashable, + resource_id: ResourceId, + *, + raise_error: bool = False, + wait: bool = False, +) -> Callable: + """ + Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. + + If decorating a command, this decorator must go before (below) the `command` decorator. + + Arguments: + namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs. + resource_id: identifies a resource on which to perform a mutually exclusive operation. + It may also be a callable or awaitable which will return the resource ID given an ordered + mapping of the parameters' names to arguments' values. + raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired. + wait (bool): If True, wait until the lock becomes available. Otherwise, if any other mutually + exclusive function currently holds the lock for a resource, do not run the decorated function + and return None. + + Raises: + :exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True. + """ + def decorator(func: types.FunctionType) -> types.FunctionType: + name = func.__name__ + + @command_wraps(func) + async def wrapper(*args, **kwargs) -> Any: + log.trace(f"{name}: mutually exclusive decorator called") + + if callable(resource_id): + log.trace(f"{name}: binding args to signature") + bound_args = function.get_bound_args(func, args, kwargs) + + log.trace(f"{name}: calling the given callable to get the resource ID") + id_ = resource_id(bound_args) + + if inspect.isawaitable(id_): + log.trace(f"{name}: awaiting to get resource ID") + id_ = await id_ + else: + id_ = resource_id + + log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}") + + # Get the lock for the ID. Create a lock if one doesn't exist yet. + locks = __lock_dicts[namespace] + lock_ = locks.setdefault(id_, asyncio.Lock()) + + # It's safe to check an asyncio.Lock is free before acquiring it because: + # 1. Synchronous code like `if not lock_.locked()` does not yield execution + # 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free + # 3. awaits only yield execution to the event loop at actual I/O boundaries + if wait or not lock_.locked(): + log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...") + async with lock_: + return await func(*args, **kwargs) + else: + log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") + if raise_error: + raise LockedResourceError(str(namespace), id_) + return None + + return wrapper + return decorator + + +def lock_arg( + namespace: Hashable, + name_or_pos: function.Argument, + func: Callable[[Any], _IdCallableReturn] | None = None, + *, + raise_error: bool = False, + wait: bool = False, +) -> Callable: + """ + Apply the `lock` decorator using the value of the arg at the given name/position as the ID. + + See `lock` docs for more information. + + Arguments: + func: An optional callable or awaitable which will return the ID given the argument value. + """ + decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait) + return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) |