diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/decorators.py | 14 | ||||
| -rw-r--r-- | bot/utils/function.py | 15 | ||||
| -rw-r--r-- | bot/utils/lock.py | 10 | 
3 files changed, 30 insertions, 9 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index a37996e80..02735d0dc 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,8 +1,8 @@  import asyncio  import logging +import types  import typing as t  from contextlib import suppress -from functools import update_wrapper  from discord import Member, NotFound  from discord.ext import commands @@ -11,6 +11,7 @@ from discord.ext.commands import Cog, Context  from bot.constants import Channels, RedirectOutput  from bot.utils import function  from bot.utils.checks import in_whitelist_check +from bot.utils.function import command_wraps  log = logging.getLogger(__name__) @@ -70,7 +71,8 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N      This decorator must go before (below) the `command` decorator.      """ -    def wrap(func: t.Callable) -> t.Callable: +    def wrap(func: types.FunctionType) -> types.FunctionType: +        @command_wraps(func)          async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:              if ctx.channel.id == destination_channel:                  log.trace(f"Command {ctx.command.name} was invoked in destination_channel, not redirecting") @@ -104,8 +106,7 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N                  with suppress(NotFound):                      await ctx.message.delete()                      log.trace("Redirect output: Deleted invocation message") - -        return update_wrapper(function.update_wrapper_globals(inner, func), func) +        return inner      return wrap @@ -121,7 +122,8 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable:      This decorator must go before (below) the `command` decorator.      """ -    def decorator(func: t.Callable) -> t.Callable: +    def decorator(func: types.FunctionType) -> types.FunctionType: +        @command_wraps(func)          async def wrapper(*args, **kwargs) -> None:              log.trace(f"{func.__name__}: respect role hierarchy decorator called") @@ -149,5 +151,5 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable:              else:                  log.trace(f"{func.__name__}: {target.top_role=} < {actor.top_role=}; calling func")                  await func(*args, **kwargs) -        return update_wrapper(function.update_wrapper_globals(wrapper, func), func) +        return wrapper      return decorator diff --git a/bot/utils/function.py b/bot/utils/function.py index 037516ac4..5fd70e1e8 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -1,5 +1,6 @@  """Utilities for interaction with functions.""" +import functools  import inspect  import types  import typing as t @@ -100,3 +101,17 @@ def update_wrapper_globals(wrapper: types.FunctionType, wrapped: types.FunctionT          argdefs=wrapper.__defaults__,          closure=wrapper.__closure__,      ) + + +def command_wraps( +        wrapped: types.FunctionType, +        assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS, +        updated: t.Sequence[str] = functools.WRAPPER_UPDATES, +) -> t.Callable[[types.FunctionType], types.FunctionType]: +    """Update `wrapped` to look like the decorated function and update globals for discordpy forwardref evaluation.""" +    def decorator(wrapper: types.FunctionType) -> types.FunctionType: +        return functools.update_wrapper( +            update_wrapper_globals(wrapper, wrapped), wrapped, assigned, updated +        ) + +    return decorator diff --git a/bot/utils/lock.py b/bot/utils/lock.py index 02188c827..978e3ae94 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -1,12 +1,14 @@  import inspect  import logging +import types  from collections import defaultdict -from functools import partial, update_wrapper +from functools import partial  from typing import Any, Awaitable, Callable, Hashable, Union  from weakref import WeakValueDictionary  from bot.errors import LockedResourceError  from bot.utils import function +from bot.utils.function import command_wraps  log = logging.getLogger(__name__)  __lock_dicts = defaultdict(WeakValueDictionary) @@ -58,9 +60,10 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa      If decorating a command, this decorator must go before (below) the `command` decorator.      """ -    def decorator(func: Callable) -> Callable: +    def decorator(func: types.FunctionType) -> types.FunctionType:          name = func.__name__ +        @command_wraps(func)          async def wrapper(*args, **kwargs) -> Any:              log.trace(f"{name}: mutually exclusive decorator called") @@ -91,7 +94,8 @@ def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = Fa                  log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")                  if raise_error:                      raise LockedResourceError(str(namespace), id_) -        return update_wrapper(function.update_wrapper_globals(wrapper, func), func) +        return wrapper +      return decorator  |