diff options
| -rw-r--r-- | bot/decorators.py | 6 | ||||
| -rw-r--r-- | bot/utils/function.py | 27 | ||||
| -rw-r--r-- | bot/utils/lock.py | 3 |
3 files changed, 30 insertions, 6 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index 063c8f878..3892e350f 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -71,7 +71,6 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N This decorator must go before (below) the `command` decorator. """ def wrap(func: t.Callable) -> t.Callable: - @wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: if ctx.channel.id == destination_channel: log.trace(f"Command {ctx.command.name} was invoked in destination_channel, not redirecting") @@ -106,7 +105,7 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N await ctx.message.delete() log.trace("Redirect output: Deleted invocation message") - return inner + return wraps(func)(function.update_wrapper_globals(inner, func)) return wrap @@ -123,7 +122,6 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: This decorator must go before (below) the `command` decorator. """ def decorator(func: t.Callable) -> t.Callable: - @wraps(func) async def wrapper(*args, **kwargs) -> None: log.trace(f"{func.__name__}: respect role hierarchy decorator called") @@ -151,5 +149,5 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: else: log.trace(f"{func.__name__}: {target.top_role=} < {actor.top_role=}; calling func") await func(*args, **kwargs) - return wrapper + return wraps(func)(function.update_wrapper_globals(wrapper, func)) return decorator diff --git a/bot/utils/function.py b/bot/utils/function.py index 3ab32fe3c..8b8c7ba5c 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -1,6 +1,7 @@ """Utilities for interaction with functions.""" import inspect +import types import typing as t Argument = t.Union[int, str] @@ -73,3 +74,29 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any]) bound_args.apply_defaults() return bound_args.arguments + + +def update_wrapper_globals(wrapper: types.FunctionType, func: types.FunctionType) -> types.FunctionType: + """ + Update globals of `wrapper` with the globals from `func`. + + For forwardrefs in command annotations discordpy uses the __global__ attribute of the function + to resolve their values, with decorators that replace the function this breaks because they have + their own globals. + + This function creates a new function functionally identical to `wrapper`, which has the globals replaced with + a merge of `func`s globals and the `wrapper`s globals. + + In case a global name from `func` conflicts with a name from `wrapper`'s globals, `wrapper` will win + to keep it functional, but this may cause problems if the name is used as an annotation and + discord.py uses it as a converter on a parameter from `func`. + """ + new_globals = wrapper.__globals__.copy() + new_globals.update((k, v) for k, v in func.__globals__.items() if k not in wrapper.__code__.co_names) + return types.FunctionType( + code=wrapper.__code__, + globals=new_globals, + name=wrapper.__name__, + argdefs=wrapper.__defaults__, + closure=wrapper.__closure__, + ) diff --git a/bot/utils/lock.py b/bot/utils/lock.py index 7aaafbc88..cf87321c5 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -61,7 +61,6 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa def decorator(func: Callable) -> Callable: name = func.__name__ - @wraps(func) async def wrapper(*args, **kwargs) -> Any: log.trace(f"{name}: mutually exclusive decorator called") @@ -93,7 +92,7 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa if raise_error: raise LockedResourceError(str(namespace), id_) - return wrapper + return wraps(func)(function.update_wrapper_globals(wrapper, func)) return decorator |