diff options
| -rw-r--r-- | bot/utils/redis_cache.py | 11 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 163 |
2 files changed, 109 insertions, 65 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 895a12da4..33e5d5852 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -81,7 +81,7 @@ class RedisCache: """Initialize the RedisCache.""" self._namespace = None self.bot = None - self._increment_lock = asyncio.Lock() + self._increment_lock = None def _set_namespace(self, namespace: str) -> None: """Try to set the namespace, but do not permit collisions.""" @@ -345,6 +345,15 @@ class RedisCache: """ log.trace(f"Attempting to increment/decrement the value with the key {key} by {amount}.") + # We initialize the lock here, because we need to ensure we get it + # running on the same loop as the calling coroutine. + # + # If we initialized the lock in the __init__, the loop that the coroutine this method + # would be called from might not exist yet, and so the lock would be on a different + # loop, which would raise RuntimeErrors. + if self._increment_lock is None: + self._increment_lock = asyncio.Lock() + # Since this has several API calls, we need a lock to prevent race conditions async with self._increment_lock: value = await self.get(key) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 900a6d035..efd168dac 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -1,3 +1,4 @@ +import asyncio import unittest import fakeredis.aioredis @@ -9,17 +10,30 @@ from tests import helpers class RedisCacheTests(unittest.IsolatedAsyncioTestCase): """Tests the RedisCache class from utils.redis_dict.py.""" - redis = RedisCache() - async def asyncSetUp(self): # noqa: N802 """Sets up the objects that only have to be initialized once.""" self.bot = helpers.MockBot() self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - await self.redis.clear() + + # Okay, so this is necessary so that we can create a clean new + # class for every test method, and we want that because it will + # ensure we get a fresh loop, which is necessary for test_increment_lock + # to be able to pass. + class DummyCog: + """A dummy cog, for dummies.""" + + redis = RedisCache() + + def __init__(self, bot: helpers.MockBot): + self.bot = bot + + self.cog = DummyCog(self.bot) + + await self.cog.redis.clear() def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.redis._namespace, "RedisCacheTests.redis") + self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") async def test_class_attribute_required(self): """Test that errors are raised when not assigned as a class attribute.""" @@ -31,9 +45,13 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): def test_namespace_collision(self): """Test that we prevent colliding namespaces.""" - bad_cache = RedisCache() - bad_cache._set_namespace("RedisCacheTests.redis") - self.assertEqual(bad_cache._namespace, "RedisCacheTests.redis_") + bob_cache_1 = RedisCache() + bob_cache_1._set_namespace("BobRoss") + self.assertEqual(bob_cache_1._namespace, "BobRoss") + + bob_cache_2 = RedisCache() + bob_cache_2._set_namespace("BobRoss") + self.assertEqual(bob_cache_2._namespace, "BobRoss_") async def test_set_get_item(self): """Test that users can set and get items from the RedisDict.""" @@ -45,35 +63,35 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # Test that we can get and set different types. for test in test_cases: - await self.redis.set(*test) - self.assertEqual(await self.redis.get(test[0]), test[1]) + await self.cog.redis.set(*test) + self.assertEqual(await self.cog.redis.get(test[0]), test[1]) # Test that .get allows a default value - self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") async def test_set_item_type(self): """Test that .set rejects keys and values that are not strings, ints or floats.""" fruits = ["lemon", "melon", "apple"] with self.assertRaises(TypeError): - await self.redis.set(fruits, "nice") + await self.cog.redis.set(fruits, "nice") async def test_delete_item(self): """Test that .delete allows us to delete stuff from the RedisCache.""" # Add an item and verify that it gets added - await self.redis.set("internet", "firetruck") - self.assertEqual(await self.redis.get("internet"), "firetruck") + await self.cog.redis.set("internet", "firetruck") + self.assertEqual(await self.cog.redis.get("internet"), "firetruck") # Delete that item and verify that it gets deleted - await self.redis.delete("internet") - self.assertIs(await self.redis.get("internet"), None) + await self.cog.redis.delete("internet") + self.assertIs(await self.cog.redis.get("internet"), None) async def test_contains(self): """Test that we can check membership with .contains.""" - await self.redis.set('favorite_country', "Burkina Faso") + await self.cog.redis.set('favorite_country', "Burkina Faso") - self.assertIs(await self.redis.contains('favorite_country'), True) - self.assertIs(await self.redis.contains('favorite_dentist'), False) + self.assertIs(await self.cog.redis.contains('favorite_country'), True) + self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) async def test_items(self): """Test that the RedisDict can be iterated.""" @@ -84,10 +102,10 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): ('third_favorite_turtle', 'Raphael'), ] for key, value in test_cases: - await self.redis.set(key, value) + await self.cog.redis.set(key, value) # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item for item in await self.redis.items()] + redis_items = [item for item in await self.cog.redis.items()] # These sequences are probably in the same order now, but probably # isn't good enough for tests. Let's not rely on .hgetall always @@ -100,43 +118,43 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_length(self): """Test that we can get the correct .length from the RedisDict.""" - await self.redis.set('one', 1) - await self.redis.set('two', 2) - await self.redis.set('three', 3) - self.assertEqual(await self.redis.length(), 3) + await self.cog.redis.set('one', 1) + await self.cog.redis.set('two', 2) + await self.cog.redis.set('three', 3) + self.assertEqual(await self.cog.redis.length(), 3) - await self.redis.set('four', 4) - self.assertEqual(await self.redis.length(), 4) + await self.cog.redis.set('four', 4) + self.assertEqual(await self.cog.redis.length(), 4) async def test_to_dict(self): """Test that the .to_dict method returns a workable dictionary copy.""" - copy = await self.redis.to_dict() - local_copy = {key: value for key, value in await self.redis.items()} + copy = await self.cog.redis.to_dict() + local_copy = {key: value for key, value in await self.cog.redis.items()} self.assertIs(type(copy), dict) self.assertDictEqual(copy, local_copy) async def test_clear(self): """Test that the .clear method removes the entire hash.""" - await self.redis.set('teddy', 'with me') - await self.redis.set('in my dreams', 'you have a weird hat') - self.assertEqual(await self.redis.length(), 2) + await self.cog.redis.set('teddy', 'with me') + await self.cog.redis.set('in my dreams', 'you have a weird hat') + self.assertEqual(await self.cog.redis.length(), 2) - await self.redis.clear() - self.assertEqual(await self.redis.length(), 0) + await self.cog.redis.clear() + self.assertEqual(await self.cog.redis.length(), 0) async def test_pop(self): """Test that we can .pop an item from the RedisDict.""" - await self.redis.set('john', 'was afraid') + await self.cog.redis.set('john', 'was afraid') - self.assertEqual(await self.redis.pop('john'), 'was afraid') - self.assertEqual(await self.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(await self.redis.length(), 0) + self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') + self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') + self.assertEqual(await self.cog.redis.length(), 0) async def test_update(self): """Test that we can .update the RedisDict with multiple items.""" - await self.redis.set("reckfried", "lona") - await self.redis.set("bel air", "prince") - await self.redis.update({ + await self.cog.redis.set("reckfried", "lona") + await self.cog.redis.set("bel air", "prince") + await self.cog.redis.update({ "reckfried": "jona", "mega": "hungry, though", }) @@ -146,7 +164,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): "bel air": "prince", "mega": "hungry, though", } - self.assertDictEqual(await self.redis.to_dict(), result) + self.assertDictEqual(await self.cog.redis.to_dict(), result) def test_typestring_conversion(self): """Test the typestring-related helper functions.""" @@ -158,58 +176,75 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # Test conversion to typestring for _input, expected in conversion_tests: - self.assertEqual(self.redis._to_typestring(_input), expected) + self.assertEqual(self.cog.redis._to_typestring(_input), expected) # Test conversion from typestrings for _input, expected in conversion_tests: - self.assertEqual(self.redis._from_typestring(expected), _input) + self.assertEqual(self.cog.redis._from_typestring(expected), _input) # Test that exceptions are raised on invalid input with self.assertRaises(TypeError): - self.redis._to_typestring(["internet"]) - self.redis._from_typestring("o|firedog") + self.cog.redis._to_typestring(["internet"]) + self.cog.redis._from_typestring("o|firedog") async def test_increment_decrement(self): """Test .increment and .decrement methods.""" - await self.redis.set("entropic", 5) - await self.redis.set("disentropic", 12.5) + await self.cog.redis.set("entropic", 5) + await self.cog.redis.set("disentropic", 12.5) # Test default increment - await self.redis.increment("entropic") - self.assertEqual(await self.redis.get("entropic"), 6) + await self.cog.redis.increment("entropic") + self.assertEqual(await self.cog.redis.get("entropic"), 6) # Test default decrement - await self.redis.decrement("entropic") - self.assertEqual(await self.redis.get("entropic"), 5) + await self.cog.redis.decrement("entropic") + self.assertEqual(await self.cog.redis.get("entropic"), 5) # Test float increment with float - await self.redis.increment("disentropic", 2.0) - self.assertEqual(await self.redis.get("disentropic"), 14.5) + await self.cog.redis.increment("disentropic", 2.0) + self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) # Test float increment with int - await self.redis.increment("disentropic", 2) - self.assertEqual(await self.redis.get("disentropic"), 16.5) + await self.cog.redis.increment("disentropic", 2) + self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) # Test negative increments, because why not. - await self.redis.increment("entropic", -5) - self.assertEqual(await self.redis.get("entropic"), 0) + await self.cog.redis.increment("entropic", -5) + self.assertEqual(await self.cog.redis.get("entropic"), 0) # Negative decrements? Sure. - await self.redis.decrement("entropic", -5) - self.assertEqual(await self.redis.get("entropic"), 5) + await self.cog.redis.decrement("entropic", -5) + self.assertEqual(await self.cog.redis.get("entropic"), 5) # What about if we use a negative float to decrement an int? # This should convert the type into a float. - await self.redis.decrement("entropic", -2.5) - self.assertEqual(await self.redis.get("entropic"), 7.5) + await self.cog.redis.decrement("entropic", -2.5) + self.assertEqual(await self.cog.redis.get("entropic"), 7.5) # Let's test that they raise the right errors with self.assertRaises(KeyError): - await self.redis.increment("doesn't_exist!") + await self.cog.redis.increment("doesn't_exist!") - await self.redis.set("stringthing", "stringthing") + await self.cog.redis.set("stringthing", "stringthing") with self.assertRaises(TypeError): - await self.redis.increment("stringthing") + await self.cog.redis.increment("stringthing") + + async def test_increment_lock(self): + """Test that we can't produce a race condition in .increment.""" + await self.cog.redis.set("test_key", 0) + tasks = [] + + # Increment this a lot in different tasks + for _ in range(100): + task = asyncio.create_task( + self.cog.redis.increment("test_key", 1) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Confirm that the value has been incremented the exact right number of times. + value = await self.cog.redis.get("test_key") + self.assertEqual(value, 100) async def test_exceptions_raised(self): """Testing that the various RuntimeErrors are reachable.""" |