From 08dfb4f8e4fe7aa169c05a3ab8aa2e1e1b3165fa Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sun, 6 Nov 2022 10:51:45 +0000 Subject: Add lock utils This includes some additional function utils too. Co-authored-by: Numerlor Co-authored-by: MarkKoz --- pydis_core/utils/function.py | 93 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) (limited to 'pydis_core/utils/function.py') diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py index 7a97027b..911f660d 100644 --- a/pydis_core/utils/function.py +++ b/pydis_core/utils/function.py @@ -3,22 +3,113 @@ from __future__ import annotations import functools +import inspect import types import typing from collections.abc import Callable, Sequence, Set -__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"] +__all__ = [ + "GlobalNameConflictError", + "command_wraps", + "get_arg_value", + "get_arg_value_wrapper", + "get_bound_args", + "update_wrapper_globals", +] if typing.TYPE_CHECKING: _P = typing.ParamSpec("_P") _R = typing.TypeVar("_R") +Argument = int | str +BoundArgs = typing.OrderedDict[str, typing.Any] +Decorator = typing.Callable[[typing.Callable], typing.Callable] +ArgValGetter = typing.Callable[[BoundArgs], typing.Any] + class GlobalNameConflictError(Exception): """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" +def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any: + """ + Return a value from `arguments` based on a name or position. + + Arguments: + arguments: An ordered mapping of parameter names to argument values. + Returns: + Value from `arguments` based on a name or position. + Raises: + TypeError: `name_or_pos` isn't a str or int. + ValueError: `name_or_pos` does not match any argument. + """ + if isinstance(name_or_pos, int): + # Convert arguments to a tuple to make them indexable. + arg_values = tuple(arguments.items()) + arg_pos = name_or_pos + + try: + _name, value = arg_values[arg_pos] + return value + except IndexError: + raise ValueError(f"Argument position {arg_pos} is out of bounds.") + elif isinstance(name_or_pos, str): + arg_name = name_or_pos + try: + return arguments[arg_name] + except KeyError: + raise ValueError(f"Argument {arg_name!r} doesn't exist.") + else: + raise TypeError("'arg' must either be an int (positional index) or a str (keyword).") + + +def get_arg_value_wrapper( + decorator_func: typing.Callable[[ArgValGetter], Decorator], + name_or_pos: Argument, + func: typing.Callable[[typing.Any], typing.Any] | None = None, +) -> Decorator: + """ + Call `decorator_func` with the value of the arg at the given name/position. + + Arguments: + decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of + parameter names to argument values of the function it's decorating. + name_or_pos: The name/position of the arg to get the value from. + func: An optional callable which will return a new value given the argument's value. + + Returns: + The decorator returned by `decorator_func`. + """ + def wrapper(args: BoundArgs) -> typing.Any: + value = get_arg_value(name_or_pos, args) + if func: + value = func(value) + return value + + return decorator_func(wrapper) + + +def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs: + """ + Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values. + + Default parameter values are also set. + + Args: + args: The arguments to bind to ``func`` + kwargs: The keyword arguments to bind to ``func`` + func: The function to bind ``args`` and ``kwargs`` to + Returns: + A mapping of parameter names to argument values. + """ + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + return bound_args.arguments + + def update_wrapper_globals( wrapper: Callable[_P, _R], wrapped: Callable[_P, _R], -- cgit v1.2.3