diff options
| -rw-r--r-- | bot/exts/info/doc/_cog.py | 10 | ||||
| -rw-r--r-- | bot/utils/lock.py | 30 |
2 files changed, 34 insertions, 6 deletions
diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py index 7b9dad135..26694ae55 100644 --- a/bot/exts/info/doc/_cog.py +++ b/bot/exts/info/doc/_cog.py @@ -16,7 +16,7 @@ from bot.bot import Bot from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import Inventory, PackageName, ValidURL from bot.pagination import LinePaginator -from bot.utils.lock import lock +from bot.utils.lock import SharedEvent, lock from bot.utils.messages import send_denial, wait_for_deletion from bot.utils.scheduling import Scheduler from . import PRIORITY_PACKAGES, doc_cache @@ -70,8 +70,7 @@ class DocCog(commands.Cog): self.refresh_event = asyncio.Event() self.refresh_event.set() - self.symbol_get_event = asyncio.Event() - self.symbol_get_event.set() + self.symbol_get_event = SharedEvent() self.init_refresh_task = self.bot.loop.create_task(self.init_refresh_inventory()) @@ -252,9 +251,8 @@ class DocCog(commands.Cog): return None self.bot.stats.incr(f"doc_fetches.{symbol_info.package}") - self.symbol_get_event.clear() - markdown = await doc_cache.get(symbol_info) - self.symbol_get_event.set() + with self.symbol_get_event: + markdown = await doc_cache.get(symbol_info) if markdown is None: log.debug(f"Redis cache miss for symbol `{symbol}`.") diff --git a/bot/utils/lock.py b/bot/utils/lock.py index 997c653a1..b4bb0ebc7 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -1,3 +1,4 @@ +import asyncio import inspect import logging import types @@ -18,6 +19,35 @@ _IdCallable = Callable[[function.BoundArgs], _IdCallableReturn] ResourceId = Union[Hashable, _IdCallable] +class SharedEvent: + """ + Context manager managing an internal event exposed through the wait coro. + + While any code is executing in this context manager, the underyling event will not be set; + when all of the holders finish the event will be set. + """ + + def __init__(self): + self._active_count = 0 + self._event = asyncio.Event() + self._event.set() + + def __enter__(self): + """Increment the count of the active holders and clear the internal event.""" + self._active_count += 1 + self._event.clear() + + def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001 + """Decrement the count of the active holders; if 0 is reached set the internal event.""" + self._active_count -= 1 + if not self._active_count: + self._event.set() + + async def wait(self) -> None: + """Wait for all active holders to exit.""" + await self._event.wait() + + class LockGuard: """ A context manager which acquires and releases a lock (mutex). |