diff options
author | 2021-01-14 05:00:22 +0100 | |
---|---|---|
committer | 2021-01-14 05:01:14 +0100 | |
commit | e86e9f921a4bbbe42a5fb6fd8486425f11af62cf (patch) | |
tree | a9b7282fd4a498aa9efab7515bece9dea6f72190 | |
parent | Ensure footer is actually max 100 chars (diff) |
Raise an error or log a warning if there's a global name conflict
When wrapper uses a global name, which conflicts with a global name
from wrapped's module that wrapped uses for its annotations, we run into
a situation that can't be solved without changing one of the names, so
an error is raised to give this clearer meaning.
The check may be erroneous in some edge cases or the objects the
conflicting names refer to can be functionally identical, so the error
can be turned into a logged warning.
-rw-r--r-- | bot/utils/function.py | 40 |
1 files changed, 35 insertions, 5 deletions
diff --git a/bot/utils/function.py b/bot/utils/function.py index ab7f45761..4fa7a9f60 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -2,15 +2,22 @@ import functools import inspect +import logging import types import typing as t +log = logging.getLogger(__name__) + Argument = t.Union[int, str] BoundArgs = t.OrderedDict[str, t.Any] Decorator = t.Callable[[t.Callable], t.Callable] ArgValGetter = t.Callable[[BoundArgs], t.Any] +class GlobalNameConflictError(Exception): + """Raised when there's a conflict between the globals used to resolve annotations of wrapped and its wrapper.""" + + def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any: """ Return a value from `arguments` based on a name or position. @@ -77,7 +84,12 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any]) return bound_args.arguments -def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionType) -> types.FunctionType: +def update_wrapper_globals( + wrapper: types.FunctionType, + wrapped: types.FunctionType, + *, + error_on_conflict: bool = True, +) -> types.FunctionType: """ Update globals of `wrapper` with the globals from `wrapped`. @@ -88,10 +100,26 @@ def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionT This function creates a new function functionally identical to `wrapper`, which has the globals replaced with a merge of `wrapped`s globals and the `wrapper`s globals. - In case a global name from `wrapped` conflicts with a name from `wrapper`'s globals, `wrapper` will win - to keep it functional, but this may cause problems if the name is used as an annotation and - discord.py uses it as a converter on a parameter from `wrapped`. + If `error_on_conflict` is True, an exception will be raised in case `wrapper` and `wrapped` share a global name + that is used by `wrapped`'s typehints, as this can cause incorrect objects being used by discordpy's converters. + The error can be turned into a warning by setting the argument to False. """ + forwardrefs = (ann for ann in wrapped.__annotations__.values() if isinstance(ann, str)) + annotation_global_names = (ann.split(".", maxsplit=1)[0] for ann in forwardrefs) + # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. + shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names) + shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__) + if shared_globals: + message = ( + f"wrapper and the wrapped function share the following " + f"global names used by annotations: {', '.join(shared_globals)}. " + f"Resolve the conflicts or pass error_on_conflict=False to suppress this error if this is intentional." + ) + if error_on_conflict: + raise GlobalNameConflictError(message) + else: + log.info(message) + new_globals = wrapper.__globals__.copy() new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) return types.FunctionType( @@ -107,11 +135,13 @@ def command_wraps( wrapped: types.FunctionType, assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS, updated: t.Sequence[str] = functools.WRAPPER_UPDATES, + *, + error_on_conflict: bool = True, ) -> t.Callable[[types.FunctionType], types.FunctionType]: """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation.""" def decorator(wrapper: types.FunctionType) -> types.FunctionType: return functools.update_wrapper( - update_wrapper_globals(wrapper, wrapped), wrapped, assigned, updated + update_wrapper_globals(wrapper, wrapped, error_on_conflict=error_on_conflict), wrapped, assigned, updated ) return decorator |