diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 2 | ||||
| -rw-r--r-- | tests/bot/cogs/test_snekbox.py | 15 | ||||
| -rw-r--r-- | tests/bot/test_constants.py | 43 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 273 | ||||
| -rw-r--r-- | tests/helpers.py | 30 | 
5 files changed, 342 insertions, 21 deletions
| diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 7e6bfc748..a8c0107c6 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          self.assertEqual(cog.bot, bot)          self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) -        bot.loop.create_loop.called_once_with(cog.fetch_webhook()) +        bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())      def test_fetch_webhook_succeeds_without_connectivity_issues(self):          """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 14299e766..cf9adbee0 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -21,7 +21,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Post the eval code to the URLs.snekbox_eval_api endpoint."""          resp = MagicMock()          resp.json = AsyncMock(return_value="return") -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(await self.cog.post_eval("import random"), "return")          self.bot.http_session.post.assert_called_with( @@ -41,7 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          key = "MarkDiamond"          resp = MagicMock()          resp.json = AsyncMock(return_value={"key": key}) -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(              await self.cog.upload_output("My awesome output"), @@ -57,7 +63,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Output upload gracefully fallback if the upload fail."""          resp = MagicMock()          resp.json = AsyncMock(side_effect=Exception) -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          log = logging.getLogger("bot.cogs.snekbox")          with self.assertLogs(logger=log, level='ERROR'): diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index dae7c066c..f10d6fbe8 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -1,14 +1,40 @@  import inspect +import typing  import unittest  from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: +    """ +    Return True if `value` is an instance of the type represented by `annotation`. + +    This doesn't account for things like Unions or checking for homogenous types in collections. +    """ +    origin = typing.get_origin(annotation) + +    # This is done in case a bare e.g. `typing.List` is used. +    # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. +    # `get_origin()` does this normalisation for us. +    type_ = annotation if origin is None else origin + +    return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: +    """Return True if `value` is an instance of any type in `types`.""" +    for type_ in types: +        if is_annotation_instance(value, type_): +            return True + +    return False + +  class ConstantsTests(unittest.TestCase):      """Tests for our constants."""      def test_section_configuration_matches_type_specification(self): -        """The section annotations should match the actual types of the sections.""" +        """"The section annotations should match the actual types of the sections."""          sections = (              cls @@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):          )          for section in sections:              for name, annotation in section.__annotations__.items(): -                with self.subTest(section=section, name=name, annotation=annotation): +                with self.subTest(section=section.__name__, name=name, annotation=annotation):                      value = getattr(section, name) +                    origin = typing.get_origin(annotation) +                    annotation_args = typing.get_args(annotation) +                    failure_msg = f"{value} is not an instance of {annotation}" -                    if getattr(annotation, '_name', None) in ('Dict', 'List'): -                        self.skipTest("Cannot validate containers yet.") - -                    self.assertIsInstance(value, annotation) +                    if origin is typing.Union: +                        is_instance = is_any_instance(value, annotation_args) +                        self.assertTrue(is_instance, failure_msg) +                    else: +                        is_instance = is_annotation_instance(value, annotation) +                        self.assertTrue(is_instance, failure_msg) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py new file mode 100644 index 000000000..8c1a40640 --- /dev/null +++ b/tests/bot/utils/test_redis_cache.py @@ -0,0 +1,273 @@ +import asyncio +import unittest + +import fakeredis.aioredis + +from bot.utils import RedisCache +from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError +from tests import helpers + + +class RedisCacheTests(unittest.IsolatedAsyncioTestCase): +    """Tests the RedisCache class from utils.redis_dict.py.""" + +    async def asyncSetUp(self):  # noqa: N802 +        """Sets up the objects that only have to be initialized once.""" +        self.bot = helpers.MockBot() +        self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() + +        # Okay, so this is necessary so that we can create a clean new +        # class for every test method, and we want that because it will +        # ensure we get a fresh loop, which is necessary for test_increment_lock +        # to be able to pass. +        class DummyCog: +            """A dummy cog, for dummies.""" + +            redis = RedisCache() + +            def __init__(self, bot: helpers.MockBot): +                self.bot = bot + +        self.cog = DummyCog(self.bot) + +        await self.cog.redis.clear() + +    def test_class_attribute_namespace(self): +        """Test that RedisDict creates a namespace automatically for class attributes.""" +        self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") + +    async def test_class_attribute_required(self): +        """Test that errors are raised when not assigned as a class attribute.""" +        bad_cache = RedisCache() +        self.assertIs(bad_cache._namespace, None) + +        with self.assertRaises(RuntimeError): +            await bad_cache.set("test", "me_up_deadman") + +    def test_namespace_collision(self): +        """Test that we prevent colliding namespaces.""" +        bob_cache_1 = RedisCache() +        bob_cache_1._set_namespace("BobRoss") +        self.assertEqual(bob_cache_1._namespace, "BobRoss") + +        bob_cache_2 = RedisCache() +        bob_cache_2._set_namespace("BobRoss") +        self.assertEqual(bob_cache_2._namespace, "BobRoss_") + +    async def test_set_get_item(self): +        """Test that users can set and get items from the RedisDict.""" +        test_cases = ( +            ('favorite_fruit', 'melon'), +            ('favorite_number', 86), +            ('favorite_fraction', 86.54) +        ) + +        # Test that we can get and set different types. +        for test in test_cases: +            await self.cog.redis.set(*test) +            self.assertEqual(await self.cog.redis.get(test[0]), test[1]) + +        # Test that .get allows a default value +        self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + +    async def test_set_item_type(self): +        """Test that .set rejects keys and values that are not permitted.""" +        fruits = ["lemon", "melon", "apple"] + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(fruits, "nice") + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(4.23, "nice") + +    async def test_delete_item(self): +        """Test that .delete allows us to delete stuff from the RedisCache.""" +        # Add an item and verify that it gets added +        await self.cog.redis.set("internet", "firetruck") +        self.assertEqual(await self.cog.redis.get("internet"), "firetruck") + +        # Delete that item and verify that it gets deleted +        await self.cog.redis.delete("internet") +        self.assertIs(await self.cog.redis.get("internet"), None) + +    async def test_contains(self): +        """Test that we can check membership with .contains.""" +        await self.cog.redis.set('favorite_country', "Burkina Faso") + +        self.assertIs(await self.cog.redis.contains('favorite_country'), True) +        self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) + +    async def test_items(self): +        """Test that the RedisDict can be iterated.""" +        # Set up our test cases in the Redis cache +        test_cases = [ +            ('favorite_turtle', 'Donatello'), +            ('second_favorite_turtle', 'Leonardo'), +            ('third_favorite_turtle', 'Raphael'), +        ] +        for key, value in test_cases: +            await self.cog.redis.set(key, value) + +        # Consume the AsyncIterator into a regular list, easier to compare that way. +        redis_items = [item for item in await self.cog.redis.items()] + +        # These sequences are probably in the same order now, but probably +        # isn't good enough for tests. Let's not rely on .hgetall always +        # returning things in sequence, and just sort both lists to be safe. +        redis_items = sorted(redis_items) +        test_cases = sorted(test_cases) + +        # If these are equal now, everything works fine. +        self.assertSequenceEqual(test_cases, redis_items) + +    async def test_length(self): +        """Test that we can get the correct .length from the RedisDict.""" +        await self.cog.redis.set('one', 1) +        await self.cog.redis.set('two', 2) +        await self.cog.redis.set('three', 3) +        self.assertEqual(await self.cog.redis.length(), 3) + +        await self.cog.redis.set('four', 4) +        self.assertEqual(await self.cog.redis.length(), 4) + +    async def test_to_dict(self): +        """Test that the .to_dict method returns a workable dictionary copy.""" +        copy = await self.cog.redis.to_dict() +        local_copy = {key: value for key, value in await self.cog.redis.items()} +        self.assertIs(type(copy), dict) +        self.assertDictEqual(copy, local_copy) + +    async def test_clear(self): +        """Test that the .clear method removes the entire hash.""" +        await self.cog.redis.set('teddy', 'with me') +        await self.cog.redis.set('in my dreams', 'you have a weird hat') +        self.assertEqual(await self.cog.redis.length(), 2) + +        await self.cog.redis.clear() +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_pop(self): +        """Test that we can .pop an item from the RedisDict.""" +        await self.cog.redis.set('john', 'was afraid') + +        self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') +        self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_update(self): +        """Test that we can .update the RedisDict with multiple items.""" +        await self.cog.redis.set("reckfried", "lona") +        await self.cog.redis.set("bel air", "prince") +        await self.cog.redis.update({ +            "reckfried": "jona", +            "mega": "hungry, though", +        }) + +        result = { +            "reckfried": "jona", +            "bel air": "prince", +            "mega": "hungry, though", +        } +        self.assertDictEqual(await self.cog.redis.to_dict(), result) + +    def test_typestring_conversion(self): +        """Test the typestring-related helper functions.""" +        conversion_tests = ( +            (12, "i|12"), +            (12.4, "f|12.4"), +            ("cowabunga", "s|cowabunga"), +        ) + +        # Test conversion to typestring +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) + +        # Test conversion from typestrings +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) + +        # Test that exceptions are raised on invalid input +        with self.assertRaises(TypeError): +            self.cog.redis._value_to_typestring(["internet"]) +            self.cog.redis._value_from_typestring("o|firedog") + +    async def test_increment_decrement(self): +        """Test .increment and .decrement methods.""" +        await self.cog.redis.set("entropic", 5) +        await self.cog.redis.set("disentropic", 12.5) + +        # Test default increment +        await self.cog.redis.increment("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 6) + +        # Test default decrement +        await self.cog.redis.decrement("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # Test float increment with float +        await self.cog.redis.increment("disentropic", 2.0) +        self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) + +        # Test float increment with int +        await self.cog.redis.increment("disentropic", 2) +        self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) + +        # Test negative increments, because why not. +        await self.cog.redis.increment("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 0) + +        # Negative decrements? Sure. +        await self.cog.redis.decrement("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # What about if we use a negative float to decrement an int? +        # This should convert the type into a float. +        await self.cog.redis.decrement("entropic", -2.5) +        self.assertEqual(await self.cog.redis.get("entropic"), 7.5) + +        # Let's test that they raise the right errors +        with self.assertRaises(KeyError): +            await self.cog.redis.increment("doesn't_exist!") + +        await self.cog.redis.set("stringthing", "stringthing") +        with self.assertRaises(TypeError): +            await self.cog.redis.increment("stringthing") + +    async def test_increment_lock(self): +        """Test that we can't produce a race condition in .increment.""" +        await self.cog.redis.set("test_key", 0) +        tasks = [] + +        # Increment this a lot in different tasks +        for _ in range(100): +            task = asyncio.create_task( +                self.cog.redis.increment("test_key", 1) +            ) +            tasks.append(task) +        await asyncio.gather(*tasks) + +        # Confirm that the value has been incremented the exact right number of times. +        value = await self.cog.redis.get("test_key") +        self.assertEqual(value, 100) + +    async def test_exceptions_raised(self): +        """Testing that the various RuntimeErrors are reachable.""" +        class MyCog: +            cache = RedisCache() + +            def __init__(self): +                self.other_cache = RedisCache() + +        cog = MyCog() + +        # Raises "No Bot instance" +        with self.assertRaises(NoBotInstanceError): +            await cog.cache.get("john") + +        # Raises "RedisCache has no namespace" +        with self.assertRaises(NoNamespaceError): +            await cog.other_cache.get("was") + +        # Raises "You must access the RedisCache instance through the cog instance" +        with self.assertRaises(NoParentInstanceError): +            await MyCog.cache.get("afraid") diff --git a/tests/helpers.py b/tests/helpers.py index 2b79a6c2a..13283339b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,12 +4,15 @@ import collections  import itertools  import logging  import unittest.mock +from asyncio import AbstractEventLoop  from typing import Iterable, Optional  import discord +from aiohttp import ClientSession  from discord.ext.commands import Context  from bot.api import APIClient +from bot.async_stats import AsyncStatsClient  from bot.bot import Bot @@ -264,10 +267,16 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):      spec_set = APIClient -# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` -bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) -bot_instance.http_session = None -bot_instance.api_client = None +def _get_mock_loop() -> unittest.mock.Mock: +    """Return a mocked asyncio.AbstractEventLoop.""" +    loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) + +    # Since calling `create_task` on our MockBot does not actually schedule the coroutine object +    # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object +    # to prevent "has not been awaited"-warnings. +    loop.create_task.side_effect = lambda coroutine: coroutine.close() + +    return loop  class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -277,17 +286,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ -    spec_set = bot_instance -    additional_spec_asyncs = ("wait_for",) +    spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) +    additional_spec_asyncs = ("wait_for", "redis_ready")      def __init__(self, **kwargs) -> None:          super().__init__(**kwargs) -        self.api_client = MockAPIClient() -        # Since calling `create_task` on our MockBot does not actually schedule the coroutine object -        # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object -        # to prevent "has not been awaited"-warnings. -        self.loop.create_task.side_effect = lambda coroutine: coroutine.close() +        self.loop = _get_mock_loop() +        self.api_client = MockAPIClient(loop=self.loop) +        self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) +        self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)  # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` | 
