aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/utils/redis_dict.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/bot/utils/redis_dict.py b/bot/utils/redis_dict.py
index b2fd7d2e9..35439b2f3 100644
--- a/bot/utils/redis_dict.py
+++ b/bot/utils/redis_dict.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from collections.abc import MutableMapping
from enum import Enum
@@ -28,7 +30,11 @@ class RedisDict(MutableMapping):
"""Initialize the RedisDict with the right namespace."""
super().__init__()
self._has_custom_namespace = namespace is not None
- self._set_namespace(namespace)
+
+ if self._has_custom_namespace:
+ self._set_namespace(namespace)
+ else:
+ self.namespace = "general"
def _set_namespace(self, namespace: str) -> None:
"""Try to set the namespace, but do not permit collisions."""
@@ -52,6 +58,14 @@ class RedisDict(MutableMapping):
"""Return a beautiful representation of this object instance."""
return f"RedisDict(namespace={self._namespace!r})"
+ def __eq__(self, other: RedisDict) -> bool:
+ """Check equality between two RedisDicts."""
+ return self.items() == other.items() and self._namespace == other._namespace
+
+ def __ne__(self, other: RedisDict) -> bool:
+ """Check inequality between two RedisDicts."""
+ return self.items() != other.items() or self._namespace != other._namespace
+
def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType):
"""Store an item in the Redis cache."""
# JSON serialize the value before storing it.
@@ -61,12 +75,18 @@ class RedisDict(MutableMapping):
def __getitem__(self, key: ValidRedisKey):
"""Get an item from the Redis cache."""
value = self._redis.hget(self._namespace, key)
- return json.loads(value)
+
+ if value:
+ return json.loads(value)
def __delitem__(self, key: ValidRedisKey):
"""Delete an item from the Redis cache."""
self._redis.hdel(self._namespace, key)
+ def __contains__(self, key: ValidRedisKey):
+ """Check if a key exists in the Redis cache."""
+ return self._redis.hexists(self._namespace, key)
+
def __iter__(self):
"""Iterate all the items in the Redis cache."""
return iter(self._redis.hkeys(self._namespace))
@@ -82,3 +102,10 @@ class RedisDict(MutableMapping):
def clear(self) -> None:
"""Deletes the entire hash from the Redis cache."""
self._redis.delete(self._namespace)
+
+ def get(self, key: ValidRedisKey, default: Optional[str] = None) -> JSONSerializableType:
+ """Get the item, but provide a default if not found."""
+ if key in self:
+ return self[key]
+ else:
+ return default