diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/info/doc/_redis_cache.py | 92 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 5 | 
2 files changed, 46 insertions, 51 deletions
| diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py index 8e08e7ae4..0f4d663d1 100644 --- a/bot/exts/info/doc/_redis_cache.py +++ b/bot/exts/info/doc/_redis_cache.py @@ -34,55 +34,52 @@ class DocRedisCache(RedisObject):          redis_key = f"{self.namespace}:{item_key(item)}"          needs_expire = False -        with await self._get_pool_connection() as connection: -            set_expire = self._set_expires.get(redis_key) -            if set_expire is None: -                # An expire is only set if the key didn't exist before. -                ttl = await connection.ttl(redis_key) -                log.debug(f"Checked TTL for `{redis_key}`.") - -                if ttl == -1: -                    log.warning(f"Key `{redis_key}` had no expire set.") -                if ttl < 0:  # not set or didn't exist -                    needs_expire = True -                else: -                    log.debug(f"Key `{redis_key}` has a {ttl} TTL.") -                    self._set_expires[redis_key] = time.monotonic() + ttl - .1  # we need this to expire before redis - -            elif time.monotonic() > set_expire: -                # If we got here the key expired in redis and we can be sure it doesn't exist. +        set_expire = self._set_expires.get(redis_key) +        if set_expire is None: +            # An expire is only set if the key didn't exist before. +            ttl = await self.redis_session.client.ttl(redis_key) +            log.debug(f"Checked TTL for `{redis_key}`.") + +            if ttl == -1: +                log.warning(f"Key `{redis_key}` had no expire set.") +            if ttl < 0:  # not set or didn't exist                  needs_expire = True -                log.debug(f"Key `{redis_key}` expired in internal key cache.") +            else: +                log.debug(f"Key `{redis_key}` has a {ttl} TTL.") +                self._set_expires[redis_key] = time.monotonic() + ttl - .1  # we need this to expire before redis -            await connection.hset(redis_key, item.symbol_id, value) -            if needs_expire: -                self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS -                await connection.expire(redis_key, WEEK_SECONDS) -                log.info(f"Set {redis_key} to expire in a week.") +        elif time.monotonic() > set_expire: +            # If we got here the key expired in redis and we can be sure it doesn't exist. +            needs_expire = True +            log.debug(f"Key `{redis_key}` expired in internal key cache.") + +        await self.redis_session.client.hset(redis_key, item.symbol_id, value) +        if needs_expire: +            self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS +            await self.redis_session.client.expire(redis_key, WEEK_SECONDS) +            log.info(f"Set {redis_key} to expire in a week.")      @namespace_lock      async def get(self, item: DocItem) -> Optional[str]:          """Return the Markdown content of the symbol `item` if it exists.""" -        with await self._get_pool_connection() as connection: -            return await connection.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id, encoding="utf8") +        return await self.redis_session.client.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id)      @namespace_lock      async def delete(self, package: str) -> bool:          """Remove all values for `package`; return True if at least one key was deleted, False otherwise."""          pattern = f"{self.namespace}:{package}:*" -        with await self._get_pool_connection() as connection: -            package_keys = [ -                package_key async for package_key in connection.iscan(match=pattern) -            ] -            if package_keys: -                await connection.delete(*package_keys) -                log.info(f"Deleted keys from redis: {package_keys}.") -                self._set_expires = { -                    key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) -                } -                return True -            return False +        package_keys = [ +            package_key async for package_key in self.redis_session.client.iscan(match=pattern) +        ] +        if package_keys: +            await self.redis_session.client.delete(*package_keys) +            log.info(f"Deleted keys from redis: {package_keys}.") +            self._set_expires = { +                key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) +            } +            return True +        return False  class StaleItemCounter(RedisObject): @@ -96,21 +93,20 @@ class StaleItemCounter(RedisObject):          If the counter didn't exist, initialize it with 1.          """          key = f"{self.namespace}:{item_key(item)}:{item.symbol_id}" -        with await self._get_pool_connection() as connection: -            await connection.expire(key, WEEK_SECONDS * 3) -            return int(await connection.incr(key)) +        await self.redis_session.client.expire(key, WEEK_SECONDS * 3) +        return int(await self.redis_session.client.incr(key))      @namespace_lock      async def delete(self, package: str) -> bool:          """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" -        with await self._get_pool_connection() as connection: -            package_keys = [ -                package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*") -            ] -            if package_keys: -                await connection.delete(*package_keys) -                return True -            return False +        package_keys = [ +            package_key +            async for package_key in self.redis_session.client.iscan(match=f"{self.namespace}:{package}:*") +        ] +        if package_keys: +            await self.redis_session.client.delete(*package_keys) +            return True +        return False  def item_key(item: DocItem) -> str: diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..f60c177c5 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -283,8 +283,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):      async def flush(self):          """Flush everything from the database to prevent carry-overs between tests.""" -        with await self.session.pool as connection: -            await connection.flushall() +        await self.session.client.flushall()      async def asyncSetUp(self):  # noqa: N802          self.session = RedisSession(use_fakeredis=True) @@ -293,7 +292,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):      async def asyncTearDown(self):  # noqa: N802          if self.session: -            await self.session.close() +            await self.session.client.close()      def setUp(self):          """ | 
