diff options
| author | 2020-02-29 17:29:57 +0530 | |
|---|---|---|
| committer | 2020-02-29 17:29:57 +0530 | |
| commit | d583e9b81ae0900f3b81f568c0562adc3adfd6e0 (patch) | |
| tree | ccb48a4fb9a000044ff3a7336e0b33ff7b805981 /tests | |
| parent | Re-corrected the lines which I had changed by mistake (diff) | |
| parent | Merge pull request #797 from Numerlor/fuzzy_zero_div (diff) | |
Merge branch 'master' into tags_overhaul
Diffstat (limited to '')
| -rw-r--r-- | tests/base.py | 34 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_base.py | 412 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 395 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_roles.py | 287 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_users.py | 241 | ||||
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 4 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 10 | ||||
| -rw-r--r-- | tests/bot/cogs/test_snekbox.py | 368 | ||||
| -rw-r--r-- | tests/bot/test_converters.py | 2 | ||||
| -rw-r--r-- | tests/bot/test_utils.py | 15 | ||||
| -rw-r--r-- | tests/helpers.py | 41 | 
11 files changed, 1583 insertions, 226 deletions
| diff --git a/tests/base.py b/tests/base.py index 029a249ed..88693f382 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@  import logging  import unittest  from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers  class _CaptureLogHandler(logging.Handler): @@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase):              standard_message = self._truncateMessage(base_message, record_message)              msg = self._formatMessage(msg, standard_message)              self.fail(msg) + + +class CommandTestCase(unittest.TestCase): +    """TestCase with additional assertions that are useful for testing Discord commands.""" + +    @helpers.async_test +    async def assertHasPermissionsCheck( +        self, +        cmd: commands.Command, +        permissions: Dict[str, bool], +    ) -> None: +        """ +        Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + +        Every permission in `permissions` is expected to be reported as missing. In other words, do +        not include permissions which should not raise an exception along with those which should. +        """ +        # Invert permission values because it's more intuitive to pass to this assertion the same +        # permissions as those given to the check decorator. +        permissions = {k: not v for k, v in permissions.items()} + +        ctx = helpers.MockContext() +        ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + +        with self.assertRaises(commands.MissingPermissions) as cm: +            await cmd.can_run(ctx) + +        self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..c2e143865 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,412 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): +    """Syncer subclass with mocks for abstract methods for testing purposes.""" + +    name = "test" +    _get_diff = helpers.AsyncMock() +    _sync = helpers.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): +    """Tests for the syncer base class.""" + +    def setUp(self): +        self.bot = helpers.MockBot() + +    def test_instantiation_fails_without_abstract_methods(self): +        """The class must have abstract methods implemented.""" +        with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): +            Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.TestCase): +    """Tests for sending the sync confirmation prompt.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = TestSyncer(self.bot) + +    def mock_get_channel(self): +        """Fixture to return a mock channel and message for when `get_channel` is used.""" +        self.bot.reset_mock() + +        mock_channel = helpers.MockTextChannel() +        mock_message = helpers.MockMessage() + +        mock_channel.send.return_value = mock_message +        self.bot.get_channel.return_value = mock_channel + +        return mock_channel, mock_message + +    def mock_fetch_channel(self): +        """Fixture to return a mock channel and message for when `fetch_channel` is used.""" +        self.bot.reset_mock() + +        mock_channel = helpers.MockTextChannel() +        mock_message = helpers.MockMessage() + +        self.bot.get_channel.return_value = None +        mock_channel.send.return_value = mock_message +        self.bot.fetch_channel.return_value = mock_channel + +        return mock_channel, mock_message + +    @helpers.async_test +    async def test_send_prompt_edits_and_returns_message(self): +        """The given message should be edited to display the prompt and then should be returned.""" +        msg = helpers.MockMessage() +        ret_val = await self.syncer._send_prompt(msg) + +        msg.edit.assert_called_once() +        self.assertIn("content", msg.edit.call_args[1]) +        self.assertEqual(ret_val, msg) + +    @helpers.async_test +    async def test_send_prompt_gets_dev_core_channel(self): +        """The dev-core channel should be retrieved if an extant message isn't given.""" +        subtests = ( +            (self.bot.get_channel, self.mock_get_channel), +            (self.bot.fetch_channel, self.mock_fetch_channel), +        ) + +        for method, mock_ in subtests: +            with self.subTest(method=method, msg=mock_.__name__): +                mock_() +                await self.syncer._send_prompt() + +                method.assert_called_once_with(constants.Channels.dev_core) + +    @helpers.async_test +    async def test_send_prompt_returns_None_if_channel_fetch_fails(self): +        """None should be returned if there's an HTTPException when fetching the channel.""" +        self.bot.get_channel.return_value = None +        self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + +        ret_val = await self.syncer._send_prompt() + +        self.assertIsNone(ret_val) + +    @helpers.async_test +    async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): +        """A new message mentioning core devs should be sent and returned if message isn't given.""" +        for mock_ in (self.mock_get_channel, self.mock_fetch_channel): +            with self.subTest(msg=mock_.__name__): +                mock_channel, mock_message = mock_() +                ret_val = await self.syncer._send_prompt() + +                mock_channel.send.assert_called_once() +                self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) +                self.assertEqual(ret_val, mock_message) + +    @helpers.async_test +    async def test_send_prompt_adds_reactions(self): +        """The message should have reactions for confirmation added.""" +        extant_message = helpers.MockMessage() +        subtests = ( +            (extant_message, lambda: (None, extant_message)), +            (None, self.mock_get_channel), +            (None, self.mock_fetch_channel), +        ) + +        for message_arg, mock_ in subtests: +            subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + +            with self.subTest(msg=subtest_msg): +                _, mock_message = mock_() +                await self.syncer._send_prompt(message_arg) + +                calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] +                mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.TestCase): +    """Tests for waiting for a sync confirmation reaction on the prompt.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = TestSyncer(self.bot) +        self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + +    @staticmethod +    def get_message_reaction(emoji): +        """Fixture to return a mock message an reaction from the given `emoji`.""" +        message = helpers.MockMessage() +        reaction = helpers.MockReaction(emoji=emoji, message=message) + +        return message, reaction + +    def test_reaction_check_for_valid_emoji_and_authors(self): +        """Should return True if authors are identical or are a bot and a core dev, respectively.""" +        user_subtests = ( +            ( +                helpers.MockMember(id=77), +                helpers.MockMember(id=77), +                "identical users", +            ), +            ( +                helpers.MockMember(id=77, bot=True), +                helpers.MockMember(id=43, roles=[self.core_dev_role]), +                "bot author and core-dev reactor", +            ), +        ) + +        for emoji in self.syncer._REACTION_EMOJIS: +            for author, user, msg in user_subtests: +                with self.subTest(author=author, user=user, emoji=emoji, msg=msg): +                    message, reaction = self.get_message_reaction(emoji) +                    ret_val = self.syncer._reaction_check(author, message, reaction, user) + +                    self.assertTrue(ret_val) + +    def test_reaction_check_for_invalid_reactions(self): +        """Should return False for invalid reaction events.""" +        valid_emoji = self.syncer._REACTION_EMOJIS[0] +        subtests = ( +            ( +                helpers.MockMember(id=77), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=43, roles=[self.core_dev_role]), +                "users are not identical", +            ), +            ( +                helpers.MockMember(id=77, bot=True), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=43), +                "reactor lacks the core-dev role", +            ), +            ( +                helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), +                "reactor is a bot", +            ), +            ( +                helpers.MockMember(id=77), +                helpers.MockMessage(id=95), +                helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), +                helpers.MockMember(id=77), +                "messages are not identical", +            ), +            ( +                helpers.MockMember(id=77), +                *self.get_message_reaction("InVaLiD"), +                helpers.MockMember(id=77), +                "emoji is invalid", +            ), +        ) + +        for *args, msg in subtests: +            kwargs = dict(zip(("author", "message", "reaction", "user"), args)) +            with self.subTest(**kwargs, msg=msg): +                ret_val = self.syncer._reaction_check(*args) +                self.assertFalse(ret_val) + +    @helpers.async_test +    async def test_wait_for_confirmation(self): +        """The message should always be edited and only return True if the emoji is a check mark.""" +        subtests = ( +            (constants.Emojis.check_mark, True, None), +            ("InVaLiD", False, None), +            (None, False, TimeoutError), +        ) + +        for emoji, ret_val, side_effect in subtests: +            for bot in (True, False): +                with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): +                    # Set up mocks +                    message = helpers.MockMessage() +                    member = helpers.MockMember(bot=bot) + +                    self.bot.wait_for.reset_mock() +                    self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) +                    self.bot.wait_for.side_effect = side_effect + +                    # Call the function +                    actual_return = await self.syncer._wait_for_confirmation(member, message) + +                    # Perform assertions +                    self.bot.wait_for.assert_called_once() +                    self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + +                    message.edit.assert_called_once() +                    kwargs = message.edit.call_args[1] +                    self.assertIn("content", kwargs) + +                    # Core devs should only be mentioned if the author is a bot. +                    if bot: +                        self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) +                    else: +                        self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + +                    self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.TestCase): +    """Tests for main function orchestrating the sync.""" + +    def setUp(self): +        self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) +        self.syncer = TestSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_respects_confirmation_result(self): +        """The sync should abort if confirmation fails and continue if confirmed.""" +        mock_message = helpers.MockMessage() +        subtests = ( +            (True, mock_message), +            (False, None), +        ) + +        for confirmed, message in subtests: +            with self.subTest(confirmed=confirmed): +                self.syncer._sync.reset_mock() +                self.syncer._get_diff.reset_mock() + +                diff = _Diff({1, 2, 3}, {4, 5}, None) +                self.syncer._get_diff.return_value = diff +                self.syncer._get_confirmation_result = helpers.AsyncMock( +                    return_value=(confirmed, message) +                ) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                self.syncer._get_diff.assert_called_once_with(guild) +                self.syncer._get_confirmation_result.assert_called_once() + +                if confirmed: +                    self.syncer._sync.assert_called_once_with(diff) +                else: +                    self.syncer._sync.assert_not_called() + +    @helpers.async_test +    async def test_sync_diff_size(self): +        """The diff size should be correctly calculated.""" +        subtests = ( +            (6, _Diff({1, 2}, {3, 4}, {5, 6})), +            (5, _Diff({1, 2, 3}, None, {4, 5})), +            (0, _Diff(None, None, None)), +            (0, _Diff(set(), set(), set())), +        ) + +        for size, diff in subtests: +            with self.subTest(size=size, diff=diff): +                self.syncer._get_diff.reset_mock() +                self.syncer._get_diff.return_value = diff +                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                self.syncer._get_diff.assert_called_once_with(guild) +                self.syncer._get_confirmation_result.assert_called_once() +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + +    @helpers.async_test +    async def test_sync_message_edited(self): +        """The message should be edited if one was sent, even if the sync has an API error.""" +        subtests = ( +            (None, None, False), +            (helpers.MockMessage(), None, True), +            (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), +        ) + +        for message, side_effect, should_edit in subtests: +            with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): +                self.syncer._sync.side_effect = side_effect +                self.syncer._get_confirmation_result = helpers.AsyncMock( +                    return_value=(True, message) +                ) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                if should_edit: +                    message.edit.assert_called_once() +                    self.assertIn("content", message.edit.call_args[1]) + +    @helpers.async_test +    async def test_sync_confirmation_context_redirect(self): +        """If ctx is given, a new message should be sent and author should be ctx's author.""" +        mock_member = helpers.MockMember() +        subtests = ( +            (None, self.bot.user, None), +            (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), +        ) + +        for ctx, author, message in subtests: +            with self.subTest(ctx=ctx, author=author, message=message): +                if ctx is not None: +                    ctx.send.return_value = message + +                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild, ctx) + +                if ctx is not None: +                    ctx.send.assert_called_once() + +                self.syncer._get_confirmation_result.assert_called_once() +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + +    @mock.patch.object(constants.Sync, "max_diff", new=3) +    @helpers.async_test +    async def test_confirmation_result_small_diff(self): +        """Should always return True and the given message if the diff size is too small.""" +        author = helpers.MockMember() +        expected_message = helpers.MockMessage() + +        for size in (3, 2): +            with self.subTest(size=size): +                self.syncer._send_prompt = helpers.AsyncMock() +                self.syncer._wait_for_confirmation = helpers.AsyncMock() + +                coro = self.syncer._get_confirmation_result(size, author, expected_message) +                result, actual_message = await coro + +                self.assertTrue(result) +                self.assertEqual(actual_message, expected_message) +                self.syncer._send_prompt.assert_not_called() +                self.syncer._wait_for_confirmation.assert_not_called() + +    @mock.patch.object(constants.Sync, "max_diff", new=3) +    @helpers.async_test +    async def test_confirmation_result_large_diff(self): +        """Should return True if confirmed and False if _send_prompt fails or aborted.""" +        author = helpers.MockMember() +        mock_message = helpers.MockMessage() + +        subtests = ( +            (True, mock_message, True, "confirmed"), +            (False, None, False, "_send_prompt failed"), +            (False, mock_message, False, "aborted"), +        ) + +        for expected_result, expected_message, confirmed, msg in subtests: +            with self.subTest(msg=msg): +                self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) +                self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + +                coro = self.syncer._get_confirmation_result(4, author) +                actual_result, actual_message = await coro + +                self.syncer._send_prompt.assert_called_once_with(None)  # message defaults to None +                self.assertIs(actual_result, expected_result) +                self.assertEqual(actual_message, expected_message) + +                if expected_message: +                    self.syncer._wait_for_confirmation.assert_called_once_with( +                        author, expected_message +                    ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..98c9afc0d --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,395 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): +    """ +    A MagicMock subclass to mock Syncer objects. + +    Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` +    instances. For more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=Syncer, **kwargs) + + +class SyncExtensionTests(unittest.TestCase): +    """Tests for the sync extension.""" + +    @staticmethod +    def test_extension_setup(): +        """The Sync cog should be added.""" +        bot = helpers.MockBot() +        sync.setup(bot) +        bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.TestCase): +    """Base class for Sync cog tests. Sets up patches for syncers.""" + +    def setUp(self): +        self.bot = helpers.MockBot() + +        # These patch the type. When the type is called, a MockSyncer instanced is returned. +        # MockSyncer is needed so that our custom AsyncMock is used. +        # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. +        self.role_syncer_patcher = mock.patch( +            "bot.cogs.sync.syncers.RoleSyncer", +            new=mock.MagicMock(return_value=MockSyncer()) +        ) +        self.user_syncer_patcher = mock.patch( +            "bot.cogs.sync.syncers.UserSyncer", +            new=mock.MagicMock(return_value=MockSyncer()) +        ) +        self.RoleSyncer = self.role_syncer_patcher.start() +        self.UserSyncer = self.user_syncer_patcher.start() + +        self.cog = sync.Sync(self.bot) + +    def tearDown(self): +        self.role_syncer_patcher.stop() +        self.user_syncer_patcher.stop() + +    @staticmethod +    def response_error(status: int) -> ResponseCodeError: +        """Fixture to return a ResponseCodeError with the given status code.""" +        response = mock.MagicMock() +        response.status = status + +        return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): +    """Tests for the Sync cog.""" + +    @mock.patch.object(sync.Sync, "sync_guild") +    def test_sync_cog_init(self, sync_guild): +        """Should instantiate syncers and run a sync for the guild.""" +        # Reset because a Sync cog was already instantiated in setUp. +        self.RoleSyncer.reset_mock() +        self.UserSyncer.reset_mock() +        self.bot.loop.create_task.reset_mock() + +        mock_sync_guild_coro = mock.MagicMock() +        sync_guild.return_value = mock_sync_guild_coro + +        sync.Sync(self.bot) + +        self.RoleSyncer.assert_called_once_with(self.bot) +        self.UserSyncer.assert_called_once_with(self.bot) +        sync_guild.assert_called_once_with() +        self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + +    @helpers.async_test +    async def test_sync_cog_sync_guild(self): +        """Roles and users should be synced only if a guild is successfully retrieved.""" +        for guild in (helpers.MockGuild(), None): +            with self.subTest(guild=guild): +                self.bot.reset_mock() +                self.cog.role_syncer.reset_mock() +                self.cog.user_syncer.reset_mock() + +                self.bot.get_guild = mock.MagicMock(return_value=guild) + +                await self.cog.sync_guild() + +                self.bot.wait_until_guild_available.assert_called_once() +                self.bot.get_guild.assert_called_once_with(constants.Guild.id) + +                if guild is None: +                    self.cog.role_syncer.sync.assert_not_called() +                    self.cog.user_syncer.sync.assert_not_called() +                else: +                    self.cog.role_syncer.sync.assert_called_once_with(guild) +                    self.cog.user_syncer.sync.assert_called_once_with(guild) + +    async def patch_user_helper(self, side_effect: BaseException) -> None: +        """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" +        self.bot.api_client.patch.reset_mock(side_effect=True) +        self.bot.api_client.patch.side_effect = side_effect + +        user_id, updated_information = 5, {"key": 123} +        await self.cog.patch_user(user_id, updated_information) + +        self.bot.api_client.patch.assert_called_once_with( +            f"bot/users/{user_id}", +            json=updated_information, +        ) + +    @helpers.async_test +    async def test_sync_cog_patch_user(self): +        """A PATCH request should be sent and 404 errors ignored.""" +        for side_effect in (None, self.response_error(404)): +            with self.subTest(side_effect=side_effect): +                await self.patch_user_helper(side_effect) + +    @helpers.async_test +    async def test_sync_cog_patch_user_non_404(self): +        """A PATCH request should be sent and the error raised if it's not a 404.""" +        with self.assertRaises(ResponseCodeError): +            await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): +    """Tests for the listeners of the Sync cog.""" + +    def setUp(self): +        super().setUp() +        self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_create(self): +        """A POST request should be sent with the new role's data.""" +        self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + +        role_data = { +            "colour": 49, +            "id": 777, +            "name": "rolename", +            "permissions": 8, +            "position": 23, +        } +        role = helpers.MockRole(**role_data) +        await self.cog.on_guild_role_create(role) + +        self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_delete(self): +        """A DELETE request should be sent.""" +        self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + +        role = helpers.MockRole(id=99) +        await self.cog.on_guild_role_delete(role) + +        self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_update(self): +        """A PUT request should be sent if the colour, name, permissions, or position changes.""" +        self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + +        role_data = { +            "colour": 49, +            "id": 777, +            "name": "rolename", +            "permissions": 8, +            "position": 23, +        } +        subtests = ( +            (True, ("colour", "name", "permissions", "position")), +            (False, ("hoist", "mentionable")), +        ) + +        for should_put, attributes in subtests: +            for attribute in attributes: +                with self.subTest(should_put=should_put, changed_attribute=attribute): +                    self.bot.api_client.put.reset_mock() + +                    after_role_data = role_data.copy() +                    after_role_data[attribute] = 876 + +                    before_role = helpers.MockRole(**role_data) +                    after_role = helpers.MockRole(**after_role_data) + +                    await self.cog.on_guild_role_update(before_role, after_role) + +                    if should_put: +                        self.bot.api_client.put.assert_called_once_with( +                            f"bot/roles/{after_role.id}", +                            json=after_role_data +                        ) +                    else: +                        self.bot.api_client.put.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_member_remove(self): +        """Member should patched to set in_guild as False.""" +        self.assertTrue(self.cog.on_member_remove.__cog_listener__) + +        member = helpers.MockMember() +        await self.cog.on_member_remove(member) + +        self.cog.patch_user.assert_called_once_with( +            member.id, +            updated_information={"in_guild": False} +        ) + +    @helpers.async_test +    async def test_sync_cog_on_member_update_roles(self): +        """Members should be patched if their roles have changed.""" +        self.assertTrue(self.cog.on_member_update.__cog_listener__) + +        # Roles are intentionally unsorted. +        before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] +        before_member = helpers.MockMember(roles=before_roles) +        after_member = helpers.MockMember(roles=before_roles[1:]) + +        await self.cog.on_member_update(before_member, after_member) + +        data = {"roles": sorted(role.id for role in after_member.roles)} +        self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + +    @helpers.async_test +    async def test_sync_cog_on_member_update_other(self): +        """Members should not be patched if other attributes have changed.""" +        self.assertTrue(self.cog.on_member_update.__cog_listener__) + +        subtests = ( +            ("activities", discord.Game("Pong"), discord.Game("Frogger")), +            ("nick", "old nick", "new nick"), +            ("status", discord.Status.online, discord.Status.offline), +        ) + +        for attribute, old_value, new_value in subtests: +            with self.subTest(attribute=attribute): +                self.cog.patch_user.reset_mock() + +                before_member = helpers.MockMember(**{attribute: old_value}) +                after_member = helpers.MockMember(**{attribute: new_value}) + +                await self.cog.on_member_update(before_member, after_member) + +                self.cog.patch_user.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_user_update(self): +        """A user should be patched only if the name, discriminator, or avatar changes.""" +        self.assertTrue(self.cog.on_user_update.__cog_listener__) + +        before_data = { +            "name": "old name", +            "discriminator": "1234", +            "avatar": "old avatar", +            "bot": False, +        } + +        subtests = ( +            (True, "name", "name", "new name", "new name"), +            (True, "discriminator", "discriminator", "8765", 8765), +            (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), +            (False, "bot", "bot", True, True), +        ) + +        for should_patch, attribute, api_field, value, api_value in subtests: +            with self.subTest(attribute=attribute): +                self.cog.patch_user.reset_mock() + +                after_data = before_data.copy() +                after_data[attribute] = value +                before_user = helpers.MockUser(**before_data) +                after_user = helpers.MockUser(**after_data) + +                await self.cog.on_user_update(before_user, after_user) + +                if should_patch: +                    self.cog.patch_user.assert_called_once() + +                    # Don't care if *all* keys are present; only the changed one is required +                    call_args = self.cog.patch_user.call_args +                    self.assertEqual(call_args[0][0], after_user.id) +                    self.assertIn("updated_information", call_args[1]) + +                    updated_information = call_args[1]["updated_information"] +                    self.assertIn(api_field, updated_information) +                    self.assertEqual(updated_information[api_field], api_value) +                else: +                    self.cog.patch_user.assert_not_called() + +    async def on_member_join_helper(self, side_effect: Exception) -> dict: +        """ +        Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + +        The request data for the mock member is returned. All exceptions will be re-raised. +        """ +        member = helpers.MockMember( +            discriminator="1234", +            roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], +        ) + +        data = { +            "avatar_hash": member.avatar, +            "discriminator": int(member.discriminator), +            "id": member.id, +            "in_guild": True, +            "name": member.name, +            "roles": sorted(role.id for role in member.roles) +        } + +        self.bot.api_client.put.reset_mock(side_effect=True) +        self.bot.api_client.put.side_effect = side_effect + +        try: +            await self.cog.on_member_join(member) +        except Exception: +            raise +        finally: +            self.bot.api_client.put.assert_called_once_with( +                f"bot/users/{member.id}", +                json=data +            ) + +        return data + +    @helpers.async_test +    async def test_sync_cog_on_member_join(self): +        """Should PUT user's data or POST it if the user doesn't exist.""" +        for side_effect in (None, self.response_error(404)): +            with self.subTest(side_effect=side_effect): +                self.bot.api_client.post.reset_mock() +                data = await self.on_member_join_helper(side_effect) + +                if side_effect: +                    self.bot.api_client.post.assert_called_once_with("bot/users", json=data) +                else: +                    self.bot.api_client.post.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_member_join_non_404(self): +        """ResponseCodeError should be re-raised if status code isn't a 404.""" +        with self.assertRaises(ResponseCodeError): +            await self.on_member_join_helper(self.response_error(500)) + +        self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): +    """Tests for the commands in the Sync cog.""" + +    @helpers.async_test +    async def test_sync_roles_command(self): +        """sync() should be called on the RoleSyncer.""" +        ctx = helpers.MockContext() +        await self.cog.sync_roles_command.callback(self.cog, ctx) + +        self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + +    @helpers.async_test +    async def test_sync_users_command(self): +        """sync() should be called on the UserSyncer.""" +        ctx = helpers.MockContext() +        await self.cog.sync_users_command.callback(self.cog, ctx) + +        self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + +    def test_commands_require_admin(self): +        """The sync commands should only run if the author has the administrator permission.""" +        cmds = ( +            self.cog.sync_group, +            self.cog.sync_roles_command, +            self.cog.sync_users_command, +        ) + +        for cmd in cmds: +            with self.subTest(cmd=cmd): +                self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..14fb2577a 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,165 @@  import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): -    """Tests constructing the roles to synchronize with the site.""" - -    def test_get_roles_for_sync_empty_return_for_equal_roles(self): -        """No roles should be synced when no diff is found.""" -        api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} -        guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            (set(), set(), set()) -        ) - -    def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): -        """Roles to be synced are returned when non-ID attributes differ.""" -        api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} -        guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            (set(), guild_roles, set()) -        ) - -    def test_get_roles_only_returns_roles_that_require_update(self): -        """Roles that require an update should be returned as the second tuple element.""" -        api_roles = { -            Role(id=41, name='old name', colour=33, permissions=0x8, position=1), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } -        guild_roles = { -            Role(id=41, name='new name', colour=35, permissions=0x8, position=2), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                set(), -                {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, -                set(), -            ) -        ) - -    def test_get_roles_returns_new_roles_in_first_tuple_element(self): -        """Newly created roles are returned as the first tuple element.""" -        api_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        } -        guild_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -            Role(id=53, name='other role', colour=55, permissions=0, position=2) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, -                set(), -                set(), -            ) -        ) - -    def test_get_roles_returns_roles_to_update_and_new_roles(self): -        """Newly created and updated roles should be returned together.""" -        api_roles = { -            Role(id=41, name='old name', colour=35, permissions=0x8, position=1), -        } -        guild_roles = { -            Role(id=41, name='new name', colour=40, permissions=0x16, position=2), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, -                {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, -                set(), -            ) -        ) - -    def test_get_roles_returns_roles_to_delete(self): -        """Roles to be deleted should be returned as the third tuple element.""" -        api_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -        } -        guild_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                set(), -                set(), -                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -            ) -        ) - -    def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): -        """When roles were added, updated, and removed, all of them are returned properly.""" -        api_roles = { -            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -            Role(id=71, name='to update', colour=99, permissions=0x9, position=3), -        } -        guild_roles = { -            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -            Role(id=81, name='to create', colour=99, permissions=0x9, position=4), -            Role(id=71, name='updated', colour=101, permissions=0x5, position=3), -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, -                {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, -                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -            ) -        ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): +    """Fixture to return a dictionary representing a role with default values set.""" +    kwargs.setdefault("id", 9) +    kwargs.setdefault("name", "fake role") +    kwargs.setdefault("colour", 7) +    kwargs.setdefault("permissions", 0) +    kwargs.setdefault("position", 55) + +    return kwargs + + +class RoleSyncerDiffTests(unittest.TestCase): +    """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = RoleSyncer(self.bot) + +    @staticmethod +    def get_guild(*roles): +        """Fixture to return a guild object with the given roles.""" +        guild = helpers.MockGuild() +        guild.roles = [] + +        for role in roles: +            mock_role = helpers.MockRole(**role) +            mock_role.colour = discord.Colour(role["colour"]) +            mock_role.permissions = discord.Permissions(role["permissions"]) +            guild.roles.append(mock_role) + +        return guild + +    @helpers.async_test +    async def test_empty_diff_for_identical_roles(self): +        """No differences should be found if the roles in the guild and DB are identical.""" +        self.bot.api_client.get.return_value = [fake_role()] +        guild = self.get_guild(fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_updated_roles(self): +        """Only updated roles should be added to the 'updated' set of the diff.""" +        updated_role = fake_role(id=41, name="new") + +        self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] +        guild = self.get_guild(updated_role, fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_Role(**updated_role)}, set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_roles(self): +        """Only new roles should be added to the 'created' set of the diff.""" +        new_role = fake_role(id=41, name="new") + +        self.bot.api_client.get.return_value = [fake_role()] +        guild = self.get_guild(fake_role(), new_role) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_Role(**new_role)}, set(), set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_deleted_roles(self): +        """Only deleted roles should be added to the 'deleted' set of the diff.""" +        deleted_role = fake_role(id=61, name="deleted") + +        self.bot.api_client.get.return_value = [fake_role(), deleted_role] +        guild = self.get_guild(fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), {_Role(**deleted_role)}) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_updated_and_deleted_roles(self): +        """When roles are added, updated, and removed, all of them are returned properly.""" +        new = fake_role(id=41, name="new") +        updated = fake_role(id=71, name="updated") +        deleted = fake_role(id=61, name="deleted") + +        self.bot.api_client.get.return_value = [ +            fake_role(), +            fake_role(id=71, name="updated name"), +            deleted, +        ] +        guild = self.get_guild(fake_role(), new, updated) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + +        self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.TestCase): +    """Tests for the API requests that sync roles.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = RoleSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_created_roles(self): +        """Only POST requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(role_tuples, set(), set()) +        await self.syncer._sync(diff) + +        calls = [mock.call("bot/roles", json=role) for role in roles] +        self.bot.api_client.post.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + +        self.bot.api_client.put.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_updated_roles(self): +        """Only PUT requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(set(), role_tuples, set()) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] +        self.bot.api_client.put.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_deleted_roles(self): +        """Only DELETE requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(set(), set(), role_tuples) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] +        self.bot.api_client.delete.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..421bf6bb6 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,169 @@  import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers  def fake_user(**kwargs): -    kwargs.setdefault('id', 43) -    kwargs.setdefault('name', 'bob the test man') -    kwargs.setdefault('discriminator', 1337) -    kwargs.setdefault('avatar_hash', None) -    kwargs.setdefault('roles', (666,)) -    kwargs.setdefault('in_guild', True) -    return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): -    """Tests constructing the users to synchronize with the site.""" - -    def test_get_users_for_sync_returns_nothing_for_empty_params(self): -        """When no users are given, none are returned.""" -        self.assertEqual( -            get_users_for_sync({}, {}), -            (set(), set()) -        ) - -    def test_get_users_for_sync_returns_nothing_for_equal_users(self): -        """When no users are updated, none are returned.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user()} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), set()) -        ) - -    def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): -        """When a non-ID-field differs, the user to update is returned.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user(name='new fancy name')} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), {fake_user(name='new fancy name')}) -        ) - -    def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): -        """When new users join the guild, they are returned as the first tuple element.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user(), 63: fake_user(id=63)} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            ({fake_user(id=63)}, set()) -        ) - -    def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): +    """Fixture to return a dictionary representing a user with default values set.""" +    kwargs.setdefault("id", 43) +    kwargs.setdefault("name", "bob the test man") +    kwargs.setdefault("discriminator", 1337) +    kwargs.setdefault("avatar_hash", None) +    kwargs.setdefault("roles", (666,)) +    kwargs.setdefault("in_guild", True) + +    return kwargs + + +class UserSyncerDiffTests(unittest.TestCase): +    """Tests for determining differences between users in the DB and users in the Guild cache.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = UserSyncer(self.bot) + +    @staticmethod +    def get_guild(*members): +        """Fixture to return a guild object with the given members.""" +        guild = helpers.MockGuild() +        guild.members = [] + +        for member in members: +            member = member.copy() +            member["avatar"] = member.pop("avatar_hash") +            del member["in_guild"] + +            mock_member = helpers.MockMember(**member) +            mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + +            guild.members.append(mock_member) + +        return guild + +    @helpers.async_test +    async def test_empty_diff_for_no_users(self): +        """When no users are given, an empty diff should be returned.""" +        guild = self.get_guild() + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_empty_diff_for_identical_users(self): +        """No differences should be found if the users in the guild and DB are identical.""" +        self.bot.api_client.get.return_value = [fake_user()] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_updated_users(self): +        """Only updated users should be added to the 'updated' set of the diff.""" +        updated_user = fake_user(id=99, name="new") + +        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        guild = self.get_guild(updated_user, fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_User(**updated_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_users(self): +        """Only new users should be added to the 'created' set of the diff.""" +        new_user = fake_user(id=99, name="new") + +        self.bot.api_client.get.return_value = [fake_user()] +        guild = self.get_guild(fake_user(), new_user) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_User(**new_user)}, set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        api_users = {43: fake_user(), 63: fake_user(id=63)} -        guild_users = {43: fake_user()} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), {fake_user(id=63, in_guild=False)}) -        ) - -    def test_get_users_for_sync_updates_and_creates_users_as_needed(self): -        """When one user left and another one was updated, both are returned.""" -        api_users = {43: fake_user()} -        guild_users = {63: fake_user(id=63)} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            ({fake_user(id=63)}, {fake_user(in_guild=False)}) -        ) - -    def test_get_users_for_sync_does_not_duplicate_update_users(self): -        """When the API knows a user the guild doesn't, nothing is performed.""" -        api_users = {43: fake_user(in_guild=False)} -        guild_users = {} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), set()) -        ) +        leaving_user = fake_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_User(**leaving_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_updated_and_leaving_users(self): +        """When users are added, updated, and removed, all of them are returned properly.""" +        new_user = fake_user(id=99, name="new") +        updated_user = fake_user(id=55, name="updated") +        leaving_user = fake_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        guild = self.get_guild(fake_user(), new_user, updated_user) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_empty_diff_for_db_users_not_in_guild(self): +        """When the DB knows a user the guild doesn't, no difference is found.""" +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.TestCase): +    """Tests for the API requests that sync users.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = UserSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_created_users(self): +        """Only POST requests should be made with the correct payload.""" +        users = [fake_user(id=111), fake_user(id=222)] + +        user_tuples = {_User(**user) for user in users} +        diff = _Diff(user_tuples, set(), None) +        await self.syncer._sync(diff) + +        calls = [mock.call("bot/users", json=user) for user in users] +        self.bot.api_client.post.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.post.call_count, len(users)) + +        self.bot.api_client.put.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_updated_users(self): +        """Only PUT requests should be made with the correct payload.""" +        users = [fake_user(id=111), fake_user(id=222)] + +        user_tuples = {_User(**user) for user in users} +        diff = _Diff(set(), user_tuples, None) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] +        self.bot.api_client.put.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.put.call_count, len(users)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..5b0a3b8c3 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase):          asyncio.run(self.cog.fetch_webhook()) -        self.bot.wait_until_ready.assert_called_once() +        self.bot.wait_until_guild_available.assert_called_once()          self.bot.fetch_webhook.assert_called_once_with(1)          self.assertEqual(self.cog.webhook, "dummy webhook") @@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase):          with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:              asyncio.run(self.cog.fetch_webhook()) -        self.bot.wait_until_ready.assert_called_once() +        self.bot.wait_until_guild_available.assert_called_once()          self.bot.fetch_webhook.assert_called_once_with(1)          self.assertEqual(len(log_watcher.records), 1) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index deae7ebad..8443cfe71 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase):      @classmethod      def setUpClass(cls): -        cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator) +        cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators)      def setUp(self):          """Sets up fresh objects for each test.""" @@ -521,7 +521,7 @@ class UserCommandTests(unittest.TestCase):          """A regular user should not be able to use this command outside of bot-commands."""          constants.MODERATION_ROLES = [self.moderator_role.id]          constants.STAFF_ROLES = [self.moderator_role.id] -        constants.Channels.bot = 50 +        constants.Channels.bot_commands = 50          ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) @@ -533,7 +533,7 @@ class UserCommandTests(unittest.TestCase):      def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):          """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""          constants.STAFF_ROLES = [self.moderator_role.id] -        constants.Channels.bot = 50 +        constants.Channels.bot_commands = 50          ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -546,7 +546,7 @@ class UserCommandTests(unittest.TestCase):      def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):          """A user should target itself with `!user` when a `user` argument was not provided."""          constants.STAFF_ROLES = [self.moderator_role.id] -        constants.Channels.bot = 50 +        constants.Channels.bot_commands = 50          ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -559,7 +559,7 @@ class UserCommandTests(unittest.TestCase):      def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction."""          constants.STAFF_ROLES = [self.moderator_role.id] -        constants.Channels.bot = 50 +        constants.Channels.bot_commands = 50          ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py new file mode 100644 index 000000000..985bc66a1 --- /dev/null +++ b/tests/bot/cogs/test_snekbox.py @@ -0,0 +1,368 @@ +import asyncio +import logging +import unittest +from functools import partial +from unittest.mock import MagicMock, Mock, call, patch + +from bot.cogs import snekbox +from bot.cogs.snekbox import Snekbox +from bot.constants import URLs +from tests.helpers import ( +    AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test +) + + +class SnekboxTests(unittest.TestCase): +    def setUp(self): +        """Add mocked bot and cog to the instance.""" +        self.bot = MockBot() + +        self.mocked_post = MagicMock() +        self.mocked_post.json = AsyncMock() +        self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post)) + +        self.cog = Snekbox(bot=self.bot) + +    @async_test +    async def test_post_eval(self): +        """Post the eval code to the URLs.snekbox_eval_api endpoint.""" +        self.mocked_post.json.return_value = {'lemon': 'AI'} + +        self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'}) +        self.bot.http_session.post.assert_called_once_with( +            URLs.snekbox_eval_api, +            json={"input": "import random"}, +            raise_for_status=True +        ) + +    @async_test +    async def test_upload_output_reject_too_long(self): +        """Reject output longer than MAX_PASTE_LEN.""" +        result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) +        self.assertEqual(result, "too long to upload") + +    @async_test +    async def test_upload_output(self): +        """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" +        key = "RainbowDash" +        self.mocked_post.json.return_value = {"key": key} + +        self.assertEqual( +            await self.cog.upload_output("My awesome output"), +            URLs.paste_service.format(key=key) +        ) +        self.bot.http_session.post.assert_called_once_with( +            URLs.paste_service.format(key="documents"), +            data="My awesome output", +            raise_for_status=True +        ) + +    @async_test +    async def test_upload_output_gracefully_fallback_if_exception_during_request(self): +        """Output upload gracefully fallback if the upload fail.""" +        self.mocked_post.json.side_effect = Exception +        log = logging.getLogger("bot.cogs.snekbox") +        with self.assertLogs(logger=log, level='ERROR'): +            await self.cog.upload_output('My awesome output!') + +    @async_test +    async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): +        """Output upload gracefully fallback if there is no key entry in the response body.""" +        self.mocked_post.json.return_value = {} +        self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + +    def test_prepare_input(self): +        cases = ( +            ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), +            ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), +            ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), +            ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), +        ) +        for case, expected, testname in cases: +            with self.subTest(msg=f'Extract code from {testname}.'): +                self.assertEqual(self.cog.prepare_input(case), expected) + +    def test_get_results_message(self): +        """Return error and message according to the eval result.""" +        cases = ( +            ('ERROR', None, ('Your eval job has failed', 'ERROR')), +            ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), +            ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) +        ) +        for stdout, returncode, expected in cases: +            with self.subTest(stdout=stdout, returncode=returncode, expected=expected): +                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) +                self.assertEqual(actual, expected) + +    @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) +    def test_get_results_message_invalid_signal(self, mock_Signals: Mock): +        self.assertEqual( +            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            ('Your eval job has completed with return code 127', '') +        ) + +    @patch('bot.cogs.snekbox.Signals') +    def test_get_results_message_valid_signal(self, mock_Signals: Mock): +        mock_Signals.return_value.name = 'SIGTEST' +        self.assertEqual( +            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            ('Your eval job has completed with return code 127 (SIGTEST)', '') +        ) + +    def test_get_status_emoji(self): +        """Return emoji according to the eval result.""" +        cases = ( +            (' ', -1, ':warning:'), +            ('Hello world!', 0, ':white_check_mark:'), +            ('Invalid beard size', -1, ':x:') +        ) +        for stdout, returncode, expected in cases: +            with self.subTest(stdout=stdout, returncode=returncode, expected=expected): +                actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) +                self.assertEqual(actual, expected) + +    @async_test +    async def test_format_output(self): +        """Test output formatting.""" +        self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + +        too_many_lines = ( +            '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' +            '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' +        ) +        too_long_too_many_lines = ( +            "\n".join( +                f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) +            )[:1000] + "\n... (truncated - too long, too many lines)" +        ) + +        cases = ( +            ('', ('[No output]', None), 'No output'), +            ('My awesome output', ('My awesome output', None), 'One line output'), +            ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), +            ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'), +            ( +                '\u202E\u202E\u202E', +                ('Code block escape attempt detected; will not output result', None), +                'Detect RIGHT-TO-LEFT OVERRIDE' +            ), +            ( +                '\u200B\u200B\u200B', +                ('Code block escape attempt detected; will not output result', None), +                'Detect ZERO WIDTH SPACE' +            ), +            ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'), +            ( +                'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd', +                (too_many_lines, 'https://testificate.com/'), +                '12 lines output' +            ), +            ( +                'verylongbeard' * 100, +                ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'), +                '1300 characters output' +            ), +            ( +                ('verylongbeard' * 10 + '\n') * 15, +                (too_long_too_many_lines, 'https://testificate.com/'), +                '15 lines, 1965 characters output' +            ), +        ) +        for case, expected, testname in cases: +            with self.subTest(msg=testname, case=case, expected=expected): +                self.assertEqual(await self.cog.format_output(case), expected) + +    @async_test +    async def test_eval_command_evaluate_once(self): +        """Test the eval command procedure.""" +        ctx = MockContext() +        response = MockMessage() +        self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') +        self.cog.send_eval = AsyncMock(return_value=response) +        self.cog.continue_eval = AsyncMock(return_value=None) + +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') +        self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') +        self.cog.continue_eval.assert_called_once_with(ctx, response) + +    @async_test +    async def test_eval_command_evaluate_twice(self): +        """Test the eval and re-eval command procedure.""" +        ctx = MockContext() +        response = MockMessage() +        self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') +        self.cog.send_eval = AsyncMock(return_value=response) +        self.cog.continue_eval = AsyncMock() +        self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) +        self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') +        self.cog.continue_eval.assert_called_with(ctx, response) + +    @async_test +    async def test_eval_command_reject_two_eval_at_the_same_time(self): +        """Test if the eval command rejects an eval if the author already have a running eval.""" +        ctx = MockContext() +        ctx.author.id = 42 +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.send = AsyncMock() +        self.cog.jobs = (42,) +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        ctx.send.assert_called_once_with( +            "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" +        ) + +    @async_test +    async def test_eval_command_call_help(self): +        """Test if the eval command call the help command if no code is provided.""" +        ctx = MockContext() +        ctx.invoke = AsyncMock() +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') +        ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") + +    @async_test +    async def test_send_eval(self): +        """Test the send_eval function.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +        self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) +        self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) +        self.cog.get_status_emoji = MagicMock(return_value=':yay!:') +        self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + +        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        ctx.send.assert_called_once_with( +            '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' +        ) +        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) +        self.cog.format_output.assert_called_once_with('') + +    @async_test +    async def test_send_eval_with_paste_link(self): +        """Test the send_eval function with a too long output that generate a paste link.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +        self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) +        self.cog.get_status_emoji = MagicMock(return_value=':yay!:') +        self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) + +        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        ctx.send.assert_called_once_with( +            '@LemonLemonishBeard#0042 :yay!: Return code 0.' +            '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com' +        ) +        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.format_output.assert_called_once_with('Way too long beard') + +    @async_test +    async def test_send_eval_with_non_zero_eval(self): +        """Test the send_eval function with a code returning a non-zero code.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +        self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) +        self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) +        self.cog.get_status_emoji = MagicMock(return_value=':nope!:') +        self.cog.format_output = AsyncMock()  # This function isn't called + +        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        ctx.send.assert_called_once_with( +            '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' +        ) +        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) +        self.cog.format_output.assert_not_called() + +    @async_test +    async def test_continue_eval_does_continue(self): +        """Test that the continue_eval function does continue if required conditions are met.""" +        ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) +        response = MockMessage(delete=AsyncMock()) +        new_msg = MockMessage(content='!e NewCode') +        self.bot.wait_for.side_effect = ((None, new_msg), None) + +        actual = await self.cog.continue_eval(ctx, response) +        self.assertEqual(actual, 'NewCode') +        self.bot.wait_for.has_calls( +            call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10), +            call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +        ) +        ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) +        ctx.message.clear_reactions.assert_called_once() +        response.delete.assert_called_once() + +    @async_test +    async def test_continue_eval_does_not_continue(self): +        ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) +        self.bot.wait_for.side_effect = asyncio.TimeoutError + +        actual = await self.cog.continue_eval(ctx, MockMessage()) +        self.assertEqual(actual, None) +        ctx.message.clear_reactions.assert_called_once() + +    def test_predicate_eval_message_edit(self): +        """Test the predicate_eval_message_edit function.""" +        msg0 = MockMessage(id=1, content='abc') +        msg1 = MockMessage(id=2, content='abcdef') +        msg2 = MockMessage(id=1, content='abcdef') + +        cases = ( +            (msg0, msg0, False, 'same ID, same content'), +            (msg0, msg1, False, 'different ID, different content'), +            (msg0, msg2, True, 'same ID, different content') +        ) +        for ctx_msg, new_msg, expected, testname in cases: +            with self.subTest(msg=f'Messages with {testname} return {expected}'): +                ctx = MockContext(message=ctx_msg) +                actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) +                self.assertEqual(actual, expected) + +    def test_predicate_eval_emoji_reaction(self): +        """Test the predicate_eval_emoji_reaction function.""" +        valid_reaction = MockReaction(message=MockMessage(id=1)) +        valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI +        valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) +        valid_user = MockUser(id=2) + +        invalid_reaction_id = MockReaction(message=MockMessage(id=42)) +        invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI +        invalid_user_id = MockUser(id=42) +        invalid_reaction_str = MockReaction(message=MockMessage(id=1)) +        invalid_reaction_str.__str__.return_value = ':longbeard:' + +        cases = ( +            (invalid_reaction_id, valid_user, False, 'invalid reaction ID'), +            (valid_reaction, invalid_user_id, False, 'invalid user ID'), +            (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'), +            (valid_reaction, valid_user, True, 'matching attributes') +        ) +        for reaction, user, expected, testname in cases: +            with self.subTest(msg=f'Test with {testname} and expected return {expected}'): +                actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) +                self.assertEqual(actual, expected) + + +class SnekboxSetupTests(unittest.TestCase): +    """Tests setup of the `Snekbox` cog.""" + +    def test_setup(self): +        """Setup of the extension should call add_cog.""" +        bot = MockBot() +        snekbox.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index b2b78d9dd..1e5ca62ae 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -68,7 +68,7 @@ class ConverterTests(unittest.TestCase):              ('👋', "Don't be ridiculous, you can't use that character!"),              ('', "Tag names should not be empty, or filled with whitespace."),              ('  ', "Tag names should not be empty, or filled with whitespace."), -            ('42', "Tag names can't be numbers."), +            ('42', "Tag names must contain at least one letter."),              ('x' * 128, "Are you insane? That's way too long!"),          ) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):          instance = utils.CaseInsensitiveDict()          instance.update({'FOO': 'bar'})          self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): -    """Tests the `chunk` method.""" - -    def test_empty_chunking(self): -        """Tests chunking on an empty iterable.""" -        generator = utils.chunks(iterable=[], size=5) -        self.assertEqual(list(generator), []) - -    def test_list_chunking(self): -        """Tests chunking a non-empty list.""" -        iterable = [1, 2, 3, 4, 5] -        generator = utils.chunks(iterable=iterable, size=2) -        self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) diff --git a/tests/helpers.py b/tests/helpers.py index 5df796c23..6f50f6ae3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,6 +12,7 @@ from typing import Any, Iterable, Optional  import discord  from discord.ext.commands import Context +from bot.api import APIClient  from bot.bot import Bot @@ -127,6 +128,18 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):          return super().__call__(*args, **kwargs) +class AsyncContextManagerMock(unittest.mock.MagicMock): +    def __init__(self, return_value: Any): +        super().__init__() +        self._return_value = return_value + +    async def __aenter__(self): +        return self._return_value + +    async def __aexit__(self, *args): +        pass + +  class AsyncIteratorMock:      """      A class to mock asynchronous iterators. @@ -269,9 +282,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      information, see the `MockGuild` docstring.      """      def __init__(self, **kwargs) -> None: -        default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} +        default_kwargs = { +            'id': next(self.discord_id), +            'name': 'role', +            'position': 1, +            'colour': discord.Colour(0xdeadbf), +            'permissions': discord.Permissions(), +        }          super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) +        if isinstance(self.colour, int): +            self.colour = discord.Colour(self.colour) + +        if isinstance(self.permissions, int): +            self.permissions = discord.Permissions(self.permissions) +          if 'mention' not in kwargs:              self.mention = f'&{self.name}' @@ -324,6 +349,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):              self.mention = f"@{self.name}" +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock APIClient objects. + +    Instances of this class will follow the specifications of `bot.api.APIClient` instances. +    For more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=APIClient, **kwargs) + +  # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`  bot_instance = Bot(command_prefix=unittest.mock.MagicMock())  bot_instance.http_session = None @@ -340,6 +377,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      def __init__(self, **kwargs) -> None:          super().__init__(spec_set=bot_instance, **kwargs) +        self.api_client = MockAPIClient()          # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and          # and should therefore be awaited. (The documentation calls it a coroutine as well, which @@ -503,6 +541,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage())          self.users = AsyncIteratorMock(kwargs.get('users', [])) +        self.__str__.return_value = str(self.emoji)  webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) | 
