From 646e21964c71e7a1de7f7bcccadebefa821c1e8d Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 13 Jun 2022 14:56:27 +0200 Subject: port command_wraps/update_wrapper_globals from bot --- botcore/utils/__init__.py | 3 +- botcore/utils/function.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 botcore/utils/function.py (limited to 'botcore') diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index fa389743..6055d144 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,6 @@ """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, logging, members, regex, scheduling +from botcore.utils import _monkey_patches, caching, channel, function, logging, members, regex, scheduling from botcore.utils._extensions import unqualify @@ -24,6 +24,7 @@ __all__ = [ apply_monkey_patches, caching, channel, + function, logging, members, regex, diff --git a/botcore/utils/function.py b/botcore/utils/function.py new file mode 100644 index 00000000..1cde5cd9 --- /dev/null +++ b/botcore/utils/function.py @@ -0,0 +1,116 @@ +"""Utils for manipulating functions.""" + +from __future__ import annotations + +import functools +import types +import typing +from collections.abc import Sequence, Set +from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable + +__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] + + +if typing.TYPE_CHECKING: + import typing_extensions + _P = typing_extensions.ParamSpec("_P") + _R = typing.TypeVar("_R") + + +class GlobalNameConflictError(Exception): + """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" + + +def update_wrapper_globals( + wrapper: Callable[_P, _R], + wrapped: Callable[_P, _R], + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[_P, _R]: + r""" + Update globals of the ``wrapper`` function with the globals from the ``wrapped`` function. + + For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function + to resolve their values, with decorators that replace the function this breaks because they have + their own globals. + + This function creates a new function functionally identical to ``wrapper``\, which has the globals replaced with + a merge of ``wrapped``\s globals and the ``wrapper``\s globals. + + .. warning:: + This function captures the state of ``wrapped``\'s module's globals when it's called, + changes won't be reflected in the new function's globals. + + Args: + wrapper: The function to wrap. + wrapped: The function to wrap with. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Raises: + :exc:`GlobalNameConflictError`: + If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints, + and is not in ``ignored_conflict_names``. + """ + wrapped = typing.cast(types.FunctionType, wrapped) + wrapper = typing.cast(types.FunctionType, wrapper) + + annotation_global_names = ( + ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) + ) + # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. + shared_globals = ( + set(wrapper.__code__.co_names) + & set(annotation_global_names) + & set(wrapped.__globals__) + & set(wrapper.__globals__) + - ignored_conflict_names + ) + if shared_globals: + raise GlobalNameConflictError( + f"wrapper and the wrapped function share the following " + f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " + f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." + ) + + new_globals = wrapper.__globals__.copy() + new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) + return types.FunctionType( + code=wrapper.__code__, + globals=new_globals, + name=wrapper.__name__, + argdefs=wrapper.__defaults__, + closure=wrapper.__closure__, + ) + + +def command_wraps( + wrapped: Callable[_P, _R], + assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS, + updated: Sequence[str] = functools.WRAPPER_UPDATES, + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + r""" + Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation. + + See :func:`update_wrapper_globals` for more details on how the globals are updated. + + Args: + wrapped: The function to wrap with. + assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``. + updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Returns: + A decorator that behaves like :func:`functools.wraps`, + with the wrapper replaced with the function :func:`update_wrapper_globals` returned. + """ # noqa: D200 + def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]: + return functools.update_wrapper( + update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), + wrapped, + assigned, + updated, + ) + + return decorator -- cgit v1.2.3