diff options
| -rw-r--r-- | bot/decorators.py | 6 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 3 | ||||
| -rw-r--r-- | bot/utils/lock.py | 23 |
3 files changed, 28 insertions, 4 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index 0e84cf37e..3418dfd11 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, check from bot.constants import Channels, ERROR_REPLIES, RedirectOutput from bot.errors import LockedResourceError -from bot.utils import function +from bot.utils import LockGuard, function from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check log = logging.getLogger(__name__) @@ -144,11 +144,11 @@ def mutually_exclusive( # Get the lock for the ID. Create a lock if one doesn't exist yet. locks = __lock_dicts[namespace] - lock = locks.setdefault(id_, asyncio.Lock()) + lock = locks.setdefault(id_, LockGuard()) if not lock.locked(): log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...") - async with lock: + with lock: return await func(*args, **kwargs) else: log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 5a6e1811b..0dd9605e8 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -2,9 +2,10 @@ from abc import ABCMeta from discord.ext.commands import CogMeta +from bot.utils.lock import LockGuard from bot.utils.redis_cache import RedisCache -__all__ = ['RedisCache', 'CogABCMeta'] +__all__ = ["CogABCMeta", "LockGuard", "RedisCache"] class CogABCMeta(CogMeta, ABCMeta): diff --git a/bot/utils/lock.py b/bot/utils/lock.py new file mode 100644 index 000000000..8f1b738aa --- /dev/null +++ b/bot/utils/lock.py @@ -0,0 +1,23 @@ +class LockGuard: + """ + A context manager which acquires and releases a lock (mutex). + + Raise RuntimeError if trying to acquire a locked lock. + """ + + def __init__(self): + self._locked = False + + def locked(self) -> bool: + """Return True if currently locked or False if unlocked.""" + return self._locked + + def __enter__(self): + if self._locked: + raise RuntimeError("Cannot acquire a locked lock.") + + self._locked = True + + def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001 + self._locked = False + return False # Indicate any raised exception shouldn't be suppressed. |