aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--botcore/utils/__init__.py3
-rw-r--r--botcore/utils/function.py116
2 files changed, 118 insertions, 1 deletions
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py
index fa389743..6055d144 100644
--- a/botcore/utils/__init__.py
+++ b/botcore/utils/__init__.py
@@ -1,6 +1,6 @@
"""Useful utilities and tools for Discord bot development."""
-from botcore.utils import _monkey_patches, caching, channel, logging, members, regex, scheduling
+from botcore.utils import _monkey_patches, caching, channel, function, logging, members, regex, scheduling
from botcore.utils._extensions import unqualify
@@ -24,6 +24,7 @@ __all__ = [
apply_monkey_patches,
caching,
channel,
+ function,
logging,
members,
regex,
diff --git a/botcore/utils/function.py b/botcore/utils/function.py
new file mode 100644
index 00000000..1cde5cd9
--- /dev/null
+++ b/botcore/utils/function.py
@@ -0,0 +1,116 @@
+"""Utils for manipulating functions."""
+
+from __future__ import annotations
+
+import functools
+import types
+import typing
+from collections.abc import Sequence, Set
+from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable
+
+__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"]
+
+
+if typing.TYPE_CHECKING:
+ import typing_extensions
+ _P = typing_extensions.ParamSpec("_P")
+ _R = typing.TypeVar("_R")
+
+
+class GlobalNameConflictError(Exception):
+ """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
+
+
+def update_wrapper_globals(
+ wrapper: Callable[_P, _R],
+ wrapped: Callable[_P, _R],
+ *,
+ ignored_conflict_names: Set[str] = frozenset(),
+) -> Callable[_P, _R]:
+ r"""
+ Update globals of the ``wrapper`` function with the globals from the ``wrapped`` function.
+
+ For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function
+ to resolve their values, with decorators that replace the function this breaks because they have
+ their own globals.
+
+ This function creates a new function functionally identical to ``wrapper``\, which has the globals replaced with
+ a merge of ``wrapped``\s globals and the ``wrapper``\s globals.
+
+ .. warning::
+ This function captures the state of ``wrapped``\'s module's globals when it's called,
+ changes won't be reflected in the new function's globals.
+
+ Args:
+ wrapper: The function to wrap.
+ wrapped: The function to wrap with.
+ ignored_conflict_names: A set of names to ignore if a conflict between them is found.
+
+ Raises:
+ :exc:`GlobalNameConflictError`:
+ If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints,
+ and is not in ``ignored_conflict_names``.
+ """
+ wrapped = typing.cast(types.FunctionType, wrapped)
+ wrapper = typing.cast(types.FunctionType, wrapper)
+
+ annotation_global_names = (
+ ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str)
+ )
+ # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
+ shared_globals = (
+ set(wrapper.__code__.co_names)
+ & set(annotation_global_names)
+ & set(wrapped.__globals__)
+ & set(wrapper.__globals__)
+ - ignored_conflict_names
+ )
+ if shared_globals:
+ raise GlobalNameConflictError(
+ f"wrapper and the wrapped function share the following "
+ f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add "
+ f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional."
+ )
+
+ new_globals = wrapper.__globals__.copy()
+ new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names)
+ return types.FunctionType(
+ code=wrapper.__code__,
+ globals=new_globals,
+ name=wrapper.__name__,
+ argdefs=wrapper.__defaults__,
+ closure=wrapper.__closure__,
+ )
+
+
+def command_wraps(
+ wrapped: Callable[_P, _R],
+ assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
+ updated: Sequence[str] = functools.WRAPPER_UPDATES,
+ *,
+ ignored_conflict_names: Set[str] = frozenset(),
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
+ r"""
+ Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation.
+
+ See :func:`update_wrapper_globals` for more details on how the globals are updated.
+
+ Args:
+ wrapped: The function to wrap with.
+ assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``.
+ updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``.
+ ignored_conflict_names: A set of names to ignore if a conflict between them is found.
+
+ Returns:
+ A decorator that behaves like :func:`functools.wraps`,
+ with the wrapper replaced with the function :func:`update_wrapper_globals` returned.
+ """ # noqa: D200
+ def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]:
+ return functools.update_wrapper(
+ update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names),
+ wrapped,
+ assigned,
+ updated,
+ )
+
+ return decorator