aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/utils/redis_cache.py72
-rw-r--r--tests/bot/utils/test_redis_cache.py36
-rw-r--r--tests/helpers.py2
3 files changed, 85 insertions, 25 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
index 483bbc2cd..24f2f2e03 100644
--- a/bot/utils/redis_cache.py
+++ b/bot/utils/redis_cache.py
@@ -1,12 +1,10 @@
from __future__ import annotations
-from enum import Enum
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Optional, Union
from bot.bot import Bot
-ValidRedisKey = Union[str, int, float]
-JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]]
+ValidRedisType = Union[str, int, float]
class RedisCache:
@@ -41,7 +39,39 @@ class RedisCache:
self._namespaces.append(namespace)
self._namespace = namespace
- def __set_name__(self, owner: object, attribute_name: str) -> None:
+ @staticmethod
+ def _to_typestring(value: ValidRedisType) -> str:
+ """Turn a valid Redis type into a typestring."""
+ if isinstance(value, float):
+ return f"f|{value}"
+ elif isinstance(value, int):
+ return f"i|{value}"
+ elif isinstance(value, str):
+ return f"s|{value}"
+
+ @staticmethod
+ def _from_typestring(value: str) -> ValidRedisType:
+ """Turn a valid Redis type into a typestring."""
+ if value.startswith("f|"):
+ return float(value[2:])
+ if value.startswith("i|"):
+ return int(value[2:])
+ if value.startswith("s|"):
+ return value[2:]
+
+ async def _validate_cache(self) -> None:
+ """Validate that the RedisCache is ready to be used."""
+ if self.bot is None:
+ raise RuntimeError("Critical error: RedisCache has no `Bot` instance.")
+
+ if self._namespace is None:
+ raise RuntimeError(
+ "Critical error: RedisCache has no namespace. "
+ "Did you initialize this object as a class attribute?"
+ )
+ await self.bot._redis_ready.wait()
+
+ def __set_name__(self, owner: Any, attribute_name: str) -> None:
"""
Set the namespace to Class.attribute_name.
@@ -54,8 +84,11 @@ class RedisCache:
if self.bot:
return self
+ if self._namespace is None:
+ raise RuntimeError("RedisCache must be a class attribute.")
+
if instance is None:
- raise NotImplementedError("You must create an instance of RedisCache to use it.")
+ raise RuntimeError("You must create an instance of RedisCache to use it.")
for attribute in vars(instance).values():
if isinstance(attribute, Bot):
@@ -69,19 +102,32 @@ class RedisCache:
"""Return a beautiful representation of this object instance."""
return f"RedisCache(namespace={self._namespace!r})"
- async def set(self, key: ValidRedisKey, value: JSONSerializableType) -> None:
+ async def set(self, key: ValidRedisType, value: ValidRedisType) -> None:
"""Store an item in the Redis cache."""
- # await self._redis.hset(self._namespace, key, value)
+ await self._validate_cache()
+
+ # Convert to a typestring and then set it
+ key = self._to_typestring(key)
+ value = self._to_typestring(value)
+ await self._redis.hset(self._namespace, key, value)
- async def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
+ async def get(self, key: ValidRedisType, default: Optional[ValidRedisType] = None) -> ValidRedisType:
"""Get an item from the Redis cache."""
- # value = await self._redis.hget(self._namespace, key)
+ await self._validate_cache()
+ key = self._to_typestring(key)
+ value = await self._redis.hget(self._namespace, key)
+
+ if value is None:
+ return default
+ else:
+ value = self._from_typestring(value.decode("utf-8"))
+ return value
- async def delete(self, key: ValidRedisKey) -> None:
+ async def delete(self, key: ValidRedisType) -> None:
"""Delete an item from the Redis cache."""
# await self._redis.hdel(self._namespace, key)
- async def contains(self, key: ValidRedisKey) -> bool:
+ async def contains(self, key: ValidRedisType) -> bool:
"""Check if a key exists in the Redis cache."""
# return await self._redis.hexists(self._namespace, key)
@@ -103,7 +149,7 @@ class RedisCache:
"""Deletes the entire hash from the Redis cache."""
# await self._redis.delete(self._namespace)
- async def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
+ async def pop(self, key: ValidRedisType, default: Optional[ValidRedisType] = None) -> ValidRedisType:
"""Get the item, remove it from the cache, and provide a default if not found."""
# value = await self.get(key, default)
# await self.delete(key)
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index 991225481..ad38bfde0 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -7,27 +7,41 @@ from tests import helpers
class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
- """Tests the RedisDict class from utils.redis_dict.py."""
+ """Tests the RedisCache class from utils.redis_dict.py."""
redis = RedisCache()
- async def asyncSetUp(self): # noqa: N802 - this special method can't be all lowercase
+ async def asyncSetUp(self): # noqa: N802
"""Sets up the objects that only have to be initialized once."""
self.bot = helpers.MockBot()
self.bot.redis_session = await fakeredis.aioredis.create_redis_pool()
- def test_class_attribute_namespace(self):
+ async def test_class_attribute_namespace(self):
"""Test that RedisDict creates a namespace automatically for class attributes."""
self.assertEqual(self.redis._namespace, "RedisCacheTests.redis")
- # Test that errors are raised when this isn't true.
- # def test_set_get_item(self):
- # """Test that users can set and get items from the RedisDict."""
- # self.redis['favorite_fruit'] = 'melon'
- # self.redis['favorite_number'] = 86
- # self.assertEqual(self.redis['favorite_fruit'], 'melon')
- # self.assertEqual(self.redis['favorite_number'], 86)
- #
+ # Test that errors are raised when not assigned as a class attribute
+ bad_cache = RedisCache()
+
+ with self.assertRaises(RuntimeError):
+ await bad_cache.set("test", "me_up_deadman")
+
+ async def test_set_get_item(self):
+ """Test that users can set and get items from the RedisDict."""
+ test_cases = (
+ ('favorite_fruit', 'melon'),
+ ('favorite_number', 86),
+ ('favorite_fraction', 86.54)
+ )
+
+ # Test that we can get and set different types.
+ for test in test_cases:
+ await self.redis.set(*test)
+ self.assertEqual(await self.redis.get(test[0]), test[1])
+
+ # Test that .get allows a default value
+ self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
+
# def test_set_item_types(self):
# """Test that setitem rejects keys and values that are not strings, ints or floats."""
# fruits = ["lemon", "melon", "apple"]
diff --git a/tests/helpers.py b/tests/helpers.py
index d226be3f0..2b176db79 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -299,7 +299,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
For more information, see the `MockGuild` docstring.
"""
spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop())
- additional_spec_asyncs = ("wait_for",)
+ additional_spec_asyncs = ("wait_for", "_redis_ready")
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)