aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2020-05-22 02:34:04 +0200
committerGravatar Leon Sandøy <[email protected]>2020-05-22 02:34:04 +0200
commitcc3591df0f14041be683bb6716d1e427c52aa2d7 (patch)
treeaf4321016a9a8a63fe0bb765d2fd04c35155770e
parentAdd the REDIS_PASSWORD environment variable (diff)
Add the REDIS_PASSWORD environment variable
In production, we will need this password to make a connection to Redis.
-rw-r--r--bot/utils/redis_cache.py115
-rw-r--r--bot/utils/redis_dict.py137
2 files changed, 115 insertions, 137 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
new file mode 100644
index 000000000..d0a7eba4a
--- /dev/null
+++ b/bot/utils/redis_cache.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union
+
+from bot.bot import Bot
+
+ValidRedisKey = Union[str, int, float]
+JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]]
+
+
+class RedisCache:
+ """
+ A simplified interface for a Redis connection.
+
+ This class must be created as a class attribute in a class. This is because it
+ uses __set_name__ to create a namespace like MyCog.my_class_attribute which is
+ used as a hash name when we store stuff in Redis, to prevent collisions.
+
+ The class this object is instantiated in must also contains an attribute with an
+ instance of Bot. This is because Bot contains our redis_pool, which is how this
+ class communicates with the Redis server.
+
+ We implement several convenient methods that are fairly similar to have a dict
+ behaves, and should be familiar to Python users. The biggest difference is that
+ all the public methods in this class are coroutines.
+ """
+
+ _namespaces = []
+
+ def __init__(self) -> None:
+ """Raise a NotImplementedError if `__set_name__` hasn't been run."""
+ if not self._namespace:
+ raise NotImplementedError("RedisCache must be a class attribute.")
+
+ def _set_namespace(self, namespace: str) -> None:
+ """Try to set the namespace, but do not permit collisions."""
+ while namespace in self._namespaces:
+ namespace += "_"
+
+ self._namespaces.append(namespace)
+ self._namespace = namespace
+
+ def __set_name__(self, owner: object, attribute_name: str) -> None:
+ """
+ Set the namespace to Class.attribute_name.
+
+ Called automatically when this class is constructed inside a class as an attribute.
+ """
+ if not self._has_custom_namespace:
+ self._set_namespace(f"{owner.__name__}.{attribute_name}")
+
+ def __get__(self, instance: RedisCache, owner: Any) -> RedisCache:
+ """Fetch the Bot instance, we need it for the redis pool."""
+ if self.bot:
+ return self
+
+ if instance is None:
+ raise NotImplementedError("You must create an instance of RedisCache to use it.")
+
+ for attribute in vars(instance).values():
+ if isinstance(attribute, Bot):
+ self.bot = attribute
+ self._redis = self.bot.redis_pool
+ return self
+ else:
+ raise RuntimeError("Cannot initialize a RedisCache without a `Bot` instance.")
+
+ def __repr__(self) -> str:
+ """Return a beautiful representation of this object instance."""
+ return f"RedisCache(namespace={self._namespace!r})"
+
+ async def set(self, key: ValidRedisKey, value: JSONSerializableType) -> None:
+ """Store an item in the Redis cache."""
+ # await self._redis.hset(self._namespace, key, value)
+
+ async def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
+ """Get an item from the Redis cache."""
+ # value = await self._redis.hget(self._namespace, key)
+
+ async def delete(self, key: ValidRedisKey) -> None:
+ """Delete an item from the Redis cache."""
+ # await self._redis.hdel(self._namespace, key)
+
+ async def contains(self, key: ValidRedisKey) -> bool:
+ """Check if a key exists in the Redis cache."""
+ # return await self._redis.hexists(self._namespace, key)
+
+ async def items(self) -> AsyncIterator:
+ """Iterate all the items in the Redis cache."""
+ # data = await redis.hgetall(self.get_with_namespace(key))
+ # for item in data:
+ # yield item
+
+ async def length(self) -> int:
+ """Return the number of items in the Redis cache."""
+ # return await self._redis.hlen(self._namespace)
+
+ async def to_dict(self) -> Dict:
+ """Convert to dict and return."""
+ # return dict(self.items())
+
+ async def clear(self) -> None:
+ """Deletes the entire hash from the Redis cache."""
+ # await self._redis.delete(self._namespace)
+
+ async def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
+ """Get the item, remove it from the cache, and provide a default if not found."""
+ value = await self.get(key, default)
+ await self.delete(key)
+ return value
+
+ async def update(self) -> None:
+ """Update the Redis cache with multiple values."""
+ # https://aioredis.readthedocs.io/en/v1.3.0/mixins.html#aioredis.commands.HashCommandsMixin.hmset_dict
diff --git a/bot/utils/redis_dict.py b/bot/utils/redis_dict.py
deleted file mode 100644
index 4a5e34249..000000000
--- a/bot/utils/redis_dict.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from __future__ import annotations
-
-import json
-from collections.abc import MutableMapping
-from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-
-import redis as redis_py
-
-from bot import constants
-
-ValidRedisKey = Union[str, int, float]
-JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]]
-
-
-class RedisDict(MutableMapping):
- """
- A dictionary interface for a Redis database.
-
- Objects created by this class should mostly behave like a normal dictionary,
- but will store all the data in our Redis database for persistence between restarts.
-
- Redis is limited to simple types, so to allow you to store collections like lists
- and dictionaries, we JSON deserialize every value. That means that it will not be possible
- to store complex objects, only stuff like strings, numbers, and collections of strings and numbers.
- """
-
- _namespaces = []
- _redis = redis_py.Redis(
- host=constants.Redis.host,
- port=constants.Redis.port,
- password=constants.Redis.password,
- ) # Can be overridden for testing
-
- def __init__(self, namespace: Optional[str] = None) -> None:
- """Initialize the RedisDict with the right namespace."""
- super().__init__()
- self._has_custom_namespace = namespace is not None
-
- if self._has_custom_namespace:
- self._set_namespace(namespace)
- else:
- self.namespace = "global"
-
- def _set_namespace(self, namespace: str) -> None:
- """Try to set the namespace, but do not permit collisions."""
- while namespace in self._namespaces:
- namespace = namespace + "_"
-
- self._namespaces.append(namespace)
- self._namespace = namespace
-
- def __set_name__(self, owner: object, attribute_name: str) -> None:
- """
- Set the namespace to Class.attribute_name.
-
- Called automatically when this class is constructed inside a class as an attribute, as long as
- no custom namespace is provided to the constructor.
- """
- if not self._has_custom_namespace:
- self._set_namespace(f"{owner.__name__}.{attribute_name}")
-
- def __repr__(self) -> str:
- """Return a beautiful representation of this object instance."""
- return f"RedisDict(namespace={self._namespace!r})"
-
- def __eq__(self, other: RedisDict) -> bool:
- """Check equality between two RedisDicts."""
- return self.items() == other.items() and self._namespace == other._namespace
-
- def __ne__(self, other: RedisDict) -> bool:
- """Check inequality between two RedisDicts."""
- return self.items() != other.items() or self._namespace != other._namespace
-
- def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType):
- """Store an item in the Redis cache."""
- # JSON serialize the value before storing it.
- json_value = json.dumps(value)
- self._redis.hset(self._namespace, key, json_value)
-
- def __getitem__(self, key: ValidRedisKey):
- """Get an item from the Redis cache."""
- value = self._redis.hget(self._namespace, key)
-
- if value:
- return json.loads(value)
-
- def __delitem__(self, key: ValidRedisKey):
- """Delete an item from the Redis cache."""
- self._redis.hdel(self._namespace, key)
-
- def __contains__(self, key: ValidRedisKey):
- """Check if a key exists in the Redis cache."""
- return self._redis.hexists(self._namespace, key)
-
- def __iter__(self):
- """Iterate all the items in the Redis cache."""
- keys = self._redis.hkeys(self._namespace)
- return iter([key.decode('utf-8') for key in keys])
-
- def __len__(self):
- """Return the number of items in the Redis cache."""
- return self._redis.hlen(self._namespace)
-
- def copy(self) -> Dict:
- """Convert to dict and return."""
- return dict(self.items())
-
- def clear(self) -> None:
- """Deletes the entire hash from the Redis cache."""
- self._redis.delete(self._namespace)
-
- def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
- """Get the item, but provide a default if not found."""
- if key in self:
- return self[key]
- else:
- return default
-
- def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
- """Get the item, remove it from the cache, and provide a default if not found."""
- value = self.get(key, default)
- del self[key]
- return value
-
- def popitem(self) -> JSONSerializableType:
- """Get the last item added to the cache."""
- key = list(self.keys())[-1]
- return self.pop(key)
-
- def setdefault(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
- """Try to get the item. If the item does not exist, set it to `default` and return that."""
- value = self.get(key)
-
- if value is None:
- self[key] = default
- return default