aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/info/doc/_cog.py10
-rw-r--r--bot/utils/lock.py30
2 files changed, 34 insertions, 6 deletions
diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py
index 7b9dad135..26694ae55 100644
--- a/bot/exts/info/doc/_cog.py
+++ b/bot/exts/info/doc/_cog.py
@@ -16,7 +16,7 @@ from bot.bot import Bot
from bot.constants import MODERATION_ROLES, RedirectOutput
from bot.converters import Inventory, PackageName, ValidURL
from bot.pagination import LinePaginator
-from bot.utils.lock import lock
+from bot.utils.lock import SharedEvent, lock
from bot.utils.messages import send_denial, wait_for_deletion
from bot.utils.scheduling import Scheduler
from . import PRIORITY_PACKAGES, doc_cache
@@ -70,8 +70,7 @@ class DocCog(commands.Cog):
self.refresh_event = asyncio.Event()
self.refresh_event.set()
- self.symbol_get_event = asyncio.Event()
- self.symbol_get_event.set()
+ self.symbol_get_event = SharedEvent()
self.init_refresh_task = self.bot.loop.create_task(self.init_refresh_inventory())
@@ -252,9 +251,8 @@ class DocCog(commands.Cog):
return None
self.bot.stats.incr(f"doc_fetches.{symbol_info.package}")
- self.symbol_get_event.clear()
- markdown = await doc_cache.get(symbol_info)
- self.symbol_get_event.set()
+ with self.symbol_get_event:
+ markdown = await doc_cache.get(symbol_info)
if markdown is None:
log.debug(f"Redis cache miss for symbol `{symbol}`.")
diff --git a/bot/utils/lock.py b/bot/utils/lock.py
index 997c653a1..b4bb0ebc7 100644
--- a/bot/utils/lock.py
+++ b/bot/utils/lock.py
@@ -1,3 +1,4 @@
+import asyncio
import inspect
import logging
import types
@@ -18,6 +19,35 @@ _IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
ResourceId = Union[Hashable, _IdCallable]
+class SharedEvent:
+ """
+ Context manager managing an internal event exposed through the wait coro.
+
+ While any code is executing in this context manager, the underyling event will not be set;
+ when all of the holders finish the event will be set.
+ """
+
+ def __init__(self):
+ self._active_count = 0
+ self._event = asyncio.Event()
+ self._event.set()
+
+ def __enter__(self):
+ """Increment the count of the active holders and clear the internal event."""
+ self._active_count += 1
+ self._event.clear()
+
+ def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001
+ """Decrement the count of the active holders; if 0 is reached set the internal event."""
+ self._active_count -= 1
+ if not self._active_count:
+ self._event.set()
+
+ async def wait(self) -> None:
+ """Wait for all active holders to exit."""
+ await self._event.wait()
+
+
class LockGuard:
"""
A context manager which acquires and releases a lock (mutex).