aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_core/utils/lock.py
blob: 83146235e8b3d580557f2b1e9d2a1b9771c2274e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import asyncio
import inspect
import types
from collections import defaultdict
from collections.abc import Awaitable, Callable, Hashable
from functools import partial
from typing import Any
from weakref import WeakValueDictionary

from pydis_core.utils import function
from pydis_core.utils.function import command_wraps
from pydis_core.utils.logging import get_logger

# Logger for this module.
log = get_logger(__name__)
# Lock registry: namespace -> (resource ID -> asyncio.Lock). The inner mappings reference their
# locks weakly, so an entry disappears once nothing else holds a reference to its lock.
__lock_dicts = defaultdict(WeakValueDictionary)

# What a resource-ID callable may return: the ID itself, or an awaitable resolving to it.
_IdCallableReturn = Hashable | Awaitable[Hashable]
# A callable that derives the resource ID from the decorated function's bound arguments.
_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
# Either a ready-made resource ID or a callable that computes one.
ResourceId = Hashable | _IdCallable


class LockedResourceError(RuntimeError):
    """
    Error signalling that a resource could not be used because it is locked.

    Attributes:
        type (str): Name of the locked resource's type
        id (typing.Hashable): ID of the locked resource
    """

    def __init__(self, resource_type: str, resource_id: Hashable):
        # Keep the raw values available to handlers that want to build their own message.
        self.type = resource_type
        self.id = resource_id

        message = (
            f"Cannot operate on {resource_type.lower()} `{resource_id}`; "
            "it is currently locked and in use by another operation."
        )
        super().__init__(message)


class SharedEvent:
    """
    Counting context manager built around an internal event, awaited through the `wait` coro.

    The event starts out set. Entering the context clears it, and it is set again only once
    every holder that entered has exited.
    """

    def __init__(self):
        self._event = asyncio.Event()
        self._event.set()
        self._active_count = 0

    def __enter__(self):
        """Register one more active holder and clear the internal event."""
        self._event.clear()
        self._active_count += 1

    def __exit__(self, _exc_type, _exc_val, _exc_tb):  # noqa: ANN001
        """Unregister one active holder; set the internal event when none remain."""
        remaining = self._active_count - 1
        self._active_count = remaining
        if remaining == 0:
            self._event.set()

    async def wait(self) -> None:
        """Wait until the internal event is set, i.e. until no active holders remain."""
        await self._event.wait()


def lock(
    namespace: Hashable,
    resource_id: ResourceId,
    *,
    raise_error: bool = False,
    wait: bool = False,
) -> Callable:
    """
    Make the decorated coroutine function mutually exclusive per `resource_id` within `namespace`.

    If decorating a command, this decorator must go before (below) the `command` decorator.

    Arguments:
        namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs.
        resource_id: Identifies the resource on which the operation is mutually exclusive.
            If it is callable, it is called with the decorated function's bound arguments and
            must return the resource ID, or an awaitable that resolves to it.
        raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired.
        wait (bool): If True, wait until the lock becomes available. Otherwise, if the lock for
            the resource is currently held, do not run the decorated function and return None.

    Raises:
        :exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True.
    """
    def decorator(func: types.FunctionType) -> types.FunctionType:
        name = func.__name__

        @command_wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            log.trace(f"{name}: mutually exclusive decorator called")

            # Start from the value as given; only a callable needs to be resolved further.
            resolved_id = resource_id
            if callable(resource_id):
                log.trace(f"{name}: binding args to signature")
                bound = function.get_bound_args(func, args, kwargs)

                log.trace(f"{name}: calling the given callable to get the resource ID")
                resolved_id = resource_id(bound)

                if inspect.isawaitable(resolved_id):
                    log.trace(f"{name}: awaiting to get resource ID")
                    resolved_id = await resolved_id

            log.trace(f"{name}: getting the lock object for resource {namespace!r}:{resolved_id!r}")

            # Fetch the lock registered for this ID, registering a new one if there is none yet.
            namespace_locks = __lock_dicts[namespace]
            resource_lock = namespace_locks.setdefault(resolved_id, asyncio.Lock())

            # Checking `locked()` and acquiring afterwards is safe: there is no `await` between
            # this check and the acquire below, so no other task can run in between, and
            # `asyncio.Lock.acquire()` does not internally await anything if the lock is free.
            if not wait and resource_lock.locked():
                log.info(f"{name}: aborted because resource {namespace!r}:{resolved_id!r} is locked")
                if raise_error:
                    raise LockedResourceError(str(namespace), resolved_id)
                return None

            log.debug(f"{name}: acquiring lock for resource {namespace!r}:{resolved_id!r}...")
            async with resource_lock:
                return await func(*args, **kwargs)

        return wrapper
    return decorator


def lock_arg(
    namespace: Hashable,
    name_or_pos: function.Argument,
    func: Callable[[Any], _IdCallableReturn] | None = None,
    *,
    raise_error: bool = False,
    wait: bool = False,
) -> Callable:
    """
    Apply the `lock` decorator, taking the resource ID from one of the decorated function's args.

    The argument is selected by name or position via `name_or_pos`. See `lock` docs for the
    meaning of `namespace`, `raise_error` and `wait`.

    Arguments:
        func: An optional callable or awaitable which will return the ID given the argument value.
    """
    # `lock` with everything but the resource ID filled in; the helper supplies the ID getter.
    make_lock_decorator = partial(lock, namespace, raise_error=raise_error, wait=wait)
    return function.get_arg_value_wrapper(make_lock_decorator, name_or_pos, func)