diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/bot/utils/test_redis_cache.py | 265 | ||||
-rw-r--r-- | tests/helpers.py | 6 |
2 files changed, 5 insertions, 266 deletions
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py deleted file mode 100644 index a2f0fe55d..000000000 --- a/tests/bot/utils/test_redis_cache.py +++ /dev/null @@ -1,265 +0,0 @@ -import asyncio -import unittest - -import fakeredis.aioredis - -from bot.utils import RedisCache -from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError -from tests import helpers - - -class RedisCacheTests(unittest.IsolatedAsyncioTestCase): - """Tests the RedisCache class from utils.redis_dict.py.""" - - async def asyncSetUp(self): # noqa: N802 - """Sets up the objects that only have to be initialized once.""" - self.bot = helpers.MockBot() - self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - - # Okay, so this is necessary so that we can create a clean new - # class for every test method, and we want that because it will - # ensure we get a fresh loop, which is necessary for test_increment_lock - # to be able to pass. - class DummyCog: - """A dummy cog, for dummies.""" - - redis = RedisCache() - - def __init__(self, bot: helpers.MockBot): - self.bot = bot - - self.cog = DummyCog(self.bot) - - await self.cog.redis.clear() - - def test_class_attribute_namespace(self): - """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") - - async def test_class_attribute_required(self): - """Test that errors are raised when not assigned as a class attribute.""" - bad_cache = RedisCache() - self.assertIs(bad_cache._namespace, None) - - with self.assertRaises(RuntimeError): - await bad_cache.set("test", "me_up_deadman") - - async def test_set_get_item(self): - """Test that users can set and get items from the RedisDict.""" - test_cases = ( - ('favorite_fruit', 'melon'), - ('favorite_number', 86), - ('favorite_fraction', 86.54), - ('favorite_boolean', False), - ('other_boolean', True), - ) - - # Test that we can get and set different types. - for test in test_cases: - await self.cog.redis.set(*test) - self.assertEqual(await self.cog.redis.get(test[0]), test[1]) - - # Test that .get allows a default value - self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") - - async def test_set_item_type(self): - """Test that .set rejects keys and values that are not permitted.""" - fruits = ["lemon", "melon", "apple"] - - with self.assertRaises(TypeError): - await self.cog.redis.set(fruits, "nice") - - with self.assertRaises(TypeError): - await self.cog.redis.set(4.23, "nice") - - async def test_delete_item(self): - """Test that .delete allows us to delete stuff from the RedisCache.""" - # Add an item and verify that it gets added - await self.cog.redis.set("internet", "firetruck") - self.assertEqual(await self.cog.redis.get("internet"), "firetruck") - - # Delete that item and verify that it gets deleted - await self.cog.redis.delete("internet") - self.assertIs(await self.cog.redis.get("internet"), None) - - async def test_contains(self): - """Test that we can check membership with .contains.""" - await self.cog.redis.set('favorite_country', "Burkina Faso") - - self.assertIs(await self.cog.redis.contains('favorite_country'), True) - self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) - - async def test_items(self): - """Test that the RedisDict can be iterated.""" - # Set up our test cases in the Redis cache - test_cases = [ - ('favorite_turtle', 'Donatello'), - ('second_favorite_turtle', 'Leonardo'), - ('third_favorite_turtle', 'Raphael'), - ] - for key, value in test_cases: - await self.cog.redis.set(key, value) - - # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item for item in await self.cog.redis.items()] - - # These sequences are probably in the same order now, but probably - # isn't good enough for tests. Let's not rely on .hgetall always - # returning things in sequence, and just sort both lists to be safe. - redis_items = sorted(redis_items) - test_cases = sorted(test_cases) - - # If these are equal now, everything works fine. - self.assertSequenceEqual(test_cases, redis_items) - - async def test_length(self): - """Test that we can get the correct .length from the RedisDict.""" - await self.cog.redis.set('one', 1) - await self.cog.redis.set('two', 2) - await self.cog.redis.set('three', 3) - self.assertEqual(await self.cog.redis.length(), 3) - - await self.cog.redis.set('four', 4) - self.assertEqual(await self.cog.redis.length(), 4) - - async def test_to_dict(self): - """Test that the .to_dict method returns a workable dictionary copy.""" - copy = await self.cog.redis.to_dict() - local_copy = {key: value for key, value in await self.cog.redis.items()} - self.assertIs(type(copy), dict) - self.assertDictEqual(copy, local_copy) - - async def test_clear(self): - """Test that the .clear method removes the entire hash.""" - await self.cog.redis.set('teddy', 'with me') - await self.cog.redis.set('in my dreams', 'you have a weird hat') - self.assertEqual(await self.cog.redis.length(), 2) - - await self.cog.redis.clear() - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_pop(self): - """Test that we can .pop an item from the RedisDict.""" - await self.cog.redis.set('john', 'was afraid') - - self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') - self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_update(self): - """Test that we can .update the RedisDict with multiple items.""" - await self.cog.redis.set("reckfried", "lona") - await self.cog.redis.set("bel air", "prince") - await self.cog.redis.update({ - "reckfried": "jona", - "mega": "hungry, though", - }) - - result = { - "reckfried": "jona", - "bel air": "prince", - "mega": "hungry, though", - } - self.assertDictEqual(await self.cog.redis.to_dict(), result) - - def test_typestring_conversion(self): - """Test the typestring-related helper functions.""" - conversion_tests = ( - (12, "i|12"), - (12.4, "f|12.4"), - ("cowabunga", "s|cowabunga"), - ) - - # Test conversion to typestring - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) - - # Test conversion from typestrings - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) - - # Test that exceptions are raised on invalid input - with self.assertRaises(TypeError): - self.cog.redis._value_to_typestring(["internet"]) - self.cog.redis._value_from_typestring("o|firedog") - - async def test_increment_decrement(self): - """Test .increment and .decrement methods.""" - await self.cog.redis.set("entropic", 5) - await self.cog.redis.set("disentropic", 12.5) - - # Test default increment - await self.cog.redis.increment("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 6) - - # Test default decrement - await self.cog.redis.decrement("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # Test float increment with float - await self.cog.redis.increment("disentropic", 2.0) - self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) - - # Test float increment with int - await self.cog.redis.increment("disentropic", 2) - self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) - - # Test negative increments, because why not. - await self.cog.redis.increment("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 0) - - # Negative decrements? Sure. - await self.cog.redis.decrement("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # What about if we use a negative float to decrement an int? - # This should convert the type into a float. - await self.cog.redis.decrement("entropic", -2.5) - self.assertEqual(await self.cog.redis.get("entropic"), 7.5) - - # Let's test that they raise the right errors - with self.assertRaises(KeyError): - await self.cog.redis.increment("doesn't_exist!") - - await self.cog.redis.set("stringthing", "stringthing") - with self.assertRaises(TypeError): - await self.cog.redis.increment("stringthing") - - async def test_increment_lock(self): - """Test that we can't produce a race condition in .increment.""" - await self.cog.redis.set("test_key", 0) - tasks = [] - - # Increment this a lot in different tasks - for _ in range(100): - task = asyncio.create_task( - self.cog.redis.increment("test_key", 1) - ) - tasks.append(task) - await asyncio.gather(*tasks) - - # Confirm that the value has been incremented the exact right number of times. - value = await self.cog.redis.get("test_key") - self.assertEqual(value, 100) - - async def test_exceptions_raised(self): - """Testing that the various RuntimeErrors are reachable.""" - class MyCog: - cache = RedisCache() - - def __init__(self): - self.other_cache = RedisCache() - - cog = MyCog() - - # Raises "No Bot instance" - with self.assertRaises(NoBotInstanceError): - await cog.cache.get("john") - - # Raises "RedisCache has no namespace" - with self.assertRaises(NoNamespaceError): - await cog.other_cache.get("was") - - # Raises "You must access the RedisCache instance through the cog instance" - with self.assertRaises(NoParentInstanceError): - await MyCog.cache.get("afraid") diff --git a/tests/helpers.py b/tests/helpers.py index facc4e1af..e47fdf28f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -308,7 +308,11 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ - spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) + spec_set = Bot( + command_prefix=unittest.mock.MagicMock(), + loop=_get_mock_loop(), + redis_session=unittest.mock.MagicMock(), + ) additional_spec_asyncs = ("wait_for", "redis_ready") def __init__(self, **kwargs) -> None: |