diff options
| author | 2020-02-27 22:16:05 +0000 | |
|---|---|---|
| committer | 2020-02-27 22:16:05 +0000 | |
| commit | afeb92010e58569e4910c708fa63f0e808278e26 (patch) | |
| tree | 1e4f9baa3be386499d52b5373a901b3f2ddf24a8 /tests | |
| parent | Merge branch 'master' into spoiler-check (diff) | |
| parent | Merge pull request #798 from python-discord/bug/mod/bot-1v/infr-edit-task-cancel (diff) | |
Merge branch 'master' into spoiler-check
Diffstat (limited to 'tests')
25 files changed, 2798 insertions, 405 deletions
diff --git a/tests/README.md b/tests/README.md index 6ab9bc93e..be78821bf 100644 --- a/tests/README.md +++ b/tests/README.md @@ -2,7 +2,7 @@ Our bot is one of the most important tools we have for running our community. As we don't want that tool break, we decided that we wanted to write unit tests for it. We hope that in the future, we'll have a 100% test coverage for the bot. This guide will help you get started with writing the tests needed to achieve that. -_**Note:** This is a practical guide to getting started with writing tests for our bot, not a general introduction to writing unit tests in Python. If you're looking for a more general introduction, you may like Corey Schafer's [Python Tutorial: Unit Testing Your Code with the unittest Module](https://www.youtube.com/watch?v=6tNS--WetLI) or Ned Batchelder's PyCon talk [Getting Started Testing](https://www.youtube.com/watch?v=FxSsnHeWQBY)._ +_**Note:** This is a practical guide to getting started with writing tests for our bot, not a general introduction to writing unit tests in Python. If you're looking for a more general introduction, you can take a look at the [Additional resources](#additional-resources) section at the bottom of this page._ ## Tools @@ -15,6 +15,7 @@ We are using the following modules and packages for our unit tests: To ensure the results you obtain on your personal machine are comparable to those generated in the Azure pipeline, please make sure to run your tests with the virtual environment defined by our [Pipfile](/Pipfile). To run your tests with `pipenv`, we've provided two "scripts" shortcuts: - `pipenv run test` will run `unittest` with `coverage.py` +- `pipenv run test path/to/test.py` will run a specific test. - `pipenv run report` will generate a coverage report of the tests you've run with `pipenv run test`. If you append the `-m` flag to this command, the report will include the lines and branches not covered by tests in addition to the test coverage report. If you want a coverage report, make sure to run the tests with `pipenv run test` *first*. @@ -211,3 +212,10 @@ All in all, it's not only important to consider if all statements or branches we Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written. + +## Additional resources + +* [Ned Batchelder's PyCon talk: Getting Started Testing](https://www.youtube.com/watch?v=FxSsnHeWQBY) +* [Corey Schafer video about unittest](https://youtu.be/6tNS--WetLI) +* [RealPython tutorial on unittest testing](https://realpython.com/python-testing/) +* [RealPython tutorial on mocking](https://realpython.com/python-mock-library/) diff --git a/tests/base.py b/tests/base.py index 029a249ed..88693f382 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@ import logging import unittest from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers class _CaptureLogHandler(logging.Handler): @@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase): standard_message = self._truncateMessage(base_message, record_message) msg = self._formatMessage(msg, standard_message) self.fail(msg) + + +class CommandTestCase(unittest.TestCase): + """TestCase with additional assertions that are useful for testing Discord commands.""" + + @helpers.async_test + async def assertHasPermissionsCheck( + self, + cmd: commands.Command, + permissions: Dict[str, bool], + ) -> None: + """ + Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + + Every permission in `permissions` is expected to be reported as missing. In other words, do + not include permissions which should not raise an exception along with those which should. + """ + # Invert permission values because it's more intuitive to pass to this assertion the same + # permissions as those given to the check decorator. + permissions = {k: not v for k, v in permissions.items()} + + ctx = helpers.MockContext() + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + + with self.assertRaises(commands.MissingPermissions) as cm: + await cmd.can_run(ctx) + + self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..e6a6f9688 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,412 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): + """Syncer subclass with mocks for abstract methods for testing purposes.""" + + name = "test" + _get_diff = helpers.AsyncMock() + _sync = helpers.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): + """Tests for the syncer base class.""" + + def setUp(self): + self.bot = helpers.MockBot() + + def test_instantiation_fails_without_abstract_methods(self): + """The class must have abstract methods implemented.""" + with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.TestCase): + """Tests for sending the sync confirmation prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + + def mock_get_channel(self): + """Fixture to return a mock channel and message for when `get_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + mock_channel.send.return_value = mock_message + self.bot.get_channel.return_value = mock_channel + + return mock_channel, mock_message + + def mock_fetch_channel(self): + """Fixture to return a mock channel and message for when `fetch_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + self.bot.get_channel.return_value = None + mock_channel.send.return_value = mock_message + self.bot.fetch_channel.return_value = mock_channel + + return mock_channel, mock_message + + @helpers.async_test + async def test_send_prompt_edits_and_returns_message(self): + """The given message should be edited to display the prompt and then should be returned.""" + msg = helpers.MockMessage() + ret_val = await self.syncer._send_prompt(msg) + + msg.edit.assert_called_once() + self.assertIn("content", msg.edit.call_args[1]) + self.assertEqual(ret_val, msg) + + @helpers.async_test + async def test_send_prompt_gets_dev_core_channel(self): + """The dev-core channel should be retrieved if an extant message isn't given.""" + subtests = ( + (self.bot.get_channel, self.mock_get_channel), + (self.bot.fetch_channel, self.mock_fetch_channel), + ) + + for method, mock_ in subtests: + with self.subTest(method=method, msg=mock_.__name__): + mock_() + await self.syncer._send_prompt() + + method.assert_called_once_with(constants.Channels.devcore) + + @helpers.async_test + async def test_send_prompt_returns_None_if_channel_fetch_fails(self): + """None should be returned if there's an HTTPException when fetching the channel.""" + self.bot.get_channel.return_value = None + self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + + ret_val = await self.syncer._send_prompt() + + self.assertIsNone(ret_val) + + @helpers.async_test + async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): + """A new message mentioning core devs should be sent and returned if message isn't given.""" + for mock_ in (self.mock_get_channel, self.mock_fetch_channel): + with self.subTest(msg=mock_.__name__): + mock_channel, mock_message = mock_() + ret_val = await self.syncer._send_prompt() + + mock_channel.send.assert_called_once() + self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) + self.assertEqual(ret_val, mock_message) + + @helpers.async_test + async def test_send_prompt_adds_reactions(self): + """The message should have reactions for confirmation added.""" + extant_message = helpers.MockMessage() + subtests = ( + (extant_message, lambda: (None, extant_message)), + (None, self.mock_get_channel), + (None, self.mock_fetch_channel), + ) + + for message_arg, mock_ in subtests: + subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.TestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + @helpers.async_test + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.TestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + @helpers.async_test + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = helpers.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + @helpers.async_test + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + @helpers.async_test + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = helpers.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + @helpers.async_test + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + @helpers.async_test + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): + with self.subTest(size=size): + self.syncer._send_prompt = helpers.AsyncMock() + self.syncer._wait_for_confirmation = helpers.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + @helpers.async_test + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: + with self.subTest(msg=msg): + self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..98c9afc0d --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,395 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): + """ + A MagicMock subclass to mock Syncer objects. + + Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` + instances. For more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=Syncer, **kwargs) + + +class SyncExtensionTests(unittest.TestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.TestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + # These patch the type. When the type is called, a MockSyncer instanced is returned. + # MockSyncer is needed so that our custom AsyncMock is used. + # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. + self.role_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.RoleSyncer", + new=mock.MagicMock(return_value=MockSyncer()) + ) + self.user_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.UserSyncer", + new=mock.MagicMock(return_value=MockSyncer()) + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = sync.Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(sync.Sync, "sync_guild") + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task.reset_mock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + sync.Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + @helpers.async_test + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + @helpers.async_test + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + @helpers.async_test + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + + @helpers.async_test + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + @helpers.async_test + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + @helpers.async_test + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data) + after_role = helpers.MockRole(**after_role_data) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_member_remove(self): + """Member should patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember() + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + updated_information={"in_guild": False} + ) + + @helpers.async_test + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles) + after_member = helpers.MockMember(roles=before_roles[1:]) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + + @helpers.async_test + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}) + after_member = helpers.MockMember(**{attribute: new_value}) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "avatar": "old avatar", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args[0][0], after_user.id) + self.assertIn("updated_information", call_args[1]) + + updated_information = call_args[1]["updated_information"] + self.assertIn(api_field, updated_information) + self.assertEqual(updated_information[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + ) + + data = { + "avatar_hash": member.avatar, + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + @helpers.async_test + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + @helpers.async_test + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + @helpers.async_test + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..14fb2577a 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,165 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): - """Tests constructing the roles to synchronize with the site.""" - - def test_get_roles_for_sync_empty_return_for_equal_roles(self): - """No roles should be synced when no diff is found.""" - api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), set(), set()) - ) - - def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): - """Roles to be synced are returned when non-ID attributes differ.""" - api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), guild_roles, set()) - ) - - def test_get_roles_only_returns_roles_that_require_update(self): - """Roles that require an update should be returned as the second tuple element.""" - api_roles = { - Role(id=41, name='old name', colour=33, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - guild_roles = { - Role(id=41, name='new name', colour=35, permissions=0x8, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_new_roles_in_first_tuple_element(self): - """Newly created roles are returned as the first tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=2) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, - set(), - set(), - ) - ) - - def test_get_roles_returns_roles_to_update_and_new_roles(self): - """Newly created and updated roles should be returned together.""" - api_roles = { - Role(id=41, name='old name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='new name', colour=40, permissions=0x16, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, - {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_roles_to_delete(self): - """Roles to be deleted should be returned as the third tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - set(), - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) - - def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): - """When roles were added, updated, and removed, all of them are returned properly.""" - api_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - Role(id=71, name='to update', colour=99, permissions=0x9, position=3), - } - guild_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=81, name='to create', colour=99, permissions=0x9, position=4), - Role(id=71, name='updated', colour=101, permissions=0x5, position=3), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, - {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.TestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + @helpers.async_test + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.TestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @helpers.async_test + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..421bf6bb6 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,169 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers def fake_user(**kwargs): - kwargs.setdefault('id', 43) - kwargs.setdefault('name', 'bob the test man') - kwargs.setdefault('discriminator', 1337) - kwargs.setdefault('avatar_hash', None) - kwargs.setdefault('roles', (666,)) - kwargs.setdefault('in_guild', True) - return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): - """Tests constructing the users to synchronize with the site.""" - - def test_get_users_for_sync_returns_nothing_for_empty_params(self): - """When no users are given, none are returned.""" - self.assertEqual( - get_users_for_sync({}, {}), - (set(), set()) - ) - - def test_get_users_for_sync_returns_nothing_for_equal_users(self): - """When no users are updated, none are returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) - - def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): - """When a non-ID-field differs, the user to update is returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(name='new fancy name')} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(name='new fancy name')}) - ) - - def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): - """When new users join the guild, they are returned as the first tuple element.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(), 63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, set()) - ) - - def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("avatar_hash", None) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.TestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + member["avatar"] = member.pop("avatar_hash") + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + @helpers.async_test + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - api_users = {43: fake_user(), 63: fake_user(id=63)} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(id=63, in_guild=False)}) - ) - - def test_get_users_for_sync_updates_and_creates_users_as_needed(self): - """When one user left and another one was updated, both are returned.""" - api_users = {43: fake_user()} - guild_users = {63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, {fake_user(in_guild=False)}) - ) - - def test_get_users_for_sync_does_not_duplicate_update_users(self): - """When the API knows a user the guild doesn't, nothing is performed.""" - api_users = {43: fake_user(in_guild=False)} - guild_users = {} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.TestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @helpers.async_test + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py new file mode 100644 index 000000000..5b0a3b8c3 --- /dev/null +++ b/tests/bot/cogs/test_duck_pond.py @@ -0,0 +1,584 @@ +import asyncio +import logging +import typing +import unittest +from unittest.mock import MagicMock, patch + +import discord + +from bot import constants +from bot.cogs import duck_pond +from tests import base +from tests import helpers + +MODULE_PATH = "bot.cogs.duck_pond" + + +class DuckPondTests(base.LoggingTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) + + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) + + def setUp(self): + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) + self.cog = duck_pond.DuckPond(bot=self.bot) + + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_loop.called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger('bot.cogs.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) + ) + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @helpers.async_test + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) + ) + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def test_send_webhook_correctly_passes_on_arguments(self): + """The `send_webhook` method should pass the arguments to the webhook correctly.""" + self.cog.webhook = helpers.MockAsyncWebhook() + + content = "fake content" + username = "fake username" + avatar_url = "fake avatar_url" + embed = "fake embed" + + asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed)) + + self.cog.webhook.send.assert_called_once_with( + content=content, + username=username, + avatar_url=avatar_url, + embed=embed + ) + + def test_send_webhook_logs_when_sending_message_fails(self): + """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly.""" + self.cog.webhook = helpers.MockAsyncWebhook() + self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.") + + log = logging.getLogger('bot.cogs.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.send_webhook()) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + @helpers.async_test + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), + 3 + 4 + ), + ) + + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + @helpers.async_test + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(clean_content="", attachments=[]), False, False), + (helpers.MockMessage(clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True), + ) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + + @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.cogs.duck_pond") + + for side_effect in side_effects: + send_attachments.side_effect = side_effect + with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.cogs.duck_pond") + + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message(message) + + send_webhook.assert_called_once_with( + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + @helpers.async_test + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + @helpers.async_test + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), + ) + + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + @helpers.async_test + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + relay_message.assert_called_once_with(message) + + @helpers.async_test + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) + + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + for duck_count, should_re_add_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) + + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) + + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) + + if should_re_add_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" + ) + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) + + channel.fetch_message.assert_not_called() + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = helpers.MockBot() + duck_pond.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 4496a2ae0..deae7ebad 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -125,10 +125,10 @@ class InformationCogTests(unittest.TestCase): ) ], members=[ - *(helpers.MockMember(status='online') for _ in range(2)), - *(helpers.MockMember(status='idle') for _ in range(1)), - *(helpers.MockMember(status='dnd') for _ in range(4)), - *(helpers.MockMember(status='offline') for _ in range(3)), + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), ], member_count=1_234, icon_url='a-lemon.jpg', @@ -153,9 +153,9 @@ class InformationCogTests(unittest.TestCase): **Counts** Members: {self.ctx.guild.member_count:,} Roles: {len(self.ctx.guild.roles)} - Text: 1 - Voice: 1 - Channel categories: 1 + Category channels: 1 + Text channels: 1 + Voice channels: 1 **Members** {constants.Emojis.status_online} 2 diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index efa7a50b1..9d1a62f7e 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -1,4 +1,3 @@ -import logging import unittest from unittest.mock import MagicMock @@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase): """Tests loading the `Security` cog.""" def test_security_cog_load(self): - """Cog loading logs a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MagicMock() - with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: - security.setup(bot) - bot.add_cog.assert_called_once() - - [line] = cm.output - self.assertIn("Cog loaded: Security", line) + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3276cf5a5..a54b839d7 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase): """Tests setup of the `TokenRemover` cog.""" def test_setup(self): - """Setup of the cog should log a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MockBot() - with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: - setup_cog(bot) - - [line] = cm.output + setup_cog(bot) bot.add_cog.assert_called_once() - self.assertIn("Cog loaded: TokenRemover", line) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index e69de29bb..36c986fe1 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -0,0 +1,76 @@ +import unittest +from abc import ABCMeta, abstractmethod +from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple + +from tests.helpers import MockMessage + + +class DisallowedCase(NamedTuple): + """Encapsulation for test cases expected to fail.""" + recent_messages: List[MockMessage] + culprits: Iterable[str] + n_violations: int + + +class RuleTest(unittest.TestCase, metaclass=ABCMeta): + """ + Abstract class for antispam rule test cases. + + Tests for specific rules should inherit from `RuleTest` and implement + `relevant_messages` and `get_report`. Each instance should also set the + `apply` and `config` attributes as necessary. + + The execution of test cases can then be delegated to the `run_allowed` + and `run_disallowed` methods. + """ + + apply: Callable # The tested rule's apply function + config: Dict[str, int] + + async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: + """Run all `cases` against `self.apply` expecting them to pass.""" + for recent_messages in cases: + last_message = recent_messages[0] + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + config=self.config, + ): + self.assertIsNone( + await self.apply(last_message, recent_messages, self.config) + ) + + async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: + """Run all `cases` against `self.apply` expecting them to fail.""" + for case in cases: + recent_messages, culprits, n_violations = case + last_message = recent_messages[0] + relevant_messages = self.relevant_messages(case) + desired_output = ( + self.get_report(case), + culprits, + relevant_messages, + ) + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + relevant_messages=relevant_messages, + n_violations=n_violations, + config=self.config, + ): + self.assertTupleEqual( + await self.apply(last_message, recent_messages, self.config), + desired_output, + ) + + @abstractmethod + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + """Give expected relevant messages for `case`.""" + raise NotImplementedError + + @abstractmethod + def get_report(self, case: DisallowedCase) -> str: + """Give expected error report for `case`.""" + raise NotImplementedError diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index 4bb0acf7c..e54b4b5b8 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,52 +1,71 @@ -import asyncio -import unittest -from dataclasses import dataclass -from typing import Any, List +from typing import Iterable from bot.rules import attachments +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test -# Using `MagicMock` sadly doesn't work for this usecase -# since it's __eq__ compares the MagicMock's ID. We just -# want to compare the actual attributes we set. -@dataclass -class FakeMessage: - author: str - attachments: List[Any] +def make_msg(author: str, total_attachments: int) -> MockMessage: + """Builds a message with `total_attachments` attachments.""" + return MockMessage(author=author, attachments=list(range(total_attachments))) -def msg(total_attachments: int) -> FakeMessage: - return FakeMessage(author='lemon', attachments=list(range(total_attachments))) +class AttachmentRuleTests(RuleTest): + """Tests applying the `attachments` antispam rule.""" + def setUp(self): + self.apply = attachments.apply + self.config = {"max": 5, "interval": 10} -class AttachmentRuleTests(unittest.TestCase): - """Tests applying the `attachment` antispam rule.""" - - def test_allows_messages_without_too_many_attachments(self): + @async_test + async def test_allows_messages_without_too_many_attachments(self): """Messages without too many attachments are allowed as-is.""" cases = ( - (msg(0), msg(0), msg(0)), - (msg(2), msg(2)), - (msg(0),), + [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], + [make_msg("bob", 2), make_msg("bob", 2)], + [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], ) - for last_message, *recent_messages in cases: - with self.subTest(last_message=last_message, recent_messages=recent_messages): - coro = attachments.apply(last_message, recent_messages, {'max': 5}) - self.assertIsNone(asyncio.run(coro)) + await self.run_allowed(cases) - def test_disallows_messages_with_too_many_attachments(self): + @async_test + async def test_disallows_messages_with_too_many_attachments(self): """Messages with too many attachments trigger the rule.""" cases = ( - ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), - ((msg(6),), [msg(6)], 6), - ((msg(1),) * 6, [msg(1)] * 6, 6), + DisallowedCase( + [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], + ("bob",), + 10, + ), + DisallowedCase( + [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], + ("bob",), + 6, + ), + DisallowedCase( + [make_msg("alice", 6)], + ("alice",), + 6, + ), + DisallowedCase( + [make_msg("alice", 1) for _ in range(6)], + ("alice",), + 6, + ), ) - for messages, relevant_messages, total in cases: - with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total): - last_message, *recent_messages = messages - coro = attachments.apply(last_message, recent_messages, {'max': 5}) - self.assertEqual( - asyncio.run(coro), - (f"sent {total} attachments in 5s", ('lemon',), relevant_messages) - ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if ( + msg.author == last_message.author + and len(msg.attachments) > 0 + ) + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py new file mode 100644 index 000000000..72f0be0c7 --- /dev/null +++ b/tests/bot/rules/test_burst.py @@ -0,0 +1,56 @@ +from typing import Iterable + +from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: + """ + Init a MockMessage instance with author set to `author`. + + This serves as a shorthand / alias to keep the test cases visually clean. + """ + return MockMessage(author=author) + + +class BurstRuleTests(RuleTest): + """Tests the `burst` antispam rule.""" + + def setUp(self): + self.apply = burst.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("bob"), make_msg("bob")], + [make_msg("bob"), make_msg("alice"), make_msg("bob")], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases where the amount of messages exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("bob")], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], + ("bob",), + 3, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py new file mode 100644 index 000000000..47367a5f8 --- /dev/null +++ b/tests/bot/rules/test_burst_shared.py @@ -0,0 +1,59 @@ +from typing import Iterable + +from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: + """ + Init a MockMessage instance with the passed arg. + + This serves as a shorthand / alias to keep the test cases visually clean. + """ + return MockMessage(author=author) + + +class BurstSharedRuleTests(RuleTest): + """Tests the `burst_shared` antispam rule.""" + + def setUp(self): + self.apply = burst_shared.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """ + Cases that do not violate the rule. + + There really isn't more to test here than a single case. + """ + cases = ( + [make_msg("spongebob"), make_msg("patrick")], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases where the amount of messages exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("bob")], + {"bob"}, + 3, + ), + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], + {"bob", "alice"}, + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return case.recent_messages + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py new file mode 100644 index 000000000..7cc36f49e --- /dev/null +++ b/tests/bot/rules/test_chars.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_chars: int) -> MockMessage: + """Build a message with arbitrary content of `n_chars` length.""" + return MockMessage(author=author, content="A" * n_chars) + + +class CharsRuleTests(RuleTest): + """Tests the `chars` antispam rule.""" + + def setUp(self): + self.apply = chars.apply + self.config = { + "max": 20, # Max allowed sum of chars per user + "interval": 10, + } + + @async_test + async def test_allows_messages_within_limit(self): + """Cases with a total amount of chars within limit.""" + cases = ( + [make_msg("bob", 0)], + [make_msg("bob", 20)], + [make_msg("bob", 15), make_msg("alice", 15)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases where the total amount of chars exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob", 21)], + ("bob",), + 21, + ), + DisallowedCase( + [make_msg("bob", 15), make_msg("bob", 15)], + ("bob",), + 30, + ), + DisallowedCase( + [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], + ("alice",), + 30, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py new file mode 100644 index 000000000..0239b0b00 --- /dev/null +++ b/tests/bot/rules/test_discord_emojis.py @@ -0,0 +1,54 @@ +from typing import Iterable + +from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + +discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> + + +def make_msg(author: str, n_emojis: int) -> MockMessage: + """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" + return MockMessage(author=author, content=discord_emoji * n_emojis) + + +class DiscordEmojisRuleTests(RuleTest): + """Tests for the `discord_emojis` antispam rule.""" + + def setUp(self): + self.apply = discord_emojis.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases with a total amount of discord emojis within limit.""" + cases = ( + [make_msg("bob", 2)], + [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases with more than the allowed amount of discord emojis.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py new file mode 100644 index 000000000..59e0fb6ef --- /dev/null +++ b/tests/bot/rules/test_duplicates.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import duplicates +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, content: str) -> MockMessage: + """Give a MockMessage instance with `author` and `content` attrs.""" + return MockMessage(author=author, content=content) + + +class DuplicatesRuleTests(RuleTest): + """Tests the `duplicates` antispam rule.""" + + def setUp(self): + self.apply = duplicates.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("alice", "A"), make_msg("alice", "A")], + [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate + [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases with too many duplicate messages from the same author.""" + cases = ( + DisallowedCase( + [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], + ("alice",), + 3, + ), + DisallowedCase( + [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], + ("bob",), + 3, # 4 duplicate messages, but only 3 from bob + ), + DisallowedCase( + [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], + ("bob",), + 3, # 4 message from bob, but only 3 duplicates + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if ( + msg.author == last_message.author + and msg.content == last_message.content + ) + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index be832843b..3c3f90e5f 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,32 +1,21 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import links -from tests.helpers import async_test +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test -class FakeMessage(NamedTuple): - author: str - content: str - - -class Case(NamedTuple): - recent_messages: List[FakeMessage] - relevant_messages: Tuple[FakeMessage] - culprit: Tuple[str] - total_links: int - - -def msg(author: str, total_links: int) -> FakeMessage: - """Makes a message with *total_links* links.""" +def make_msg(author: str, total_links: int) -> MockMessage: + """Makes a message with `total_links` links.""" content = " ".join(["https://pydis.com"] * total_links) - return FakeMessage(author=author, content=content) + return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest): """Tests applying the `links` rule.""" def setUp(self): + self.apply = links.apply self.config = { "max": 2, "interval": 10 @@ -36,66 +25,45 @@ class LinksTests(unittest.TestCase): async def test_links_within_limit(self): """Messages with an allowed amount of links.""" cases = ( - [msg("bob", 0)], - [msg("bob", 2)], - [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 - [msg("bob", 1), msg("bob", 1)], - [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await links.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) @async_test async def test_links_exceeding_limit(self): """Messages with a a higher than allowed amount of links.""" cases = ( - Case( - [msg("bob", 1), msg("bob", 2)], - (msg("bob", 1), msg("bob", 2)), + DisallowedCase( + [make_msg("bob", 1), make_msg("bob", 2)], ("bob",), 3 ), - Case( - [msg("alice", 1), msg("alice", 1), msg("alice", 1)], - (msg("alice", 1), msg("alice", 1), msg("alice", 1)), + DisallowedCase( + [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], ("alice",), 3 ), - Case( - [msg("alice", 2), msg("bob", 3), msg("alice", 1)], - (msg("alice", 2), msg("alice", 1)), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], ("alice",), 3 ) ) - for recent_messages, relevant_messages, culprit, total_links in cases: - last_message = recent_messages[0] + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - culprit=culprit, - total_links=total_links, - config=self.config - ): - desired_output = ( - f"sent {total_links} links in {self.config['interval']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await links.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py new file mode 100644 index 000000000..ebcdabac6 --- /dev/null +++ b/tests/bot/rules/test_mentions.py @@ -0,0 +1,67 @@ +from typing import Iterable + +from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, total_mentions: int) -> MockMessage: + """Makes a message with `total_mentions` mentions.""" + return MockMessage(author=author, mentions=list(range(total_mentions))) + + +class TestMentions(RuleTest): + """Tests applying the `mentions` antispam rule.""" + + def setUp(self): + self.apply = mentions.apply + self.config = { + "max": 2, + "interval": 10, + } + + @async_test + async def test_mentions_within_limit(self): + """Messages with an allowed amount of mentions.""" + cases = ( + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 1), make_msg("alice", 2)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_mentions_exceeding_limit(self): + """Messages with a higher than allowed amount of mentions.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], + ("alice",), + 3, + ), + DisallowedCase( + [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], + ("bob",), + 4, + ) + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py new file mode 100644 index 000000000..d61c4609d --- /dev/null +++ b/tests/bot/rules/test_newlines.py @@ -0,0 +1,105 @@ +from typing import Iterable, List + +from bot.rules import newlines +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, newline_groups: List[int]) -> MockMessage: + """Init a MockMessage instance with `author` and content configured by `newline_groups". + + Configure content by passing a list of ints, where each int `n` will generate + a separate group of `n` newlines. + + Example: + newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" + """ + content = " ".join("\n" * n for n in newline_groups) + return MockMessage(author=author, content=content) + + +class TotalNewlinesRuleTests(RuleTest): + """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" + + def setUp(self): + self.apply = newlines.apply + self.config = { + "max": 5, # Max sum of newlines in relevant messages + "max_consecutive": 3, # Max newlines in one group, in one message + "interval": 10, + } + + @async_test + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("alice", [])], # Single message with no newlines + [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages + [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author + [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_total(self): + """Cases which violate the rule by having too many newlines in total.""" + cases = ( + DisallowedCase( # Alice sends a total of 6 newlines (disallowed) + [make_msg("alice", [2, 2]), make_msg("alice", [2])], + ("alice",), + 6, + ), + DisallowedCase( # Here we test that only alice's newlines count in the sum + [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], + ("alice",), + 7, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_author = case.recent_messages[0].author + return tuple(msg for msg in case.recent_messages if msg.author == last_author) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} newlines in {self.config['interval']}s" + + +class GroupNewlinesRuleTests(RuleTest): + """ + Tests the `newlines` antispam rule against max consecutive newline violations. + + As these violations yield a different error report, they require a different + `get_report` implementation. + """ + + def setUp(self): + self.apply = newlines.apply + self.config = {"max": 5, "max_consecutive": 3, "interval": 10} + + @async_test + async def test_disallows_messages_consecutive(self): + """Cases which violate the rule due to having too many consecutive newlines.""" + cases = ( + DisallowedCase( # Bob sends a group of newlines too large + [make_msg("bob", [4])], + ("bob",), + 4, + ), + DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed) + [make_msg("alice", [1]), make_msg("alice", [4])], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_author = case.recent_messages[0].author + return tuple(msg for msg in case.recent_messages if msg.author == last_author) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py new file mode 100644 index 000000000..b339cccf7 --- /dev/null +++ b/tests/bot/rules/test_role_mentions.py @@ -0,0 +1,57 @@ +from typing import Iterable + +from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_mentions: int) -> MockMessage: + """Build a MockMessage instance with `n_mentions` role mentions.""" + return MockMessage(author=author, role_mentions=[None] * n_mentions) + + +class RoleMentionsRuleTests(RuleTest): + """Tests for the `role_mentions` antispam rule.""" + + def setUp(self): + self.apply = role_mentions.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases with a total amount of role mentions within limit.""" + cases = ( + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases with more than the allowed amount of role mentions.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 5a88adc5c..bdfcc73e4 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -1,9 +1,7 @@ -import logging import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from bot import api -from tests.base import LoggingTestCase from tests.helpers import async_test @@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase): self.assertEqual(error.response_text, "") self.assertIs(error.response, self.error_api_response) - def test_responde_code_error_string_representation_default_initialization(self): + def test_response_code_error_string_representation_default_initialization(self): """Test the string representation of `ResponseCodeError` initialized without text or json.""" error = api.ResponseCodeError(response=self.error_api_response) self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") @@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase): response_text=text_data ) self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") - - -class LoggingHandlerTests(LoggingTestCase): - """Tests the bot's API Log Handler.""" - - @classmethod - def setUpClass(cls): - cls.debug_log_record = logging.LogRecord( - name='my.logger', level=logging.DEBUG, - pathname='my/logger.py', lineno=666, - msg="Lemon wins", args=(), - exc_info=None - ) - - cls.trace_log_record = logging.LogRecord( - name='my.logger', level=logging.TRACE, - pathname='my/logger.py', lineno=666, - msg="This will not be logged", args=(), - exc_info=None - ) - - def setUp(self): - self.log_handler = api.APILoggingHandler(None) - - def test_emit_appends_to_queue_with_stopped_event_loop(self): - """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" - with patch("bot.api.APILoggingHandler.ship_off") as ship_off: - # Patch `ship_off` to ease testing against the return value of this coroutine. - ship_off.return_value = 42 - self.log_handler.emit(self.debug_log_record) - - self.assertListEqual(self.log_handler.queue, [42]) - - def test_emit_ignores_less_than_debug(self): - """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" - self.log_handler.emit(self.trace_log_record) - self.assertListEqual(self.log_handler.queue, []) - - def test_schedule_queued_tasks_for_empty_queue(self): - """`APILoggingHandler` should not schedule anything when the queue is empty.""" - with self.assertNotLogs(level=logging.DEBUG): - self.log_handler.schedule_queued_tasks() - - def test_schedule_queued_tasks_for_nonempty_queue(self): - """`APILoggingHandler` should schedule logs when the queue is not empty.""" - log = logging.getLogger("bot.api") - - with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: - self.log_handler.queue = [555] - self.log_handler.schedule_queued_tasks() - self.assertListEqual(self.log_handler.queue, []) - create_task.assert_called_once_with(555) - - [record] = logs.records - self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") - self.assertEqual(record.levelno, logging.DEBUG) - self.assertEqual(record.name, 'bot.api') - self.assertIn('via_handler', record.__dict__) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase): instance = utils.CaseInsensitiveDict() instance.update({'FOO': 'bar'}) self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): - """Tests the `chunk` method.""" - - def test_empty_chunking(self): - """Tests chunking on an empty iterable.""" - generator = utils.chunks(iterable=[], size=5) - self.assertEqual(list(generator), []) - - def test_list_chunking(self): - """Tests chunking a non-empty list.""" - iterable = [1, 2, 3, 4, 5] - generator = utils.chunks(iterable=iterable, size=2) - self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..69f35f2f5 --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,162 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): + """Test helper functions in bot.utils.time.""" + + def test_humanize_delta_handle_unknown_units(self): + """humanize_delta should be able to handle unknown units, and will not abort.""" + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + + def test_humanize_delta_handle_high_units(self): + """humanize_delta should be able to handle very high units.""" + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + + def test_humanize_delta_should_normal_usage(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + ) + + for delta, precision, max_units, expected in test_cases: + with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_raises_for_invalid_max_units(self): + """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" + test_cases = (-1, 0) + + for max_units in test_cases: + with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: + time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + self.assertEqual(str(error), 'max_units must be positive') + + def test_parse_rfc1123(self): + """Testing parse_rfc1123.""" + self.assertEqual( + time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), + datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) + ) + + def test_format_infraction(self): + """Testing format_infraction.""" + self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') + + @patch('asyncio.sleep', new_callable=AsyncMock) + def test_wait_until(self, mock): + """Testing wait_until.""" + start = datetime(2019, 1, 1, 0, 0) + then = datetime(2019, 1, 1, 0, 10) + + # No return value + self.assertIs(asyncio.run(time.wait_until(then, start)), None) + + mock.assert_called_once_with(10 * 60) + + def test_format_infraction_with_duration_none_expiry(self): + """format_infraction_with_duration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_custom_units(self): + """format_infraction_with_duration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_normal_usage(self): + """format_infraction_with_duration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, + '2019-11-23 23:59 (9 minutes and 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_until_expiration_with_duration_none_expiry(self): + """until_expiration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that now and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_with_duration_custom_units(self): + """until_expiration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_normal_usage(self): + """until_expiration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) diff --git a/tests/helpers.py b/tests/helpers.py index 8a14aeef4..9d9dd5da6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -10,7 +10,10 @@ import unittest.mock from typing import Any, Iterable, Optional import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.api import APIClient +from bot.bot import Bot for logger in logging.Logger.manager.loggerDict.values(): @@ -120,8 +123,80 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): Python 3.8 will introduce an AsyncMock class in the standard library that will have some more features; this stand-in only overwrites the `__call__` method to an async version. """ + async def __call__(self, *args, **kwargs): - return super(AsyncMock, self).__call__(*args, **kwargs) + return super().__call__(*args, **kwargs) + + +class AsyncIteratorMock: + """ + A class to mock asynchronous iterators. + + This allows async for, which is used in certain Discord.py objects. For example, + an async iterator is returned by the Reaction.users() method. + """ + + def __init__(self, iterable: Iterable = None): + if iterable is None: + iterable = [] + + self.iter = iter(iterable) + self.iterable = iterable + + self.call_count = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + def __call__(self): + """ + Keeps track of the number of times an instance has been called. + + This is useful, since it typically shows that the iterator has actually been used somewhere after we have + instantiated the mock for an attribute that normally returns an iterator when called. + """ + self.call_count += 1 + return self + + @property + def return_value(self): + """Makes `self.iterable` accessible as self.return_value.""" + return self.iterable + + @return_value.setter + def return_value(self, iterable): + """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" + self.iter = iter(iterable) + self.iterable = iterable + + def assert_called(self): + """Asserts if the AsyncIteratorMock instance has been called at least once.""" + if self.call_count == 0: + raise AssertionError("Expected AsyncIteratorMock to have been called.") + + def assert_called_once(self): + """Asserts if the AsyncIteratorMock instance has been called exactly once.""" + if self.call_count != 1: + raise AssertionError( + f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." + ) + + def assert_not_called(self): + """Asserts if the AsyncIteratorMock instance has not been called.""" + if self.call_count != 0: + raise AssertionError( + f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." + ) + + def reset_mock(self): + """Resets the call count, but not the return value or iterator.""" + self.call_count = 0 # Create a guild instance to get a realistic Mock of `discord.Guild` @@ -195,9 +270,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): information, see the `MockGuild` docstring. """ def __init__(self, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} + default_kwargs = { + 'id': next(self.discord_id), + 'name': 'role', + 'position': 1, + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), + } super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) + if isinstance(self.colour, int): + self.colour = discord.Colour(self.colour) + + if isinstance(self.permissions, int): + self.permissions = discord.Permissions(self.permissions) + if 'mention' not in kwargs: self.mention = f'&{self.name}' @@ -220,7 +307,7 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin information, see the `MockGuild` docstring. """ def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'name': 'member', 'id': next(self.discord_id)} + default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] @@ -231,6 +318,37 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" +# Create a User instance to get a realistic Mock of `discord.User` +user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock()) + + +class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): + """ + A Mock subclass to mock User objects. + + Instances of this class will follow the specifications of `discord.User` instances. For more + information, see the `MockGuild` docstring. + """ + def __init__(self, **kwargs) -> None: + default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} + super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"@{self.name}" + + +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock APIClient objects. + + Instances of this class will follow the specifications of `bot.api.APIClient` instances. + For more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=APIClient, **kwargs) + + # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) bot_instance.http_session = None @@ -244,8 +362,10 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=bot_instance, **kwargs) + self.api_client = MockAPIClient() # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and # and should therefore be awaited. (The documentation calls it a coroutine as well, which @@ -281,6 +401,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) @@ -322,6 +443,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=context_instance, **kwargs) self.bot = kwargs.get('bot', MockBot()) @@ -330,6 +452,20 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.channel = kwargs.get('channel', MockTextChannel()) +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) + + +class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Attachment objects. + + Instances of this class will follow the specifications of `discord.Attachment` instances. For + more information, see the `MockGuild` docstring. + """ + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=attachment_instance, **kwargs) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. @@ -337,8 +473,10 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: - super().__init__(spec_set=message_instance, **kwargs) + default_kwargs = {'attachments': []} + super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) @@ -354,6 +492,7 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=emoji_instance, **kwargs) self.guild = kwargs.get('guild', MockGuild()) @@ -369,6 +508,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=partial_emoji_instance, **kwargs) @@ -383,7 +523,32 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=reaction_instance, **kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) + self.users = AsyncIteratorMock(kwargs.get('users', [])) + self.__str__.return_value = str(self.emoji) + + +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) + + +class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. + + Instances of this class will follow the specifications of `discord.Webhook` instances. For + more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=webhook_instance, **kwargs) + + # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined + # as coroutines. That's why we need to set the methods manually. + self.send = AsyncMock() + self.edit = AsyncMock() + self.delete = AsyncMock() + self.execute = AsyncMock() |