aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-07-17 23:12:07 -0700
committerGravatar MarkKoz <[email protected]>2020-07-31 22:58:06 -0700
commita30b640a61943c1992b8614b98667d49653d8b71 (patch)
treefe2eaa519ad9b14a9399894a3d58e7d209026e49
parentDecorators: clarify use of mutually_exclusive with commands (diff)
Decorators: pass bound arguments to callable
Bound arguments are more convenient to work with than the raw args and kwargs.
-rw-r--r--bot/decorators.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index f581e66d2..7f58abd1c 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -18,9 +18,10 @@ log = logging.getLogger(__name__)
__lock_dicts = defaultdict(WeakValueDictionary)
Argument = t.Union[int, str]
-_IdCallable = t.Callable[..., t.Hashable]
-_IdAwaitable = t.Callable[..., t.Awaitable[t.Hashable]]
-ResourceId = t.Union[t.Hashable, _IdCallable, _IdAwaitable]
+BoundArgs = t.OrderedDict[str, t.Any]
+_IdCallableReturn = t.Union[t.Hashable, t.Awaitable[t.Hashable]]
+_IdCallable = t.Callable[[BoundArgs], _IdCallableReturn]
+ResourceId = t.Union[t.Hashable, _IdCallable]
def in_whitelist(
@@ -108,8 +109,8 @@ def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Call
`namespace` is an identifier used to prevent collisions among resource IDs.
`resource_id` identifies a resource on which to perform a mutually exclusive operation.
- It may also be a callable or awaitable which will return the resource ID given the decorated
- function's args and kwargs.
+ It may also be a callable or awaitable which will return the resource ID given an ordered
+ mapping of the parameters' names to arguments' values.
If decorating a command, this decorator must go before (below) the `command` decorator.
"""
@@ -121,8 +122,13 @@ def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Call
log.trace(f"{name}: mutually exclusive decorator called")
if callable(resource_id):
+ log.trace(f"{name}: binding args to signature")
+ sig = inspect.signature(func)
+ bound_args = sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+
log.trace(f"{name}: calling the given callable to get the resource ID")
- id_ = resource_id(*args, **kwargs)
+ id_ = resource_id(bound_args.arguments)
if inspect.isawaitable(id_):
log.trace(f"{name}: awaiting to get resource ID")