aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2020-05-27 20:15:19 +0200
committerGravatar Leon Sandøy <[email protected]>2020-05-27 20:15:19 +0200
commit4db313e9a7899666f1597094b0d88447c7b64311 (patch)
treef33ec4122be58991da50dcff482ea5aa9a6d8eb6
parentRefactor .increment and add lock test. (diff)
Floats are no longer permitted as RedisCache keys.
Also added a test for this. This is the DRYest approach I could find. It's a little ugly, but I think it's probably good enough.
-rw-r--r--bot/utils/redis_cache.py116
-rw-r--r--tests/bot/utils/test_redis_cache.py13
2 files changed, 86 insertions, 43 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
index 33e5d5852..afd37f8f8 100644
--- a/bot/utils/redis_cache.py
+++ b/bot/utils/redis_cache.py
@@ -2,26 +2,42 @@ from __future__ import annotations
import asyncio
import logging
-from typing import Any, Dict, ItemsView, Optional, Union
+import typing
+from typing import Any, Dict, ItemsView, Optional, Tuple, Union
from bot.bot import Bot
log = logging.getLogger(__name__)
-RedisType = Union[str, int, float]
-TYPESTRING_PREFIXES = (
+# Type aliases
+RedisKeyType = Union[str, int]
+RedisValueType = Union[str, int, float]
+
+# Prefix tuples
+PrefixTuple = Tuple[Tuple[str, Any]]
+TYPESTRING_VALUE_PREFIXES = (
("f|", float),
("i|", int),
("s|", str),
)
+TYPESTRING_KEY_PREFIXES = (
+ ("i|", int),
+ ("s|", str),
+)
# Makes a nice list like "float, int, and str"
-NICE_TYPE_LIST = ", ".join(str(_type.__name__) for _, _type in TYPESTRING_PREFIXES)
-NICE_TYPE_LIST = ", and ".join(NICE_TYPE_LIST.rsplit(", ", 1))
+NICE_VALUE_TYPE_LIST = ", ".join(str(_type.__name__) for _type in typing.get_args(RedisValueType))
+NICE_VALUE_TYPE_LIST = ", and ".join(NICE_VALUE_TYPE_LIST.rsplit(", ", 1))
+
+NICE_KEY_TYPE_LIST = ", ".join(str(_type.__name__) for _type in typing.get_args(RedisKeyType))
+NICE_KEY_TYPE_LIST = ", and ".join(NICE_KEY_TYPE_LIST.rsplit(", ", 1))
# Makes a list like "'f|', 'i|', and 's|'"
-NICE_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_PREFIXES])
-NICE_PREFIX_LIST = ", and ".join(NICE_PREFIX_LIST.rsplit(", ", 1))
+NICE_VALUE_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_VALUE_PREFIXES])
+NICE_VALUE_PREFIX_LIST = ", and ".join(NICE_VALUE_PREFIX_LIST.rsplit(", ", 1))
+
+NICE_KEY_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_KEY_PREFIXES])
+NICE_KEY_PREFIX_LIST = ", and ".join(NICE_KEY_PREFIX_LIST.rsplit(", ", 1))
class RedisCache:
@@ -99,33 +115,57 @@ class RedisCache:
self._namespace = namespace
@staticmethod
- def _to_typestring(value: RedisType) -> str:
+ def _to_typestring(
+ key_or_value: Union[RedisKeyType, RedisValueType],
+ prefixes: PrefixTuple,
+ nice_type_list: str
+ ) -> str:
"""Turn a valid Redis type into a typestring."""
- for prefix, _type in TYPESTRING_PREFIXES:
- if isinstance(value, _type):
- return f"{prefix}{value}"
- raise TypeError(f"RedisCache._from_typestring only supports the types {NICE_TYPE_LIST}.")
+ for prefix, _type in prefixes:
+ if isinstance(key_or_value, _type):
+ return f"{prefix}{key_or_value}"
+ raise TypeError(f"RedisCache._from_typestring only supports the types {nice_type_list}.")
@staticmethod
- def _from_typestring(value: Union[bytes, str]) -> RedisType:
- """Turn a typestring into a valid Redis type."""
+ def _from_typestring(
+ key_or_value: Union[bytes, str],
+ prefixes: PrefixTuple,
+ nice_prefix_list: str,
+ ) -> Union[RedisKeyType, RedisValueType]:
+ """Deserialize a typestring into a valid Redis type."""
# Stuff that comes out of Redis will be bytestrings, so let's decode those.
- if isinstance(value, bytes):
- value = value.decode('utf-8')
+ if isinstance(key_or_value, bytes):
+ key_or_value = key_or_value.decode('utf-8')
# Now we convert our unicode string back into the type it originally was.
- for prefix, _type in TYPESTRING_PREFIXES:
- if value.startswith(prefix):
- return _type(value[len(prefix):])
- raise TypeError(f"RedisCache._to_typestring only supports the prefixes {NICE_PREFIX_LIST}.")
+ for prefix, _type in prefixes:
+ if key_or_value.startswith(prefix):
+ return _type(key_or_value[len(prefix):])
+ raise TypeError(f"RedisCache._to_typestring only supports the prefixes {nice_prefix_list}.")
+
+ def _key_to_typestring(self, key: RedisKeyType) -> str:
+ """Serialize a RedisKeyType object into a typestring."""
+ return self._to_typestring(key, TYPESTRING_KEY_PREFIXES, NICE_KEY_TYPE_LIST)
+
+ def _value_to_typestring(self, value: RedisValueType) -> str:
+ """Serialize a RedisValueType object into a typestring."""
+ return self._to_typestring(value, TYPESTRING_VALUE_PREFIXES, NICE_VALUE_TYPE_LIST)
+
+ def _key_from_typestring(self, key: Union[bytes, str]) -> RedisKeyType:
+ """Deserialize a RedisKeyType object from a typestring."""
+ return self._from_typestring(key, TYPESTRING_KEY_PREFIXES, NICE_KEY_PREFIX_LIST)
+
+ def _value_from_typestring(self, value: Union[bytes, str]) -> RedisValueType:
+ """Deserialize a RedisValueType object from a typestring."""
+ return self._from_typestring(value, TYPESTRING_VALUE_PREFIXES, NICE_VALUE_PREFIX_LIST)
def _dict_from_typestring(self, dictionary: Dict) -> Dict:
"""Turns all contents of a dict into valid Redis types."""
- return {self._from_typestring(key): self._from_typestring(value) for key, value in dictionary.items()}
+ return {self._key_from_typestring(key): self._value_from_typestring(value) for key, value in dictionary.items()}
def _dict_to_typestring(self, dictionary: Dict) -> Dict:
"""Turns all contents of a dict into typestrings."""
- return {self._to_typestring(key): self._to_typestring(value) for key, value in dictionary.items()}
+ return {self._key_to_typestring(key): self._value_to_typestring(value) for key, value in dictionary.items()}
async def _validate_cache(self) -> None:
"""Validate that the RedisCache is ready to be used."""
@@ -209,21 +249,21 @@ class RedisCache:
"""Return a beautiful representation of this object instance."""
return f"RedisCache(namespace={self._namespace!r})"
- async def set(self, key: RedisType, value: RedisType) -> None:
+ async def set(self, key: RedisKeyType, value: RedisValueType) -> None:
"""Store an item in the Redis cache."""
await self._validate_cache()
# Convert to a typestring and then set it
- key = self._to_typestring(key)
- value = self._to_typestring(value)
+ key = self._key_to_typestring(key)
+ value = self._value_to_typestring(value)
log.trace(f"Setting {key} to {value}.")
await self._redis.hset(self._namespace, key, value)
- async def get(self, key: RedisType, default: Optional[RedisType] = None) -> Optional[RedisType]:
+ async def get(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> Optional[RedisValueType]:
"""Get an item from the Redis cache."""
await self._validate_cache()
- key = self._to_typestring(key)
+ key = self._key_to_typestring(key)
log.trace(f"Attempting to retrieve {key}.")
value = await self._redis.hget(self._namespace, key)
@@ -232,11 +272,11 @@ class RedisCache:
log.trace(f"Value not found, returning default value {default}")
return default
else:
- value = self._from_typestring(value)
+ value = self._value_from_typestring(value)
log.trace(f"Value found, returning value {value}")
return value
- async def delete(self, key: RedisType) -> None:
+ async def delete(self, key: RedisKeyType) -> None:
"""
Delete an item from the Redis cache.
@@ -245,19 +285,19 @@ class RedisCache:
See https://redis.io/commands/hdel for more info on how this works.
"""
await self._validate_cache()
- key = self._to_typestring(key)
+ key = self._key_to_typestring(key)
log.trace(f"Attempting to delete {key}.")
return await self._redis.hdel(self._namespace, key)
- async def contains(self, key: RedisType) -> bool:
+ async def contains(self, key: RedisKeyType) -> bool:
"""
Check if a key exists in the Redis cache.
Return True if the key exists, otherwise False.
"""
await self._validate_cache()
- key = self._to_typestring(key)
+ key = self._key_to_typestring(key)
exists = await self._redis.hexists(self._namespace, key)
log.trace(f"Testing if {key} exists in the RedisCache - Result is {exists}")
@@ -304,7 +344,7 @@ class RedisCache:
log.trace("Clearing the cache of all key/value pairs.")
await self._redis.delete(self._namespace)
- async def pop(self, key: RedisType, default: Optional[RedisType] = None) -> RedisType:
+ async def pop(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> RedisValueType:
"""Get the item, remove it from the cache, and provide a default if not found."""
log.trace(f"Attempting to pop {key}.")
value = await self.get(key, default)
@@ -317,7 +357,7 @@ class RedisCache:
return value
- async def update(self, items: Dict[RedisType, RedisType]) -> None:
+ async def update(self, items: Dict[RedisKeyType, RedisValueType]) -> None:
"""
Update the Redis cache with multiple values.
@@ -326,14 +366,14 @@ class RedisCache:
do not exist in the RedisCache, they are created. If they do exist, the values
are updated with the new ones from `items`.
- Please note that both the keys and the values in the `items` dictionary
- must consist of valid RedisTypes - ints, floats, or strings.
+ Please note that keys and the values in the `items` dictionary
+ must consist of valid RedisKeyTypes and RedisValueTypes.
"""
await self._validate_cache()
log.trace(f"Updating the cache with the following items:\n{items}")
await self._redis.hmset_dict(self._namespace, self._dict_to_typestring(items))
- async def increment(self, key: RedisType, amount: Optional[int, float] = 1) -> None:
+ async def increment(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None:
"""
Increment the value by `amount`.
@@ -373,7 +413,7 @@ class RedisCache:
log.error(error_message)
raise TypeError(error_message)
- async def decrement(self, key: RedisType, amount: Optional[int, float] = 1) -> None:
+ async def decrement(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None:
"""
Decrement the value by `amount`.
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index efd168dac..4f95dff03 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -70,12 +70,15 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
async def test_set_item_type(self):
- """Test that .set rejects keys and values that are not strings, ints or floats."""
+ """Test that .set rejects keys and values that are not permitted."""
fruits = ["lemon", "melon", "apple"]
with self.assertRaises(TypeError):
await self.cog.redis.set(fruits, "nice")
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(4.23, "nice")
+
async def test_delete_item(self):
"""Test that .delete allows us to delete stuff from the RedisCache."""
# Add an item and verify that it gets added
@@ -176,16 +179,16 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
# Test conversion to typestring
for _input, expected in conversion_tests:
- self.assertEqual(self.cog.redis._to_typestring(_input), expected)
+ self.assertEqual(self.cog.redis._value_to_typestring(_input), expected)
# Test conversion from typestrings
for _input, expected in conversion_tests:
- self.assertEqual(self.cog.redis._from_typestring(expected), _input)
+ self.assertEqual(self.cog.redis._value_from_typestring(expected), _input)
# Test that exceptions are raised on invalid input
with self.assertRaises(TypeError):
- self.cog.redis._to_typestring(["internet"])
- self.cog.redis._from_typestring("o|firedog")
+ self.cog.redis._value_to_typestring(["internet"])
+ self.cog.redis._value_from_typestring("o|firedog")
async def test_increment_decrement(self):
"""Test .increment and .decrement methods."""