diff options
| author | 2020-02-25 10:15:26 -0800 | |
|---|---|---|
| committer | 2020-02-25 10:15:26 -0800 | |
| commit | 2dff2967c1f4348dbb9fcc98b6c79c0f5136eb57 (patch) | |
| tree | fd126276a6089f07552ea91878c7b4fd9ed5b869 /tests | |
| parent | Scheduler: make _scheduled_tasks private (diff) | |
| parent | Merge pull request #781 from python-discord/bug/utils/bot-1c/reminder-unsched... (diff) | |
Merge remote-tracking branch 'origin/master' into bug/backend/b754/scheduler-suppresses-errors
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/base.py | 34 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_base.py | 412 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 395 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_roles.py | 287 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_users.py | 241 | ||||
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 4 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 14 | ||||
| -rw-r--r-- | tests/bot/rules/__init__.py | 76 | ||||
| -rw-r--r-- | tests/bot/rules/test_attachments.py | 97 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst.py | 56 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst_shared.py | 59 | ||||
| -rw-r--r-- | tests/bot/rules/test_chars.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_discord_emojis.py | 54 | ||||
| -rw-r--r-- | tests/bot/rules/test_duplicates.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_links.py | 84 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 90 | ||||
| -rw-r--r-- | tests/bot/rules/test_newlines.py | 105 | ||||
| -rw-r--r-- | tests/bot/rules/test_role_mentions.py | 57 | ||||
| -rw-r--r-- | tests/bot/test_api.py | 64 | ||||
| -rw-r--r-- | tests/helpers.py | 29 | 
20 files changed, 1839 insertions, 451 deletions
| diff --git a/tests/base.py b/tests/base.py index 029a249ed..88693f382 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@  import logging  import unittest  from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers  class _CaptureLogHandler(logging.Handler): @@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase):              standard_message = self._truncateMessage(base_message, record_message)              msg = self._formatMessage(msg, standard_message)              self.fail(msg) + + +class CommandTestCase(unittest.TestCase): +    """TestCase with additional assertions that are useful for testing Discord commands.""" + +    @helpers.async_test +    async def assertHasPermissionsCheck( +        self, +        cmd: commands.Command, +        permissions: Dict[str, bool], +    ) -> None: +        """ +        Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + +        Every permission in `permissions` is expected to be reported as missing. In other words, do +        not include permissions which should not raise an exception along with those which should. +        """ +        # Invert permission values because it's more intuitive to pass to this assertion the same +        # permissions as those given to the check decorator. +        permissions = {k: not v for k, v in permissions.items()} + +        ctx = helpers.MockContext() +        ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + +        with self.assertRaises(commands.MissingPermissions) as cm: +            await cmd.can_run(ctx) + +        self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..e6a6f9688 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,412 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): +    """Syncer subclass with mocks for abstract methods for testing purposes.""" + +    name = "test" +    _get_diff = helpers.AsyncMock() +    _sync = helpers.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): +    """Tests for the syncer base class.""" + +    def setUp(self): +        self.bot = helpers.MockBot() + +    def test_instantiation_fails_without_abstract_methods(self): +        """The class must have abstract methods implemented.""" +        with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): +            Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.TestCase): +    """Tests for sending the sync confirmation prompt.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = TestSyncer(self.bot) + +    def mock_get_channel(self): +        """Fixture to return a mock channel and message for when `get_channel` is used.""" +        self.bot.reset_mock() + +        mock_channel = helpers.MockTextChannel() +        mock_message = helpers.MockMessage() + +        mock_channel.send.return_value = mock_message +        self.bot.get_channel.return_value = mock_channel + +        return mock_channel, mock_message + +    def mock_fetch_channel(self): +        """Fixture to return a mock channel and message for when `fetch_channel` is used.""" +        self.bot.reset_mock() + +        mock_channel = helpers.MockTextChannel() +        mock_message = helpers.MockMessage() + +        self.bot.get_channel.return_value = None +        mock_channel.send.return_value = mock_message +        self.bot.fetch_channel.return_value = mock_channel + +        return mock_channel, mock_message + +    @helpers.async_test +    async def test_send_prompt_edits_and_returns_message(self): +        """The given message should be edited to display the prompt and then should be returned.""" +        msg = helpers.MockMessage() +        ret_val = await self.syncer._send_prompt(msg) + +        msg.edit.assert_called_once() +        self.assertIn("content", msg.edit.call_args[1]) +        self.assertEqual(ret_val, msg) + +    @helpers.async_test +    async def test_send_prompt_gets_dev_core_channel(self): +        """The dev-core channel should be retrieved if an extant message isn't given.""" +        subtests = ( +            (self.bot.get_channel, self.mock_get_channel), +            (self.bot.fetch_channel, self.mock_fetch_channel), +        ) + +        for method, mock_ in subtests: +            with self.subTest(method=method, msg=mock_.__name__): +                mock_() +                await self.syncer._send_prompt() + +                method.assert_called_once_with(constants.Channels.devcore) + +    @helpers.async_test +    async def test_send_prompt_returns_None_if_channel_fetch_fails(self): +        """None should be returned if there's an HTTPException when fetching the channel.""" +        self.bot.get_channel.return_value = None +        self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + +        ret_val = await self.syncer._send_prompt() + +        self.assertIsNone(ret_val) + +    @helpers.async_test +    async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): +        """A new message mentioning core devs should be sent and returned if message isn't given.""" +        for mock_ in (self.mock_get_channel, self.mock_fetch_channel): +            with self.subTest(msg=mock_.__name__): +                mock_channel, mock_message = mock_() +                ret_val = await self.syncer._send_prompt() + +                mock_channel.send.assert_called_once() +                self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) +                self.assertEqual(ret_val, mock_message) + +    @helpers.async_test +    async def test_send_prompt_adds_reactions(self): +        """The message should have reactions for confirmation added.""" +        extant_message = helpers.MockMessage() +        subtests = ( +            (extant_message, lambda: (None, extant_message)), +            (None, self.mock_get_channel), +            (None, self.mock_fetch_channel), +        ) + +        for message_arg, mock_ in subtests: +            subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + +            with self.subTest(msg=subtest_msg): +                _, mock_message = mock_() +                await self.syncer._send_prompt(message_arg) + +                calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] +                mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.TestCase): +    """Tests for waiting for a sync confirmation reaction on the prompt.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = TestSyncer(self.bot) +        self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer) + +    @staticmethod +    def get_message_reaction(emoji): +        """Fixture to return a mock message an reaction from the given `emoji`.""" +        message = helpers.MockMessage() +        reaction = helpers.MockReaction(emoji=emoji, message=message) + +        return message, reaction + +    def test_reaction_check_for_valid_emoji_and_authors(self): +        """Should return True if authors are identical or are a bot and a core dev, respectively.""" +        user_subtests = ( +            ( +                helpers.MockMember(id=77), +                helpers.MockMember(id=77), +                "identical users", +            ), +            ( +                helpers.MockMember(id=77, bot=True), +                helpers.MockMember(id=43, roles=[self.core_dev_role]), +                "bot author and core-dev reactor", +            ), +        ) + +        for emoji in self.syncer._REACTION_EMOJIS: +            for author, user, msg in user_subtests: +                with self.subTest(author=author, user=user, emoji=emoji, msg=msg): +                    message, reaction = self.get_message_reaction(emoji) +                    ret_val = self.syncer._reaction_check(author, message, reaction, user) + +                    self.assertTrue(ret_val) + +    def test_reaction_check_for_invalid_reactions(self): +        """Should return False for invalid reaction events.""" +        valid_emoji = self.syncer._REACTION_EMOJIS[0] +        subtests = ( +            ( +                helpers.MockMember(id=77), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=43, roles=[self.core_dev_role]), +                "users are not identical", +            ), +            ( +                helpers.MockMember(id=77, bot=True), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=43), +                "reactor lacks the core-dev role", +            ), +            ( +                helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), +                *self.get_message_reaction(valid_emoji), +                helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), +                "reactor is a bot", +            ), +            ( +                helpers.MockMember(id=77), +                helpers.MockMessage(id=95), +                helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), +                helpers.MockMember(id=77), +                "messages are not identical", +            ), +            ( +                helpers.MockMember(id=77), +                *self.get_message_reaction("InVaLiD"), +                helpers.MockMember(id=77), +                "emoji is invalid", +            ), +        ) + +        for *args, msg in subtests: +            kwargs = dict(zip(("author", "message", "reaction", "user"), args)) +            with self.subTest(**kwargs, msg=msg): +                ret_val = self.syncer._reaction_check(*args) +                self.assertFalse(ret_val) + +    @helpers.async_test +    async def test_wait_for_confirmation(self): +        """The message should always be edited and only return True if the emoji is a check mark.""" +        subtests = ( +            (constants.Emojis.check_mark, True, None), +            ("InVaLiD", False, None), +            (None, False, TimeoutError), +        ) + +        for emoji, ret_val, side_effect in subtests: +            for bot in (True, False): +                with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): +                    # Set up mocks +                    message = helpers.MockMessage() +                    member = helpers.MockMember(bot=bot) + +                    self.bot.wait_for.reset_mock() +                    self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) +                    self.bot.wait_for.side_effect = side_effect + +                    # Call the function +                    actual_return = await self.syncer._wait_for_confirmation(member, message) + +                    # Perform assertions +                    self.bot.wait_for.assert_called_once() +                    self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + +                    message.edit.assert_called_once() +                    kwargs = message.edit.call_args[1] +                    self.assertIn("content", kwargs) + +                    # Core devs should only be mentioned if the author is a bot. +                    if bot: +                        self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) +                    else: +                        self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + +                    self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.TestCase): +    """Tests for main function orchestrating the sync.""" + +    def setUp(self): +        self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) +        self.syncer = TestSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_respects_confirmation_result(self): +        """The sync should abort if confirmation fails and continue if confirmed.""" +        mock_message = helpers.MockMessage() +        subtests = ( +            (True, mock_message), +            (False, None), +        ) + +        for confirmed, message in subtests: +            with self.subTest(confirmed=confirmed): +                self.syncer._sync.reset_mock() +                self.syncer._get_diff.reset_mock() + +                diff = _Diff({1, 2, 3}, {4, 5}, None) +                self.syncer._get_diff.return_value = diff +                self.syncer._get_confirmation_result = helpers.AsyncMock( +                    return_value=(confirmed, message) +                ) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                self.syncer._get_diff.assert_called_once_with(guild) +                self.syncer._get_confirmation_result.assert_called_once() + +                if confirmed: +                    self.syncer._sync.assert_called_once_with(diff) +                else: +                    self.syncer._sync.assert_not_called() + +    @helpers.async_test +    async def test_sync_diff_size(self): +        """The diff size should be correctly calculated.""" +        subtests = ( +            (6, _Diff({1, 2}, {3, 4}, {5, 6})), +            (5, _Diff({1, 2, 3}, None, {4, 5})), +            (0, _Diff(None, None, None)), +            (0, _Diff(set(), set(), set())), +        ) + +        for size, diff in subtests: +            with self.subTest(size=size, diff=diff): +                self.syncer._get_diff.reset_mock() +                self.syncer._get_diff.return_value = diff +                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                self.syncer._get_diff.assert_called_once_with(guild) +                self.syncer._get_confirmation_result.assert_called_once() +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + +    @helpers.async_test +    async def test_sync_message_edited(self): +        """The message should be edited if one was sent, even if the sync has an API error.""" +        subtests = ( +            (None, None, False), +            (helpers.MockMessage(), None, True), +            (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), +        ) + +        for message, side_effect, should_edit in subtests: +            with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): +                self.syncer._sync.side_effect = side_effect +                self.syncer._get_confirmation_result = helpers.AsyncMock( +                    return_value=(True, message) +                ) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild) + +                if should_edit: +                    message.edit.assert_called_once() +                    self.assertIn("content", message.edit.call_args[1]) + +    @helpers.async_test +    async def test_sync_confirmation_context_redirect(self): +        """If ctx is given, a new message should be sent and author should be ctx's author.""" +        mock_member = helpers.MockMember() +        subtests = ( +            (None, self.bot.user, None), +            (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), +        ) + +        for ctx, author, message in subtests: +            with self.subTest(ctx=ctx, author=author, message=message): +                if ctx is not None: +                    ctx.send.return_value = message + +                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + +                guild = helpers.MockGuild() +                await self.syncer.sync(guild, ctx) + +                if ctx is not None: +                    ctx.send.assert_called_once() + +                self.syncer._get_confirmation_result.assert_called_once() +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) +                self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + +    @mock.patch.object(constants.Sync, "max_diff", new=3) +    @helpers.async_test +    async def test_confirmation_result_small_diff(self): +        """Should always return True and the given message if the diff size is too small.""" +        author = helpers.MockMember() +        expected_message = helpers.MockMessage() + +        for size in (3, 2): +            with self.subTest(size=size): +                self.syncer._send_prompt = helpers.AsyncMock() +                self.syncer._wait_for_confirmation = helpers.AsyncMock() + +                coro = self.syncer._get_confirmation_result(size, author, expected_message) +                result, actual_message = await coro + +                self.assertTrue(result) +                self.assertEqual(actual_message, expected_message) +                self.syncer._send_prompt.assert_not_called() +                self.syncer._wait_for_confirmation.assert_not_called() + +    @mock.patch.object(constants.Sync, "max_diff", new=3) +    @helpers.async_test +    async def test_confirmation_result_large_diff(self): +        """Should return True if confirmed and False if _send_prompt fails or aborted.""" +        author = helpers.MockMember() +        mock_message = helpers.MockMessage() + +        subtests = ( +            (True, mock_message, True, "confirmed"), +            (False, None, False, "_send_prompt failed"), +            (False, mock_message, False, "aborted"), +        ) + +        for expected_result, expected_message, confirmed, msg in subtests: +            with self.subTest(msg=msg): +                self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) +                self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + +                coro = self.syncer._get_confirmation_result(4, author) +                actual_result, actual_message = await coro + +                self.syncer._send_prompt.assert_called_once_with(None)  # message defaults to None +                self.assertIs(actual_result, expected_result) +                self.assertEqual(actual_message, expected_message) + +                if expected_message: +                    self.syncer._wait_for_confirmation.assert_called_once_with( +                        author, expected_message +                    ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..98c9afc0d --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,395 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): +    """ +    A MagicMock subclass to mock Syncer objects. + +    Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` +    instances. For more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=Syncer, **kwargs) + + +class SyncExtensionTests(unittest.TestCase): +    """Tests for the sync extension.""" + +    @staticmethod +    def test_extension_setup(): +        """The Sync cog should be added.""" +        bot = helpers.MockBot() +        sync.setup(bot) +        bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.TestCase): +    """Base class for Sync cog tests. Sets up patches for syncers.""" + +    def setUp(self): +        self.bot = helpers.MockBot() + +        # These patch the type. When the type is called, a MockSyncer instanced is returned. +        # MockSyncer is needed so that our custom AsyncMock is used. +        # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. +        self.role_syncer_patcher = mock.patch( +            "bot.cogs.sync.syncers.RoleSyncer", +            new=mock.MagicMock(return_value=MockSyncer()) +        ) +        self.user_syncer_patcher = mock.patch( +            "bot.cogs.sync.syncers.UserSyncer", +            new=mock.MagicMock(return_value=MockSyncer()) +        ) +        self.RoleSyncer = self.role_syncer_patcher.start() +        self.UserSyncer = self.user_syncer_patcher.start() + +        self.cog = sync.Sync(self.bot) + +    def tearDown(self): +        self.role_syncer_patcher.stop() +        self.user_syncer_patcher.stop() + +    @staticmethod +    def response_error(status: int) -> ResponseCodeError: +        """Fixture to return a ResponseCodeError with the given status code.""" +        response = mock.MagicMock() +        response.status = status + +        return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): +    """Tests for the Sync cog.""" + +    @mock.patch.object(sync.Sync, "sync_guild") +    def test_sync_cog_init(self, sync_guild): +        """Should instantiate syncers and run a sync for the guild.""" +        # Reset because a Sync cog was already instantiated in setUp. +        self.RoleSyncer.reset_mock() +        self.UserSyncer.reset_mock() +        self.bot.loop.create_task.reset_mock() + +        mock_sync_guild_coro = mock.MagicMock() +        sync_guild.return_value = mock_sync_guild_coro + +        sync.Sync(self.bot) + +        self.RoleSyncer.assert_called_once_with(self.bot) +        self.UserSyncer.assert_called_once_with(self.bot) +        sync_guild.assert_called_once_with() +        self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + +    @helpers.async_test +    async def test_sync_cog_sync_guild(self): +        """Roles and users should be synced only if a guild is successfully retrieved.""" +        for guild in (helpers.MockGuild(), None): +            with self.subTest(guild=guild): +                self.bot.reset_mock() +                self.cog.role_syncer.reset_mock() +                self.cog.user_syncer.reset_mock() + +                self.bot.get_guild = mock.MagicMock(return_value=guild) + +                await self.cog.sync_guild() + +                self.bot.wait_until_guild_available.assert_called_once() +                self.bot.get_guild.assert_called_once_with(constants.Guild.id) + +                if guild is None: +                    self.cog.role_syncer.sync.assert_not_called() +                    self.cog.user_syncer.sync.assert_not_called() +                else: +                    self.cog.role_syncer.sync.assert_called_once_with(guild) +                    self.cog.user_syncer.sync.assert_called_once_with(guild) + +    async def patch_user_helper(self, side_effect: BaseException) -> None: +        """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" +        self.bot.api_client.patch.reset_mock(side_effect=True) +        self.bot.api_client.patch.side_effect = side_effect + +        user_id, updated_information = 5, {"key": 123} +        await self.cog.patch_user(user_id, updated_information) + +        self.bot.api_client.patch.assert_called_once_with( +            f"bot/users/{user_id}", +            json=updated_information, +        ) + +    @helpers.async_test +    async def test_sync_cog_patch_user(self): +        """A PATCH request should be sent and 404 errors ignored.""" +        for side_effect in (None, self.response_error(404)): +            with self.subTest(side_effect=side_effect): +                await self.patch_user_helper(side_effect) + +    @helpers.async_test +    async def test_sync_cog_patch_user_non_404(self): +        """A PATCH request should be sent and the error raised if it's not a 404.""" +        with self.assertRaises(ResponseCodeError): +            await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): +    """Tests for the listeners of the Sync cog.""" + +    def setUp(self): +        super().setUp() +        self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_create(self): +        """A POST request should be sent with the new role's data.""" +        self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + +        role_data = { +            "colour": 49, +            "id": 777, +            "name": "rolename", +            "permissions": 8, +            "position": 23, +        } +        role = helpers.MockRole(**role_data) +        await self.cog.on_guild_role_create(role) + +        self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_delete(self): +        """A DELETE request should be sent.""" +        self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + +        role = helpers.MockRole(id=99) +        await self.cog.on_guild_role_delete(role) + +        self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + +    @helpers.async_test +    async def test_sync_cog_on_guild_role_update(self): +        """A PUT request should be sent if the colour, name, permissions, or position changes.""" +        self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + +        role_data = { +            "colour": 49, +            "id": 777, +            "name": "rolename", +            "permissions": 8, +            "position": 23, +        } +        subtests = ( +            (True, ("colour", "name", "permissions", "position")), +            (False, ("hoist", "mentionable")), +        ) + +        for should_put, attributes in subtests: +            for attribute in attributes: +                with self.subTest(should_put=should_put, changed_attribute=attribute): +                    self.bot.api_client.put.reset_mock() + +                    after_role_data = role_data.copy() +                    after_role_data[attribute] = 876 + +                    before_role = helpers.MockRole(**role_data) +                    after_role = helpers.MockRole(**after_role_data) + +                    await self.cog.on_guild_role_update(before_role, after_role) + +                    if should_put: +                        self.bot.api_client.put.assert_called_once_with( +                            f"bot/roles/{after_role.id}", +                            json=after_role_data +                        ) +                    else: +                        self.bot.api_client.put.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_member_remove(self): +        """Member should patched to set in_guild as False.""" +        self.assertTrue(self.cog.on_member_remove.__cog_listener__) + +        member = helpers.MockMember() +        await self.cog.on_member_remove(member) + +        self.cog.patch_user.assert_called_once_with( +            member.id, +            updated_information={"in_guild": False} +        ) + +    @helpers.async_test +    async def test_sync_cog_on_member_update_roles(self): +        """Members should be patched if their roles have changed.""" +        self.assertTrue(self.cog.on_member_update.__cog_listener__) + +        # Roles are intentionally unsorted. +        before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] +        before_member = helpers.MockMember(roles=before_roles) +        after_member = helpers.MockMember(roles=before_roles[1:]) + +        await self.cog.on_member_update(before_member, after_member) + +        data = {"roles": sorted(role.id for role in after_member.roles)} +        self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + +    @helpers.async_test +    async def test_sync_cog_on_member_update_other(self): +        """Members should not be patched if other attributes have changed.""" +        self.assertTrue(self.cog.on_member_update.__cog_listener__) + +        subtests = ( +            ("activities", discord.Game("Pong"), discord.Game("Frogger")), +            ("nick", "old nick", "new nick"), +            ("status", discord.Status.online, discord.Status.offline), +        ) + +        for attribute, old_value, new_value in subtests: +            with self.subTest(attribute=attribute): +                self.cog.patch_user.reset_mock() + +                before_member = helpers.MockMember(**{attribute: old_value}) +                after_member = helpers.MockMember(**{attribute: new_value}) + +                await self.cog.on_member_update(before_member, after_member) + +                self.cog.patch_user.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_user_update(self): +        """A user should be patched only if the name, discriminator, or avatar changes.""" +        self.assertTrue(self.cog.on_user_update.__cog_listener__) + +        before_data = { +            "name": "old name", +            "discriminator": "1234", +            "avatar": "old avatar", +            "bot": False, +        } + +        subtests = ( +            (True, "name", "name", "new name", "new name"), +            (True, "discriminator", "discriminator", "8765", 8765), +            (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), +            (False, "bot", "bot", True, True), +        ) + +        for should_patch, attribute, api_field, value, api_value in subtests: +            with self.subTest(attribute=attribute): +                self.cog.patch_user.reset_mock() + +                after_data = before_data.copy() +                after_data[attribute] = value +                before_user = helpers.MockUser(**before_data) +                after_user = helpers.MockUser(**after_data) + +                await self.cog.on_user_update(before_user, after_user) + +                if should_patch: +                    self.cog.patch_user.assert_called_once() + +                    # Don't care if *all* keys are present; only the changed one is required +                    call_args = self.cog.patch_user.call_args +                    self.assertEqual(call_args[0][0], after_user.id) +                    self.assertIn("updated_information", call_args[1]) + +                    updated_information = call_args[1]["updated_information"] +                    self.assertIn(api_field, updated_information) +                    self.assertEqual(updated_information[api_field], api_value) +                else: +                    self.cog.patch_user.assert_not_called() + +    async def on_member_join_helper(self, side_effect: Exception) -> dict: +        """ +        Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + +        The request data for the mock member is returned. All exceptions will be re-raised. +        """ +        member = helpers.MockMember( +            discriminator="1234", +            roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], +        ) + +        data = { +            "avatar_hash": member.avatar, +            "discriminator": int(member.discriminator), +            "id": member.id, +            "in_guild": True, +            "name": member.name, +            "roles": sorted(role.id for role in member.roles) +        } + +        self.bot.api_client.put.reset_mock(side_effect=True) +        self.bot.api_client.put.side_effect = side_effect + +        try: +            await self.cog.on_member_join(member) +        except Exception: +            raise +        finally: +            self.bot.api_client.put.assert_called_once_with( +                f"bot/users/{member.id}", +                json=data +            ) + +        return data + +    @helpers.async_test +    async def test_sync_cog_on_member_join(self): +        """Should PUT user's data or POST it if the user doesn't exist.""" +        for side_effect in (None, self.response_error(404)): +            with self.subTest(side_effect=side_effect): +                self.bot.api_client.post.reset_mock() +                data = await self.on_member_join_helper(side_effect) + +                if side_effect: +                    self.bot.api_client.post.assert_called_once_with("bot/users", json=data) +                else: +                    self.bot.api_client.post.assert_not_called() + +    @helpers.async_test +    async def test_sync_cog_on_member_join_non_404(self): +        """ResponseCodeError should be re-raised if status code isn't a 404.""" +        with self.assertRaises(ResponseCodeError): +            await self.on_member_join_helper(self.response_error(500)) + +        self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): +    """Tests for the commands in the Sync cog.""" + +    @helpers.async_test +    async def test_sync_roles_command(self): +        """sync() should be called on the RoleSyncer.""" +        ctx = helpers.MockContext() +        await self.cog.sync_roles_command.callback(self.cog, ctx) + +        self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + +    @helpers.async_test +    async def test_sync_users_command(self): +        """sync() should be called on the UserSyncer.""" +        ctx = helpers.MockContext() +        await self.cog.sync_users_command.callback(self.cog, ctx) + +        self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + +    def test_commands_require_admin(self): +        """The sync commands should only run if the author has the administrator permission.""" +        cmds = ( +            self.cog.sync_group, +            self.cog.sync_roles_command, +            self.cog.sync_users_command, +        ) + +        for cmd in cmds: +            with self.subTest(cmd=cmd): +                self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..14fb2577a 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,165 @@  import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): -    """Tests constructing the roles to synchronize with the site.""" - -    def test_get_roles_for_sync_empty_return_for_equal_roles(self): -        """No roles should be synced when no diff is found.""" -        api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} -        guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            (set(), set(), set()) -        ) - -    def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): -        """Roles to be synced are returned when non-ID attributes differ.""" -        api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} -        guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            (set(), guild_roles, set()) -        ) - -    def test_get_roles_only_returns_roles_that_require_update(self): -        """Roles that require an update should be returned as the second tuple element.""" -        api_roles = { -            Role(id=41, name='old name', colour=33, permissions=0x8, position=1), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } -        guild_roles = { -            Role(id=41, name='new name', colour=35, permissions=0x8, position=2), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                set(), -                {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, -                set(), -            ) -        ) - -    def test_get_roles_returns_new_roles_in_first_tuple_element(self): -        """Newly created roles are returned as the first tuple element.""" -        api_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        } -        guild_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -            Role(id=53, name='other role', colour=55, permissions=0, position=2) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, -                set(), -                set(), -            ) -        ) - -    def test_get_roles_returns_roles_to_update_and_new_roles(self): -        """Newly created and updated roles should be returned together.""" -        api_roles = { -            Role(id=41, name='old name', colour=35, permissions=0x8, position=1), -        } -        guild_roles = { -            Role(id=41, name='new name', colour=40, permissions=0x16, position=2), -            Role(id=53, name='other role', colour=55, permissions=0, position=3) -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, -                {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, -                set(), -            ) -        ) - -    def test_get_roles_returns_roles_to_delete(self): -        """Roles to be deleted should be returned as the third tuple element.""" -        api_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -        } -        guild_roles = { -            Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                set(), -                set(), -                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -            ) -        ) - -    def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): -        """When roles were added, updated, and removed, all of them are returned properly.""" -        api_roles = { -            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -            Role(id=71, name='to update', colour=99, permissions=0x9, position=3), -        } -        guild_roles = { -            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -            Role(id=81, name='to create', colour=99, permissions=0x9, position=4), -            Role(id=71, name='updated', colour=101, permissions=0x5, position=3), -        } - -        self.assertEqual( -            get_roles_for_sync(guild_roles, api_roles), -            ( -                {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, -                {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, -                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -            ) -        ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): +    """Fixture to return a dictionary representing a role with default values set.""" +    kwargs.setdefault("id", 9) +    kwargs.setdefault("name", "fake role") +    kwargs.setdefault("colour", 7) +    kwargs.setdefault("permissions", 0) +    kwargs.setdefault("position", 55) + +    return kwargs + + +class RoleSyncerDiffTests(unittest.TestCase): +    """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = RoleSyncer(self.bot) + +    @staticmethod +    def get_guild(*roles): +        """Fixture to return a guild object with the given roles.""" +        guild = helpers.MockGuild() +        guild.roles = [] + +        for role in roles: +            mock_role = helpers.MockRole(**role) +            mock_role.colour = discord.Colour(role["colour"]) +            mock_role.permissions = discord.Permissions(role["permissions"]) +            guild.roles.append(mock_role) + +        return guild + +    @helpers.async_test +    async def test_empty_diff_for_identical_roles(self): +        """No differences should be found if the roles in the guild and DB are identical.""" +        self.bot.api_client.get.return_value = [fake_role()] +        guild = self.get_guild(fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_updated_roles(self): +        """Only updated roles should be added to the 'updated' set of the diff.""" +        updated_role = fake_role(id=41, name="new") + +        self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] +        guild = self.get_guild(updated_role, fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_Role(**updated_role)}, set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_roles(self): +        """Only new roles should be added to the 'created' set of the diff.""" +        new_role = fake_role(id=41, name="new") + +        self.bot.api_client.get.return_value = [fake_role()] +        guild = self.get_guild(fake_role(), new_role) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_Role(**new_role)}, set(), set()) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_deleted_roles(self): +        """Only deleted roles should be added to the 'deleted' set of the diff.""" +        deleted_role = fake_role(id=61, name="deleted") + +        self.bot.api_client.get.return_value = [fake_role(), deleted_role] +        guild = self.get_guild(fake_role()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), {_Role(**deleted_role)}) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_updated_and_deleted_roles(self): +        """When roles are added, updated, and removed, all of them are returned properly.""" +        new = fake_role(id=41, name="new") +        updated = fake_role(id=71, name="updated") +        deleted = fake_role(id=61, name="deleted") + +        self.bot.api_client.get.return_value = [ +            fake_role(), +            fake_role(id=71, name="updated name"), +            deleted, +        ] +        guild = self.get_guild(fake_role(), new, updated) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + +        self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.TestCase): +    """Tests for the API requests that sync roles.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = RoleSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_created_roles(self): +        """Only POST requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(role_tuples, set(), set()) +        await self.syncer._sync(diff) + +        calls = [mock.call("bot/roles", json=role) for role in roles] +        self.bot.api_client.post.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + +        self.bot.api_client.put.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_updated_roles(self): +        """Only PUT requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(set(), role_tuples, set()) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] +        self.bot.api_client.put.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_deleted_roles(self): +        """Only DELETE requests should be made with the correct payload.""" +        roles = [fake_role(id=111), fake_role(id=222)] + +        role_tuples = {_Role(**role) for role in roles} +        diff = _Diff(set(), set(), role_tuples) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] +        self.bot.api_client.delete.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..421bf6bb6 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,169 @@  import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers  def fake_user(**kwargs): -    kwargs.setdefault('id', 43) -    kwargs.setdefault('name', 'bob the test man') -    kwargs.setdefault('discriminator', 1337) -    kwargs.setdefault('avatar_hash', None) -    kwargs.setdefault('roles', (666,)) -    kwargs.setdefault('in_guild', True) -    return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): -    """Tests constructing the users to synchronize with the site.""" - -    def test_get_users_for_sync_returns_nothing_for_empty_params(self): -        """When no users are given, none are returned.""" -        self.assertEqual( -            get_users_for_sync({}, {}), -            (set(), set()) -        ) - -    def test_get_users_for_sync_returns_nothing_for_equal_users(self): -        """When no users are updated, none are returned.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user()} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), set()) -        ) - -    def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): -        """When a non-ID-field differs, the user to update is returned.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user(name='new fancy name')} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), {fake_user(name='new fancy name')}) -        ) - -    def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): -        """When new users join the guild, they are returned as the first tuple element.""" -        api_users = {43: fake_user()} -        guild_users = {43: fake_user(), 63: fake_user(id=63)} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            ({fake_user(id=63)}, set()) -        ) - -    def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): +    """Fixture to return a dictionary representing a user with default values set.""" +    kwargs.setdefault("id", 43) +    kwargs.setdefault("name", "bob the test man") +    kwargs.setdefault("discriminator", 1337) +    kwargs.setdefault("avatar_hash", None) +    kwargs.setdefault("roles", (666,)) +    kwargs.setdefault("in_guild", True) + +    return kwargs + + +class UserSyncerDiffTests(unittest.TestCase): +    """Tests for determining differences between users in the DB and users in the Guild cache.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = UserSyncer(self.bot) + +    @staticmethod +    def get_guild(*members): +        """Fixture to return a guild object with the given members.""" +        guild = helpers.MockGuild() +        guild.members = [] + +        for member in members: +            member = member.copy() +            member["avatar"] = member.pop("avatar_hash") +            del member["in_guild"] + +            mock_member = helpers.MockMember(**member) +            mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + +            guild.members.append(mock_member) + +        return guild + +    @helpers.async_test +    async def test_empty_diff_for_no_users(self): +        """When no users are given, an empty diff should be returned.""" +        guild = self.get_guild() + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_empty_diff_for_identical_users(self): +        """No differences should be found if the users in the guild and DB are identical.""" +        self.bot.api_client.get.return_value = [fake_user()] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_updated_users(self): +        """Only updated users should be added to the 'updated' set of the diff.""" +        updated_user = fake_user(id=99, name="new") + +        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        guild = self.get_guild(updated_user, fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_User(**updated_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_users(self): +        """Only new users should be added to the 'created' set of the diff.""" +        new_user = fake_user(id=99, name="new") + +        self.bot.api_client.get.return_value = [fake_user()] +        guild = self.get_guild(fake_user(), new_user) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_User(**new_user)}, set(), None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        api_users = {43: fake_user(), 63: fake_user(id=63)} -        guild_users = {43: fake_user()} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), {fake_user(id=63, in_guild=False)}) -        ) - -    def test_get_users_for_sync_updates_and_creates_users_as_needed(self): -        """When one user left and another one was updated, both are returned.""" -        api_users = {43: fake_user()} -        guild_users = {63: fake_user(id=63)} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            ({fake_user(id=63)}, {fake_user(in_guild=False)}) -        ) - -    def test_get_users_for_sync_does_not_duplicate_update_users(self): -        """When the API knows a user the guild doesn't, nothing is performed.""" -        api_users = {43: fake_user(in_guild=False)} -        guild_users = {} - -        self.assertEqual( -            get_users_for_sync(guild_users, api_users), -            (set(), set()) -        ) +        leaving_user = fake_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), {_User(**leaving_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_diff_for_new_updated_and_leaving_users(self): +        """When users are added, updated, and removed, all of them are returned properly.""" +        new_user = fake_user(id=99, name="new") +        updated_user = fake_user(id=55, name="updated") +        leaving_user = fake_user(id=63, in_guild=False) + +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        guild = self.get_guild(fake_user(), new_user, updated_user) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + +        self.assertEqual(actual_diff, expected_diff) + +    @helpers.async_test +    async def test_empty_diff_for_db_users_not_in_guild(self): +        """When the DB knows a user the guild doesn't, no difference is found.""" +        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        guild = self.get_guild(fake_user()) + +        actual_diff = await self.syncer._get_diff(guild) +        expected_diff = (set(), set(), None) + +        self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.TestCase): +    """Tests for the API requests that sync users.""" + +    def setUp(self): +        self.bot = helpers.MockBot() +        self.syncer = UserSyncer(self.bot) + +    @helpers.async_test +    async def test_sync_created_users(self): +        """Only POST requests should be made with the correct payload.""" +        users = [fake_user(id=111), fake_user(id=222)] + +        user_tuples = {_User(**user) for user in users} +        diff = _Diff(user_tuples, set(), None) +        await self.syncer._sync(diff) + +        calls = [mock.call("bot/users", json=user) for user in users] +        self.bot.api_client.post.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.post.call_count, len(users)) + +        self.bot.api_client.put.assert_not_called() +        self.bot.api_client.delete.assert_not_called() + +    @helpers.async_test +    async def test_sync_updated_users(self): +        """Only PUT requests should be made with the correct payload.""" +        users = [fake_user(id=111), fake_user(id=222)] + +        user_tuples = {_User(**user) for user in users} +        diff = _Diff(set(), user_tuples, None) +        await self.syncer._sync(diff) + +        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] +        self.bot.api_client.put.assert_has_calls(calls, any_order=True) +        self.assertEqual(self.bot.api_client.put.call_count, len(users)) + +        self.bot.api_client.post.assert_not_called() +        self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..5b0a3b8c3 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase):          asyncio.run(self.cog.fetch_webhook()) -        self.bot.wait_until_ready.assert_called_once() +        self.bot.wait_until_guild_available.assert_called_once()          self.bot.fetch_webhook.assert_called_once_with(1)          self.assertEqual(self.cog.webhook, "dummy webhook") @@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase):          with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:              asyncio.run(self.cog.fetch_webhook()) -        self.bot.wait_until_ready.assert_called_once() +        self.bot.wait_until_guild_available.assert_called_once()          self.bot.fetch_webhook.assert_called_once_with(1)          self.assertEqual(len(log_watcher.records), 1) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 4496a2ae0..deae7ebad 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -125,10 +125,10 @@ class InformationCogTests(unittest.TestCase):                  )              ],              members=[ -                *(helpers.MockMember(status='online') for _ in range(2)), -                *(helpers.MockMember(status='idle') for _ in range(1)), -                *(helpers.MockMember(status='dnd') for _ in range(4)), -                *(helpers.MockMember(status='offline') for _ in range(3)), +                *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), +                *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), +                *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), +                *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),              ],              member_count=1_234,              icon_url='a-lemon.jpg', @@ -153,9 +153,9 @@ class InformationCogTests(unittest.TestCase):                  **Counts**                  Members: {self.ctx.guild.member_count:,}                  Roles: {len(self.ctx.guild.roles)} -                Text: 1 -                Voice: 1 -                Channel categories: 1 +                Category channels: 1 +                Text channels: 1 +                Voice channels: 1                  **Members**                  {constants.Emojis.status_online} 2 diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index e69de29bb..36c986fe1 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -0,0 +1,76 @@ +import unittest +from abc import ABCMeta, abstractmethod +from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple + +from tests.helpers import MockMessage + + +class DisallowedCase(NamedTuple): +    """Encapsulation for test cases expected to fail.""" +    recent_messages: List[MockMessage] +    culprits: Iterable[str] +    n_violations: int + + +class RuleTest(unittest.TestCase, metaclass=ABCMeta): +    """ +    Abstract class for antispam rule test cases. + +    Tests for specific rules should inherit from `RuleTest` and implement +    `relevant_messages` and `get_report`. Each instance should also set the +    `apply` and `config` attributes as necessary. + +    The execution of test cases can then be delegated to the `run_allowed` +    and `run_disallowed` methods. +    """ + +    apply: Callable  # The tested rule's apply function +    config: Dict[str, int] + +    async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: +        """Run all `cases` against `self.apply` expecting them to pass.""" +        for recent_messages in cases: +            last_message = recent_messages[0] + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                config=self.config, +            ): +                self.assertIsNone( +                    await self.apply(last_message, recent_messages, self.config) +                ) + +    async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: +        """Run all `cases` against `self.apply` expecting them to fail.""" +        for case in cases: +            recent_messages, culprits, n_violations = case +            last_message = recent_messages[0] +            relevant_messages = self.relevant_messages(case) +            desired_output = ( +                self.get_report(case), +                culprits, +                relevant_messages, +            ) + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                relevant_messages=relevant_messages, +                n_violations=n_violations, +                config=self.config, +            ): +                self.assertTupleEqual( +                    await self.apply(last_message, recent_messages, self.config), +                    desired_output, +                ) + +    @abstractmethod +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        """Give expected relevant messages for `case`.""" +        raise NotImplementedError + +    @abstractmethod +    def get_report(self, case: DisallowedCase) -> str: +        """Give expected error report for `case`.""" +        raise NotImplementedError diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index d7187f315..e54b4b5b8 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,98 +1,71 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import attachments +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_attachments: int - - -def msg(author: str, total_attachments: int) -> MockMessage: +def make_msg(author: str, total_attachments: int) -> MockMessage:      """Builds a message with `total_attachments` attachments."""      return MockMessage(author=author, attachments=list(range(total_attachments))) -class AttachmentRuleTests(unittest.TestCase): +class AttachmentRuleTests(RuleTest):      """Tests applying the `attachments` antispam rule."""      def setUp(self): -        self.config = {"max": 5} +        self.apply = attachments.apply +        self.config = {"max": 5, "interval": 10}      @async_test      async def test_allows_messages_without_too_many_attachments(self):          """Messages without too many attachments are allowed as-is."""          cases = ( -            [msg("bob", 0), msg("bob", 0), msg("bob", 0)], -            [msg("bob", 2), msg("bob", 2)], -            [msg("bob", 2), msg("alice", 2), msg("bob", 2)], +            [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], +            [make_msg("bob", 2), make_msg("bob", 2)], +            [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await attachments.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( -            Case( -                [msg("bob", 4), msg("bob", 0), msg("bob", 6)], +            DisallowedCase( +                [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],                  ("bob",), -                10 +                10,              ), -            Case( -                [msg("bob", 4), msg("alice", 6), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],                  ("bob",), -                6 +                6,              ), -            Case( -                [msg("alice", 6)], +            DisallowedCase( +                [make_msg("alice", 6)],                  ("alice",), -                6 +                6,              ), -            ( -                [msg("alice", 1) for _ in range(6)], +            DisallowedCase( +                [make_msg("alice", 1) for _ in range(6)],                  ("alice",), -                6 +                6,              ),          ) -        for recent_messages, culprit, total_attachments in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if ( -                    msg.author == last_message.author -                    and len(msg.attachments) > 0 -                ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and len(msg.attachments) > 0              ) +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                total_attachments=total_attachments, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_attachments} attachments in {self.config['max']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await attachments.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py new file mode 100644 index 000000000..72f0be0c7 --- /dev/null +++ b/tests/bot/rules/test_burst.py @@ -0,0 +1,56 @@ +from typing import Iterable + +from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: +    """ +    Init a MockMessage instance with author set to `author`. + +    This serves as a shorthand / alias to keep the test cases visually clean. +    """ +    return MockMessage(author=author) + + +class BurstRuleTests(RuleTest): +    """Tests the `burst` antispam rule.""" + +    def setUp(self): +        self.apply = burst.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("bob"), make_msg("bob")], +            [make_msg("bob"), make_msg("alice"), make_msg("bob")], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the amount of messages exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("bob")], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], +                ("bob",), +                3, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py new file mode 100644 index 000000000..47367a5f8 --- /dev/null +++ b/tests/bot/rules/test_burst_shared.py @@ -0,0 +1,59 @@ +from typing import Iterable + +from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: +    """ +    Init a MockMessage instance with the passed arg. + +    This serves as a shorthand / alias to keep the test cases visually clean. +    """ +    return MockMessage(author=author) + + +class BurstSharedRuleTests(RuleTest): +    """Tests the `burst_shared` antispam rule.""" + +    def setUp(self): +        self.apply = burst_shared.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """ +        Cases that do not violate the rule. + +        There really isn't more to test here than a single case. +        """ +        cases = ( +            [make_msg("spongebob"), make_msg("patrick")], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the amount of messages exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("bob")], +                {"bob"}, +                3, +            ), +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], +                {"bob", "alice"}, +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return case.recent_messages + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py new file mode 100644 index 000000000..7cc36f49e --- /dev/null +++ b/tests/bot/rules/test_chars.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_chars: int) -> MockMessage: +    """Build a message with arbitrary content of `n_chars` length.""" +    return MockMessage(author=author, content="A" * n_chars) + + +class CharsRuleTests(RuleTest): +    """Tests the `chars` antispam rule.""" + +    def setUp(self): +        self.apply = chars.apply +        self.config = { +            "max": 20,  # Max allowed sum of chars per user +            "interval": 10, +        } + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of chars within limit.""" +        cases = ( +            [make_msg("bob", 0)], +            [make_msg("bob", 20)], +            [make_msg("bob", 15), make_msg("alice", 15)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the total amount of chars exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 21)], +                ("bob",), +                21, +            ), +            DisallowedCase( +                [make_msg("bob", 15), make_msg("bob", 15)], +                ("bob",), +                30, +            ), +            DisallowedCase( +                [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], +                ("alice",), +                30, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py new file mode 100644 index 000000000..0239b0b00 --- /dev/null +++ b/tests/bot/rules/test_discord_emojis.py @@ -0,0 +1,54 @@ +from typing import Iterable + +from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + +discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> + + +def make_msg(author: str, n_emojis: int) -> MockMessage: +    """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" +    return MockMessage(author=author, content=discord_emoji * n_emojis) + + +class DiscordEmojisRuleTests(RuleTest): +    """Tests for the `discord_emojis` antispam rule.""" + +    def setUp(self): +        self.apply = discord_emojis.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of discord emojis within limit.""" +        cases = ( +            [make_msg("bob", 2)], +            [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with more than the allowed amount of discord emojis.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 3)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py new file mode 100644 index 000000000..59e0fb6ef --- /dev/null +++ b/tests/bot/rules/test_duplicates.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import duplicates +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, content: str) -> MockMessage: +    """Give a MockMessage instance with `author` and `content` attrs.""" +    return MockMessage(author=author, content=content) + + +class DuplicatesRuleTests(RuleTest): +    """Tests the `duplicates` antispam rule.""" + +    def setUp(self): +        self.apply = duplicates.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("alice", "A"), make_msg("alice", "A")], +            [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")],  # Non-duplicate +            [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")],  # Different author +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with too many duplicate messages from the same author.""" +        cases = ( +            DisallowedCase( +                [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], +                ("alice",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], +                ("bob",), +                3,  # 4 duplicate messages, but only 3 from bob +            ), +            DisallowedCase( +                [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], +                ("bob",), +                3,  # 4 message from bob, but only 3 duplicates +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and msg.content == last_message.content +            ) +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 02a5d5501..3c3f90e5f 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,26 +1,21 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import links +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_links: int - - -def msg(author: str, total_links: int) -> MockMessage: +def make_msg(author: str, total_links: int) -> MockMessage:      """Makes a message with `total_links` links."""      content = " ".join(["https://pydis.com"] * total_links)      return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest):      """Tests applying the `links` rule."""      def setUp(self): +        self.apply = links.apply          self.config = {              "max": 2,              "interval": 10 @@ -30,68 +25,45 @@ class LinksTests(unittest.TestCase):      async def test_links_within_limit(self):          """Messages with an allowed amount of links."""          cases = ( -            [msg("bob", 0)], -            [msg("bob", 2)], -            [msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 -            [msg("bob", 1), msg("bob", 1)], -            [msg("bob", 2), msg("alice", 2)]  # Only messages from latest author count +            [make_msg("bob", 0)], +            [make_msg("bob", 2)], +            [make_msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 +            [make_msg("bob", 1), make_msg("bob", 1)], +            [make_msg("bob", 2), make_msg("alice", 2)]  # Only messages from latest author count          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await links.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_links_exceeding_limit(self):          """Messages with a a higher than allowed amount of links."""          cases = ( -            Case( -                [msg("bob", 1), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 1), make_msg("bob", 2)],                  ("bob",),                  3              ), -            Case( -                [msg("alice", 1), msg("alice", 1), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],                  ("alice",),                  3              ), -            Case( -                [msg("alice", 2), msg("bob", 3), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],                  ("alice",),                  3              )          ) -        for recent_messages, culprit, total_links in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_links=total_links, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_links} links in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await links.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ad49ead32..ebcdabac6 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,95 +1,67 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_mentions: int - - -def msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_mentions: int) -> MockMessage:      """Makes a message with `total_mentions` mentions."""      return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest):      """Tests applying the `mentions` antispam rule."""      def setUp(self): +        self.apply = mentions.apply          self.config = {              "max": 2, -            "interval": 10 +            "interval": 10,          }      @async_test      async def test_mentions_within_limit(self):          """Messages with an allowed amount of mentions."""          cases = ( -            [msg("bob", 0)], -            [msg("bob", 2)], -            [msg("bob", 1), msg("bob", 1)], -            [msg("bob", 1), msg("alice", 2)] +            [make_msg("bob", 0)], +            [make_msg("bob", 2)], +            [make_msg("bob", 1), make_msg("bob", 1)], +            [make_msg("bob", 1), make_msg("alice", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await mentions.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_mentions_exceeding_limit(self):          """Messages with a higher than allowed amount of mentions."""          cases = ( -            Case( -                [msg("bob", 3)], +            DisallowedCase( +                [make_msg("bob", 3)],                  ("bob",), -                3 +                3,              ), -            Case( -                [msg("alice", 2), msg("alice", 0), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],                  ("alice",), -                3 +                3,              ), -            Case( -                [msg("bob", 2), msg("alice", 3), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],                  ("bob",), -                4 +                4,              )          ) -        for recent_messages, culprit, total_mentions in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_mentions=total_mentions, -                cofig=self.config -            ): -                desired_output = ( -                    f"sent {total_mentions} mentions in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await mentions.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py new file mode 100644 index 000000000..d61c4609d --- /dev/null +++ b/tests/bot/rules/test_newlines.py @@ -0,0 +1,105 @@ +from typing import Iterable, List + +from bot.rules import newlines +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, newline_groups: List[int]) -> MockMessage: +    """Init a MockMessage instance with `author` and content configured by `newline_groups". + +    Configure content by passing a list of ints, where each int `n` will generate +    a separate group of `n` newlines. + +    Example: +        newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" +    """ +    content = " ".join("\n" * n for n in newline_groups) +    return MockMessage(author=author, content=content) + + +class TotalNewlinesRuleTests(RuleTest): +    """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" + +    def setUp(self): +        self.apply = newlines.apply +        self.config = { +            "max": 5,  # Max sum of newlines in relevant messages +            "max_consecutive": 3,  # Max newlines in one group, in one message +            "interval": 10, +        } + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("alice", [])],  # Single message with no newlines +            [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])],  # 5 newlines in 2 messages +            [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])],  # 5 newlines from each author +            [make_msg("bob", [1]), make_msg("alice", [5])],  # Alice breaks the rule, but only bob is relevant +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_total(self): +        """Cases which violate the rule by having too many newlines in total.""" +        cases = ( +            DisallowedCase(  # Alice sends a total of 6 newlines (disallowed) +                [make_msg("alice", [2, 2]), make_msg("alice", [2])], +                ("alice",), +                6, +            ), +            DisallowedCase(  # Here we test that only alice's newlines count in the sum +                [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], +                ("alice",), +                7, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_author = case.recent_messages[0].author +        return tuple(msg for msg in case.recent_messages if msg.author == last_author) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} newlines in {self.config['interval']}s" + + +class GroupNewlinesRuleTests(RuleTest): +    """ +    Tests the `newlines` antispam rule against max consecutive newline violations. + +    As these violations yield a different error report, they require a different +    `get_report` implementation. +    """ + +    def setUp(self): +        self.apply = newlines.apply +        self.config = {"max": 5, "max_consecutive": 3, "interval": 10} + +    @async_test +    async def test_disallows_messages_consecutive(self): +        """Cases which violate the rule due to having too many consecutive newlines.""" +        cases = ( +            DisallowedCase(  # Bob sends a group of newlines too large +                [make_msg("bob", [4])], +                ("bob",), +                4, +            ), +            DisallowedCase(  # Alice sends 5 in total (allowed), but 4 in one group (disallowed) +                [make_msg("alice", [1]), make_msg("alice", [4])], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_author = case.recent_messages[0].author +        return tuple(msg for msg in case.recent_messages if msg.author == last_author) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py new file mode 100644 index 000000000..b339cccf7 --- /dev/null +++ b/tests/bot/rules/test_role_mentions.py @@ -0,0 +1,57 @@ +from typing import Iterable + +from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_mentions: int) -> MockMessage: +    """Build a MockMessage instance with `n_mentions` role mentions.""" +    return MockMessage(author=author, role_mentions=[None] * n_mentions) + + +class RoleMentionsRuleTests(RuleTest): +    """Tests for the `role_mentions` antispam rule.""" + +    def setUp(self): +        self.apply = role_mentions.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of role mentions within limit.""" +        cases = ( +            [make_msg("bob", 2)], +            [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with more than the allowed amount of role mentions.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 3)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 5a88adc5c..bdfcc73e4 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -1,9 +1,7 @@ -import logging  import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock  from bot import api -from tests.base import LoggingTestCase  from tests.helpers import async_test @@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase):          self.assertEqual(error.response_text, "")          self.assertIs(error.response, self.error_api_response) -    def test_responde_code_error_string_representation_default_initialization(self): +    def test_response_code_error_string_representation_default_initialization(self):          """Test the string representation of `ResponseCodeError` initialized without text or json."""          error = api.ResponseCodeError(response=self.error_api_response)          self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") @@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase):              response_text=text_data          )          self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") - - -class LoggingHandlerTests(LoggingTestCase): -    """Tests the bot's API Log Handler.""" - -    @classmethod -    def setUpClass(cls): -        cls.debug_log_record = logging.LogRecord( -            name='my.logger', level=logging.DEBUG, -            pathname='my/logger.py', lineno=666, -            msg="Lemon wins", args=(), -            exc_info=None -        ) - -        cls.trace_log_record = logging.LogRecord( -            name='my.logger', level=logging.TRACE, -            pathname='my/logger.py', lineno=666, -            msg="This will not be logged", args=(), -            exc_info=None -        ) - -    def setUp(self): -        self.log_handler = api.APILoggingHandler(None) - -    def test_emit_appends_to_queue_with_stopped_event_loop(self): -        """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" -        with patch("bot.api.APILoggingHandler.ship_off") as ship_off: -            # Patch `ship_off` to ease testing against the return value of this coroutine. -            ship_off.return_value = 42 -            self.log_handler.emit(self.debug_log_record) - -        self.assertListEqual(self.log_handler.queue, [42]) - -    def test_emit_ignores_less_than_debug(self): -        """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" -        self.log_handler.emit(self.trace_log_record) -        self.assertListEqual(self.log_handler.queue, []) - -    def test_schedule_queued_tasks_for_empty_queue(self): -        """`APILoggingHandler` should not schedule anything when the queue is empty.""" -        with self.assertNotLogs(level=logging.DEBUG): -            self.log_handler.schedule_queued_tasks() - -    def test_schedule_queued_tasks_for_nonempty_queue(self): -        """`APILoggingHandler` should schedule logs when the queue is not empty.""" -        log = logging.getLogger("bot.api") - -        with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: -            self.log_handler.queue = [555] -            self.log_handler.schedule_queued_tasks() -            self.assertListEqual(self.log_handler.queue, []) -            create_task.assert_called_once_with(555) - -            [record] = logs.records -            self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") -            self.assertEqual(record.levelno, logging.DEBUG) -            self.assertEqual(record.name, 'bot.api') -            self.assertIn('via_handler', record.__dict__) diff --git a/tests/helpers.py b/tests/helpers.py index 5df796c23..9d9dd5da6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,6 +12,7 @@ from typing import Any, Iterable, Optional  import discord  from discord.ext.commands import Context +from bot.api import APIClient  from bot.bot import Bot @@ -269,9 +270,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      information, see the `MockGuild` docstring.      """      def __init__(self, **kwargs) -> None: -        default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} +        default_kwargs = { +            'id': next(self.discord_id), +            'name': 'role', +            'position': 1, +            'colour': discord.Colour(0xdeadbf), +            'permissions': discord.Permissions(), +        }          super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) +        if isinstance(self.colour, int): +            self.colour = discord.Colour(self.colour) + +        if isinstance(self.permissions, int): +            self.permissions = discord.Permissions(self.permissions) +          if 'mention' not in kwargs:              self.mention = f'&{self.name}' @@ -324,6 +337,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):              self.mention = f"@{self.name}" +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock APIClient objects. + +    Instances of this class will follow the specifications of `bot.api.APIClient` instances. +    For more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=APIClient, **kwargs) + +  # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`  bot_instance = Bot(command_prefix=unittest.mock.MagicMock())  bot_instance.http_session = None @@ -340,6 +365,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      def __init__(self, **kwargs) -> None:          super().__init__(spec_set=bot_instance, **kwargs) +        self.api_client = MockAPIClient()          # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and          # and should therefore be awaited. (The documentation calls it a coroutine as well, which @@ -503,6 +529,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage())          self.users = AsyncIteratorMock(kwargs.get('users', [])) +        self.__str__.return_value = str(self.emoji)  webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) | 
