diff options
| author | 2020-05-17 02:30:28 +0200 | |
|---|---|---|
| committer | 2020-05-17 02:30:28 +0200 | |
| commit | 677a7f755a15f8fdf0cd97e399c4265dd8e702d9 (patch) | |
| tree | 6d91e02f94bc0b068436a3770aaf7e394921dee5 | |
| parent | Implements .clear with hash deletion. (diff) | |
Implement .get, equality, and membership check
This is supposed to be provided by our MutableMapping mixin, but unit
tests are demonstrating that these don't really work as intended.
Diffstat (limited to '')
| -rw-r--r-- | bot/utils/redis_dict.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/bot/utils/redis_dict.py b/bot/utils/redis_dict.py index b2fd7d2e9..35439b2f3 100644 --- a/bot/utils/redis_dict.py +++ b/bot/utils/redis_dict.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import MutableMapping from enum import Enum @@ -28,7 +30,11 @@ class RedisDict(MutableMapping): """Initialize the RedisDict with the right namespace.""" super().__init__() self._has_custom_namespace = namespace is not None - self._set_namespace(namespace) + + if self._has_custom_namespace: + self._set_namespace(namespace) + else: + self.namespace = "general" def _set_namespace(self, namespace: str) -> None: """Try to set the namespace, but do not permit collisions.""" @@ -52,6 +58,14 @@ class RedisDict(MutableMapping): """Return a beautiful representation of this object instance.""" return f"RedisDict(namespace={self._namespace!r})" + def __eq__(self, other: RedisDict) -> bool: + """Check equality between two RedisDicts.""" + return self.items() == other.items() and self._namespace == other._namespace + + def __ne__(self, other: RedisDict) -> bool: + """Check inequality between two RedisDicts.""" + return self.items() != other.items() or self._namespace != other._namespace + def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType): """Store an item in the Redis cache.""" # JSON serialize the value before storing it. @@ -61,12 +75,18 @@ class RedisDict(MutableMapping): def __getitem__(self, key: ValidRedisKey): """Get an item from the Redis cache.""" value = self._redis.hget(self._namespace, key) - return json.loads(value) + + if value: + return json.loads(value) def __delitem__(self, key: ValidRedisKey): """Delete an item from the Redis cache.""" self._redis.hdel(self._namespace, key) + def __contains__(self, key: ValidRedisKey): + """Check if a key exists in the Redis cache.""" + return self._redis.hexists(self._namespace, key) + def __iter__(self): """Iterate all the items in the Redis cache.""" return iter(self._redis.hkeys(self._namespace)) @@ -82,3 +102,10 @@ class RedisDict(MutableMapping): def clear(self) -> None: """Deletes the entire hash from the Redis cache.""" self._redis.delete(self._namespace) + + def get(self, key: ValidRedisKey, default: Optional[str] = None) -> JSONSerializableType: + """Get the item, but provide a default if not found.""" + if key in self: + return self[key] + else: + return default |