diff options
| -rw-r--r-- | bot/utils/redis_cache.py | 116 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 13 | 
2 files changed, 86 insertions, 43 deletions
| diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 33e5d5852..afd37f8f8 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -2,26 +2,42 @@ from __future__ import annotations  import asyncio  import logging -from typing import Any, Dict, ItemsView, Optional, Union +import typing +from typing import Any, Dict, ItemsView, Optional, Tuple, Union  from bot.bot import Bot  log = logging.getLogger(__name__) -RedisType = Union[str, int, float] -TYPESTRING_PREFIXES = ( +# Type aliases +RedisKeyType = Union[str, int] +RedisValueType = Union[str, int, float] + +# Prefix tuples +PrefixTuple = Tuple[Tuple[str, Any]] +TYPESTRING_VALUE_PREFIXES = (      ("f|", float),      ("i|", int),      ("s|", str),  ) +TYPESTRING_KEY_PREFIXES = ( +    ("i|", int), +    ("s|", str), +)  # Makes a nice list like "float, int, and str" -NICE_TYPE_LIST = ", ".join(str(_type.__name__) for _, _type in TYPESTRING_PREFIXES) -NICE_TYPE_LIST = ", and ".join(NICE_TYPE_LIST.rsplit(", ", 1)) +NICE_VALUE_TYPE_LIST = ", ".join(str(_type.__name__) for _type in typing.get_args(RedisValueType)) +NICE_VALUE_TYPE_LIST = ", and ".join(NICE_VALUE_TYPE_LIST.rsplit(", ", 1)) + +NICE_KEY_TYPE_LIST = ", ".join(str(_type.__name__) for _type in typing.get_args(RedisKeyType)) +NICE_KEY_TYPE_LIST = ", and ".join(NICE_KEY_TYPE_LIST.rsplit(", ", 1))  # Makes a list like "'f|', 'i|', and 's|'" -NICE_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_PREFIXES]) -NICE_PREFIX_LIST = ", and ".join(NICE_PREFIX_LIST.rsplit(", ", 1)) +NICE_VALUE_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_VALUE_PREFIXES]) +NICE_VALUE_PREFIX_LIST = ", and ".join(NICE_VALUE_PREFIX_LIST.rsplit(", ", 1)) + +NICE_KEY_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_KEY_PREFIXES]) +NICE_KEY_PREFIX_LIST = ", and ".join(NICE_KEY_PREFIX_LIST.rsplit(", ", 1))  class RedisCache: @@ -99,33 +115,57 @@ class RedisCache:          self._namespace = namespace      @staticmethod -    def _to_typestring(value: RedisType) -> str: +    def _to_typestring( +            key_or_value: Union[RedisKeyType, RedisValueType], +            prefixes: PrefixTuple, +            nice_type_list: str +    ) -> str:          """Turn a valid Redis type into a typestring.""" -        for prefix, _type in TYPESTRING_PREFIXES: -            if isinstance(value, _type): -                return f"{prefix}{value}" -        raise TypeError(f"RedisCache._from_typestring only supports the types {NICE_TYPE_LIST}.") +        for prefix, _type in prefixes: +            if isinstance(key_or_value, _type): +                return f"{prefix}{key_or_value}" +        raise TypeError(f"RedisCache._from_typestring only supports the types {nice_type_list}.")      @staticmethod -    def _from_typestring(value: Union[bytes, str]) -> RedisType: -        """Turn a typestring into a valid Redis type.""" +    def _from_typestring( +            key_or_value: Union[bytes, str], +            prefixes: PrefixTuple, +            nice_prefix_list: str, +    ) -> Union[RedisKeyType, RedisValueType]: +        """Deserialize a typestring into a valid Redis type."""          # Stuff that comes out of Redis will be bytestrings, so let's decode those. -        if isinstance(value, bytes): -            value = value.decode('utf-8') +        if isinstance(key_or_value, bytes): +            key_or_value = key_or_value.decode('utf-8')          # Now we convert our unicode string back into the type it originally was. -        for prefix, _type in TYPESTRING_PREFIXES: -            if value.startswith(prefix): -                return _type(value[len(prefix):]) -        raise TypeError(f"RedisCache._to_typestring only supports the prefixes {NICE_PREFIX_LIST}.") +        for prefix, _type in prefixes: +            if key_or_value.startswith(prefix): +                return _type(key_or_value[len(prefix):]) +        raise TypeError(f"RedisCache._to_typestring only supports the prefixes {nice_prefix_list}.") + +    def _key_to_typestring(self, key: RedisKeyType) -> str: +        """Serialize a RedisKeyType object into a typestring.""" +        return self._to_typestring(key, TYPESTRING_KEY_PREFIXES, NICE_KEY_TYPE_LIST) + +    def _value_to_typestring(self, value: RedisValueType) -> str: +        """Serialize a RedisValueType object into a typestring.""" +        return self._to_typestring(value, TYPESTRING_VALUE_PREFIXES, NICE_VALUE_TYPE_LIST) + +    def _key_from_typestring(self, key: Union[bytes, str]) -> RedisKeyType: +        """Deserialize a RedisKeyType object from a typestring.""" +        return self._from_typestring(key, TYPESTRING_KEY_PREFIXES, NICE_KEY_PREFIX_LIST) + +    def _value_from_typestring(self, value: Union[bytes, str]) -> RedisValueType: +        """Deserialize a RedisValueType object from a typestring.""" +        return self._from_typestring(value, TYPESTRING_VALUE_PREFIXES, NICE_VALUE_PREFIX_LIST)      def _dict_from_typestring(self, dictionary: Dict) -> Dict:          """Turns all contents of a dict into valid Redis types.""" -        return {self._from_typestring(key): self._from_typestring(value) for key, value in dictionary.items()} +        return {self._key_from_typestring(key): self._value_from_typestring(value) for key, value in dictionary.items()}      def _dict_to_typestring(self, dictionary: Dict) -> Dict:          """Turns all contents of a dict into typestrings.""" -        return {self._to_typestring(key): self._to_typestring(value) for key, value in dictionary.items()} +        return {self._key_to_typestring(key): self._value_to_typestring(value) for key, value in dictionary.items()}      async def _validate_cache(self) -> None:          """Validate that the RedisCache is ready to be used.""" @@ -209,21 +249,21 @@ class RedisCache:          """Return a beautiful representation of this object instance."""          return f"RedisCache(namespace={self._namespace!r})" -    async def set(self, key: RedisType, value: RedisType) -> None: +    async def set(self, key: RedisKeyType, value: RedisValueType) -> None:          """Store an item in the Redis cache."""          await self._validate_cache()          # Convert to a typestring and then set it -        key = self._to_typestring(key) -        value = self._to_typestring(value) +        key = self._key_to_typestring(key) +        value = self._value_to_typestring(value)          log.trace(f"Setting {key} to {value}.")          await self._redis.hset(self._namespace, key, value) -    async def get(self, key: RedisType, default: Optional[RedisType] = None) -> Optional[RedisType]: +    async def get(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> Optional[RedisValueType]:          """Get an item from the Redis cache."""          await self._validate_cache() -        key = self._to_typestring(key) +        key = self._key_to_typestring(key)          log.trace(f"Attempting to retrieve {key}.")          value = await self._redis.hget(self._namespace, key) @@ -232,11 +272,11 @@ class RedisCache:              log.trace(f"Value not found, returning default value {default}")              return default          else: -            value = self._from_typestring(value) +            value = self._value_from_typestring(value)              log.trace(f"Value found, returning value {value}")              return value -    async def delete(self, key: RedisType) -> None: +    async def delete(self, key: RedisKeyType) -> None:          """          Delete an item from the Redis cache. @@ -245,19 +285,19 @@ class RedisCache:          See https://redis.io/commands/hdel for more info on how this works.          """          await self._validate_cache() -        key = self._to_typestring(key) +        key = self._key_to_typestring(key)          log.trace(f"Attempting to delete {key}.")          return await self._redis.hdel(self._namespace, key) -    async def contains(self, key: RedisType) -> bool: +    async def contains(self, key: RedisKeyType) -> bool:          """          Check if a key exists in the Redis cache.          Return True if the key exists, otherwise False.          """          await self._validate_cache() -        key = self._to_typestring(key) +        key = self._key_to_typestring(key)          exists = await self._redis.hexists(self._namespace, key)          log.trace(f"Testing if {key} exists in the RedisCache - Result is {exists}") @@ -304,7 +344,7 @@ class RedisCache:          log.trace("Clearing the cache of all key/value pairs.")          await self._redis.delete(self._namespace) -    async def pop(self, key: RedisType, default: Optional[RedisType] = None) -> RedisType: +    async def pop(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> RedisValueType:          """Get the item, remove it from the cache, and provide a default if not found."""          log.trace(f"Attempting to pop {key}.")          value = await self.get(key, default) @@ -317,7 +357,7 @@ class RedisCache:          return value -    async def update(self, items: Dict[RedisType, RedisType]) -> None: +    async def update(self, items: Dict[RedisKeyType, RedisValueType]) -> None:          """          Update the Redis cache with multiple values. @@ -326,14 +366,14 @@ class RedisCache:          do not exist in the RedisCache, they are created. If they do exist, the values          are updated with the new ones from `items`. -        Please note that both the keys and the values in the `items` dictionary -        must consist of valid RedisTypes - ints, floats, or strings. +        Please note that keys and the values in the `items` dictionary +        must consist of valid RedisKeyTypes and RedisValueTypes.          """          await self._validate_cache()          log.trace(f"Updating the cache with the following items:\n{items}")          await self._redis.hmset_dict(self._namespace, self._dict_to_typestring(items)) -    async def increment(self, key: RedisType, amount: Optional[int, float] = 1) -> None: +    async def increment(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None:          """          Increment the value by `amount`. @@ -373,7 +413,7 @@ class RedisCache:                  log.error(error_message)                  raise TypeError(error_message) -    async def decrement(self, key: RedisType, amount: Optional[int, float] = 1) -> None: +    async def decrement(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None:          """          Decrement the value by `amount`. diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index efd168dac..4f95dff03 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -70,12 +70,15 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")      async def test_set_item_type(self): -        """Test that .set rejects keys and values that are not strings, ints or floats.""" +        """Test that .set rejects keys and values that are not permitted."""          fruits = ["lemon", "melon", "apple"]          with self.assertRaises(TypeError):              await self.cog.redis.set(fruits, "nice") +        with self.assertRaises(TypeError): +            await self.cog.redis.set(4.23, "nice") +      async def test_delete_item(self):          """Test that .delete allows us to delete stuff from the RedisCache."""          # Add an item and verify that it gets added @@ -176,16 +179,16 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):          # Test conversion to typestring          for _input, expected in conversion_tests: -            self.assertEqual(self.cog.redis._to_typestring(_input), expected) +            self.assertEqual(self.cog.redis._value_to_typestring(_input), expected)          # Test conversion from typestrings          for _input, expected in conversion_tests: -            self.assertEqual(self.cog.redis._from_typestring(expected), _input) +            self.assertEqual(self.cog.redis._value_from_typestring(expected), _input)          # Test that exceptions are raised on invalid input          with self.assertRaises(TypeError): -            self.cog.redis._to_typestring(["internet"]) -            self.cog.redis._from_typestring("o|firedog") +            self.cog.redis._value_to_typestring(["internet"]) +            self.cog.redis._value_from_typestring("o|firedog")      async def test_increment_decrement(self):          """Test .increment and .decrement methods.""" | 
