diff options
| -rw-r--r-- | bot/decorators.py | 19 | ||||
| -rw-r--r-- | bot/errors.py | 20 |
2 files changed, 35 insertions, 4 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index e370bf834..15386e506 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -12,6 +12,7 @@ from discord import Colour, Embed, Member, NotFound from discord.ext.commands import Cog, Command, Context, check from bot.constants import Channels, ERROR_REPLIES, RedirectOutput +from bot.errors import LockedResourceError from bot.utils import function from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check @@ -98,12 +99,18 @@ def locked() -> t.Callable: return wrap -def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Callable: +def mutually_exclusive( + namespace: t.Hashable, + resource_id: ResourceId, + *, + raise_error: bool = False, +) -> t.Callable: """ Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. If any other mutually exclusive function currently holds the lock for a resource, do not run the - decorated function and return None. + decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if + the lock cannot be acquired. `namespace` is an identifier used to prevent collisions among resource IDs. @@ -145,6 +152,8 @@ def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Call return await func(*args, **kwargs) else: log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") + if raise_error: + raise LockedResourceError(str(namespace), id_) return wrapper return decorator @@ -153,7 +162,9 @@ def mutually_exclusive(namespace: t.Hashable, resource_id: ResourceId) -> t.Call def mutually_exclusive_arg( namespace: t.Hashable, name_or_pos: function.Argument, - func: t.Callable[[t.Any], _IdCallableReturn] = None + func: t.Callable[[t.Any], _IdCallableReturn] = None, + *, + raise_error: bool = False, ) -> t.Callable: """ Apply `mutually_exclusive` using the value of the arg at the given name/position as the ID. @@ -161,7 +172,7 @@ def mutually_exclusive_arg( `func` is an optional callable or awaitable which will return the ID given the argument value. See `mutually_exclusive` docs for more information. """ - decorator_func = partial(mutually_exclusive, namespace) + decorator_func = partial(mutually_exclusive, namespace, raise_error=raise_error) return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) diff --git a/bot/errors.py b/bot/errors.py new file mode 100644 index 000000000..34de3c2b1 --- /dev/null +++ b/bot/errors.py @@ -0,0 +1,20 @@ +from typing import Hashable + + +class LockedResourceError(RuntimeError): + """ + Exception raised when an operation is attempted on a locked resource. + + Attributes: + `type` -- name of the locked resource's type + `resource_id` -- ID of the locked resource + """ + + def __init__(self, resource_type: str, resource_id: Hashable): + self.type = resource_type + self.id = resource_id + + super().__init__( + f"Cannot operate on {self.type.lower()} `{self.id}`; " + "it is currently locked and in use by another operation." + ) |