aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/utils/function.py40
1 files changed, 35 insertions, 5 deletions
diff --git a/bot/utils/function.py b/bot/utils/function.py
index ab7f45761..4fa7a9f60 100644
--- a/bot/utils/function.py
+++ b/bot/utils/function.py
@@ -2,15 +2,22 @@
import functools
import inspect
+import logging
import types
import typing as t
+log = logging.getLogger(__name__)
+
Argument = t.Union[int, str]
BoundArgs = t.OrderedDict[str, t.Any]
Decorator = t.Callable[[t.Callable], t.Callable]
ArgValGetter = t.Callable[[BoundArgs], t.Any]
+class GlobalNameConflictError(Exception):
+ """Raised when there's a conflict between the globals used to resolve annotations of wrapped and its wrapper."""
+
+
def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any:
"""
Return a value from `arguments` based on a name or position.
@@ -77,7 +84,12 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any])
return bound_args.arguments
-def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionType) -> types.FunctionType:
+def update_wrapper_globals(
+ wrapper: types.FunctionType,
+ wrapped: types.FunctionType,
+ *,
+ error_on_conflict: bool = True,
+) -> types.FunctionType:
"""
Update globals of `wrapper` with the globals from `wrapped`.
@@ -88,10 +100,26 @@ def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionT
This function creates a new function functionally identical to `wrapper`, which has the globals replaced with
a merge of `wrapped`s globals and the `wrapper`s globals.
- In case a global name from `wrapped` conflicts with a name from `wrapper`'s globals, `wrapper` will win
- to keep it functional, but this may cause problems if the name is used as an annotation and
- discord.py uses it as a converter on a parameter from `wrapped`.
+ If `error_on_conflict` is True, an exception will be raised in case `wrapper` and `wrapped` share a global name
+ that is used by `wrapped`'s typehints, as this can cause incorrect objects being used by discordpy's converters.
+ The error can be turned into a warning by setting the argument to False.
"""
+ forwardrefs = (ann for ann in wrapped.__annotations__.values() if isinstance(ann, str))
+ annotation_global_names = (ann.split(".", maxsplit=1)[0] for ann in forwardrefs)
+ # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
+ shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names)
+ shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__)
+ if shared_globals:
+ message = (
+ f"wrapper and the wrapped function share the following "
+ f"global names used by annotations: {', '.join(shared_globals)}. "
+ f"Resolve the conflicts or pass error_on_conflict=False to suppress this error if this is intentional."
+ )
+ if error_on_conflict:
+ raise GlobalNameConflictError(message)
+ else:
+ log.info(message)
+
new_globals = wrapper.__globals__.copy()
new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names)
return types.FunctionType(
@@ -107,11 +135,13 @@ def command_wraps(
wrapped: types.FunctionType,
assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
updated: t.Sequence[str] = functools.WRAPPER_UPDATES,
+ *,
+ error_on_conflict: bool = True,
) -> t.Callable[[types.FunctionType], types.FunctionType]:
"""Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation."""
def decorator(wrapper: types.FunctionType) -> types.FunctionType:
return functools.update_wrapper(
- update_wrapper_globals(wrapper, wrapped), wrapped, assigned, updated
+ update_wrapper_globals(wrapper, wrapped, error_on_conflict=error_on_conflict), wrapped, assigned, updated
)
return decorator