aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-07-17 11:22:00 -0700
committerGravatar MarkKoz <[email protected]>2020-07-31 22:58:05 -0700
commit226d5e17b74d711776f7e7f49a8712ba820ac5ba (patch)
treee05e024212289b4587565cc5c5d2ff10a3abfe28
parentDecorators: drop arg pos/name support for mutually_exclusive (diff)
Decorators: support awaitables for resource ID
-rw-r--r--bot/decorators.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index f49499856..063368dda 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,4 +1,5 @@
import asyncio
+import inspect
import logging
import random
import typing as t
@@ -17,7 +18,9 @@ log = logging.getLogger(__name__)
__lock_dicts = defaultdict(WeakValueDictionary)
Argument = t.Union[int, str]
-ResourceId = t.Union[t.Hashable, t.Callable[..., t.Hashable]]
+_IdCallable = t.Callable[..., t.Hashable]
+_IdAwaitable = t.Callable[..., t.Awaitable[t.Hashable]]
+ResourceId = t.Union[t.Hashable, _IdCallable, _IdAwaitable]
def in_whitelist(
@@ -104,9 +107,9 @@ def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Call
`namespace` is an identifier used to prevent collisions among resource IDs.
- `resource_id` identifies a resource on which to perform a mutually exclusive operation. It may
- also be a callable which will return the resource ID given the decorated function's args and
- kwargs.
+ `resource_id` identifies a resource on which to perform a mutually exclusive operation.
+ It may also be a callable or awaitable which will return the resource ID given the decorated
+ function's args and kwargs.
"""
def decorator(func: t.Callable) -> t.Callable:
@wraps(func)
@@ -114,6 +117,10 @@ def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Call
if callable(resource_id):
# Call to get the ID if a callable was given.
id_ = resource_id(*args, **kwargs)
+
+ if inspect.isawaitable(id_):
+ # Await to get the ID if an awaitable was given.
+ id_ = await id_
else:
id_ = resource_id