diff options
| author | 2021-01-10 22:07:58 +0100 | |
|---|---|---|
| committer | 2021-01-10 22:07:58 +0100 | |
| commit | 4788a9364ac84cf0ee210c8b026ea7f2d5dd31ee (patch) | |
| tree | bd4a1c84bb6b570e9c01054c30fea47b05b51bc3 | |
| parent | Change the func name to wrapped for clarity (diff) | |
Create decorator for update_wrapper_globals mimicking functools.wraps
| -rw-r--r-- | bot/decorators.py | 14 | ||||
| -rw-r--r-- | bot/utils/function.py | 15 | ||||
| -rw-r--r-- | bot/utils/lock.py | 10 |
3 files changed, 30 insertions, 9 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index a37996e80..02735d0dc 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,8 +1,8 @@ import asyncio import logging +import types import typing as t from contextlib import suppress -from functools import update_wrapper from discord import Member, NotFound from discord.ext import commands @@ -11,6 +11,7 @@ from discord.ext.commands import Cog, Context from bot.constants import Channels, RedirectOutput from bot.utils import function from bot.utils.checks import in_whitelist_check +from bot.utils.function import command_wraps log = logging.getLogger(__name__) @@ -70,7 +71,8 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N This decorator must go before (below) the `command` decorator. """ - def wrap(func: t.Callable) -> t.Callable: + def wrap(func: types.FunctionType) -> types.FunctionType: + @command_wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: if ctx.channel.id == destination_channel: log.trace(f"Command {ctx.command.name} was invoked in destination_channel, not redirecting") @@ -104,8 +106,7 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N with suppress(NotFound): await ctx.message.delete() log.trace("Redirect output: Deleted invocation message") - - return update_wrapper(function.update_wrapper_globals(inner, func), func) + return inner return wrap @@ -121,7 +122,8 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: This decorator must go before (below) the `command` decorator. """ - def decorator(func: t.Callable) -> t.Callable: + def decorator(func: types.FunctionType) -> types.FunctionType: + @command_wraps(func) async def wrapper(*args, **kwargs) -> None: log.trace(f"{func.__name__}: respect role hierarchy decorator called") @@ -149,5 +151,5 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: else: log.trace(f"{func.__name__}: {target.top_role=} < {actor.top_role=}; calling func") await func(*args, **kwargs) - return update_wrapper(function.update_wrapper_globals(wrapper, func), func) + return wrapper return decorator diff --git a/bot/utils/function.py b/bot/utils/function.py index 037516ac4..5fd70e1e8 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -1,5 +1,6 @@ """Utilities for interaction with functions.""" +import functools import inspect import types import typing as t @@ -100,3 +101,17 @@ def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionT argdefs=wrapper.__defaults__, closure=wrapper.__closure__, ) + + +def command_wraps( + wrapped: types.FunctionType, + assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS, + updated: t.Sequence[str] = functools.WRAPPER_UPDATES, +) -> t.Callable[[types.FunctionType], types.FunctionType]: + """Update `wrapped` to look like the decorated function and update globals for discordpy forwardref evaluation.""" + def decorator(wrapper: types.FunctionType) -> types.FunctionType: + return functools.update_wrapper( + update_wrapper_globals(wrapper, wrapped), wrapped, assigned, updated + ) + + return decorator diff --git a/bot/utils/lock.py b/bot/utils/lock.py index 02188c827..978e3ae94 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -1,12 +1,14 @@ import inspect import logging +import types from collections import defaultdict -from functools import partial, update_wrapper +from functools import partial from typing import Any, Awaitable, Callable, Hashable, Union from weakref import WeakValueDictionary from bot.errors import LockedResourceError from bot.utils import function +from bot.utils.function import command_wraps log = logging.getLogger(__name__) __lock_dicts = defaultdict(WeakValueDictionary) @@ -58,9 +60,10 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa If decorating a command, this decorator must go before (below) the `command` decorator. """ - def decorator(func: Callable) -> Callable: + def decorator(func: types.FunctionType) -> types.FunctionType: name = func.__name__ + @command_wraps(func) async def wrapper(*args, **kwargs) -> Any: log.trace(f"{name}: mutually exclusive decorator called") @@ -91,7 +94,8 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") if raise_error: raise LockedResourceError(str(namespace), id_) - return update_wrapper(function.update_wrapper_globals(wrapper, func), func) + return wrapper + return decorator |