aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2020-05-27 12:51:40 +0200
committerGravatar Leon Sandøy <[email protected]>2020-05-27 12:51:40 +0200
commitb18930735e05e09ba615cb54fe1dbdfd41bb0f81 (patch)
tree6d9289baf36b5a93955e11a9405ee7159c1cadb7
parentClear cache in asyncSetUp instead of tests. (diff)
Refactor .increment and add lock test.
The way we were doing the asyncio.Lock() stuff for increment was slightly problematic. @aeros has adviced us that it's better to just initialize the lock as None in __init__, and then initialize it inside the first coroutine that uses it instead. This ensures that the correct loop gets attached to the lock, so we don't end up getting errors like this one: RuntimeError: got Future <Future pending> attached to a different loop This happens because the lock and the actual calling coroutines aren't on the same loop. When creating a new test, test_increment_lock, we discovered that we needed a small refactor here and also in the test class to make this new test pass. So, now we're creating a DummyCog for every test method, and this will ensure the loop streams never cross. Cause we all know we must never cross the streams.
-rw-r--r--bot/utils/redis_cache.py11
-rw-r--r--tests/bot/utils/test_redis_cache.py163
2 files changed, 109 insertions, 65 deletions
diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py
index 895a12da4..33e5d5852 100644
--- a/bot/utils/redis_cache.py
+++ b/bot/utils/redis_cache.py
@@ -81,7 +81,7 @@ class RedisCache:
"""Initialize the RedisCache."""
self._namespace = None
self.bot = None
- self._increment_lock = asyncio.Lock()
+ self._increment_lock = None
def _set_namespace(self, namespace: str) -> None:
"""Try to set the namespace, but do not permit collisions."""
@@ -345,6 +345,15 @@ class RedisCache:
"""
log.trace(f"Attempting to increment/decrement the value with the key {key} by {amount}.")
+ # We initialize the lock here, because we need to ensure we get it
+ # running on the same loop as the calling coroutine.
+ #
+ # If we initialized the lock in the __init__, the loop that the coroutine this method
+ # would be called from might not exist yet, and so the lock would be on a different
+ # loop, which would raise RuntimeErrors.
+ if self._increment_lock is None:
+ self._increment_lock = asyncio.Lock()
+
# Since this has several API calls, we need a lock to prevent race conditions
async with self._increment_lock:
value = await self.get(key)
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index 900a6d035..efd168dac 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -1,3 +1,4 @@
+import asyncio
import unittest
import fakeredis.aioredis
@@ -9,17 +10,30 @@ from tests import helpers
class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
"""Tests the RedisCache class from utils.redis_dict.py."""
- redis = RedisCache()
-
async def asyncSetUp(self): # noqa: N802
"""Sets up the objects that only have to be initialized once."""
self.bot = helpers.MockBot()
self.bot.redis_session = await fakeredis.aioredis.create_redis_pool()
- await self.redis.clear()
+
+ # Okay, so this is necessary so that we can create a clean new
+ # class for every test method, and we want that because it will
+ # ensure we get a fresh loop, which is necessary for test_increment_lock
+ # to be able to pass.
+ class DummyCog:
+ """A dummy cog, for dummies."""
+
+ redis = RedisCache()
+
+ def __init__(self, bot: helpers.MockBot):
+ self.bot = bot
+
+ self.cog = DummyCog(self.bot)
+
+ await self.cog.redis.clear()
def test_class_attribute_namespace(self):
"""Test that RedisDict creates a namespace automatically for class attributes."""
- self.assertEqual(self.redis._namespace, "RedisCacheTests.redis")
+ self.assertEqual(self.cog.redis._namespace, "DummyCog.redis")
async def test_class_attribute_required(self):
"""Test that errors are raised when not assigned as a class attribute."""
@@ -31,9 +45,13 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
def test_namespace_collision(self):
"""Test that we prevent colliding namespaces."""
- bad_cache = RedisCache()
- bad_cache._set_namespace("RedisCacheTests.redis")
- self.assertEqual(bad_cache._namespace, "RedisCacheTests.redis_")
+ bob_cache_1 = RedisCache()
+ bob_cache_1._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_1._namespace, "BobRoss")
+
+ bob_cache_2 = RedisCache()
+ bob_cache_2._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_2._namespace, "BobRoss_")
async def test_set_get_item(self):
"""Test that users can set and get items from the RedisDict."""
@@ -45,35 +63,35 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
# Test that we can get and set different types.
for test in test_cases:
- await self.redis.set(*test)
- self.assertEqual(await self.redis.get(test[0]), test[1])
+ await self.cog.redis.set(*test)
+ self.assertEqual(await self.cog.redis.get(test[0]), test[1])
# Test that .get allows a default value
- self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
+ self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
async def test_set_item_type(self):
"""Test that .set rejects keys and values that are not strings, ints or floats."""
fruits = ["lemon", "melon", "apple"]
with self.assertRaises(TypeError):
- await self.redis.set(fruits, "nice")
+ await self.cog.redis.set(fruits, "nice")
async def test_delete_item(self):
"""Test that .delete allows us to delete stuff from the RedisCache."""
# Add an item and verify that it gets added
- await self.redis.set("internet", "firetruck")
- self.assertEqual(await self.redis.get("internet"), "firetruck")
+ await self.cog.redis.set("internet", "firetruck")
+ self.assertEqual(await self.cog.redis.get("internet"), "firetruck")
# Delete that item and verify that it gets deleted
- await self.redis.delete("internet")
- self.assertIs(await self.redis.get("internet"), None)
+ await self.cog.redis.delete("internet")
+ self.assertIs(await self.cog.redis.get("internet"), None)
async def test_contains(self):
"""Test that we can check membership with .contains."""
- await self.redis.set('favorite_country', "Burkina Faso")
+ await self.cog.redis.set('favorite_country', "Burkina Faso")
- self.assertIs(await self.redis.contains('favorite_country'), True)
- self.assertIs(await self.redis.contains('favorite_dentist'), False)
+ self.assertIs(await self.cog.redis.contains('favorite_country'), True)
+ self.assertIs(await self.cog.redis.contains('favorite_dentist'), False)
async def test_items(self):
"""Test that the RedisDict can be iterated."""
@@ -84,10 +102,10 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
('third_favorite_turtle', 'Raphael'),
]
for key, value in test_cases:
- await self.redis.set(key, value)
+ await self.cog.redis.set(key, value)
# Consume the AsyncIterator into a regular list, easier to compare that way.
- redis_items = [item for item in await self.redis.items()]
+ redis_items = [item for item in await self.cog.redis.items()]
# These sequences are probably in the same order now, but probably
# isn't good enough for tests. Let's not rely on .hgetall always
@@ -100,43 +118,43 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
async def test_length(self):
"""Test that we can get the correct .length from the RedisDict."""
- await self.redis.set('one', 1)
- await self.redis.set('two', 2)
- await self.redis.set('three', 3)
- self.assertEqual(await self.redis.length(), 3)
+ await self.cog.redis.set('one', 1)
+ await self.cog.redis.set('two', 2)
+ await self.cog.redis.set('three', 3)
+ self.assertEqual(await self.cog.redis.length(), 3)
- await self.redis.set('four', 4)
- self.assertEqual(await self.redis.length(), 4)
+ await self.cog.redis.set('four', 4)
+ self.assertEqual(await self.cog.redis.length(), 4)
async def test_to_dict(self):
"""Test that the .to_dict method returns a workable dictionary copy."""
- copy = await self.redis.to_dict()
- local_copy = {key: value for key, value in await self.redis.items()}
+ copy = await self.cog.redis.to_dict()
+ local_copy = {key: value for key, value in await self.cog.redis.items()}
self.assertIs(type(copy), dict)
self.assertDictEqual(copy, local_copy)
async def test_clear(self):
"""Test that the .clear method removes the entire hash."""
- await self.redis.set('teddy', 'with me')
- await self.redis.set('in my dreams', 'you have a weird hat')
- self.assertEqual(await self.redis.length(), 2)
+ await self.cog.redis.set('teddy', 'with me')
+ await self.cog.redis.set('in my dreams', 'you have a weird hat')
+ self.assertEqual(await self.cog.redis.length(), 2)
- await self.redis.clear()
- self.assertEqual(await self.redis.length(), 0)
+ await self.cog.redis.clear()
+ self.assertEqual(await self.cog.redis.length(), 0)
async def test_pop(self):
"""Test that we can .pop an item from the RedisDict."""
- await self.redis.set('john', 'was afraid')
+ await self.cog.redis.set('john', 'was afraid')
- self.assertEqual(await self.redis.pop('john'), 'was afraid')
- self.assertEqual(await self.redis.pop('pete', 'breakneck'), 'breakneck')
- self.assertEqual(await self.redis.length(), 0)
+ self.assertEqual(await self.cog.redis.pop('john'), 'was afraid')
+ self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck')
+ self.assertEqual(await self.cog.redis.length(), 0)
async def test_update(self):
"""Test that we can .update the RedisDict with multiple items."""
- await self.redis.set("reckfried", "lona")
- await self.redis.set("bel air", "prince")
- await self.redis.update({
+ await self.cog.redis.set("reckfried", "lona")
+ await self.cog.redis.set("bel air", "prince")
+ await self.cog.redis.update({
"reckfried": "jona",
"mega": "hungry, though",
})
@@ -146,7 +164,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
"bel air": "prince",
"mega": "hungry, though",
}
- self.assertDictEqual(await self.redis.to_dict(), result)
+ self.assertDictEqual(await self.cog.redis.to_dict(), result)
def test_typestring_conversion(self):
"""Test the typestring-related helper functions."""
@@ -158,58 +176,75 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
# Test conversion to typestring
for _input, expected in conversion_tests:
- self.assertEqual(self.redis._to_typestring(_input), expected)
+ self.assertEqual(self.cog.redis._to_typestring(_input), expected)
# Test conversion from typestrings
for _input, expected in conversion_tests:
- self.assertEqual(self.redis._from_typestring(expected), _input)
+ self.assertEqual(self.cog.redis._from_typestring(expected), _input)
# Test that exceptions are raised on invalid input
with self.assertRaises(TypeError):
- self.redis._to_typestring(["internet"])
- self.redis._from_typestring("o|firedog")
+ self.cog.redis._to_typestring(["internet"])
+ self.cog.redis._from_typestring("o|firedog")
async def test_increment_decrement(self):
"""Test .increment and .decrement methods."""
- await self.redis.set("entropic", 5)
- await self.redis.set("disentropic", 12.5)
+ await self.cog.redis.set("entropic", 5)
+ await self.cog.redis.set("disentropic", 12.5)
# Test default increment
- await self.redis.increment("entropic")
- self.assertEqual(await self.redis.get("entropic"), 6)
+ await self.cog.redis.increment("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 6)
# Test default decrement
- await self.redis.decrement("entropic")
- self.assertEqual(await self.redis.get("entropic"), 5)
+ await self.cog.redis.decrement("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
# Test float increment with float
- await self.redis.increment("disentropic", 2.0)
- self.assertEqual(await self.redis.get("disentropic"), 14.5)
+ await self.cog.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 14.5)
# Test float increment with int
- await self.redis.increment("disentropic", 2)
- self.assertEqual(await self.redis.get("disentropic"), 16.5)
+ await self.cog.redis.increment("disentropic", 2)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 16.5)
# Test negative increments, because why not.
- await self.redis.increment("entropic", -5)
- self.assertEqual(await self.redis.get("entropic"), 0)
+ await self.cog.redis.increment("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 0)
# Negative decrements? Sure.
- await self.redis.decrement("entropic", -5)
- self.assertEqual(await self.redis.get("entropic"), 5)
+ await self.cog.redis.decrement("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
# What about if we use a negative float to decrement an int?
# This should convert the type into a float.
- await self.redis.decrement("entropic", -2.5)
- self.assertEqual(await self.redis.get("entropic"), 7.5)
+ await self.cog.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 7.5)
# Let's test that they raise the right errors
with self.assertRaises(KeyError):
- await self.redis.increment("doesn't_exist!")
+ await self.cog.redis.increment("doesn't_exist!")
- await self.redis.set("stringthing", "stringthing")
+ await self.cog.redis.set("stringthing", "stringthing")
with self.assertRaises(TypeError):
- await self.redis.increment("stringthing")
+ await self.cog.redis.increment("stringthing")
+
+ async def test_increment_lock(self):
+ """Test that we can't produce a race condition in .increment."""
+ await self.cog.redis.set("test_key", 0)
+ tasks = []
+
+ # Increment this a lot in different tasks
+ for _ in range(100):
+ task = asyncio.create_task(
+ self.cog.redis.increment("test_key", 1)
+ )
+ tasks.append(task)
+ await asyncio.gather(*tasks)
+
+ # Confirm that the value has been incremented the exact right number of times.
+ value = await self.cog.redis.get("test_key")
+ self.assertEqual(value, 100)
async def test_exceptions_raised(self):
"""Testing that the various RuntimeErrors are reachable."""