aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2020-05-24 13:04:41 +0200
committerGravatar Leon Sandøy <[email protected]>2020-05-24 13:04:41 +0200
commit01bedcadf762262eef0a2b406faf66cdc16a5c85 (patch)
treeb03fa0a94c620b33c2e350699936d68b79bdc765
parentMake .items return ItemsView instead of AsyncIter (diff)
Add .increment and .decrement methods.
Sometimes, we just want to store a counter in the cache. In this case, it is convenient to have a single method that will allow us to increment or decrement this counter. These methods allow you to decrement or increment floats and integers by an specified amount. By default, it'll increment or decrement by 1. Since this involves several API requests, we create an asyncio.Lock so that we don't end up with race conditions.
-rw-r--r--bot/utils/redis_cache.py35
-rw-r--r--tests/bot/utils/test_redis_cache.py34
2 files changed, 69 insertions, 0 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
index fb9a534bd..290fae1a0 100644
--- a/bot/utils/redis_cache.py
+++ b/bot/utils/redis_cache.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import asyncio
from typing import Any, Dict, ItemsView, Optional, Union
from bot.bot import Bot
@@ -77,6 +78,7 @@ class RedisCache:
"""Initialize the RedisCache."""
self._namespace = None
self.bot = None
+ self.increment_lock = asyncio.Lock()
def _set_namespace(self, namespace: str) -> None:
"""Try to set the namespace, but do not permit collisions."""
@@ -287,3 +289,36 @@ class RedisCache:
"""Update the Redis cache with multiple values."""
await self._validate_cache()
await self._redis.hmset_dict(self._namespace, self._dict_to_typestring(items))
+
+ async def increment(self, key: RedisType, amount: Optional[int, float] = 1) -> None:
+ """
+ Increment the value by `amount`.
+
+ This works for both floats and ints, but will raise a TypeError
+ if you try to do it for any other type of value.
+
+ This also supports negative amounts, although it would provide better
+ readability to use .decrement() for that.
+ """
+ # Since this has several API calls, we need a lock to prevent race conditions
+ async with self.increment_lock:
+ value = await self.get(key)
+
+ # Can't increment a non-existing value
+ if value is None:
+ raise RuntimeError("The provided key does not exist!")
+
+ # If it does exist, and it's an int or a float, increment and set it.
+ if isinstance(value, int) or isinstance(value, float):
+ value += amount
+ await self.set(key, value)
+ else:
+ raise TypeError("You may only increment or decrement values that are integers or floats.")
+
+ async def decrement(self, key: RedisType, amount: Optional[int, float] = 1) -> None:
+ """
+ Decrement the value by `amount`.
+
+ Basically just does the opposite of .increment.
+ """
+ await self.increment(key, -amount)
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index 6e12002ed..dbbaef018 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -173,3 +173,37 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(TypeError):
self.redis._to_typestring(["internet"])
self.redis._from_typestring("o|firedog")
+
+ async def test_increment_decrement(self):
+ """Test .increment and .decrement methods."""
+ await self.redis.set("entropic", 5)
+ await self.redis.set("disentropic", 12.5)
+
+ # Test default increment
+ await self.redis.increment("entropic")
+ self.assertEqual(await self.redis.get("entropic"), 6)
+
+ # Test default decrement
+ await self.redis.decrement("entropic")
+ self.assertEqual(await self.redis.get("entropic"), 5)
+
+ # Test float increment with float
+ await self.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.redis.get("disentropic"), 14.5)
+
+ # Test float increment with int
+ await self.redis.increment("disentropic", 2)
+ self.assertEqual(await self.redis.get("disentropic"), 16.5)
+
+ # Test negative increments, because why not.
+ await self.redis.increment("entropic", -5)
+ self.assertEqual(await self.redis.get("entropic"), 0)
+
+ # Negative decrements? Sure.
+ await self.redis.decrement("entropic", -5)
+ self.assertEqual(await self.redis.get("entropic"), 5)
+
+ # What about if we use a negative float to decrement an int?
+ # This should convert the type into a float.
+ await self.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.redis.get("entropic"), 7.5)