diff options
| author | 2020-05-17 02:30:28 +0200 | |
|---|---|---|
| committer | 2020-05-17 02:30:28 +0200 | |
| commit | 677a7f755a15f8fdf0cd97e399c4265dd8e702d9 (patch) | |
| tree | 6d91e02f94bc0b068436a3770aaf7e394921dee5 | |
| parent | Implements .clear with hash deletion. (diff) | |
Implement .get, equality, and membership check
This is supposed to be provided by our MutableMapping mixin, but unit
tests are demonstrating that these don't really work as intended.
| -rw-r--r-- | bot/utils/redis_dict.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/bot/utils/redis_dict.py b/bot/utils/redis_dict.py index b2fd7d2e9..35439b2f3 100644 --- a/bot/utils/redis_dict.py +++ b/bot/utils/redis_dict.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import MutableMapping from enum import Enum @@ -28,7 +30,11 @@ class RedisDict(MutableMapping): """Initialize the RedisDict with the right namespace.""" super().__init__() self._has_custom_namespace = namespace is not None - self._set_namespace(namespace) + + if self._has_custom_namespace: + self._set_namespace(namespace) + else: + self.namespace = "general" def _set_namespace(self, namespace: str) -> None: """Try to set the namespace, but do not permit collisions.""" @@ -52,6 +58,14 @@ class RedisDict(MutableMapping): """Return a beautiful representation of this object instance.""" return f"RedisDict(namespace={self._namespace!r})" + def __eq__(self, other: RedisDict) -> bool: + """Check equality between two RedisDicts.""" + return self.items() == other.items() and self._namespace == other._namespace + + def __ne__(self, other: RedisDict) -> bool: + """Check inequality between two RedisDicts.""" + return self.items() != other.items() or self._namespace != other._namespace + def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType): """Store an item in the Redis cache.""" # JSON serialize the value before storing it. @@ -61,12 +75,18 @@ class RedisDict(MutableMapping): def __getitem__(self, key: ValidRedisKey): """Get an item from the Redis cache.""" value = self._redis.hget(self._namespace, key) - return json.loads(value) + + if value: + return json.loads(value) def __delitem__(self, key: ValidRedisKey): """Delete an item from the Redis cache.""" self._redis.hdel(self._namespace, key) + def __contains__(self, key: ValidRedisKey): + """Check if a key exists in the Redis cache.""" + return self._redis.hexists(self._namespace, key) + def __iter__(self): """Iterate all the items in the Redis cache.""" return iter(self._redis.hkeys(self._namespace)) @@ -82,3 +102,10 @@ class RedisDict(MutableMapping): def clear(self) -> None: """Deletes the entire hash from the Redis cache.""" self._redis.delete(self._namespace) + + def get(self, key: ValidRedisKey, default: Optional[str] = None) -> JSONSerializableType: + """Get the item, but provide a default if not found.""" + if key in self: + return self[key] + else: + return default |