1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
"""Utils for manipulating functions."""
from __future__ import annotations
import functools
import types
import typing
from collections.abc import Callable, Sequence, Set
# Names exported by a star-import of this module.
__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"]
# The type variables below are only referenced inside annotations in this module. Because of
# ``from __future__ import annotations`` those annotations are never evaluated at runtime,
# so the variables only need to exist for static type checkers.
if typing.TYPE_CHECKING:
    _P = typing.ParamSpec("_P")  # Parameters of the wrapped callable.
    _R = typing.TypeVar("_R")  # Return type of the wrapped callable.
class GlobalNameConflictError(Exception):
    """
    Signal a clash between the globals of a wrapped function and those of its wrapper.

    Raised when a global name needed to resolve the wrapped function's annotations
    is also a global of the wrapper, so merging the two sets of globals would be ambiguous.
    """
def update_wrapper_globals(
    wrapper: Callable[_P, _R],
    wrapped: Callable[_P, _R],
    *,
    ignored_conflict_names: Set[str] = frozenset(),
) -> Callable[_P, _R]:
    r"""
    Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals.

    For forwardrefs in command annotations, discord.py uses the ``__globals__`` attribute of the function
    to resolve their values. This breaks for decorators that replace the function because they have
    their own globals.

    .. warning::
        This function captures the state of ``wrapped``\'s module's globals when it's called;
        changes won't be reflected in the new function's globals.

    Args:
        wrapper: The function to copy. Names referenced by its code keep the values from its own globals.
        wrapped: The function whose globals are merged into the copy's globals.
        ignored_conflict_names: A set of names to ignore if a conflict between them is found.

    Returns:
        A new function with ``wrapper``\'s code, name, positional and keyword-only defaults, and closure,
        whose globals are ``wrapper``\'s globals extended with ``wrapped``\'s globals.

    Raises:
        :exc:`GlobalNameConflictError`:
            If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints,
            and is not in ``ignored_conflict_names``.
    """
    wrapped = typing.cast(types.FunctionType, wrapped)
    wrapper = typing.cast(types.FunctionType, wrapper)

    # ``co_names`` is a tuple; build the set once since it is consulted for every global of ``wrapped``.
    wrapper_names = frozenset(wrapper.__code__.co_names)

    # Only string annotations are considered; the part before the first dot is the name
    # that has to be looked up in the function's globals.
    annotation_global_names = {
        ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str)
    }

    # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
    shared_globals = set(wrapper_names).intersection(
        annotation_global_names,
        wrapped.__globals__,
        wrapper.__globals__,
    ).difference(ignored_conflict_names)

    if shared_globals:
        # Sorted so the message is deterministic regardless of set iteration order.
        raise GlobalNameConflictError(
            f"wrapper and the wrapped function share the following "
            f"global names used by annotations: {', '.join(sorted(shared_globals))}. Resolve the conflicts or add "
            f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional."
        )

    # Start from the wrapper's globals and add wrapped's globals, but never override
    # a name the wrapper's own code refers to.
    new_globals = wrapper.__globals__.copy()
    new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper_names)

    new_function = types.FunctionType(
        code=wrapper.__code__,
        globals=new_globals,
        name=wrapper.__name__,
        argdefs=wrapper.__defaults__,
        closure=wrapper.__closure__,
    )
    # The constructor arguments above don't cover keyword-only defaults; without this the copy
    # would raise ``TypeError`` when called without arguments that ``wrapper`` treats as optional.
    if wrapper.__kwdefaults__ is not None:
        new_function.__kwdefaults__ = dict(wrapper.__kwdefaults__)
    return new_function
def command_wraps(
    wrapped: Callable[_P, _R],
    assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
    updated: Sequence[str] = functools.WRAPPER_UPDATES,
    *,
    ignored_conflict_names: Set[str] = frozenset(),
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    r"""
    Build a :func:`functools.wraps`\-like decorator that also merges globals for discord.py forwardref evaluation.

    The decorated function is first replaced by the copy :func:`update_wrapper_globals` produces,
    and that copy is then made to look like ``wrapped`` via :func:`functools.update_wrapper`.

    Args:
        wrapped: The function being wrapped; its metadata and globals are applied to the decorated function.
        assigned: Attribute names copied directly from ``wrapped`` onto the new wrapper.
        updated: Attribute names whose values on the new wrapper are ``.update``\d from ``wrapped``.
        ignored_conflict_names: Names for which a globals conflict should not raise an error.

    Returns:
        A decorator taking the wrapper function and returning its globals-updated,
        metadata-updated replacement.
    """

    def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]:
        # Swap in the copy with merged globals first, then copy wrapped's metadata onto it.
        patched_wrapper = update_wrapper_globals(
            wrapper,
            wrapped,
            ignored_conflict_names=ignored_conflict_names,
        )
        return functools.update_wrapper(patched_wrapper, wrapped, assigned=assigned, updated=updated)

    return decorator
|