aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pydis_core/utils/__init__.py2
-rw-r--r--pydis_core/utils/function.py93
-rw-r--r--pydis_core/utils/lock.py156
3 files changed, 250 insertions, 1 deletions
diff --git a/pydis_core/utils/__init__.py b/pydis_core/utils/__init__.py
index 6e55f911..1636b35e 100644
--- a/pydis_core/utils/__init__.py
+++ b/pydis_core/utils/__init__.py
@@ -10,6 +10,7 @@ from pydis_core.utils import (
error_handling,
function,
interactions,
+ lock,
logging,
members,
messages,
@@ -47,6 +48,7 @@ __all__ = [
error_handling,
function,
interactions,
+ lock,
logging,
members,
messages,
diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py
index 7a97027b..911f660d 100644
--- a/pydis_core/utils/function.py
+++ b/pydis_core/utils/function.py
@@ -3,22 +3,113 @@
from __future__ import annotations
import functools
+import inspect
import types
import typing
from collections.abc import Callable, Sequence, Set
-__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"]
+__all__ = [
+ "GlobalNameConflictError",
+ "command_wraps",
+ "get_arg_value",
+ "get_arg_value_wrapper",
+ "get_bound_args",
+ "update_wrapper_globals",
+]
if typing.TYPE_CHECKING:
_P = typing.ParamSpec("_P")
_R = typing.TypeVar("_R")
+Argument = int | str
+BoundArgs = typing.OrderedDict[str, typing.Any]
+Decorator = typing.Callable[[typing.Callable], typing.Callable]
+ArgValGetter = typing.Callable[[BoundArgs], typing.Any]
+
class GlobalNameConflictError(Exception):
"""Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
+def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any:
+ """
+ Return a value from `arguments` based on a name or position.
+
+ Arguments:
+ arguments: An ordered mapping of parameter names to argument values.
+ Returns:
+ Value from `arguments` based on a name or position.
+ Raises:
+ TypeError: `name_or_pos` isn't a str or int.
+ ValueError: `name_or_pos` does not match any argument.
+ """
+ if isinstance(name_or_pos, int):
+ # Convert arguments to a tuple to make them indexable.
+ arg_values = tuple(arguments.items())
+ arg_pos = name_or_pos
+
+ try:
+ _name, value = arg_values[arg_pos]
+ return value
+ except IndexError:
+ raise ValueError(f"Argument position {arg_pos} is out of bounds.")
+ elif isinstance(name_or_pos, str):
+ arg_name = name_or_pos
+ try:
+ return arguments[arg_name]
+ except KeyError:
+ raise ValueError(f"Argument {arg_name!r} doesn't exist.")
+ else:
+ raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")
+
+
+def get_arg_value_wrapper(
+ decorator_func: typing.Callable[[ArgValGetter], Decorator],
+ name_or_pos: Argument,
+ func: typing.Callable[[typing.Any], typing.Any] | None = None,
+) -> Decorator:
+ """
+ Call `decorator_func` with the value of the arg at the given name/position.
+
+ Arguments:
+ decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of
+ parameter names to argument values of the function it's decorating.
+ name_or_pos: The name/position of the arg to get the value from.
+ func: An optional callable which will return a new value given the argument's value.
+
+ Returns:
+ The decorator returned by `decorator_func`.
+ """
+ def wrapper(args: BoundArgs) -> typing.Any:
+ value = get_arg_value(name_or_pos, args)
+ if func:
+ value = func(value)
+ return value
+
+ return decorator_func(wrapper)
+
+
+def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs:
+ """
+ Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.
+
+ Default parameter values are also set.
+
+ Args:
+ args: The arguments to bind to ``func``
+ kwargs: The keyword arguments to bind to ``func``
+ func: The function to bind ``args`` and ``kwargs`` to
+ Returns:
+ A mapping of parameter names to argument values.
+ """
+ sig = inspect.signature(func)
+ bound_args = sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+
+ return bound_args.arguments
+
+
def update_wrapper_globals(
wrapper: Callable[_P, _R],
wrapped: Callable[_P, _R],
diff --git a/pydis_core/utils/lock.py b/pydis_core/utils/lock.py
new file mode 100644
index 00000000..83146235
--- /dev/null
+++ b/pydis_core/utils/lock.py
@@ -0,0 +1,156 @@
+import asyncio
+import inspect
+import types
+from collections import defaultdict
+from collections.abc import Awaitable, Callable, Hashable
+from functools import partial
+from typing import Any
+from weakref import WeakValueDictionary
+
+from pydis_core.utils import function
+from pydis_core.utils.function import command_wraps
+from pydis_core.utils.logging import get_logger
+
+log = get_logger(__name__)
+__lock_dicts = defaultdict(WeakValueDictionary)
+
+_IdCallableReturn = Hashable | Awaitable[Hashable]
+_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
+ResourceId = Hashable | _IdCallable
+
+
+class LockedResourceError(RuntimeError):
+ """
+ Exception raised when an operation is attempted on a locked resource.
+
+ Attributes:
+ type (str): Name of the locked resource's type
+ id (typing.Hashable): ID of the locked resource
+ """
+
+ def __init__(self, resource_type: str, resource_id: Hashable):
+ self.type = resource_type
+ self.id = resource_id
+
+ super().__init__(
+ f"Cannot operate on {self.type.lower()} `{self.id}`; "
+ "it is currently locked and in use by another operation."
+ )
+
+
+class SharedEvent:
+ """
+ Context manager managing an internal event exposed through the wait coro.
+
+ While any code is executing in this context manager, the underlying event will not be set;
+ when all of the holders finish the event will be set.
+ """
+
+ def __init__(self):
+ self._active_count = 0
+ self._event = asyncio.Event()
+ self._event.set()
+
+ def __enter__(self):
+ """Increment the count of the active holders and clear the internal event."""
+ self._active_count += 1
+ self._event.clear()
+
+ def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001
+ """Decrement the count of the active holders; if 0 is reached set the internal event."""
+ self._active_count -= 1
+ if not self._active_count:
+ self._event.set()
+
+ async def wait(self) -> None:
+ """Wait for all active holders to exit."""
+ await self._event.wait()
+
+
+def lock(
+ namespace: Hashable,
+ resource_id: ResourceId,
+ *,
+ raise_error: bool = False,
+ wait: bool = False,
+) -> Callable:
+ """
+ Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
+
+ If decorating a command, this decorator must go before (below) the `command` decorator.
+
+ Arguments:
+ namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs.
+ resource_id: identifies a resource on which to perform a mutually exclusive operation.
+ It may also be a callable or awaitable which will return the resource ID given an ordered
+ mapping of the parameters' names to arguments' values.
+ raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired.
+ wait (bool): If True, wait until the lock becomes available. Otherwise, if any other mutually
+ exclusive function currently holds the lock for a resource, do not run the decorated function
+ and return None.
+
+ Raises:
+ :exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True.
+ """
+ def decorator(func: types.FunctionType) -> types.FunctionType:
+ name = func.__name__
+
+ @command_wraps(func)
+ async def wrapper(*args, **kwargs) -> Any:
+ log.trace(f"{name}: mutually exclusive decorator called")
+
+ if callable(resource_id):
+ log.trace(f"{name}: binding args to signature")
+ bound_args = function.get_bound_args(func, args, kwargs)
+
+ log.trace(f"{name}: calling the given callable to get the resource ID")
+ id_ = resource_id(bound_args)
+
+ if inspect.isawaitable(id_):
+ log.trace(f"{name}: awaiting to get resource ID")
+ id_ = await id_
+ else:
+ id_ = resource_id
+
+ log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}")
+
+ # Get the lock for the ID. Create a lock if one doesn't exist yet.
+ locks = __lock_dicts[namespace]
+ lock_ = locks.setdefault(id_, asyncio.Lock())
+
+ # It's safe to check an asyncio.Lock is free before acquiring it because:
+ # 1. Synchronous code like `if not lock_.locked()` does not yield execution
+ # 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free
+ # 3. awaits only yield execution to the event loop at actual I/O boundaries
+ if wait or not lock_.locked():
+ log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...")
+ async with lock_:
+ return await func(*args, **kwargs)
+ else:
+ log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
+ if raise_error:
+ raise LockedResourceError(str(namespace), id_)
+ return None
+
+ return wrapper
+ return decorator
+
+
+def lock_arg(
+ namespace: Hashable,
+ name_or_pos: function.Argument,
+ func: Callable[[Any], _IdCallableReturn] | None = None,
+ *,
+ raise_error: bool = False,
+ wait: bool = False,
+) -> Callable:
+ """
+ Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
+
+ See `lock` docs for more information.
+
+ Arguments:
+ func: An optional callable or awaitable which will return the ID given the argument value.
+ """
+ decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait)
+ return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)