aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Numerlor <[email protected]>2021-01-14 05:00:22 +0100
committerGravatar Numerlor <[email protected]>2021-01-14 05:01:14 +0100
commite86e9f921a4bbbe42a5fb6fd8486425f11af62cf (patch)
treea9b7282fd4a498aa9efab7515bece9dea6f72190
parentEnsure footer is actually max 100 chars (diff)
Raise an error or log a warning if there's a global name conflict
When wrapper uses a global name, which conflicts with a global name from wrapped's module that wrapped uses for its annotations, we run into a situation that can't be solved without changing one of the names, so an error is raised to give this clearer meaning. The check may be erroneous in some edge cases or the objects the conflicting names refer to can be functionally identical, so the error can be turned into a logged warning.
-rw-r--r--bot/utils/function.py40
1 files changed, 35 insertions, 5 deletions
diff --git a/bot/utils/function.py b/bot/utils/function.py
index ab7f45761..4fa7a9f60 100644
--- a/bot/utils/function.py
+++ b/bot/utils/function.py
@@ -2,15 +2,22 @@
import functools
import inspect
+import logging
import types
import typing as t
+log = logging.getLogger(__name__)
+
Argument = t.Union[int, str]
BoundArgs = t.OrderedDict[str, t.Any]
Decorator = t.Callable[[t.Callable], t.Callable]
ArgValGetter = t.Callable[[BoundArgs], t.Any]
+class GlobalNameConflictError(Exception):
+ """Raised when there's a conflict between the globals used to resolve annotations of wrapped and its wrapper."""
+
+
def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any:
"""
Return a value from `arguments` based on a name or position.
@@ -77,7 +84,12 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any])
return bound_args.arguments
-def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionType) -> types.FunctionType:
+def update_wrapper_globals(
+ wrapper: types.FunctionType,
+ wrapped: types.FunctionType,
+ *,
+ error_on_conflict: bool = True,
+) -> types.FunctionType:
"""
Update globals of `wrapper` with the globals from `wrapped`.
@@ -88,10 +100,26 @@ def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionT
This function creates a new function functionally identical to `wrapper`, which has the globals replaced with
a merge of `wrapped`s globals and the `wrapper`s globals.
- In case a global name from `wrapped` conflicts with a name from `wrapper`'s globals, `wrapper` will win
- to keep it functional, but this may cause problems if the name is used as an annotation and
- discord.py uses it as a converter on a parameter from `wrapped`.
+ If `error_on_conflict` is True, an exception will be raised in case `wrapper` and `wrapped` share a global name
+ that is used by `wrapped`'s typehints, as this can cause incorrect objects being used by discordpy's converters.
+ The error can be turned into a warning by setting the argument to False.
"""
+ forwardrefs = (ann for ann in wrapped.__annotations__.values() if isinstance(ann, str))
+ annotation_global_names = (ann.split(".", maxsplit=1)[0] for ann in forwardrefs)
+ # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
+ shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names)
+ shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__)
+ if shared_globals:
+ message = (
+ f"wrapper and the wrapped function share the following "
+ f"global names used by annotations: {', '.join(shared_globals)}. "
+ f"Resolve the conflicts or pass error_on_conflict=False to suppress this error if this is intentional."
+ )
+ if error_on_conflict:
+ raise GlobalNameConflictError(message)
+ else:
+ log.info(message)
+
new_globals = wrapper.__globals__.copy()
new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names)
return types.FunctionType(
@@ -107,11 +135,13 @@ def command_wraps(
wrapped: types.FunctionType,
assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
updated: t.Sequence[str] = functools.WRAPPER_UPDATES,
+ *,
+ error_on_conflict: bool = True,
) -> t.Callable[[types.FunctionType], types.FunctionType]:
"""Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation."""
def decorator(wrapper: types.FunctionType) -> types.FunctionType:
return functools.update_wrapper(
- update_wrapper_globals(wrapper, wrapped), wrapped, assigned, updated
+ update_wrapper_globals(wrapper, wrapped, error_on_conflict=error_on_conflict), wrapped, assigned, updated
)
return decorator