diff options
| author | 2020-07-17 10:33:21 -0700 | |
|---|---|---|
| committer | 2020-07-31 22:58:05 -0700 | |
| commit | be93601a31dcfa8acb03996eaaf2edcb654712f5 (patch) | |
| tree | a2ff86f21c04698f80a6904121820eaf1aa22e0a | |
| parent | Decorators: create helper function to get arg value (diff) | |
Decorators: add mutually exclusive decorator
This will be used to prevent race conditions on a resource by stopping
all other access to the resource once its been acquired.
| -rw-r--r-- | bot/decorators.py | 39 | 
1 files changed, 39 insertions, 0 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index 1fe082b6e..cae0870b6 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -2,6 +2,7 @@ import asyncio  import logging  import random  import typing as t +from collections import defaultdict  from contextlib import suppress  from functools import wraps  from weakref import WeakValueDictionary @@ -13,8 +14,10 @@ from bot.constants import Channels, ERROR_REPLIES, RedirectOutput  from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check  log = logging.getLogger(__name__) +__lock_dicts = defaultdict(WeakValueDictionary)  Argument = t.Union[int, str] +ResourceId = t.Union[Argument, t.Callable[..., t.Hashable]]  def in_whitelist( @@ -92,6 +95,42 @@ def locked() -> t.Callable:      return wrap +def mutually_exclusive(namespace: t.Hashable, resource_arg: ResourceId) -> t.Callable: +    """ +    Turn the decorated coroutine function into a mutually exclusive operation on a resource. + +    If any other mutually exclusive function currently holds the lock for a resource, do not run the +    decorated function and return None. + +    `namespace` is an identifier used to prevent collisions among resource IDs. + +    `resource_arg` is the positional index or name of the parameter of the decorated function whose +    value will be the resource ID. It may also be a callable which will return the resource ID +    given the decorated function's args and kwargs. +    """ +    def decorator(func: t.Callable) -> t.Callable: +        @wraps(func) +        async def wrapper(*args, **kwargs) -> t.Any: +            if callable(resource_arg): +                # Call to get the ID if a callable was given. +                id_ = resource_arg(*args, **kwargs) +            else: +                # Retrieve the ID from the args via position or name. +                id_ = _get_arg_value(resource_arg, args, kwargs) + +            # Get the lock for the ID. Create a Lock if one doesn't exist yet. +            locks = __lock_dicts[namespace] +            lock = locks.setdefault(id_, asyncio.Lock) + +            if not lock.locked(): +                # Resource is free; acquire it. +                async with lock: +                    return await func(*args, **kwargs) + +        return wrapper +    return decorator + +  def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable:      """      Changes the channel in the context of the command to redirect the output to a certain channel. | 
