diff options
author | 2020-02-23 15:42:20 -0800 | |
---|---|---|
committer | 2020-02-23 15:42:20 -0800 | |
commit | c81a4d401ea434e98b0a1ece51d3d10f1a3ad226 (patch) | |
tree | 1fe8caf0ae751cfe94a71fcf02ffbbfda5bd8812 /tests | |
parent | Merge pull request #749 from python-discord/reminder-enhancements (diff) | |
parent | Merge remote-tracking branch 'origin/master' into bug/backend/b704/ready-miss... (diff) |
Merge pull request #711 from python-discord/bug/backend/b704/ready-missing-cache
Prevent the role syncer from wiping the database table during API latency
Diffstat (limited to 'tests')
-rw-r--r-- | tests/base.py | 34 | ||||
-rw-r--r-- | tests/bot/cogs/sync/test_base.py | 412 | ||||
-rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 395 | ||||
-rw-r--r-- | tests/bot/cogs/sync/test_roles.py | 287 | ||||
-rw-r--r-- | tests/bot/cogs/sync/test_users.py | 241 | ||||
-rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 4 | ||||
-rw-r--r-- | tests/helpers.py | 29 |
7 files changed, 1197 insertions, 205 deletions
diff --git a/tests/base.py b/tests/base.py index 029a249ed..88693f382 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@ import logging import unittest from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers class _CaptureLogHandler(logging.Handler): @@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase): standard_message = self._truncateMessage(base_message, record_message) msg = self._formatMessage(msg, standard_message) self.fail(msg) + + +class CommandTestCase(unittest.TestCase): + """TestCase with additional assertions that are useful for testing Discord commands.""" + + @helpers.async_test + async def assertHasPermissionsCheck( + self, + cmd: commands.Command, + permissions: Dict[str, bool], + ) -> None: + """ + Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + + Every permission in `permissions` is expected to be reported as missing. In other words, do + not include permissions which should not raise an exception along with those which should. + """ + # Invert permission values because it's more intuitive to pass to this assertion the same + # permissions as those given to the check decorator. + permissions = {k: not v for k, v in permissions.items()} + + ctx = helpers.MockContext() + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + + with self.assertRaises(commands.MissingPermissions) as cm: + await cmd.can_run(ctx) + + self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..e6a6f9688 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,412 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): + """Syncer subclass with mocks for abstract methods for testing purposes.""" + + name = "test" + _get_diff = helpers.AsyncMock() + _sync = helpers.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): + """Tests for the syncer base class.""" + + def setUp(self): + self.bot = helpers.MockBot() + + def test_instantiation_fails_without_abstract_methods(self): + """The class must have abstract methods implemented.""" + with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.TestCase): + """Tests for sending the sync confirmation prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + + def mock_get_channel(self): + """Fixture to return a mock channel and message for when `get_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + mock_channel.send.return_value = mock_message + self.bot.get_channel.return_value = mock_channel + + return mock_channel, mock_message + + def mock_fetch_channel(self): + """Fixture to return a mock channel and message for when `fetch_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + self.bot.get_channel.return_value = None + mock_channel.send.return_value = mock_message + self.bot.fetch_channel.return_value = mock_channel + + return mock_channel, mock_message + + @helpers.async_test + async def test_send_prompt_edits_and_returns_message(self): + """The given message should be edited to display the prompt and then should be returned.""" + msg = helpers.MockMessage() + ret_val = await self.syncer._send_prompt(msg) + + msg.edit.assert_called_once() + self.assertIn("content", msg.edit.call_args[1]) + self.assertEqual(ret_val, msg) + + @helpers.async_test + async def test_send_prompt_gets_dev_core_channel(self): + """The dev-core channel should be retrieved if an extant message isn't given.""" + subtests = ( + (self.bot.get_channel, self.mock_get_channel), + (self.bot.fetch_channel, self.mock_fetch_channel), + ) + + for method, mock_ in subtests: + with self.subTest(method=method, msg=mock_.__name__): + mock_() + await self.syncer._send_prompt() + + method.assert_called_once_with(constants.Channels.devcore) + + @helpers.async_test + async def test_send_prompt_returns_None_if_channel_fetch_fails(self): + """None should be returned if there's an HTTPException when fetching the channel.""" + self.bot.get_channel.return_value = None + self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + + ret_val = await self.syncer._send_prompt() + + self.assertIsNone(ret_val) + + @helpers.async_test + async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): + """A new message mentioning core devs should be sent and returned if message isn't given.""" + for mock_ in (self.mock_get_channel, self.mock_fetch_channel): + with self.subTest(msg=mock_.__name__): + mock_channel, mock_message = mock_() + ret_val = await self.syncer._send_prompt() + + mock_channel.send.assert_called_once() + self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) + self.assertEqual(ret_val, mock_message) + + @helpers.async_test + async def test_send_prompt_adds_reactions(self): + """The message should have reactions for confirmation added.""" + extant_message = helpers.MockMessage() + subtests = ( + (extant_message, lambda: (None, extant_message)), + (None, self.mock_get_channel), + (None, self.mock_fetch_channel), + ) + + for message_arg, mock_ in subtests: + subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.TestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + @helpers.async_test + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.TestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + @helpers.async_test + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = helpers.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + @helpers.async_test + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + @helpers.async_test + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = helpers.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + @helpers.async_test + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + @helpers.async_test + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): + with self.subTest(size=size): + self.syncer._send_prompt = helpers.AsyncMock() + self.syncer._wait_for_confirmation = helpers.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + @helpers.async_test + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: + with self.subTest(msg=msg): + self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..98c9afc0d --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,395 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): + """ + A MagicMock subclass to mock Syncer objects. + + Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` + instances. For more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=Syncer, **kwargs) + + +class SyncExtensionTests(unittest.TestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.TestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + # These patch the type. When the type is called, a MockSyncer instanced is returned. + # MockSyncer is needed so that our custom AsyncMock is used. + # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. + self.role_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.RoleSyncer", + new=mock.MagicMock(return_value=MockSyncer()) + ) + self.user_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.UserSyncer", + new=mock.MagicMock(return_value=MockSyncer()) + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = sync.Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(sync.Sync, "sync_guild") + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task.reset_mock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + sync.Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + @helpers.async_test + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + @helpers.async_test + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + @helpers.async_test + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + + @helpers.async_test + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + @helpers.async_test + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + @helpers.async_test + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data) + after_role = helpers.MockRole(**after_role_data) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_member_remove(self): + """Member should patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember() + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + updated_information={"in_guild": False} + ) + + @helpers.async_test + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles) + after_member = helpers.MockMember(roles=before_roles[1:]) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + + @helpers.async_test + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}) + after_member = helpers.MockMember(**{attribute: new_value}) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "avatar": "old avatar", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args[0][0], after_user.id) + self.assertIn("updated_information", call_args[1]) + + updated_information = call_args[1]["updated_information"] + self.assertIn(api_field, updated_information) + self.assertEqual(updated_information[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + ) + + data = { + "avatar_hash": member.avatar, + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + @helpers.async_test + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + @helpers.async_test + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + @helpers.async_test + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..14fb2577a 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,165 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): - """Tests constructing the roles to synchronize with the site.""" - - def test_get_roles_for_sync_empty_return_for_equal_roles(self): - """No roles should be synced when no diff is found.""" - api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), set(), set()) - ) - - def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): - """Roles to be synced are returned when non-ID attributes differ.""" - api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), guild_roles, set()) - ) - - def test_get_roles_only_returns_roles_that_require_update(self): - """Roles that require an update should be returned as the second tuple element.""" - api_roles = { - Role(id=41, name='old name', colour=33, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - guild_roles = { - Role(id=41, name='new name', colour=35, permissions=0x8, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_new_roles_in_first_tuple_element(self): - """Newly created roles are returned as the first tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=2) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, - set(), - set(), - ) - ) - - def test_get_roles_returns_roles_to_update_and_new_roles(self): - """Newly created and updated roles should be returned together.""" - api_roles = { - Role(id=41, name='old name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='new name', colour=40, permissions=0x16, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, - {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_roles_to_delete(self): - """Roles to be deleted should be returned as the third tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - set(), - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) - - def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): - """When roles were added, updated, and removed, all of them are returned properly.""" - api_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - Role(id=71, name='to update', colour=99, permissions=0x9, position=3), - } - guild_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=81, name='to create', colour=99, permissions=0x9, position=4), - Role(id=71, name='updated', colour=101, permissions=0x5, position=3), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, - {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.TestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + @helpers.async_test + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.TestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @helpers.async_test + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..421bf6bb6 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,169 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers def fake_user(**kwargs): - kwargs.setdefault('id', 43) - kwargs.setdefault('name', 'bob the test man') - kwargs.setdefault('discriminator', 1337) - kwargs.setdefault('avatar_hash', None) - kwargs.setdefault('roles', (666,)) - kwargs.setdefault('in_guild', True) - return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): - """Tests constructing the users to synchronize with the site.""" - - def test_get_users_for_sync_returns_nothing_for_empty_params(self): - """When no users are given, none are returned.""" - self.assertEqual( - get_users_for_sync({}, {}), - (set(), set()) - ) - - def test_get_users_for_sync_returns_nothing_for_equal_users(self): - """When no users are updated, none are returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) - - def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): - """When a non-ID-field differs, the user to update is returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(name='new fancy name')} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(name='new fancy name')}) - ) - - def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): - """When new users join the guild, they are returned as the first tuple element.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(), 63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, set()) - ) - - def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("avatar_hash", None) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.TestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + member["avatar"] = member.pop("avatar_hash") + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + @helpers.async_test + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - api_users = {43: fake_user(), 63: fake_user(id=63)} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(id=63, in_guild=False)}) - ) - - def test_get_users_for_sync_updates_and_creates_users_as_needed(self): - """When one user left and another one was updated, both are returned.""" - api_users = {43: fake_user()} - guild_users = {63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, {fake_user(in_guild=False)}) - ) - - def test_get_users_for_sync_does_not_duplicate_update_users(self): - """When the API knows a user the guild doesn't, nothing is performed.""" - api_users = {43: fake_user(in_guild=False)} - guild_users = {} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.TestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @helpers.async_test + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..5b0a3b8c3 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase): asyncio.run(self.cog.fetch_webhook()) - self.bot.wait_until_ready.assert_called_once() + self.bot.wait_until_guild_available.assert_called_once() self.bot.fetch_webhook.assert_called_once_with(1) self.assertEqual(self.cog.webhook, "dummy webhook") @@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase): with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: asyncio.run(self.cog.fetch_webhook()) - self.bot.wait_until_ready.assert_called_once() + self.bot.wait_until_guild_available.assert_called_once() self.bot.fetch_webhook.assert_called_once_with(1) self.assertEqual(len(log_watcher.records), 1) diff --git a/tests/helpers.py b/tests/helpers.py index 5df796c23..9d9dd5da6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,6 +12,7 @@ from typing import Any, Iterable, Optional import discord from discord.ext.commands import Context +from bot.api import APIClient from bot.bot import Bot @@ -269,9 +270,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): information, see the `MockGuild` docstring. """ def __init__(self, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} + default_kwargs = { + 'id': next(self.discord_id), + 'name': 'role', + 'position': 1, + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), + } super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) + if isinstance(self.colour, int): + self.colour = discord.Colour(self.colour) + + if isinstance(self.permissions, int): + self.permissions = discord.Permissions(self.permissions) + if 'mention' not in kwargs: self.mention = f'&{self.name}' @@ -324,6 +337,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): self.mention = f"@{self.name}" +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock APIClient objects. + + Instances of this class will follow the specifications of `bot.api.APIClient` instances. + For more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=APIClient, **kwargs) + + # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) bot_instance.http_session = None @@ -340,6 +365,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): def __init__(self, **kwargs) -> None: super().__init__(spec_set=bot_instance, **kwargs) + self.api_client = MockAPIClient() # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and # and should therefore be awaited. (The documentation calls it a coroutine as well, which @@ -503,6 +529,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) self.users = AsyncIteratorMock(kwargs.get('users', [])) + self.__str__.return_value = str(self.emoji) webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) |