diff options
| -rw-r--r-- | pydis_core/utils/__init__.py | 2 | ||||
| -rw-r--r-- | pydis_core/utils/function.py | 93 | ||||
| -rw-r--r-- | pydis_core/utils/lock.py | 156 | 
3 files changed, 250 insertions, 1 deletions
| diff --git a/pydis_core/utils/__init__.py b/pydis_core/utils/__init__.py index 6e55f911..1636b35e 100644 --- a/pydis_core/utils/__init__.py +++ b/pydis_core/utils/__init__.py @@ -10,6 +10,7 @@ from pydis_core.utils import (      error_handling,      function,      interactions, +    lock,      logging,      members,      messages, @@ -47,6 +48,7 @@ __all__ = [      error_handling,      function,      interactions, +    lock,      logging,      members,      messages, diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py index 7a97027b..911f660d 100644 --- a/pydis_core/utils/function.py +++ b/pydis_core/utils/function.py @@ -3,22 +3,113 @@  from __future__ import annotations  import functools +import inspect  import types  import typing  from collections.abc import Callable, Sequence, Set -__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"] +__all__ = [ +    "GlobalNameConflictError", +    "command_wraps", +    "get_arg_value", +    "get_arg_value_wrapper", +    "get_bound_args", +    "update_wrapper_globals", +]  if typing.TYPE_CHECKING:      _P = typing.ParamSpec("_P")      _R = typing.TypeVar("_R") +Argument = int | str +BoundArgs = typing.OrderedDict[str, typing.Any] +Decorator = typing.Callable[[typing.Callable], typing.Callable] +ArgValGetter = typing.Callable[[BoundArgs], typing.Any] +  class GlobalNameConflictError(Exception):      """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" +def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any: +    """ +    Return a value from `arguments` based on a name or position. + +    Arguments: +        arguments: An ordered mapping of parameter names to argument values. +    Returns: +        Value from `arguments` based on a name or position. +    Raises: +        TypeError: `name_or_pos` isn't a str or int. +        ValueError: `name_or_pos` does not match any argument. +    """ +    if isinstance(name_or_pos, int): +        # Convert arguments to a tuple to make them indexable. +        arg_values = tuple(arguments.items()) +        arg_pos = name_or_pos + +        try: +            _name, value = arg_values[arg_pos] +            return value +        except IndexError: +            raise ValueError(f"Argument position {arg_pos} is out of bounds.") +    elif isinstance(name_or_pos, str): +        arg_name = name_or_pos +        try: +            return arguments[arg_name] +        except KeyError: +            raise ValueError(f"Argument {arg_name!r} doesn't exist.") +    else: +        raise TypeError("'arg' must either be an int (positional index) or a str (keyword).") + + +def get_arg_value_wrapper( +    decorator_func: typing.Callable[[ArgValGetter], Decorator], +    name_or_pos: Argument, +    func: typing.Callable[[typing.Any], typing.Any] | None = None, +) -> Decorator: +    """ +    Call `decorator_func` with the value of the arg at the given name/position. + +    Arguments: +        decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of +            parameter names to argument values of the function it's decorating. +        name_or_pos: The name/position of the arg to get the value from. +        func: An optional callable which will return a new value given the argument's value. + +    Returns: +        The decorator returned by `decorator_func`. +    """ +    def wrapper(args: BoundArgs) -> typing.Any: +        value = get_arg_value(name_or_pos, args) +        if func: +            value = func(value) +        return value + +    return decorator_func(wrapper) + + +def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs: +    """ +    Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values. + +    Default parameter values are also set. + +    Args: +        args: The arguments to bind to ``func`` +        kwargs: The keyword arguments to bind to ``func`` +        func: The function to bind ``args`` and ``kwargs`` to +    Returns: +        A mapping of parameter names to argument values. +    """ +    sig = inspect.signature(func) +    bound_args = sig.bind(*args, **kwargs) +    bound_args.apply_defaults() + +    return bound_args.arguments + +  def update_wrapper_globals(      wrapper: Callable[_P, _R],      wrapped: Callable[_P, _R], diff --git a/pydis_core/utils/lock.py b/pydis_core/utils/lock.py new file mode 100644 index 00000000..83146235 --- /dev/null +++ b/pydis_core/utils/lock.py @@ -0,0 +1,156 @@ +import asyncio +import inspect +import types +from collections import defaultdict +from collections.abc import Awaitable, Callable, Hashable +from functools import partial +from typing import Any +from weakref import WeakValueDictionary + +from pydis_core.utils import function +from pydis_core.utils.function import command_wraps +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) +__lock_dicts = defaultdict(WeakValueDictionary) + +_IdCallableReturn = Hashable | Awaitable[Hashable] +_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn] +ResourceId = Hashable | _IdCallable + + +class LockedResourceError(RuntimeError): +    """ +    Exception raised when an operation is attempted on a locked resource. + +    Attributes: +        type (str): Name of the locked resource's type +        id (typing.Hashable): ID of the locked resource +    """ + +    def __init__(self, resource_type: str, resource_id: Hashable): +        self.type = resource_type +        self.id = resource_id + +        super().__init__( +            f"Cannot operate on {self.type.lower()} `{self.id}`; " +            "it is currently locked and in use by another operation." +        ) + + +class SharedEvent: +    """ +    Context manager managing an internal event exposed through the wait coro. + +    While any code is executing in this context manager, the underlying event will not be set; +    when all of the holders finish the event will be set. +    """ + +    def __init__(self): +        self._active_count = 0 +        self._event = asyncio.Event() +        self._event.set() + +    def __enter__(self): +        """Increment the count of the active holders and clear the internal event.""" +        self._active_count += 1 +        self._event.clear() + +    def __exit__(self, _exc_type, _exc_val, _exc_tb):  # noqa: ANN001 +        """Decrement the count of the active holders; if 0 is reached set the internal event.""" +        self._active_count -= 1 +        if not self._active_count: +            self._event.set() + +    async def wait(self) -> None: +        """Wait for all active holders to exit.""" +        await self._event.wait() + + +def lock( +    namespace: Hashable, +    resource_id: ResourceId, +    *, +    raise_error: bool = False, +    wait: bool = False, +) -> Callable: +    """ +    Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. + +    If decorating a command, this decorator must go before (below) the `command` decorator. + +    Arguments: +        namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs. +        resource_id: identifies a resource on which to perform a mutually exclusive operation. +            It may also be a callable or awaitable which will return the resource ID given an ordered +            mapping of the parameters' names to arguments' values. +        raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired. +        wait (bool): If True, wait until the lock becomes available. Otherwise, if any other mutually +            exclusive function currently holds the lock for a resource, do not run the decorated function +            and return None. + +    Raises: +        :exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True. +    """ +    def decorator(func: types.FunctionType) -> types.FunctionType: +        name = func.__name__ + +        @command_wraps(func) +        async def wrapper(*args, **kwargs) -> Any: +            log.trace(f"{name}: mutually exclusive decorator called") + +            if callable(resource_id): +                log.trace(f"{name}: binding args to signature") +                bound_args = function.get_bound_args(func, args, kwargs) + +                log.trace(f"{name}: calling the given callable to get the resource ID") +                id_ = resource_id(bound_args) + +                if inspect.isawaitable(id_): +                    log.trace(f"{name}: awaiting to get resource ID") +                    id_ = await id_ +            else: +                id_ = resource_id + +            log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}") + +            # Get the lock for the ID. Create a lock if one doesn't exist yet. +            locks = __lock_dicts[namespace] +            lock_ = locks.setdefault(id_, asyncio.Lock()) + +            # It's safe to check an asyncio.Lock is free before acquiring it because: +            #   1. Synchronous code like `if not lock_.locked()` does not yield execution +            #   2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free +            #   3. awaits only yield execution to the event loop at actual I/O boundaries +            if wait or not lock_.locked(): +                log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...") +                async with lock_: +                    return await func(*args, **kwargs) +            else: +                log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") +                if raise_error: +                    raise LockedResourceError(str(namespace), id_) +                return None + +        return wrapper +    return decorator + + +def lock_arg( +    namespace: Hashable, +    name_or_pos: function.Argument, +    func: Callable[[Any], _IdCallableReturn] | None = None, +    *, +    raise_error: bool = False, +    wait: bool = False, +) -> Callable: +    """ +    Apply the `lock` decorator using the value of the arg at the given name/position as the ID. + +    See `lock` docs for more information. + +    Arguments: +        func: An optional callable or awaitable which will return the ID given the argument value. +    """ +    decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait) +    return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) | 
