diff options
| -rw-r--r-- | bot/utils/function.py | 40 |
1 files changed, 35 insertions, 5 deletions
diff --git a/bot/utils/function.py b/bot/utils/function.py index ab7f45761..4fa7a9f60 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -2,15 +2,22 @@ import functools import inspect +import logging import types import typing as t +log = logging.getLogger(__name__) + Argument = t.Union[int, str] BoundArgs = t.OrderedDict[str, t.Any] Decorator = t.Callable[[t.Callable], t.Callable] ArgValGetter = t.Callable[[BoundArgs], t.Any] +class GlobalNameConflictError(Exception): + """Raised when there's a conflict between the globals used to resolve annotations of wrapped and its wrapper.""" + + def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any: """ Return a value from `arguments` based on a name or position. @@ -77,7 +84,12 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any]) return bound_args.arguments -def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionType) -> types.FunctionType: +def update_wrapper_globals( + wrapper: types.FunctionType, + wrapped: types.FunctionType, + *, + error_on_conflict: bool = True, +) -> types.FunctionType: """ Update globals of `wrapper` with the globals from `wrapped`. @@ -88,10 +100,26 @@ def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionT This function creates a new function functionally identical to `wrapper`, which has the globals replaced with a merge of `wrapped`s globals and the `wrapper`s globals. - In case a global name from `wrapped` conflicts with a name from `wrapper`'s globals, `wrapper` will win - to keep it functional, but this may cause problems if the name is used as an annotation and - discord.py uses it as a converter on a parameter from `wrapped`. + If `error_on_conflict` is True, an exception will be raised in case `wrapper` and `wrapped` share a global name + that is used by `wrapped`'s typehints, as this can cause incorrect objects being used by discordpy's converters. + The error can be turned into a warning by setting the argument to False. """ + forwardrefs = (ann for ann in wrapped.__annotations__.values() if isinstance(ann, str)) + annotation_global_names = (ann.split(".", maxsplit=1)[0] for ann in forwardrefs) + # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. + shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names) + shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__) + if shared_globals: + message = ( + f"wrapper and the wrapped function share the following " + f"global names used by annotations: {', '.join(shared_globals)}. " + f"Resolve the conflicts or pass error_on_conflict=False to suppress this error if this is intentional." + ) + if error_on_conflict: + raise GlobalNameConflictError(message) + else: + log.info(message) + new_globals = wrapper.__globals__.copy() new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) return types.FunctionType( @@ -107,11 +135,13 @@ def command_wraps( wrapped: types.FunctionType, assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS, updated: t.Sequence[str] = functools.WRAPPER_UPDATES, + *, + error_on_conflict: bool = True, ) -> t.Callable[[types.FunctionType], types.FunctionType]: """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation.""" def decorator(wrapper: types.FunctionType) -> types.FunctionType: return functools.update_wrapper( - update_wrapper_globals(wrapper, wrapped), wrapped, assigned, updated + update_wrapper_globals(wrapper, wrapped, error_on_conflict=error_on_conflict), wrapped, assigned, updated ) return decorator |