aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/utils/redis_cache.py35
-rw-r--r--tests/bot/utils/test_redis_cache.py34
2 files changed, 69 insertions, 0 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
index fb9a534bd..290fae1a0 100644
--- a/bot/utils/redis_cache.py
+++ b/bot/utils/redis_cache.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import asyncio
from typing import Any, Dict, ItemsView, Optional, Union
from bot.bot import Bot
@@ -77,6 +78,7 @@ class RedisCache:
"""Initialize the RedisCache."""
self._namespace = None
self.bot = None
+ self.increment_lock = asyncio.Lock()
def _set_namespace(self, namespace: str) -> None:
"""Try to set the namespace, but do not permit collisions."""
@@ -287,3 +289,36 @@ class RedisCache:
"""Update the Redis cache with multiple values."""
await self._validate_cache()
await self._redis.hmset_dict(self._namespace, self._dict_to_typestring(items))
+
+ async def increment(self, key: RedisType, amount: Optional[int, float] = 1) -> None:
+ """
+ Increment the value by `amount`.
+
+ This works for both floats and ints, but will raise a TypeError
+ if you try to do it for any other type of value.
+
+ This also supports negative amounts, although it would provide better
+ readability to use .decrement() for that.
+ """
+ # Since this has several API calls, we need a lock to prevent race conditions
+ async with self.increment_lock:
+ value = await self.get(key)
+
+ # Can't increment a non-existing value
+ if value is None:
+ raise RuntimeError("The provided key does not exist!")
+
+ # If it does exist, and it's an int or a float, increment and set it.
+ if isinstance(value, int) or isinstance(value, float):
+ value += amount
+ await self.set(key, value)
+ else:
+ raise TypeError("You may only increment or decrement values that are integers or floats.")
+
+ async def decrement(self, key: RedisType, amount: Optional[int, float] = 1) -> None:
+ """
+ Decrement the value by `amount`.
+
+ Basically just does the opposite of .increment.
+ """
+ await self.increment(key, -amount)
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index 6e12002ed..dbbaef018 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -173,3 +173,37 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(TypeError):
self.redis._to_typestring(["internet"])
self.redis._from_typestring("o|firedog")
+
+ async def test_increment_decrement(self):
+ """Test .increment and .decrement methods."""
+ await self.redis.set("entropic", 5)
+ await self.redis.set("disentropic", 12.5)
+
+ # Test default increment
+ await self.redis.increment("entropic")
+ self.assertEqual(await self.redis.get("entropic"), 6)
+
+ # Test default decrement
+ await self.redis.decrement("entropic")
+ self.assertEqual(await self.redis.get("entropic"), 5)
+
+ # Test float increment with float
+ await self.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.redis.get("disentropic"), 14.5)
+
+ # Test float increment with int
+ await self.redis.increment("disentropic", 2)
+ self.assertEqual(await self.redis.get("disentropic"), 16.5)
+
+ # Test negative increments, because why not.
+ await self.redis.increment("entropic", -5)
+ self.assertEqual(await self.redis.get("entropic"), 0)
+
+ # Negative decrements? Sure.
+ await self.redis.decrement("entropic", -5)
+ self.assertEqual(await self.redis.get("entropic"), 5)
+
+ # What about if we use a negative float to decrement an int?
+ # This should convert the type into a float.
+ await self.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.redis.get("entropic"), 7.5)