aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils/redis_dict.py
blob: 4a5e342493e937d706276b512b19c750fde9a185 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations

import json
from collections.abc import MutableMapping
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import redis as redis_py

from bot import constants

# Types accepted as hash field names. Iterating a RedisDict yields keys decoded to
# `str`, so int/float keys do not come back with their original type.
ValidRedisKey = Union[str, int, float]
# Types RedisDict stores by passing them through json.dumps.
# NOTE(review): json.dumps raises TypeError for plain Enum members (only str/int-mixin
# enums such as IntEnum serialize), and tuples are read back by json.loads as lists.
JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]]


class RedisDict(MutableMapping):
    """
    A dictionary interface for a Redis database.

    Objects created by this class should mostly behave like a normal dictionary,
    but will store all the data in our Redis database for persistence between restarts.
    Each instance is backed by a single Redis hash named after the instance's namespace.

    Redis is limited to simple types, so to allow you to store collections like lists
    and dictionaries, we JSON serialize every value when storing it and deserialize it
    when reading it back. That means that it will not be possible to store complex objects,
    only stuff like strings, numbers, and collections of strings and numbers.
    Keys come back as `str` when iterating, whatever type they were stored with.
    """

    # Namespaces claimed by any instance so far. Shared at class level on purpose, so
    # that two RedisDicts cannot accidentally end up writing into the same hash.
    _namespaces = []
    _redis = redis_py.Redis(
        host=constants.Redis.host,
        port=constants.Redis.port,
        password=constants.Redis.password,
    )  # Can be overridden for testing

    def __init__(self, namespace: Optional[str] = None) -> None:
        """
        Initialize the RedisDict with the right namespace.

        An explicit `namespace` is claimed immediately, with "_" appended until it is
        unique. Without one, the dict starts in the shared "global" namespace, which
        `__set_name__` replaces with "Owner.attribute" when the instance is created as
        a class attribute.
        """
        super().__init__()
        self._has_custom_namespace = namespace is not None

        if self._has_custom_namespace:
            self._set_namespace(namespace)
        else:
            # This used to assign `self.namespace`, leaving `_namespace` unset so that
            # every operation on a standalone instance raised AttributeError.
            # "global" is deliberately not registered, so `__set_name__` can still
            # override it and several standalone instances share the same hash.
            self._namespace = "global"

    @property
    def namespace(self) -> str:
        """The name of the Redis hash that backs this dict."""
        return self._namespace

    def _set_namespace(self, namespace: str) -> None:
        """Try to set the namespace, but do not permit collisions."""
        while namespace in self._namespaces:
            namespace = namespace + "_"

        self._namespaces.append(namespace)
        self._namespace = namespace

    def __set_name__(self, owner: object, attribute_name: str) -> None:
        """
        Set the namespace to Class.attribute_name.

        Called automatically when this class is constructed inside a class as an attribute, as long as
        no custom namespace is provided to the constructor.
        """
        if not self._has_custom_namespace:
            self._set_namespace(f"{owner.__name__}.{attribute_name}")

    def __repr__(self) -> str:
        """Return a beautiful representation of this object instance."""
        return f"RedisDict(namespace={self._namespace!r})"

    def __eq__(self, other: RedisDict) -> bool:
        """
        Check equality between two RedisDicts.

        Two RedisDicts are equal when they share a namespace and hold the same items.
        Comparing against anything else returns NotImplemented instead of raising
        AttributeError, so Python falls back to its default comparison.
        """
        if not isinstance(other, RedisDict):
            return NotImplemented
        # Compare the cheap local attribute first so that mismatched namespaces
        # never trigger a fetch of every item from Redis.
        return self._namespace == other._namespace and self.items() == other.items()

    def __ne__(self, other: RedisDict) -> bool:
        """Check inequality between two RedisDicts."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __setitem__(self, key: ValidRedisKey, value: JSONSerializableType):
        """Store an item in the Redis cache."""
        # JSON serialize the value before storing it.
        json_value = json.dumps(value)
        self._redis.hset(self._namespace, key, json_value)

    def __getitem__(self, key: ValidRedisKey) -> JSONSerializableType:
        """
        Get an item from the Redis cache.

        Raises KeyError if `key` is not present, like a normal dict. The Mapping
        mixins (`items()` membership, equality) depend on this.
        """
        value = self._redis.hget(self._namespace, key)

        # HGET gives None for a missing field. A stored value is never empty,
        # because json.dumps always produces at least one character.
        if value is None:
            raise KeyError(key)
        return json.loads(value)

    def __delitem__(self, key: ValidRedisKey):
        """
        Delete an item from the Redis cache.

        The result of HDEL is ignored, so deleting a missing key does not raise
        (unlike a normal dict).
        """
        self._redis.hdel(self._namespace, key)

    def __contains__(self, key: ValidRedisKey):
        """Check if a key exists in the Redis cache."""
        return self._redis.hexists(self._namespace, key)

    def __iter__(self):
        """Iterate all the keys in the Redis cache, decoded to `str`."""
        keys = self._redis.hkeys(self._namespace)
        return iter([key.decode('utf-8') for key in keys])

    def __len__(self):
        """Return the number of items in the Redis cache."""
        return self._redis.hlen(self._namespace)

    def copy(self) -> Dict:
        """
        Convert to dict and return.

        Uses a single HGETALL rather than one HGET per key. This saves round trips
        and gives a consistent snapshot even if keys are removed meanwhile.
        """
        raw_items = self._redis.hgetall(self._namespace)
        return {key.decode('utf-8'): json.loads(value) for key, value in raw_items.items()}

    def clear(self) -> None:
        """Deletes the entire hash from the Redis cache."""
        self._redis.delete(self._namespace)

    def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
        """Get the item, but provide a default if not found."""
        # One HGET both checks for the key and fetches it. This avoids a separate
        # HEXISTS round trip and the race where the key vanishes in between.
        value = self._redis.hget(self._namespace, key)
        if value is None:
            return default
        return json.loads(value)

    def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
        """
        Get the item, remove it from the cache, and provide a default if not found.

        Unlike `dict.pop`, a missing key returns `default` (None unless given)
        instead of raising KeyError.
        """
        value = self.get(key, default)
        del self[key]
        return value

    def popitem(self) -> JSONSerializableType:
        """
        Remove and return the value of the last key listed for the hash.

        Unlike `dict.popitem`, only the value is returned, not a `(key, value)` pair.
        Raises IndexError when the dict is empty.
        NOTE(review): "last" means last in HKEYS order. Redis does not document that
        order, so this may not be the most recently added item.
        """
        key = list(self.keys())[-1]
        return self.pop(key)

    def setdefault(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType:
        """Try to get the item. If the item does not exist, set it to `default` and return that."""
        # Check existence with the raw HGET result rather than the decoded value, so a
        # stored JSON null is kept and returned instead of being overwritten.
        value = self._redis.hget(self._namespace, key)

        if value is None:
            self[key] = default
            return default

        # Previously this branch fell off the end and returned None.
        return json.loads(value)