diff options
| -rw-r--r-- | bot/utils/redis_dict.py | 31 | 
1 files changed, 29 insertions, 2 deletions
diff --git a/bot/utils/redis_dict.py b/bot/utils/redis_dict.py index b2fd7d2e9..35439b2f3 100644 --- a/bot/utils/redis_dict.py +++ b/bot/utils/redis_dict.py @@ -1,3 +1,5 @@ +from __future__ import annotations +  import json  from collections.abc import MutableMapping  from enum import Enum @@ -28,7 +30,11 @@ class RedisDict(MutableMapping):          """Initialize the RedisDict with the right namespace."""          super().__init__()          self._has_custom_namespace = namespace is not None -        self._set_namespace(namespace) + +        if self._has_custom_namespace: +            self._set_namespace(namespace) +        else: +            self.namespace = "general"      def _set_namespace(self, namespace: str) -> None:          """Try to set the namespace, but do not permit collisions.""" @@ -52,6 +58,14 @@ class RedisDict(MutableMapping):          """Return a beautiful representation of this object instance."""          return f"RedisDict(namespace={self._namespace!r})" +    def __eq__(self, other: RedisDict) -> bool: +        """Check equality between two RedisDicts.""" +        return self.items() == other.items() and self._namespace == other._namespace + +    def __ne__(self, other: RedisDict) -> bool: +        """Check inequality between two RedisDicts.""" +        return self.items() != other.items() or self._namespace != other._namespace +      def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType):          """Store an item in the Redis cache."""          # JSON serialize the value before storing it. @@ -61,12 +75,18 @@ class RedisDict(MutableMapping):      def __getitem__(self, key: ValidRedisKey):          """Get an item from the Redis cache."""          value = self._redis.hget(self._namespace, key) -        return json.loads(value) + +        if value: +            return json.loads(value)      def __delitem__(self, key: ValidRedisKey):          """Delete an item from the Redis cache."""          self._redis.hdel(self._namespace, key) +    def __contains__(self, key: ValidRedisKey): +        """Check if a key exists in the Redis cache.""" +        return self._redis.hexists(self._namespace, key) +      def __iter__(self):          """Iterate all the items in the Redis cache."""          return iter(self._redis.hkeys(self._namespace)) @@ -82,3 +102,10 @@ class RedisDict(MutableMapping):      def clear(self) -> None:          """Deletes the entire hash from the Redis cache."""          self._redis.delete(self._namespace) + +    def get(self, key: ValidRedisKey, default: Optional[str] = None) -> JSONSerializableType: +        """Get the item, but provide a default if not found.""" +        if key in self: +            return self[key] +        else: +            return default  |