aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/utils/test_redis_cache.py163
1 files changed, 99 insertions, 64 deletions
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
index 900a6d035..efd168dac 100644
--- a/tests/bot/utils/test_redis_cache.py
+++ b/tests/bot/utils/test_redis_cache.py
@@ -1,3 +1,4 @@
+import asyncio
import unittest
import fakeredis.aioredis
@@ -9,17 +10,30 @@ from tests import helpers
class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
"""Tests the RedisCache class from utils.redis_dict.py."""
- redis = RedisCache()
-
async def asyncSetUp(self): # noqa: N802
"""Sets up the objects that only have to be initialized once."""
self.bot = helpers.MockBot()
self.bot.redis_session = await fakeredis.aioredis.create_redis_pool()
- await self.redis.clear()
+
+ # Okay, so this is necessary so that we can create a clean new
+ # class for every test method, and we want that because it will
+ # ensure we get a fresh loop, which is necessary for test_increment_lock
+ # to be able to pass.
+ class DummyCog:
+ """A dummy cog, for dummies."""
+
+ redis = RedisCache()
+
+ def __init__(self, bot: helpers.MockBot):
+ self.bot = bot
+
+ self.cog = DummyCog(self.bot)
+
+ await self.cog.redis.clear()
def test_class_attribute_namespace(self):
"""Test that RedisDict creates a namespace automatically for class attributes."""
- self.assertEqual(self.redis._namespace, "RedisCacheTests.redis")
+ self.assertEqual(self.cog.redis._namespace, "DummyCog.redis")
async def test_class_attribute_required(self):
"""Test that errors are raised when not assigned as a class attribute."""
@@ -31,9 +45,13 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
def test_namespace_collision(self):
"""Test that we prevent colliding namespaces."""
- bad_cache = RedisCache()
- bad_cache._set_namespace("RedisCacheTests.redis")
- self.assertEqual(bad_cache._namespace, "RedisCacheTests.redis_")
+ bob_cache_1 = RedisCache()
+ bob_cache_1._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_1._namespace, "BobRoss")
+
+ bob_cache_2 = RedisCache()
+ bob_cache_2._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_2._namespace, "BobRoss_")
async def test_set_get_item(self):
"""Test that users can set and get items from the RedisDict."""
@@ -45,35 +63,35 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
# Test that we can get and set different types.
for test in test_cases:
- await self.redis.set(*test)
- self.assertEqual(await self.redis.get(test[0]), test[1])
+ await self.cog.redis.set(*test)
+ self.assertEqual(await self.cog.redis.get(test[0]), test[1])
# Test that .get allows a default value
- self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
+ self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
async def test_set_item_type(self):
"""Test that .set rejects keys and values that are not strings, ints or floats."""
fruits = ["lemon", "melon", "apple"]
with self.assertRaises(TypeError):
- await self.redis.set(fruits, "nice")
+ await self.cog.redis.set(fruits, "nice")
async def test_delete_item(self):
"""Test that .delete allows us to delete stuff from the RedisCache."""
# Add an item and verify that it gets added
- await self.redis.set("internet", "firetruck")
- self.assertEqual(await self.redis.get("internet"), "firetruck")
+ await self.cog.redis.set("internet", "firetruck")
+ self.assertEqual(await self.cog.redis.get("internet"), "firetruck")
# Delete that item and verify that it gets deleted
- await self.redis.delete("internet")
- self.assertIs(await self.redis.get("internet"), None)
+ await self.cog.redis.delete("internet")
+ self.assertIs(await self.cog.redis.get("internet"), None)
async def test_contains(self):
"""Test that we can check membership with .contains."""
- await self.redis.set('favorite_country', "Burkina Faso")
+ await self.cog.redis.set('favorite_country', "Burkina Faso")
- self.assertIs(await self.redis.contains('favorite_country'), True)
- self.assertIs(await self.redis.contains('favorite_dentist'), False)
+ self.assertIs(await self.cog.redis.contains('favorite_country'), True)
+ self.assertIs(await self.cog.redis.contains('favorite_dentist'), False)
async def test_items(self):
"""Test that the RedisDict can be iterated."""
@@ -84,10 +102,10 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
('third_favorite_turtle', 'Raphael'),
]
for key, value in test_cases:
- await self.redis.set(key, value)
+ await self.cog.redis.set(key, value)
# Consume the AsyncIterator into a regular list, easier to compare that way.
- redis_items = [item for item in await self.redis.items()]
+ redis_items = [item for item in await self.cog.redis.items()]
# These sequences are probably in the same order now, but probably
# isn't good enough for tests. Let's not rely on .hgetall always
@@ -100,43 +118,43 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
async def test_length(self):
"""Test that we can get the correct .length from the RedisDict."""
- await self.redis.set('one', 1)
- await self.redis.set('two', 2)
- await self.redis.set('three', 3)
- self.assertEqual(await self.redis.length(), 3)
+ await self.cog.redis.set('one', 1)
+ await self.cog.redis.set('two', 2)
+ await self.cog.redis.set('three', 3)
+ self.assertEqual(await self.cog.redis.length(), 3)
- await self.redis.set('four', 4)
- self.assertEqual(await self.redis.length(), 4)
+ await self.cog.redis.set('four', 4)
+ self.assertEqual(await self.cog.redis.length(), 4)
async def test_to_dict(self):
"""Test that the .to_dict method returns a workable dictionary copy."""
- copy = await self.redis.to_dict()
- local_copy = {key: value for key, value in await self.redis.items()}
+ copy = await self.cog.redis.to_dict()
+ local_copy = {key: value for key, value in await self.cog.redis.items()}
self.assertIs(type(copy), dict)
self.assertDictEqual(copy, local_copy)
async def test_clear(self):
"""Test that the .clear method removes the entire hash."""
- await self.redis.set('teddy', 'with me')
- await self.redis.set('in my dreams', 'you have a weird hat')
- self.assertEqual(await self.redis.length(), 2)
+ await self.cog.redis.set('teddy', 'with me')
+ await self.cog.redis.set('in my dreams', 'you have a weird hat')
+ self.assertEqual(await self.cog.redis.length(), 2)
- await self.redis.clear()
- self.assertEqual(await self.redis.length(), 0)
+ await self.cog.redis.clear()
+ self.assertEqual(await self.cog.redis.length(), 0)
async def test_pop(self):
"""Test that we can .pop an item from the RedisDict."""
- await self.redis.set('john', 'was afraid')
+ await self.cog.redis.set('john', 'was afraid')
- self.assertEqual(await self.redis.pop('john'), 'was afraid')
- self.assertEqual(await self.redis.pop('pete', 'breakneck'), 'breakneck')
- self.assertEqual(await self.redis.length(), 0)
+ self.assertEqual(await self.cog.redis.pop('john'), 'was afraid')
+ self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck')
+ self.assertEqual(await self.cog.redis.length(), 0)
async def test_update(self):
"""Test that we can .update the RedisDict with multiple items."""
- await self.redis.set("reckfried", "lona")
- await self.redis.set("bel air", "prince")
- await self.redis.update({
+ await self.cog.redis.set("reckfried", "lona")
+ await self.cog.redis.set("bel air", "prince")
+ await self.cog.redis.update({
"reckfried": "jona",
"mega": "hungry, though",
})
@@ -146,7 +164,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
"bel air": "prince",
"mega": "hungry, though",
}
- self.assertDictEqual(await self.redis.to_dict(), result)
+ self.assertDictEqual(await self.cog.redis.to_dict(), result)
def test_typestring_conversion(self):
"""Test the typestring-related helper functions."""
@@ -158,58 +176,75 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
# Test conversion to typestring
for _input, expected in conversion_tests:
- self.assertEqual(self.redis._to_typestring(_input), expected)
+ self.assertEqual(self.cog.redis._to_typestring(_input), expected)
# Test conversion from typestrings
for _input, expected in conversion_tests:
- self.assertEqual(self.redis._from_typestring(expected), _input)
+ self.assertEqual(self.cog.redis._from_typestring(expected), _input)
# Test that exceptions are raised on invalid input
with self.assertRaises(TypeError):
- self.redis._to_typestring(["internet"])
- self.redis._from_typestring("o|firedog")
+ self.cog.redis._to_typestring(["internet"])
+ self.cog.redis._from_typestring("o|firedog")
async def test_increment_decrement(self):
"""Test .increment and .decrement methods."""
- await self.redis.set("entropic", 5)
- await self.redis.set("disentropic", 12.5)
+ await self.cog.redis.set("entropic", 5)
+ await self.cog.redis.set("disentropic", 12.5)
# Test default increment
- await self.redis.increment("entropic")
- self.assertEqual(await self.redis.get("entropic"), 6)
+ await self.cog.redis.increment("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 6)
# Test default decrement
- await self.redis.decrement("entropic")
- self.assertEqual(await self.redis.get("entropic"), 5)
+ await self.cog.redis.decrement("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
# Test float increment with float
- await self.redis.increment("disentropic", 2.0)
- self.assertEqual(await self.redis.get("disentropic"), 14.5)
+ await self.cog.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 14.5)
# Test float increment with int
- await self.redis.increment("disentropic", 2)
- self.assertEqual(await self.redis.get("disentropic"), 16.5)
+ await self.cog.redis.increment("disentropic", 2)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 16.5)
# Test negative increments, because why not.
- await self.redis.increment("entropic", -5)
- self.assertEqual(await self.redis.get("entropic"), 0)
+ await self.cog.redis.increment("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 0)
# Negative decrements? Sure.
- await self.redis.decrement("entropic", -5)
- self.assertEqual(await self.redis.get("entropic"), 5)
+ await self.cog.redis.decrement("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
# What about if we use a negative float to decrement an int?
# This should convert the type into a float.
- await self.redis.decrement("entropic", -2.5)
- self.assertEqual(await self.redis.get("entropic"), 7.5)
+ await self.cog.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 7.5)
# Let's test that they raise the right errors
with self.assertRaises(KeyError):
- await self.redis.increment("doesn't_exist!")
+ await self.cog.redis.increment("doesn't_exist!")
- await self.redis.set("stringthing", "stringthing")
+ await self.cog.redis.set("stringthing", "stringthing")
with self.assertRaises(TypeError):
- await self.redis.increment("stringthing")
+ await self.cog.redis.increment("stringthing")
+
+ async def test_increment_lock(self):
+ """Test that we can't produce a race condition in .increment."""
+ await self.cog.redis.set("test_key", 0)
+ tasks = []
+
+ # Increment this a lot in different tasks
+ for _ in range(100):
+ task = asyncio.create_task(
+ self.cog.redis.increment("test_key", 1)
+ )
+ tasks.append(task)
+ await asyncio.gather(*tasks)
+
+ # Confirm that the value has been incremented the exact right number of times.
+ value = await self.cog.redis.get("test_key")
+ self.assertEqual(value, 100)
async def test_exceptions_raised(self):
"""Testing that the various RuntimeErrors are reachable."""