diff options
| author | 2022-06-13 14:56:27 +0200 | |
|---|---|---|
| committer | 2022-06-15 00:02:08 +0200 | |
| commit | 646e21964c71e7a1de7f7bcccadebefa821c1e8d (patch) | |
| tree | 67f6d1523f7e5f0c2b12d2e5a7556025268b10dd /botcore | |
| parent | Merge pull request #79 from python-discord/upgrade-docs (diff) | |
port command_wraps/update_wrapper_globals from bot
Diffstat (limited to 'botcore')
| -rw-r--r-- | botcore/utils/__init__.py | 3 | ||||
| -rw-r--r-- | botcore/utils/function.py | 116 | 
2 files changed, 118 insertions, 1 deletions
| diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index fa389743..6055d144 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,6 @@  """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, logging, members, regex, scheduling +from botcore.utils import _monkey_patches, caching, channel, function, logging, members, regex, scheduling  from botcore.utils._extensions import unqualify @@ -24,6 +24,7 @@ __all__ = [      apply_monkey_patches,      caching,      channel, +    function,      logging,      members,      regex, diff --git a/botcore/utils/function.py b/botcore/utils/function.py new file mode 100644 index 00000000..1cde5cd9 --- /dev/null +++ b/botcore/utils/function.py @@ -0,0 +1,116 @@ +"""Utils for manipulating functions.""" + +from __future__ import annotations + +import functools +import types +import typing +from collections.abc import Sequence, Set +from typing import Callable  # sphinx-autodoc-typehints breaks with collections.abc.Callable + +__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] + + +if typing.TYPE_CHECKING: +    import typing_extensions +    _P = typing_extensions.ParamSpec("_P") +    _R = typing.TypeVar("_R") + + +class GlobalNameConflictError(Exception): +    """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" + + +def update_wrapper_globals( +    wrapper: Callable[_P, _R], +    wrapped: Callable[_P, _R], +    *, +    ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[_P, _R]: +    r""" +    Update globals of the ``wrapper`` function with the globals from the ``wrapped`` function. + +    For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function +    to resolve their values, with decorators that replace the function this breaks because they have +    their own globals. + +    This function creates a new function functionally identical to ``wrapper``\, which has the globals replaced with +    a merge of ``wrapped``\s globals and the ``wrapper``\s globals. + +    .. warning:: +        This function captures the state of ``wrapped``\'s module's globals when it's called, +        changes won't be reflected in the new function's globals. + +    Args: +        wrapper: The function to wrap. +        wrapped: The function to wrap with. +        ignored_conflict_names: A set of names to ignore if a conflict between them is found. + +    Raises: +        :exc:`GlobalNameConflictError`: +            If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints, +            and is not in ``ignored_conflict_names``. +    """ +    wrapped = typing.cast(types.FunctionType, wrapped) +    wrapper = typing.cast(types.FunctionType, wrapper) + +    annotation_global_names = ( +        ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) +    ) +    # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. +    shared_globals = ( +        set(wrapper.__code__.co_names) +        & set(annotation_global_names) +        & set(wrapped.__globals__) +        & set(wrapper.__globals__) +        - ignored_conflict_names +    ) +    if shared_globals: +        raise GlobalNameConflictError( +            f"wrapper and the wrapped function share the following " +            f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " +            f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." +        ) + +    new_globals = wrapper.__globals__.copy() +    new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) +    return types.FunctionType( +        code=wrapper.__code__, +        globals=new_globals, +        name=wrapper.__name__, +        argdefs=wrapper.__defaults__, +        closure=wrapper.__closure__, +    ) + + +def command_wraps( +    wrapped: Callable[_P, _R], +    assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS, +    updated: Sequence[str] = functools.WRAPPER_UPDATES, +    *, +    ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: +    r""" +    Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation. + +    See :func:`update_wrapper_globals` for more details on how the globals are updated. + +    Args: +        wrapped: The function to wrap with. +        assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``. +        updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``. +        ignored_conflict_names: A set of names to ignore if a conflict between them is found. + +    Returns: +        A decorator that behaves like :func:`functools.wraps`, +        with the wrapper replaced with the function :func:`update_wrapper_globals` returned. +    """  # noqa: D200 +    def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]: +        return functools.update_wrapper( +            update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), +            wrapped, +            assigned, +            updated, +        ) + +    return decorator | 
