aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_core/utils/function.py
blob: 98737af0f69f7579f4e9ab40eb1dfe2bb3ccf6d0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Utils for manipulating functions."""

from __future__ import annotations

import ast
import functools
import types
import typing
from collections.abc import Callable, Sequence, Set

# Public API of this module.
__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"]


# Type variables used only in the annotations below. They exist solely for static type checkers:
# ``from __future__ import annotations`` keeps annotations as unevaluated strings, so these names are not
# needed at import time. NOTE(review): runtime introspection such as ``typing.get_type_hints`` on the
# functions in this module would fail to resolve them.
if typing.TYPE_CHECKING:
    _P = typing.ParamSpec("_P")  # parameter specification shared by the wrapper and wrapped callables
    _R = typing.TypeVar("_R")  # return type shared by the wrapper and wrapped callables


class GlobalNameConflictError(Exception):
    """
    Signal that a wrapper and the function it wraps disagree about a global name.

    Raised when a name exists in both functions' module globals, is referenced by the wrapper's code,
    and is also needed to resolve the wrapped function's string annotations. Only one of the two values
    can survive in the merged globals, so annotation resolution could silently pick up the wrong object.
    """


def _referenced_names(code: types.CodeType) -> set[str]:
    """Return the names in ``code.co_names`` together with those of every code object nested inside ``code``."""
    names = set(code.co_names)
    for const in code.co_consts:
        # Lambdas, inner functions and (before Python 3.12) comprehensions compile to separate code objects;
        # their global lookups are listed in their own ``co_names``, not in the parent's.
        if isinstance(const, types.CodeType):
            names |= _referenced_names(const)
    return names


def _annotation_root_names(annotation: str) -> set[str]:
    """
    Return the bare names a string annotation refers to.

    ``"discord.Member"`` yields ``{"discord"}`` and ``"list[Member] | None"`` yields ``{"list", "Member"}``.
    Strings that are not valid Python expressions fall back to their leading dotted component.
    """
    try:
        tree = ast.parse(annotation, mode="eval")
    except (SyntaxError, ValueError):
        # ValueError: source containing null bytes on older Python versions.
        return {annotation.split(".", maxsplit=1)[0]}
    return {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}


def update_wrapper_globals(
    wrapper: Callable[_P, _R],
    wrapped: Callable[_P, _R],
    *,
    ignored_conflict_names: Set[str] = frozenset(),
) -> Callable[_P, _R]:
    r"""
    Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals.

    For forwardrefs in command annotations, discord.py uses the ``__globals__`` attribute of the function
    to resolve their values. This breaks for decorators that replace the function because they have
    their own globals.

    Names referenced by ``wrapper``\'s code (including code nested inside it, such as lambdas and inner
    functions) keep ``wrapper``\'s values so the wrapper keeps working; every other global is taken from
    ``wrapped``\'s module. ``wrapper``\'s positional defaults, keyword-only defaults and closure are kept.

    .. warning::
        This function captures the state of ``wrapped``\'s module's globals when it's called;
        changes won't be reflected in the new function's globals.

    Args:
        wrapper: The function doing the wrapping; its code is reused for the returned copy.
        wrapped: The function being wrapped, whose module globals are merged into the copy's globals.
        ignored_conflict_names: A set of names to ignore if a conflict between them is found.

    Returns:
        A new function with ``wrapper``\'s code and the merged globals.

    Raises:
        :exc:`GlobalNameConflictError`:
            If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints,
            and is not in ``ignored_conflict_names``.
    """
    wrapped = typing.cast(types.FunctionType, wrapped)
    wrapper = typing.cast(types.FunctionType, wrapper)

    wrapper_names = _referenced_names(wrapper.__code__)

    annotation_names: set[str] = set()
    for annotation in wrapped.__annotations__.values():
        # Non-string annotations are already evaluated objects and need no resolution through globals.
        if isinstance(annotation, str):
            annotation_names |= _annotation_root_names(annotation)

    # Names the wrapper needs from its own module that wrapped's annotations also need from wrapped's module.
    # Only one value per name can live in the merged globals, so such overlaps are reported as conflicts.
    shared_globals = (
        wrapper_names & annotation_names & set(wrapped.__globals__) & set(wrapper.__globals__)
    ).difference(ignored_conflict_names)
    if shared_globals:
        raise GlobalNameConflictError(
            f"wrapper and the wrapped function share the following "
            f"global names used by annotations: {', '.join(sorted(shared_globals))}. Resolve the conflicts or add "
            f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional."
        )

    new_globals = wrapper.__globals__.copy()
    new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper_names)
    new_function = types.FunctionType(
        code=wrapper.__code__,
        globals=new_globals,
        name=wrapper.__name__,
        argdefs=wrapper.__defaults__,
        closure=wrapper.__closure__,
    )
    # FunctionType() has no parameter for keyword-only defaults before Python 3.13; without this,
    # keyword-only parameters that have defaults would become required on the copy.
    new_function.__kwdefaults__ = wrapper.__kwdefaults__
    return new_function


def command_wraps(
    wrapped: Callable[_P, _R],
    assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
    updated: Sequence[str] = functools.WRAPPER_UPDATES,
    *,
    ignored_conflict_names: Set[str] = frozenset(),
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    r"""
    Make the decorated function mimic ``wrapped``\, rebinding its globals so discord.py can evaluate forwardrefs.

    The decorated function is first copied with merged globals by :func:`update_wrapper_globals` (see it for
    the merge rules), and that copy is then passed through :func:`functools.update_wrapper`.

    Args:
        wrapped: The function being wrapped; its metadata and module globals are carried over.
        assigned: Names of attributes copied directly from ``wrapped`` onto the wrapper.
        updated: Names of attributes on the wrapper that are updated in place (via ``.update``) from the
            same attributes on ``wrapped``.
        ignored_conflict_names: Global names for which a conflict between the two functions is not an error.

    Returns:
        A decorator usable like :func:`functools.wraps`, except that it returns the copy produced by
        :func:`update_wrapper_globals` instead of the decorated function itself.
    """
    def decorate(wrapper: Callable[_P, _R]) -> Callable[_P, _R]:
        rebound = update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names)
        return functools.update_wrapper(rebound, wrapped, assigned=assigned, updated=updated)

    return decorate