aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-07-17 10:33:21 -0700
committerGravatar MarkKoz <[email protected]>2020-07-31 22:58:05 -0700
commitbe93601a31dcfa8acb03996eaaf2edcb654712f5 (patch)
treea2ff86f21c04698f80a6904121820eaf1aa22e0a
parentDecorators: create helper function to get arg value (diff)
Decorators: add mutually exclusive decorator
This will be used to prevent race conditions on a resource by stopping all other access to the resource once its been acquired.
-rw-r--r--bot/decorators.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index 1fe082b6e..cae0870b6 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -2,6 +2,7 @@ import asyncio
import logging
import random
import typing as t
+from collections import defaultdict
from contextlib import suppress
from functools import wraps
from weakref import WeakValueDictionary
@@ -13,8 +14,10 @@ from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check
log = logging.getLogger(__name__)
+__lock_dicts = defaultdict(WeakValueDictionary)
Argument = t.Union[int, str]
+ResourceId = t.Union[Argument, t.Callable[..., t.Hashable]]
def in_whitelist(
@@ -92,6 +95,42 @@ def locked() -> t.Callable:
return wrap
+def mutually_exclusive(namespace: t.Hashable, resource_arg: ResourceId) -> t.Callable:
+ """
+ Turn the decorated coroutine function into a mutually exclusive operation on a resource.
+
+ If any other mutually exclusive function currently holds the lock for a resource, do not run the
+ decorated function and return None.
+
+ `namespace` is an identifier used to prevent collisions among resource IDs.
+
+ `resource_arg` is the positional index or name of the parameter of the decorated function whose
+ value will be the resource ID. It may also be a callable which will return the resource ID
+ given the decorated function's args and kwargs.
+ """
+ def decorator(func: t.Callable) -> t.Callable:
+ @wraps(func)
+ async def wrapper(*args, **kwargs) -> t.Any:
+ if callable(resource_arg):
+ # Call to get the ID if a callable was given.
+ id_ = resource_arg(*args, **kwargs)
+ else:
+ # Retrieve the ID from the args via position or name.
+ id_ = _get_arg_value(resource_arg, args, kwargs)
+
+ # Get the lock for the ID. Create a Lock if one doesn't exist yet.
+ locks = __lock_dicts[namespace]
+ lock = locks.setdefault(id_, asyncio.Lock)
+
+ if not lock.locked():
+ # Resource is free; acquire it.
+ async with lock:
+ return await func(*args, **kwargs)
+
+ return wrapper
+ return decorator
+
+
def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable:
"""
Changes the channel in the context of the command to redirect the output to a certain channel.