diff options
| -rw-r--r-- | bot/utils/redis_cache.py | 54 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 21 | 
2 files changed, 60 insertions, 15 deletions
| diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 6831be157..1ec1b9fea 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -4,6 +4,11 @@ from typing import Any, AsyncIterator, Dict, Optional, Union  from bot.bot import Bot +TYPESTRING_PREFIXES = ( +    ("f|", float), +    ("i|", int), +    ("s|", str), +)  ValidRedisType = Union[str, int, float] @@ -78,26 +83,45 @@ class RedisCache:          self._namespace = namespace      @staticmethod -    def _to_typestring(value: ValidRedisType) -> str: -        """Turn a valid Redis type into a typestring.""" -        if isinstance(value, float): -            return f"f|{value}" -        elif isinstance(value, int): -            return f"i|{value}" -        elif isinstance(value, str): -            return f"s|{value}" +    def _valid_typestring_types() -> str: +        """ +        Creates a nice, readable list of valid types for typestrings, useful for error messages. + +        This will be dynamically updated if we change the TYPESTRING_PREFIXES constant up top. +        """ +        valid_types = ", ".join([str(_type).split("'")[1] for _, _type in TYPESTRING_PREFIXES]) +        valid_types = ", and ".join(valid_types.rsplit(", ", 1)) +        return valid_types      @staticmethod -    def _from_typestring(value: Union[bytes, str]) -> ValidRedisType: +    def _valid_typestring_prefixes() -> str: +        """ +        Creates a nice, readable list of valid prefixes for typestrings, useful for error messages. + +        This will be dynamically updated if we change the TYPESTRING_PREFIXES constant up top. +        """ +        valid_prefixes = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_PREFIXES]) +        valid_prefixes = ", and ".join(valid_prefixes.rsplit(", ", 1)) +        return valid_prefixes + +    def _to_typestring(self, value: ValidRedisType) -> str: +        """Turn a valid Redis type into a typestring.""" +        for prefix, _type in TYPESTRING_PREFIXES: +            if isinstance(value, _type): +                return f"{prefix}{value}" +        raise TypeError(f"RedisCache._from_typestring only supports the types {self._valid_typestring_types()}.") + +    def _from_typestring(self, value: Union[bytes, str]) -> ValidRedisType:          """Turn a typestring into a valid Redis type.""" +        # Stuff that comes out of Redis will be bytestrings, so let's decode those.          if isinstance(value, bytes):              value = value.decode('utf-8') -        if value.startswith("f|"): -            return float(value[2:]) -        if value.startswith("i|"): -            return int(value[2:]) -        if value.startswith("s|"): -            return value[2:] + +        # Now we convert our unicode string back into the type it originally was. +        for prefix, _type in TYPESTRING_PREFIXES: +            if value.startswith(prefix): +                return _type(value[2:]) +        raise TypeError(f"RedisCache._to_typestring only supports the prefixes {self._valid_typestring_prefixes()}.")      def _dict_from_typestring(self, dictionary: Dict) -> Dict:          """Turns all contents of a dict into valid Redis types.""" diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 2ce57499a..150195726 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -152,3 +152,24 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):              "mega": "hungry, though",          }          self.assertDictEqual(await self.redis.to_dict(), result) + +    def test_typestring_conversion(self): +        """Test the typestring-related helper functions.""" +        conversion_tests = ( +            (12, "i|12"), +            (12.4, "f|12.4"), +            ("cowabunga", "s|cowabunga"), +        ) + +        # Test conversion to typestring +        for _input, expected in conversion_tests: +            self.assertEqual(self.redis._to_typestring(_input), expected) + +        # Test conversion from typestrings +        for _input, expected in conversion_tests: +            self.assertEqual(self.redis._from_typestring(expected), _input) + +        # Test that exceptions are raised on invalid input +        with self.assertRaises(TypeError): +            self.redis._to_typestring(["internet"]) +            self.redis._from_typestring("o|firedog") | 
