diff options
-rw-r--r-- | bot/utils/redis_cache.py | 72 | ||||
-rw-r--r-- | tests/bot/utils/test_redis_cache.py | 36 | ||||
-rw-r--r-- | tests/helpers.py | 2 |
3 files changed, 85 insertions, 25 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 483bbc2cd..24f2f2e03 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -1,12 +1,10 @@ from __future__ import annotations -from enum import Enum -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union +from typing import Any, AsyncIterator, Dict, Optional, Union from bot.bot import Bot -ValidRedisKey = Union[str, int, float] -JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]] +ValidRedisType = Union[str, int, float] class RedisCache: @@ -41,7 +39,39 @@ class RedisCache: self._namespaces.append(namespace) self._namespace = namespace - def __set_name__(self, owner: object, attribute_name: str) -> None: + @staticmethod + def _to_typestring(value: ValidRedisType) -> str: + """Turn a valid Redis type into a typestring.""" + if isinstance(value, float): + return f"f|{value}" + elif isinstance(value, int): + return f"i|{value}" + elif isinstance(value, str): + return f"s|{value}" + + @staticmethod + def _from_typestring(value: str) -> ValidRedisType: + """Turn a valid Redis type into a typestring.""" + if value.startswith("f|"): + return float(value[2:]) + if value.startswith("i|"): + return int(value[2:]) + if value.startswith("s|"): + return value[2:] + + async def _validate_cache(self) -> None: + """Validate that the RedisCache is ready to be used.""" + if self.bot is None: + raise RuntimeError("Critical error: RedisCache has no `Bot` instance.") + + if self._namespace is None: + raise RuntimeError( + "Critical error: RedisCache has no namespace. " + "Did you initialize this object as a class attribute?" + ) + await self.bot._redis_ready.wait() + + def __set_name__(self, owner: Any, attribute_name: str) -> None: """ Set the namespace to Class.attribute_name. @@ -54,8 +84,11 @@ class RedisCache: if self.bot: return self + if self._namespace is None: + raise RuntimeError("RedisCache must be a class attribute.") + if instance is None: - raise NotImplementedError("You must create an instance of RedisCache to use it.") + raise RuntimeError("You must create an instance of RedisCache to use it.") for attribute in vars(instance).values(): if isinstance(attribute, Bot): @@ -69,19 +102,32 @@ class RedisCache: """Return a beautiful representation of this object instance.""" return f"RedisCache(namespace={self._namespace!r})" - async def set(self, key: ValidRedisKey, value: JSONSerializableType) -> None: + async def set(self, key: ValidRedisType, value: ValidRedisType) -> None: """Store an item in the Redis cache.""" - # await self._redis.hset(self._namespace, key, value) + await self._validate_cache() + + # Convert to a typestring and then set it + key = self._to_typestring(key) + value = self._to_typestring(value) + await self._redis.hset(self._namespace, key, value) - async def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: + async def get(self, key: ValidRedisType, default: Optional[ValidRedisType] = None) -> ValidRedisType: """Get an item from the Redis cache.""" - # value = await self._redis.hget(self._namespace, key) + await self._validate_cache() + key = self._to_typestring(key) + value = await self._redis.hget(self._namespace, key) + + if value is None: + return default + else: + value = self._from_typestring(value.decode("utf-8")) + return value - async def delete(self, key: ValidRedisKey) -> None: + async def delete(self, key: ValidRedisType) -> None: """Delete an item from the Redis cache.""" # await self._redis.hdel(self._namespace, key) - async def contains(self, key: ValidRedisKey) -> bool: + async def contains(self, key: ValidRedisType) -> bool: """Check if a key exists in the Redis cache.""" # return await self._redis.hexists(self._namespace, key) @@ -103,7 +149,7 @@ class RedisCache: """Deletes the entire hash from the Redis cache.""" # await self._redis.delete(self._namespace) - async def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: + async def pop(self, key: ValidRedisType, default: Optional[ValidRedisType] = None) -> ValidRedisType: """Get the item, remove it from the cache, and provide a default if not found.""" # value = await self.get(key, default) # await self.delete(key) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 991225481..ad38bfde0 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -7,27 +7,41 @@ from tests import helpers class RedisCacheTests(unittest.IsolatedAsyncioTestCase): - """Tests the RedisDict class from utils.redis_dict.py.""" + """Tests the RedisCache class from utils.redis_dict.py.""" redis = RedisCache() - async def asyncSetUp(self): # noqa: N802 - this special method can't be all lowercase + async def asyncSetUp(self): # noqa: N802 """Sets up the objects that only have to be initialized once.""" self.bot = helpers.MockBot() self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - def test_class_attribute_namespace(self): + async def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" self.assertEqual(self.redis._namespace, "RedisCacheTests.redis") - # Test that errors are raised when this isn't true. - # def test_set_get_item(self): - # """Test that users can set and get items from the RedisDict.""" - # self.redis['favorite_fruit'] = 'melon' - # self.redis['favorite_number'] = 86 - # self.assertEqual(self.redis['favorite_fruit'], 'melon') - # self.assertEqual(self.redis['favorite_number'], 86) - # + # Test that errors are raised when not assigned as a class attribute + bad_cache = RedisCache() + + with self.assertRaises(RuntimeError): + await bad_cache.set("test", "me_up_deadman") + + async def test_set_get_item(self): + """Test that users can set and get items from the RedisDict.""" + test_cases = ( + ('favorite_fruit', 'melon'), + ('favorite_number', 86), + ('favorite_fraction', 86.54) + ) + + # Test that we can get and set different types. + for test in test_cases: + await self.redis.set(*test) + self.assertEqual(await self.redis.get(test[0]), test[1]) + + # Test that .get allows a default value + self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + # def test_set_item_types(self): # """Test that setitem rejects keys and values that are not strings, ints or floats.""" # fruits = ["lemon", "melon", "apple"] diff --git a/tests/helpers.py b/tests/helpers.py index d226be3f0..2b176db79 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -299,7 +299,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): For more information, see the `MockGuild` docstring. """ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) - additional_spec_asyncs = ("wait_for",) + additional_spec_asyncs = ("wait_for", "_redis_ready") def __init__(self, **kwargs) -> None: super().__init__(**kwargs) |