diff options
-rw-r--r-- | bot/utils/redis_cache.py | 115 | ||||
-rw-r--r-- | bot/utils/redis_dict.py | 137 |
2 files changed, 115 insertions, 137 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py new file mode 100644 index 000000000..d0a7eba4a --- /dev/null +++ b/bot/utils/redis_cache.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union + +from bot.bot import Bot + +ValidRedisKey = Union[str, int, float] +JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]] + + +class RedisCache: + """ + A simplified interface for a Redis connection. + + This class must be created as a class attribute in a class. This is because it + uses __set_name__ to create a namespace like MyCog.my_class_attribute which is + used as a hash name when we store stuff in Redis, to prevent collisions. + + The class this object is instantiated in must also contains an attribute with an + instance of Bot. This is because Bot contains our redis_pool, which is how this + class communicates with the Redis server. + + We implement several convenient methods that are fairly similar to have a dict + behaves, and should be familiar to Python users. The biggest difference is that + all the public methods in this class are coroutines. + """ + + _namespaces = [] + + def __init__(self) -> None: + """Raise a NotImplementedError if `__set_name__` hasn't been run.""" + if not self._namespace: + raise NotImplementedError("RedisCache must be a class attribute.") + + def _set_namespace(self, namespace: str) -> None: + """Try to set the namespace, but do not permit collisions.""" + while namespace in self._namespaces: + namespace += "_" + + self._namespaces.append(namespace) + self._namespace = namespace + + def __set_name__(self, owner: object, attribute_name: str) -> None: + """ + Set the namespace to Class.attribute_name. + + Called automatically when this class is constructed inside a class as an attribute. + """ + if not self._has_custom_namespace: + self._set_namespace(f"{owner.__name__}.{attribute_name}") + + def __get__(self, instance: RedisCache, owner: Any) -> RedisCache: + """Fetch the Bot instance, we need it for the redis pool.""" + if self.bot: + return self + + if instance is None: + raise NotImplementedError("You must create an instance of RedisCache to use it.") + + for attribute in vars(instance).values(): + if isinstance(attribute, Bot): + self.bot = attribute + self._redis = self.bot.redis_pool + return self + else: + raise RuntimeError("Cannot initialize a RedisCache without a `Bot` instance.") + + def __repr__(self) -> str: + """Return a beautiful representation of this object instance.""" + return f"RedisCache(namespace={self._namespace!r})" + + async def set(self, key: ValidRedisKey, value: JSONSerializableType) -> None: + """Store an item in the Redis cache.""" + # await self._redis.hset(self._namespace, key, value) + + async def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: + """Get an item from the Redis cache.""" + # value = await self._redis.hget(self._namespace, key) + + async def delete(self, key: ValidRedisKey) -> None: + """Delete an item from the Redis cache.""" + # await self._redis.hdel(self._namespace, key) + + async def contains(self, key: ValidRedisKey) -> bool: + """Check if a key exists in the Redis cache.""" + # return await self._redis.hexists(self._namespace, key) + + async def items(self) -> AsyncIterator: + """Iterate all the items in the Redis cache.""" + # data = await redis.hgetall(self.get_with_namespace(key)) + # for item in data: + # yield item + + async def length(self) -> int: + """Return the number of items in the Redis cache.""" + # return await self._redis.hlen(self._namespace) + + async def to_dict(self) -> Dict: + """Convert to dict and return.""" + # return dict(self.items()) + + async def clear(self) -> None: + """Deletes the entire hash from the Redis cache.""" + # await self._redis.delete(self._namespace) + + async def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: + """Get the item, remove it from the cache, and provide a default if not found.""" + value = await self.get(key, default) + await self.delete(key) + return value + + async def update(self) -> None: + """Update the Redis cache with multiple values.""" + # https://aioredis.readthedocs.io/en/v1.3.0/mixins.html#aioredis.commands.HashCommandsMixin.hmset_dict diff --git a/bot/utils/redis_dict.py b/bot/utils/redis_dict.py deleted file mode 100644 index 4a5e34249..000000000 --- a/bot/utils/redis_dict.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import MutableMapping -from enum import Enum -from typing import Dict, List, Optional, Tuple, Union - -import redis as redis_py - -from bot import constants - -ValidRedisKey = Union[str, int, float] -JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]] - - -class RedisDict(MutableMapping): - """ - A dictionary interface for a Redis database. - - Objects created by this class should mostly behave like a normal dictionary, - but will store all the data in our Redis database for persistence between restarts. - - Redis is limited to simple types, so to allow you to store collections like lists - and dictionaries, we JSON deserialize every value. That means that it will not be possible - to store complex objects, only stuff like strings, numbers, and collections of strings and numbers. - """ - - _namespaces = [] - _redis = redis_py.Redis( - host=constants.Redis.host, - port=constants.Redis.port, - password=constants.Redis.password, - ) # Can be overridden for testing - - def __init__(self, namespace: Optional[str] = None) -> None: - """Initialize the RedisDict with the right namespace.""" - super().__init__() - self._has_custom_namespace = namespace is not None - - if self._has_custom_namespace: - self._set_namespace(namespace) - else: - self.namespace = "global" - - def _set_namespace(self, namespace: str) -> None: - """Try to set the namespace, but do not permit collisions.""" - while namespace in self._namespaces: - namespace = namespace + "_" - - self._namespaces.append(namespace) - self._namespace = namespace - - def __set_name__(self, owner: object, attribute_name: str) -> None: - """ - Set the namespace to Class.attribute_name. - - Called automatically when this class is constructed inside a class as an attribute, as long as - no custom namespace is provided to the constructor. - """ - if not self._has_custom_namespace: - self._set_namespace(f"{owner.__name__}.{attribute_name}") - - def __repr__(self) -> str: - """Return a beautiful representation of this object instance.""" - return f"RedisDict(namespace={self._namespace!r})" - - def __eq__(self, other: RedisDict) -> bool: - """Check equality between two RedisDicts.""" - return self.items() == other.items() and self._namespace == other._namespace - - def __ne__(self, other: RedisDict) -> bool: - """Check inequality between two RedisDicts.""" - return self.items() != other.items() or self._namespace != other._namespace - - def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType): - """Store an item in the Redis cache.""" - # JSON serialize the value before storing it. - json_value = json.dumps(value) - self._redis.hset(self._namespace, key, json_value) - - def __getitem__(self, key: ValidRedisKey): - """Get an item from the Redis cache.""" - value = self._redis.hget(self._namespace, key) - - if value: - return json.loads(value) - - def __delitem__(self, key: ValidRedisKey): - """Delete an item from the Redis cache.""" - self._redis.hdel(self._namespace, key) - - def __contains__(self, key: ValidRedisKey): - """Check if a key exists in the Redis cache.""" - return self._redis.hexists(self._namespace, key) - - def __iter__(self): - """Iterate all the items in the Redis cache.""" - keys = self._redis.hkeys(self._namespace) - return iter([key.decode('utf-8') for key in keys]) - - def __len__(self): - """Return the number of items in the Redis cache.""" - return self._redis.hlen(self._namespace) - - def copy(self) -> Dict: - """Convert to dict and return.""" - return dict(self.items()) - - def clear(self) -> None: - """Deletes the entire hash from the Redis cache.""" - self._redis.delete(self._namespace) - - def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: - """Get the item, but provide a default if not found.""" - if key in self: - return self[key] - else: - return default - - def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: - """Get the item, remove it from the cache, and provide a default if not found.""" - value = self.get(key, default) - del self[key] - return value - - def popitem(self) -> JSONSerializableType: - """Get the last item added to the cache.""" - key = list(self.keys())[-1] - return self.pop(key) - - def setdefault(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: - """Try to get the item. If the item does not exist, set it to `default` and return that.""" - value = self.get(key) - - if value is None: - self[key] = default - return default |