aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_core/utils/function.py
diff options
context:
space:
mode:
Diffstat (limited to 'pydis_core/utils/function.py')
-rw-r--r--pydis_core/utils/function.py93
1 files changed, 92 insertions, 1 deletions
diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py
index 7a97027b..911f660d 100644
--- a/pydis_core/utils/function.py
+++ b/pydis_core/utils/function.py
@@ -3,22 +3,113 @@
from __future__ import annotations
import functools
+import inspect
import types
import typing
from collections.abc import Callable, Sequence, Set
-__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"]
+__all__ = [
+ "GlobalNameConflictError",
+ "command_wraps",
+ "get_arg_value",
+ "get_arg_value_wrapper",
+ "get_bound_args",
+ "update_wrapper_globals",
+]
if typing.TYPE_CHECKING:
_P = typing.ParamSpec("_P")
_R = typing.TypeVar("_R")
+Argument = int | str
+BoundArgs = typing.OrderedDict[str, typing.Any]
+Decorator = typing.Callable[[typing.Callable], typing.Callable]
+ArgValGetter = typing.Callable[[BoundArgs], typing.Any]
+
class GlobalNameConflictError(Exception):
"""Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
+def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any:
+ """
+ Return a value from `arguments` based on a name or position.
+
+ Arguments:
+ arguments: An ordered mapping of parameter names to argument values.
+ Returns:
+ Value from `arguments` based on a name or position.
+ Raises:
+ TypeError: `name_or_pos` isn't a str or int.
+ ValueError: `name_or_pos` does not match any argument.
+ """
+ if isinstance(name_or_pos, int):
+ # Convert arguments to a tuple to make them indexable.
+ arg_values = tuple(arguments.items())
+ arg_pos = name_or_pos
+
+ try:
+ _name, value = arg_values[arg_pos]
+ return value
+ except IndexError:
+ raise ValueError(f"Argument position {arg_pos} is out of bounds.")
+ elif isinstance(name_or_pos, str):
+ arg_name = name_or_pos
+ try:
+ return arguments[arg_name]
+ except KeyError:
+ raise ValueError(f"Argument {arg_name!r} doesn't exist.")
+ else:
+ raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")
+
+
+def get_arg_value_wrapper(
+ decorator_func: typing.Callable[[ArgValGetter], Decorator],
+ name_or_pos: Argument,
+ func: typing.Callable[[typing.Any], typing.Any] | None = None,
+) -> Decorator:
+ """
+ Call `decorator_func` with the value of the arg at the given name/position.
+
+ Arguments:
+ decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of
+ parameter names to argument values of the function it's decorating.
+ name_or_pos: The name/position of the arg to get the value from.
+ func: An optional callable which will return a new value given the argument's value.
+
+ Returns:
+ The decorator returned by `decorator_func`.
+ """
+ def wrapper(args: BoundArgs) -> typing.Any:
+ value = get_arg_value(name_or_pos, args)
+ if func:
+ value = func(value)
+ return value
+
+ return decorator_func(wrapper)
+
+
+def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs:
+ """
+ Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.
+
+ Default parameter values are also set.
+
+ Args:
+ args: The arguments to bind to ``func``
+ kwargs: The keyword arguments to bind to ``func``
+ func: The function to bind ``args`` and ``kwargs`` to
+ Returns:
+ A mapping of parameter names to argument values.
+ """
+ sig = inspect.signature(func)
+ bound_args = sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+
+ return bound_args.arguments
+
+
def update_wrapper_globals(
wrapper: Callable[_P, _R],
wrapped: Callable[_P, _R],