diff options
-rw-r--r-- | pydis_core/utils/__init__.py | 2 | ||||
-rw-r--r-- | pydis_core/utils/function.py | 93 | ||||
-rw-r--r-- | pydis_core/utils/lock.py | 156 |
3 files changed, 250 insertions, 1 deletions
diff --git a/pydis_core/utils/__init__.py b/pydis_core/utils/__init__.py index 6e55f911..1636b35e 100644 --- a/pydis_core/utils/__init__.py +++ b/pydis_core/utils/__init__.py @@ -10,6 +10,7 @@ from pydis_core.utils import ( error_handling, function, interactions, + lock, logging, members, messages, @@ -47,6 +48,7 @@ __all__ = [ error_handling, function, interactions, + lock, logging, members, messages, diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py index 7a97027b..911f660d 100644 --- a/pydis_core/utils/function.py +++ b/pydis_core/utils/function.py @@ -3,22 +3,113 @@ from __future__ import annotations import functools +import inspect import types import typing from collections.abc import Callable, Sequence, Set -__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"] +__all__ = [ + "GlobalNameConflictError", + "command_wraps", + "get_arg_value", + "get_arg_value_wrapper", + "get_bound_args", + "update_wrapper_globals", +] if typing.TYPE_CHECKING: _P = typing.ParamSpec("_P") _R = typing.TypeVar("_R") +Argument = int | str +BoundArgs = typing.OrderedDict[str, typing.Any] +Decorator = typing.Callable[[typing.Callable], typing.Callable] +ArgValGetter = typing.Callable[[BoundArgs], typing.Any] + class GlobalNameConflictError(Exception): """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" +def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any: + """ + Return a value from `arguments` based on a name or position. + + Arguments: + arguments: An ordered mapping of parameter names to argument values. + Returns: + Value from `arguments` based on a name or position. + Raises: + TypeError: `name_or_pos` isn't a str or int. + ValueError: `name_or_pos` does not match any argument. + """ + if isinstance(name_or_pos, int): + # Convert arguments to a tuple to make them indexable. + arg_values = tuple(arguments.items()) + arg_pos = name_or_pos + + try: + _name, value = arg_values[arg_pos] + return value + except IndexError: + raise ValueError(f"Argument position {arg_pos} is out of bounds.") + elif isinstance(name_or_pos, str): + arg_name = name_or_pos + try: + return arguments[arg_name] + except KeyError: + raise ValueError(f"Argument {arg_name!r} doesn't exist.") + else: + raise TypeError("'arg' must either be an int (positional index) or a str (keyword).") + + +def get_arg_value_wrapper( + decorator_func: typing.Callable[[ArgValGetter], Decorator], + name_or_pos: Argument, + func: typing.Callable[[typing.Any], typing.Any] | None = None, +) -> Decorator: + """ + Call `decorator_func` with the value of the arg at the given name/position. + + Arguments: + decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of + parameter names to argument values of the function it's decorating. + name_or_pos: The name/position of the arg to get the value from. + func: An optional callable which will return a new value given the argument's value. + + Returns: + The decorator returned by `decorator_func`. + """ + def wrapper(args: BoundArgs) -> typing.Any: + value = get_arg_value(name_or_pos, args) + if func: + value = func(value) + return value + + return decorator_func(wrapper) + + +def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs: + """ + Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values. + + Default parameter values are also set. + + Args: + args: The arguments to bind to ``func`` + kwargs: The keyword arguments to bind to ``func`` + func: The function to bind ``args`` and ``kwargs`` to + Returns: + A mapping of parameter names to argument values. + """ + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + return bound_args.arguments + + def update_wrapper_globals( wrapper: Callable[_P, _R], wrapped: Callable[_P, _R], diff --git a/pydis_core/utils/lock.py b/pydis_core/utils/lock.py new file mode 100644 index 00000000..83146235 --- /dev/null +++ b/pydis_core/utils/lock.py @@ -0,0 +1,156 @@ +import asyncio +import inspect +import types +from collections import defaultdict +from collections.abc import Awaitable, Callable, Hashable +from functools import partial +from typing import Any +from weakref import WeakValueDictionary + +from pydis_core.utils import function +from pydis_core.utils.function import command_wraps +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) +__lock_dicts = defaultdict(WeakValueDictionary) + +_IdCallableReturn = Hashable | Awaitable[Hashable] +_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn] +ResourceId = Hashable | _IdCallable + + +class LockedResourceError(RuntimeError): + """ + Exception raised when an operation is attempted on a locked resource. + + Attributes: + type (str): Name of the locked resource's type + id (typing.Hashable): ID of the locked resource + """ + + def __init__(self, resource_type: str, resource_id: Hashable): + self.type = resource_type + self.id = resource_id + + super().__init__( + f"Cannot operate on {self.type.lower()} `{self.id}`; " + "it is currently locked and in use by another operation." + ) + + +class SharedEvent: + """ + Context manager managing an internal event exposed through the wait coro. + + While any code is executing in this context manager, the underlying event will not be set; + when all of the holders finish the event will be set. + """ + + def __init__(self): + self._active_count = 0 + self._event = asyncio.Event() + self._event.set() + + def __enter__(self): + """Increment the count of the active holders and clear the internal event.""" + self._active_count += 1 + self._event.clear() + + def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001 + """Decrement the count of the active holders; if 0 is reached set the internal event.""" + self._active_count -= 1 + if not self._active_count: + self._event.set() + + async def wait(self) -> None: + """Wait for all active holders to exit.""" + await self._event.wait() + + +def lock( + namespace: Hashable, + resource_id: ResourceId, + *, + raise_error: bool = False, + wait: bool = False, +) -> Callable: + """ + Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. + + If decorating a command, this decorator must go before (below) the `command` decorator. + + Arguments: + namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs. + resource_id: identifies a resource on which to perform a mutually exclusive operation. + It may also be a callable or awaitable which will return the resource ID given an ordered + mapping of the parameters' names to arguments' values. + raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired. + wait (bool): If True, wait until the lock becomes available. Otherwise, if any other mutually + exclusive function currently holds the lock for a resource, do not run the decorated function + and return None. + + Raises: + :exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True. + """ + def decorator(func: types.FunctionType) -> types.FunctionType: + name = func.__name__ + + @command_wraps(func) + async def wrapper(*args, **kwargs) -> Any: + log.trace(f"{name}: mutually exclusive decorator called") + + if callable(resource_id): + log.trace(f"{name}: binding args to signature") + bound_args = function.get_bound_args(func, args, kwargs) + + log.trace(f"{name}: calling the given callable to get the resource ID") + id_ = resource_id(bound_args) + + if inspect.isawaitable(id_): + log.trace(f"{name}: awaiting to get resource ID") + id_ = await id_ + else: + id_ = resource_id + + log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}") + + # Get the lock for the ID. Create a lock if one doesn't exist yet. + locks = __lock_dicts[namespace] + lock_ = locks.setdefault(id_, asyncio.Lock()) + + # It's safe to check an asyncio.Lock is free before acquiring it because: + # 1. Synchronous code like `if not lock_.locked()` does not yield execution + # 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free + # 3. awaits only yield execution to the event loop at actual I/O boundaries + if wait or not lock_.locked(): + log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...") + async with lock_: + return await func(*args, **kwargs) + else: + log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") + if raise_error: + raise LockedResourceError(str(namespace), id_) + return None + + return wrapper + return decorator + + +def lock_arg( + namespace: Hashable, + name_or_pos: function.Argument, + func: Callable[[Any], _IdCallableReturn] | None = None, + *, + raise_error: bool = False, + wait: bool = False, +) -> Callable: + """ + Apply the `lock` decorator using the value of the arg at the given name/position as the ID. + + See `lock` docs for more information. + + Arguments: + func: An optional callable or awaitable which will return the ID given the argument value. + """ + decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait) + return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) |