aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/info/doc/_redis_cache.py92
-rw-r--r--tests/bot/exts/moderation/test_incidents.py5
2 files changed, 46 insertions, 51 deletions
diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py
index 8e08e7ae4..0f4d663d1 100644
--- a/bot/exts/info/doc/_redis_cache.py
+++ b/bot/exts/info/doc/_redis_cache.py
@@ -34,55 +34,52 @@ class DocRedisCache(RedisObject):
redis_key = f"{self.namespace}:{item_key(item)}"
needs_expire = False
- with await self._get_pool_connection() as connection:
- set_expire = self._set_expires.get(redis_key)
- if set_expire is None:
- # An expire is only set if the key didn't exist before.
- ttl = await connection.ttl(redis_key)
- log.debug(f"Checked TTL for `{redis_key}`.")
-
- if ttl == -1:
- log.warning(f"Key `{redis_key}` had no expire set.")
- if ttl < 0: # not set or didn't exist
- needs_expire = True
- else:
- log.debug(f"Key `{redis_key}` has a {ttl} TTL.")
- self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis
-
- elif time.monotonic() > set_expire:
- # If we got here the key expired in redis and we can be sure it doesn't exist.
+ set_expire = self._set_expires.get(redis_key)
+ if set_expire is None:
+ # An expire is only set if the key didn't exist before.
+ ttl = await self.redis_session.client.ttl(redis_key)
+ log.debug(f"Checked TTL for `{redis_key}`.")
+
+ if ttl == -1:
+ log.warning(f"Key `{redis_key}` had no expire set.")
+ if ttl < 0: # not set or didn't exist
needs_expire = True
- log.debug(f"Key `{redis_key}` expired in internal key cache.")
+ else:
+ log.debug(f"Key `{redis_key}` has a {ttl} TTL.")
+ self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis
- await connection.hset(redis_key, item.symbol_id, value)
- if needs_expire:
- self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS
- await connection.expire(redis_key, WEEK_SECONDS)
- log.info(f"Set {redis_key} to expire in a week.")
+ elif time.monotonic() > set_expire:
+ # If we got here the key expired in redis and we can be sure it doesn't exist.
+ needs_expire = True
+ log.debug(f"Key `{redis_key}` expired in internal key cache.")
+
+ await self.redis_session.client.hset(redis_key, item.symbol_id, value)
+ if needs_expire:
+ self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS
+ await self.redis_session.client.expire(redis_key, WEEK_SECONDS)
+ log.info(f"Set {redis_key} to expire in a week.")
@namespace_lock
async def get(self, item: DocItem) -> Optional[str]:
"""Return the Markdown content of the symbol `item` if it exists."""
- with await self._get_pool_connection() as connection:
- return await connection.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id, encoding="utf8")
+ return await self.redis_session.client.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id)
@namespace_lock
async def delete(self, package: str) -> bool:
"""Remove all values for `package`; return True if at least one key was deleted, False otherwise."""
pattern = f"{self.namespace}:{package}:*"
- with await self._get_pool_connection() as connection:
- package_keys = [
- package_key async for package_key in connection.iscan(match=pattern)
- ]
- if package_keys:
- await connection.delete(*package_keys)
- log.info(f"Deleted keys from redis: {package_keys}.")
- self._set_expires = {
- key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern)
- }
- return True
- return False
+ package_keys = [
+ package_key async for package_key in self.redis_session.client.iscan(match=pattern)
+ ]
+ if package_keys:
+ await self.redis_session.client.delete(*package_keys)
+ log.info(f"Deleted keys from redis: {package_keys}.")
+ self._set_expires = {
+ key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern)
+ }
+ return True
+ return False
class StaleItemCounter(RedisObject):
@@ -96,21 +93,20 @@ class StaleItemCounter(RedisObject):
If the counter didn't exist, initialize it with 1.
"""
key = f"{self.namespace}:{item_key(item)}:{item.symbol_id}"
- with await self._get_pool_connection() as connection:
- await connection.expire(key, WEEK_SECONDS * 3)
- return int(await connection.incr(key))
+ await self.redis_session.client.expire(key, WEEK_SECONDS * 3)
+ return int(await self.redis_session.client.incr(key))
@namespace_lock
async def delete(self, package: str) -> bool:
"""Remove all values for `package`; return True if at least one key was deleted, False otherwise."""
- with await self._get_pool_connection() as connection:
- package_keys = [
- package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*")
- ]
- if package_keys:
- await connection.delete(*package_keys)
- return True
- return False
+ package_keys = [
+ package_key
+ async for package_key in self.redis_session.client.iscan(match=f"{self.namespace}:{package}:*")
+ ]
+ if package_keys:
+ await self.redis_session.client.delete(*package_keys)
+ return True
+ return False
def item_key(item: DocItem) -> str:
diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py
index cfe0c4b03..f60c177c5 100644
--- a/tests/bot/exts/moderation/test_incidents.py
+++ b/tests/bot/exts/moderation/test_incidents.py
@@ -283,8 +283,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):
async def flush(self):
"""Flush everything from the database to prevent carry-overs between tests."""
- with await self.session.pool as connection:
- await connection.flushall()
+ await self.session.client.flushall()
async def asyncSetUp(self): # noqa: N802
self.session = RedisSession(use_fakeredis=True)
@@ -293,7 +292,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):
async def asyncTearDown(self): # noqa: N802
if self.session:
- await self.session.close()
+ await self.session.client.close()
def setUp(self):
"""