diff options
author | 2022-07-23 22:52:52 +0100 | |
---|---|---|
committer | 2022-08-14 19:43:52 +0100 | |
commit | 7782c196830098f81f39d235354636cd0d4a481d (patch) | |
tree | 68ffeec0c351449940b2dd4c7349b4458d65cda2 | |
parent | redis-py breaking changes (diff) |
No longer use the removed RedisSession connection object
This has been abstracted away, the correct way to do this now is to directly access the client.
-rw-r--r-- | bot/exts/info/doc/_redis_cache.py | 92 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 5 |
2 files changed, 46 insertions, 51 deletions
diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py index 8e08e7ae4..0f4d663d1 100644 --- a/bot/exts/info/doc/_redis_cache.py +++ b/bot/exts/info/doc/_redis_cache.py @@ -34,55 +34,52 @@ class DocRedisCache(RedisObject): redis_key = f"{self.namespace}:{item_key(item)}" needs_expire = False - with await self._get_pool_connection() as connection: - set_expire = self._set_expires.get(redis_key) - if set_expire is None: - # An expire is only set if the key didn't exist before. - ttl = await connection.ttl(redis_key) - log.debug(f"Checked TTL for `{redis_key}`.") - - if ttl == -1: - log.warning(f"Key `{redis_key}` had no expire set.") - if ttl < 0: # not set or didn't exist - needs_expire = True - else: - log.debug(f"Key `{redis_key}` has a {ttl} TTL.") - self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - - elif time.monotonic() > set_expire: - # If we got here the key expired in redis and we can be sure it doesn't exist. + set_expire = self._set_expires.get(redis_key) + if set_expire is None: + # An expire is only set if the key didn't exist before. + ttl = await self.redis_session.client.ttl(redis_key) + log.debug(f"Checked TTL for `{redis_key}`.") + + if ttl == -1: + log.warning(f"Key `{redis_key}` had no expire set.") + if ttl < 0: # not set or didn't exist needs_expire = True - log.debug(f"Key `{redis_key}` expired in internal key cache.") + else: + log.debug(f"Key `{redis_key}` has a {ttl} TTL.") + self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - await connection.hset(redis_key, item.symbol_id, value) - if needs_expire: - self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS - await connection.expire(redis_key, WEEK_SECONDS) - log.info(f"Set {redis_key} to expire in a week.") + elif time.monotonic() > set_expire: + # If we got here the key expired in redis and we can be sure it doesn't exist. + needs_expire = True + log.debug(f"Key `{redis_key}` expired in internal key cache.") + + await self.redis_session.client.hset(redis_key, item.symbol_id, value) + if needs_expire: + self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS + await self.redis_session.client.expire(redis_key, WEEK_SECONDS) + log.info(f"Set {redis_key} to expire in a week.") @namespace_lock async def get(self, item: DocItem) -> Optional[str]: """Return the Markdown content of the symbol `item` if it exists.""" - with await self._get_pool_connection() as connection: - return await connection.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id, encoding="utf8") + return await self.redis_session.client.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" pattern = f"{self.namespace}:{package}:*" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=pattern) - ] - if package_keys: - await connection.delete(*package_keys) - log.info(f"Deleted keys from redis: {package_keys}.") - self._set_expires = { - key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) - } - return True - return False + package_keys = [ + package_key async for package_key in self.redis_session.client.iscan(match=pattern) + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + log.info(f"Deleted keys from redis: {package_keys}.") + self._set_expires = { + key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) + } + return True + return False class StaleItemCounter(RedisObject): @@ -96,21 +93,20 @@ class StaleItemCounter(RedisObject): If the counter didn't exist, initialize it with 1. """ key = f"{self.namespace}:{item_key(item)}:{item.symbol_id}" - with await self._get_pool_connection() as connection: - await connection.expire(key, WEEK_SECONDS * 3) - return int(await connection.incr(key)) + await self.redis_session.client.expire(key, WEEK_SECONDS * 3) + return int(await self.redis_session.client.incr(key)) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*") - ] - if package_keys: - await connection.delete(*package_keys) - return True - return False + package_keys = [ + package_key + async for package_key in self.redis_session.client.iscan(match=f"{self.namespace}:{package}:*") + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + return True + return False def item_key(item: DocItem) -> str: diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..f60c177c5 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -283,8 +283,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def flush(self): """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() + await self.session.client.flushall() async def asyncSetUp(self): # noqa: N802 self.session = RedisSession(use_fakeredis=True) @@ -293,7 +292,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def asyncTearDown(self): # noqa: N802 if self.session: - await self.session.close() + await self.session.client.close() def setUp(self): """ |