diff options
author | 2020-11-21 09:37:57 +0200 | |
---|---|---|
committer | 2020-11-21 09:37:57 +0200 | |
commit | 95429b10aab81d849ab86241f411bbc122ea1594 (patch) | |
tree | 707b7f6055e8cd6d11f1034ed6296470bc437416 /tests | |
parent | EH tests: Fix InWhitelistCheckFailure import path (diff) | |
parent | Merge pull request #1287 from python-discord/help-channel-msg (diff) |
Merge branch 'master' into error-handler-test
Diffstat (limited to 'tests')
-rw-r--r-- | tests/_autospec.py | 64 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_base.py | 378 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py | 38 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_roles.py | 26 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 147 | ||||
-rw-r--r-- | tests/bot/exts/info/test_information.py | 175 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 148 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 587 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 25 | ||||
-rw-r--r-- | tests/bot/patches/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/rules/test_discord_emojis.py | 29 | ||||
-rw-r--r-- | tests/bot/utils/test_services.py | 39 | ||||
-rw-r--r-- | tests/helpers.py | 21 |
13 files changed, 893 insertions, 784 deletions
diff --git a/tests/_autospec.py b/tests/_autospec.py new file mode 100644 index 000000000..ee2fc1973 --- /dev/null +++ b/tests/_autospec.py @@ -0,0 +1,64 @@ +import contextlib +import functools +import unittest.mock +from typing import Callable + + [email protected](unittest.mock._patch.decoration_helper) +def _decoration_helper(self, patched, args, keywargs): + """Skips adding patchings as args if their `dont_pass` attribute is True.""" + # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added. + extra_args = [] + with contextlib.ExitStack() as exit_stack: + for patching in patched.patchings: + arg = exit_stack.enter_context(patching) + if not getattr(patching, "dont_pass", False): + # Only add the patching as an arg if dont_pass is False. + if patching.attribute_name is not None: + keywargs.update(arg) + elif patching.new is unittest.mock.DEFAULT: + extra_args.append(arg) + + args += tuple(extra_args) + yield args, keywargs + + [email protected](unittest.mock._patch.copy) +def _copy(self): + """Copy the `dont_pass` attribute along with the standard copy operation.""" + patcher_copy = _copy.original(self) + patcher_copy.dont_pass = getattr(self, "dont_pass", False) + return patcher_copy + + +# Monkey-patch the patcher class :) +_copy.original = unittest.mock._patch.copy +unittest.mock._patch.copy = _copy +unittest.mock._patch.decoration_helper = _decoration_helper + + +def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable: + """ + Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + + If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object. + """ + # Caller's kwargs should take priority and overwrite the defaults. + kwargs = dict(spec_set=True, autospec=True) + kwargs.update(patch_kwargs) + + # Import the target if it's a string. + # This is to support both object and string targets like patch.multiple. + if type(target) is str: + target = unittest.mock._importer(target) + + def decorator(func): + for attribute in attributes: + patcher = unittest.mock.patch.object(target, attribute, **kwargs) + if not pass_mocks: + # A custom attribute to keep track of which patchings should be skipped. + patcher.dont_pass = True + func = patcher(func) + return func + return decorator diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 886c243cf..3ad9db9c3 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -1,12 +1,9 @@ -import asyncio import unittest from unittest import mock -import discord -from bot import constants from bot.api import ResponseCodeError -from bot.exts.backend.sync._syncers import Syncer, _Diff +from bot.exts.backend.sync._syncers import Syncer from tests import helpers @@ -18,292 +15,21 @@ class TestSyncer(Syncer): _sync = mock.AsyncMock() -class SyncerBaseTests(unittest.TestCase): - """Tests for the syncer base class.""" - - def setUp(self): - self.bot = helpers.MockBot() - - def test_instantiation_fails_without_abstract_methods(self): - """The class must have abstract methods implemented.""" - with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): - Syncer(self.bot) - - -class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): - """Tests for sending the sync confirmation prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - - def mock_get_channel(self): - """Fixture to return a mock channel and message for when `get_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - mock_channel.send.return_value = mock_message - self.bot.get_channel.return_value = mock_channel - - return mock_channel, mock_message - - def mock_fetch_channel(self): - """Fixture to return a mock channel and message for when `fetch_channel` is used.""" - self.bot.reset_mock() - - mock_channel = helpers.MockTextChannel() - mock_message = helpers.MockMessage() - - self.bot.get_channel.return_value = None - mock_channel.send.return_value = mock_message - self.bot.fetch_channel.return_value = mock_channel - - return mock_channel, mock_message - - async def test_send_prompt_edits_and_returns_message(self): - """The given message should be edited to display the prompt and then should be returned.""" - msg = helpers.MockMessage() - ret_val = await self.syncer._send_prompt(msg) - - msg.edit.assert_called_once() - self.assertIn("content", msg.edit.call_args[1]) - self.assertEqual(ret_val, msg) - - async def test_send_prompt_gets_dev_core_channel(self): - """The dev-core channel should be retrieved if an extant message isn't given.""" - subtests = ( - (self.bot.get_channel, self.mock_get_channel), - (self.bot.fetch_channel, self.mock_fetch_channel), - ) - - for method, mock_ in subtests: - with self.subTest(method=method, msg=mock_.__name__): - mock_() - await self.syncer._send_prompt() - - method.assert_called_once_with(constants.Channels.dev_core) - - async def test_send_prompt_returns_none_if_channel_fetch_fails(self): - """None should be returned if there's an HTTPException when fetching the channel.""" - self.bot.get_channel.return_value = None - self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") - - ret_val = await self.syncer._send_prompt() - - self.assertIsNone(ret_val) - - async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): - """A new message mentioning core devs should be sent and returned if message isn't given.""" - for mock_ in (self.mock_get_channel, self.mock_fetch_channel): - with self.subTest(msg=mock_.__name__): - mock_channel, mock_message = mock_() - ret_val = await self.syncer._send_prompt() - - mock_channel.send.assert_called_once() - self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) - self.assertEqual(ret_val, mock_message) - - async def test_send_prompt_adds_reactions(self): - """The message should have reactions for confirmation added.""" - extant_message = helpers.MockMessage() - subtests = ( - (extant_message, lambda: (None, extant_message)), - (None, self.mock_get_channel), - (None, self.mock_fetch_channel), - ) - - for message_arg, mock_ in subtests: - subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ - - with self.subTest(msg=subtest_msg): - _, mock_message = mock_() - await self.syncer._send_prompt(message_arg) - - calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] - mock_message.add_reaction.assert_has_calls(calls) - - -class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): - """Tests for waiting for a sync confirmation reaction on the prompt.""" - - def setUp(self): - self.bot = helpers.MockBot() - self.syncer = TestSyncer(self.bot) - self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) - - @staticmethod - def get_message_reaction(emoji): - """Fixture to return a mock message an reaction from the given `emoji`.""" - message = helpers.MockMessage() - reaction = helpers.MockReaction(emoji=emoji, message=message) - - return message, reaction - - def test_reaction_check_for_valid_emoji_and_authors(self): - """Should return True if authors are identical or are a bot and a core dev, respectively.""" - user_subtests = ( - ( - helpers.MockMember(id=77), - helpers.MockMember(id=77), - "identical users", - ), - ( - helpers.MockMember(id=77, bot=True), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "bot author and core-dev reactor", - ), - ) - - for emoji in self.syncer._REACTION_EMOJIS: - for author, user, msg in user_subtests: - with self.subTest(author=author, user=user, emoji=emoji, msg=msg): - message, reaction = self.get_message_reaction(emoji) - ret_val = self.syncer._reaction_check(author, message, reaction, user) - - self.assertTrue(ret_val) - - def test_reaction_check_for_invalid_reactions(self): - """Should return False for invalid reaction events.""" - valid_emoji = self.syncer._REACTION_EMOJIS[0] - subtests = ( - ( - helpers.MockMember(id=77), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43, roles=[self.core_dev_role]), - "users are not identical", - ), - ( - helpers.MockMember(id=77, bot=True), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=43), - "reactor lacks the core-dev role", - ), - ( - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - *self.get_message_reaction(valid_emoji), - helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), - "reactor is a bot", - ), - ( - helpers.MockMember(id=77), - helpers.MockMessage(id=95), - helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), - helpers.MockMember(id=77), - "messages are not identical", - ), - ( - helpers.MockMember(id=77), - *self.get_message_reaction("InVaLiD"), - helpers.MockMember(id=77), - "emoji is invalid", - ), - ) - - for *args, msg in subtests: - kwargs = dict(zip(("author", "message", "reaction", "user"), args)) - with self.subTest(**kwargs, msg=msg): - ret_val = self.syncer._reaction_check(*args) - self.assertFalse(ret_val) - - async def test_wait_for_confirmation(self): - """The message should always be edited and only return True if the emoji is a check mark.""" - subtests = ( - (constants.Emojis.check_mark, True, None), - ("InVaLiD", False, None), - (None, False, asyncio.TimeoutError), - ) - - for emoji, ret_val, side_effect in subtests: - for bot in (True, False): - with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): - # Set up mocks - message = helpers.MockMessage() - member = helpers.MockMember(bot=bot) - - self.bot.wait_for.reset_mock() - self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) - self.bot.wait_for.side_effect = side_effect - - # Call the function - actual_return = await self.syncer._wait_for_confirmation(member, message) - - # Perform assertions - self.bot.wait_for.assert_called_once() - self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) - - message.edit.assert_called_once() - kwargs = message.edit.call_args[1] - self.assertIn("content", kwargs) - - # Core devs should only be mentioned if the author is a bot. - if bot: - self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - else: - self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) - - self.assertIs(actual_return, ret_val) - - class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for main function orchestrating the sync.""" def setUp(self): - self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) - self.syncer = TestSyncer(self.bot) - - async def test_sync_respects_confirmation_result(self): - """The sync should abort if confirmation fails and continue if confirmed.""" - mock_message = helpers.MockMessage() - subtests = ( - (True, mock_message), - (False, None), - ) - - for confirmed, message in subtests: - with self.subTest(confirmed=confirmed): - self.syncer._sync.reset_mock() - self.syncer._get_diff.reset_mock() - - diff = _Diff({1, 2, 3}, {4, 5}, None) - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(confirmed, message) - ) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) + patcher = mock.patch("bot.instance", new=helpers.MockBot(user=helpers.MockMember(bot=True))) + self.bot = patcher.start() + self.addCleanup(patcher.stop) - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() + self.guild = helpers.MockGuild() - if confirmed: - self.syncer._sync.assert_called_once_with(diff) - else: - self.syncer._sync.assert_not_called() + TestSyncer._get_diff.reset_mock(return_value=True, side_effect=True) + TestSyncer._sync.reset_mock(return_value=True, side_effect=True) - async def test_sync_diff_size(self): - """The diff size should be correctly calculated.""" - subtests = ( - (6, _Diff({1, 2}, {3, 4}, {5, 6})), - (5, _Diff({1, 2, 3}, None, {4, 5})), - (0, _Diff(None, None, None)), - (0, _Diff(set(), set(), set())), - ) - - for size, diff in subtests: - with self.subTest(size=size, diff=diff): - self.syncer._get_diff.reset_mock() - self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild) - - self.syncer._get_diff.assert_called_once_with(guild) - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + TestSyncer._get_diff.return_value = mock.MagicMock() async def test_sync_message_edited(self): """The message should be edited if one was sent, even if the sync has an API error.""" @@ -315,90 +41,26 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): for message, side_effect, should_edit in subtests: with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): - self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = mock.AsyncMock( - return_value=(True, message) - ) + TestSyncer._sync.side_effect = side_effect + ctx = helpers.MockContext() + ctx.send.return_value = message - guild = helpers.MockGuild() - await self.syncer.sync(guild) + await TestSyncer.sync(self.guild, ctx) if should_edit: message.edit.assert_called_once() self.assertIn("content", message.edit.call_args[1]) - async def test_sync_confirmation_context_redirect(self): - """If ctx is given, a new message should be sent and author should be ctx's author.""" - mock_member = helpers.MockMember() + async def test_sync_message_sent(self): + """If ctx is given, a new message should be sent.""" subtests = ( - (None, self.bot.user, None), - (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + (None, None), + (helpers.MockContext(), helpers.MockMessage()), ) - for ctx, author, message in subtests: - with self.subTest(ctx=ctx, author=author, message=message): - if ctx is not None: - ctx.send.return_value = message - - # Make sure `_get_diff` returns a MagicMock, not an AsyncMock - self.syncer._get_diff.return_value = mock.MagicMock() - - self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) - - guild = helpers.MockGuild() - await self.syncer.sync(guild, ctx) + for ctx, message in subtests: + with self.subTest(ctx=ctx, message=message): + await TestSyncer.sync(self.guild, ctx) if ctx is not None: ctx.send.assert_called_once() - - self.syncer._get_confirmation_result.assert_called_once() - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) - self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_small_diff(self): - """Should always return True and the given message if the diff size is too small.""" - author = helpers.MockMember() - expected_message = helpers.MockMessage() - - for size in (3, 2): # pragma: no cover - with self.subTest(size=size): - self.syncer._send_prompt = mock.AsyncMock() - self.syncer._wait_for_confirmation = mock.AsyncMock() - - coro = self.syncer._get_confirmation_result(size, author, expected_message) - result, actual_message = await coro - - self.assertTrue(result) - self.assertEqual(actual_message, expected_message) - self.syncer._send_prompt.assert_not_called() - self.syncer._wait_for_confirmation.assert_not_called() - - @mock.patch.object(constants.Sync, "max_diff", new=3) - async def test_confirmation_result_large_diff(self): - """Should return True if confirmed and False if _send_prompt fails or aborted.""" - author = helpers.MockMember() - mock_message = helpers.MockMessage() - - subtests = ( - (True, mock_message, True, "confirmed"), - (False, None, False, "_send_prompt failed"), - (False, mock_message, False, "aborted"), - ) - - for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover - with self.subTest(msg=msg): - self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) - - coro = self.syncer._get_confirmation_result(4, author) - actual_result, actual_message = await coro - - self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None - self.assertIs(actual_result, expected_result) - self.assertEqual(actual_message, expected_message) - - if expected_message: - self.syncer._wait_for_confirmation.assert_called_once_with( - author, expected_message - ) diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 1b89564f2..22a07313e 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -29,24 +29,24 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = helpers.MockBot() - self.role_syncer_patcher = mock.patch( + role_syncer_patcher = mock.patch( "bot.exts.backend.sync._syncers.RoleSyncer", autospec=Syncer, spec_set=True ) - self.user_syncer_patcher = mock.patch( + user_syncer_patcher = mock.patch( "bot.exts.backend.sync._syncers.UserSyncer", autospec=Syncer, spec_set=True ) - self.RoleSyncer = self.role_syncer_patcher.start() - self.UserSyncer = self.user_syncer_patcher.start() - self.cog = Sync(self.bot) + self.RoleSyncer = role_syncer_patcher.start() + self.UserSyncer = user_syncer_patcher.start() - def tearDown(self): - self.role_syncer_patcher.stop() - self.user_syncer_patcher.stop() + self.addCleanup(role_syncer_patcher.stop) + self.addCleanup(user_syncer_patcher.stop) + + self.cog = Sync(self.bot) @staticmethod def response_error(status: int) -> ResponseCodeError: @@ -73,8 +73,6 @@ class SyncCogTests(SyncCogTestCase): Sync(self.bot) - self.RoleSyncer.assert_called_once_with(self.bot) - self.UserSyncer.assert_called_once_with(self.bot) sync_guild.assert_called_once_with() self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) @@ -83,8 +81,8 @@ class SyncCogTests(SyncCogTestCase): for guild in (helpers.MockGuild(), None): with self.subTest(guild=guild): self.bot.reset_mock() - self.cog.role_syncer.reset_mock() - self.cog.user_syncer.reset_mock() + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() self.bot.get_guild = mock.MagicMock(return_value=guild) @@ -94,11 +92,11 @@ class SyncCogTests(SyncCogTestCase): self.bot.get_guild.assert_called_once_with(constants.Guild.id) if guild is None: - self.cog.role_syncer.sync.assert_not_called() - self.cog.user_syncer.sync.assert_not_called() + self.RoleSyncer.sync.assert_not_called() + self.UserSyncer.sync.assert_not_called() else: - self.cog.role_syncer.sync.assert_called_once_with(guild) - self.cog.user_syncer.sync.assert_called_once_with(guild) + self.RoleSyncer.sync.assert_called_once_with(guild) + self.UserSyncer.sync.assert_called_once_with(guild) async def patch_user_helper(self, side_effect: BaseException) -> None: """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" @@ -392,16 +390,16 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): async def test_sync_roles_command(self): """sync() should be called on the RoleSyncer.""" ctx = helpers.MockContext() - await self.cog.sync_roles_command.callback(self.cog, ctx) + await self.cog.sync_roles_command(self.cog, ctx) - self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + self.RoleSyncer.sync.assert_called_once_with(ctx.guild, ctx) async def test_sync_users_command(self): """sync() should be called on the UserSyncer.""" ctx = helpers.MockContext() - await self.cog.sync_users_command.callback(self.cog, ctx) + await self.cog.sync_users_command(self.cog, ctx) - self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + self.UserSyncer.sync.assert_called_once_with(ctx.guild, ctx) async def test_commands_require_admin(self): """The sync commands should only run if the author has the administrator permission.""" diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 7b9f40cad..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -22,8 +22,9 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between roles in the DB and roles in the Guild cache.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) @staticmethod def get_guild(*roles): @@ -44,7 +45,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role()] guild = self.get_guild(fake_role()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = (set(), set(), set()) self.assertEqual(actual_diff, expected_diff) @@ -56,7 +57,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] guild = self.get_guild(updated_role, fake_role()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = (set(), {_Role(**updated_role)}, set()) self.assertEqual(actual_diff, expected_diff) @@ -68,7 +69,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role()] guild = self.get_guild(fake_role(), new_role) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = ({_Role(**new_role)}, set(), set()) self.assertEqual(actual_diff, expected_diff) @@ -80,7 +81,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role(), deleted_role] guild = self.get_guild(fake_role()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = (set(), set(), {_Role(**deleted_role)}) self.assertEqual(actual_diff, expected_diff) @@ -98,7 +99,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): ] guild = self.get_guild(fake_role(), new, updated) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) self.assertEqual(actual_diff, expected_diff) @@ -108,8 +109,9 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync roles.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) async def test_sync_created_roles(self): """Only POST requests should be made with the correct payload.""" @@ -117,7 +119,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): role_tuples = {_Role(**role) for role in roles} diff = _Diff(role_tuples, set(), set()) - await self.syncer._sync(diff) + await RoleSyncer._sync(diff) calls = [mock.call("bot/roles", json=role) for role in roles] self.bot.api_client.post.assert_has_calls(calls, any_order=True) @@ -132,7 +134,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): role_tuples = {_Role(**role) for role in roles} diff = _Diff(set(), role_tuples, set()) - await self.syncer._sync(diff) + await RoleSyncer._sync(diff) calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] self.bot.api_client.put.assert_has_calls(calls, any_order=True) @@ -147,7 +149,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): role_tuples = {_Role(**role) for role in roles} diff = _Diff(set(), set(), role_tuples) - await self.syncer._sync(diff) + await RoleSyncer._sync(diff) calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] self.bot.api_client.delete.assert_has_calls(calls, any_order=True) diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index c0a1da35c..61673e1bb 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User +from bot.exts.backend.sync._syncers import UserSyncer, _Diff from tests import helpers @@ -10,7 +10,7 @@ def fake_user(**kwargs): kwargs.setdefault("id", 43) kwargs.setdefault("name", "bob the test man") kwargs.setdefault("discriminator", 1337) - kwargs.setdefault("roles", (666,)) + kwargs.setdefault("roles", [666]) kwargs.setdefault("in_guild", True) return kwargs @@ -20,8 +20,9 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between users in the DB and users in the Guild cache.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) @staticmethod def get_guild(*members): @@ -40,22 +41,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): return guild + @staticmethod + def get_mock_member(member: dict): + member = member.copy() + del member["in_guild"] + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + return mock_member + async def test_empty_diff_for_no_users(self): """When no users are given, an empty diff should be returned.""" + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [] + } guild = self.get_guild() - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([], [], None) self.assertEqual(actual_diff, expected_diff) async def test_empty_diff_for_identical_users(self): """No differences should be found if the users in the guild and DB are identical.""" - self.bot.api_client.get.return_value = [fake_user()] + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [fake_user()] + } guild = self.get_guild(fake_user()) - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) + guild.get_member.return_value = self.get_mock_member(fake_user()) + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([], [], None) self.assertEqual(actual_diff, expected_diff) @@ -63,59 +84,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Only updated users should be added to the 'updated' set of the diff.""" updated_user = fake_user(id=99, name="new") - self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [fake_user(id=99, name="old"), fake_user()] + } guild = self.get_guild(updated_user, fake_user()) + guild.get_member.side_effect = [ + self.get_mock_member(updated_user), + self.get_mock_member(fake_user()) + ] - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**updated_user)}, None) + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([], [{"id": 99, "name": "new"}], None) self.assertEqual(actual_diff, expected_diff) async def test_diff_for_new_users(self): - """Only new users should be added to the 'created' set of the diff.""" + """Only new users should be added to the 'created' list of the diff.""" new_user = fake_user(id=99, name="new") - self.bot.api_client.get.return_value = [fake_user()] + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [fake_user()] + } guild = self.get_guild(fake_user(), new_user) - - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, set(), None) + guild.get_member.side_effect = [ + self.get_mock_member(fake_user()), + self.get_mock_member(new_user) + ] + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([new_user], [], None) self.assertEqual(actual_diff, expected_diff) async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - leaving_user = fake_user(id=63, in_guild=False) - - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [fake_user(), fake_user(id=63)] + } guild = self.get_guild(fake_user()) + guild.get_member.side_effect = [ + self.get_mock_member(fake_user()), + None + ] - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), {_User(**leaving_user)}, None) + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([], [{"id": 63, "in_guild": False}], None) self.assertEqual(actual_diff, expected_diff) async def test_diff_for_new_updated_and_leaving_users(self): """When users are added, updated, and removed, all of them are returned properly.""" new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") - leaving_user = fake_user(id=63, in_guild=False) - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [fake_user(), fake_user(id=55), fake_user(id=63)] + } guild = self.get_guild(fake_user(), new_user, updated_user) + guild.get_member.side_effect = [ + self.get_mock_member(fake_user()), + self.get_mock_member(updated_user), + None + ] - actual_diff = await self.syncer._get_diff(guild) - expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None) self.assertEqual(actual_diff, expected_diff) async def test_empty_diff_for_db_users_not_in_guild(self): - """When the DB knows a user the guild doesn't, no difference is found.""" - self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + """When the DB knows a user, but the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = { + "count": 3, + "next_page_no": None, + "previous_page_no": None, + "results": [fake_user(), fake_user(id=63, in_guild=False)] + } guild = self.get_guild(fake_user()) + guild.get_member.side_effect = [ + self.get_mock_member(fake_user()), + None + ] - actual_diff = await self.syncer._get_diff(guild) - expected_diff = (set(), set(), None) + actual_diff = await UserSyncer._get_diff(guild) + expected_diff = ([], [], None) self.assertEqual(actual_diff, expected_diff) @@ -124,20 +188,18 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync users.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) async def test_sync_created_users(self): """Only POST requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] - user_tuples = {_User(**user) for user in users} - diff = _Diff(user_tuples, set(), None) - await self.syncer._sync(diff) + diff = _Diff(users, [], None) + await UserSyncer._sync(diff) - calls = [mock.call("bot/users", json=user) for user in users] - self.bot.api_client.post.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.post.call_count, len(users)) + self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created) self.bot.api_client.put.assert_not_called() self.bot.api_client.delete.assert_not_called() @@ -146,13 +208,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Only PUT requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] - user_tuples = {_User(**user) for user in users} - diff = _Diff(set(), user_tuples, None) - await self.syncer._sync(diff) + diff = _Diff([], users, None) + await UserSyncer._sync(diff) - calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] - self.bot.api_client.put.assert_has_calls(calls, any_order=True) - self.assertEqual(self.bot.api_client.put.call_count, len(users)) + self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated) self.bot.api_client.post.assert_not_called() self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index d3f2995fb..daede54c5 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -1,4 +1,3 @@ -import asyncio import textwrap import unittest import unittest.mock @@ -13,7 +12,7 @@ from tests import helpers COG_PATH = "bot.exts.info.information.Information" -class InformationCogTests(unittest.TestCase): +class InformationCogTests(unittest.IsolatedAsyncioTestCase): """Tests the Information cog.""" @classmethod @@ -29,16 +28,14 @@ class InformationCogTests(unittest.TestCase): self.ctx = helpers.MockContext() self.ctx.author.roles.append(self.moderator_role) - def test_roles_command_command(self): + async def test_roles_command_command(self): """Test if the `role_info` command correctly returns the `moderator_role`.""" self.ctx.guild.roles.append(self.moderator_role) self.cog.roles_info.can_run = unittest.mock.AsyncMock() self.cog.roles_info.can_run.return_value = True - coroutine = self.cog.roles_info.callback(self.cog, self.ctx) - - self.assertIsNone(asyncio.run(coroutine)) + self.assertIsNone(await self.cog.roles_info(self.cog, self.ctx)) self.ctx.send.assert_called_once() _, kwargs = self.ctx.send.call_args @@ -48,7 +45,7 @@ class InformationCogTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour.blurple()) self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - def test_role_info_command(self): + async def test_role_info_command(self): """Tests the `role info` command.""" dummy_role = helpers.MockRole( name="Dummy", @@ -73,9 +70,7 @@ class InformationCogTests(unittest.TestCase): self.cog.role_info.can_run = unittest.mock.AsyncMock() self.cog.role_info.can_run.return_value = True - coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) - - self.assertIsNone(asyncio.run(coroutine)) + self.assertIsNone(await self.cog.role_info(self.cog, self.ctx, dummy_role, admin_role)) self.assertEqual(self.ctx.send.call_count, 2) @@ -97,80 +92,8 @@ class InformationCogTests(unittest.TestCase): self.assertEqual(admin_embed.title, "Admins info") self.assertEqual(admin_embed.colour, discord.Colour.red()) - @unittest.mock.patch('bot.exts.info.information.time_since') - def test_server_info_command(self, time_since_patch): - time_since_patch.return_value = '2 days ago' - - self.ctx.guild = helpers.MockGuild( - features=('lemons', 'apples'), - region="The Moon", - roles=[self.moderator_role], - channels=[ - discord.TextChannel( - state={}, - guild=self.ctx.guild, - data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} - ), - discord.CategoryChannel( - state={}, - guild=self.ctx.guild, - data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} - ), - discord.VoiceChannel( - state={}, - guild=self.ctx.guild, - data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} - ) - ], - members=[ - *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), - *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), - *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), - *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), - ], - member_count=1_234, - icon_url='a-lemon.jpg', - ) - - coroutine = self.cog.server_info.callback(self.cog, self.ctx) - self.assertIsNone(asyncio.run(coroutine)) - - time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') - _, kwargs = self.ctx.send.call_args - embed = kwargs.pop('embed') - self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual( - embed.description, - textwrap.dedent( - f""" - **Server information** - Created: {time_since_patch.return_value} - Voice region: {self.ctx.guild.region} - Features: {', '.join(self.ctx.guild.features)} - - **Channel counts** - Category channels: 1 - Text channels: 1 - Voice channels: 1 - Staff channels: 0 - - **Member counts** - Members: {self.ctx.guild.member_count:,} - Staff members: 0 - Roles: {len(self.ctx.guild.roles)} - - **Member statuses** - {constants.Emojis.status_online} 2 - {constants.Emojis.status_idle} 1 - {constants.Emojis.status_dnd} 4 - {constants.Emojis.status_offline} 3 - """ - ) - ) - self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') - -class UserInfractionHelperMethodTests(unittest.TestCase): +class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase): """Tests for the helper methods of the `!user` command.""" def setUp(self): @@ -180,7 +103,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): self.cog = information.Information(self.bot) self.member = helpers.MockMember(id=1234) - def test_user_command_helper_method_get_requests(self): + async def test_user_command_helper_method_get_requests(self): """The helper methods should form the correct get requests.""" test_values = ( { @@ -202,11 +125,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase): endpoint, params = test_value["expected_args"] with self.subTest(method=helper_method, endpoint=endpoint, params=params): - asyncio.run(helper_method(self.member)) + await helper_method(self.member) self.bot.api_client.get.assert_called_once_with(endpoint, params=params) self.bot.api_client.get.reset_mock() - def _method_subtests(self, method, test_values, default_header): + async def _method_subtests(self, method, test_values, default_header): """Helper method that runs the subtests for the different helper methods.""" for test_value in test_values: api_response = test_value["api response"] @@ -216,11 +139,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase): self.bot.api_client.get.return_value = api_response expected_output = "\n".join(expected_lines) - actual_output = asyncio.run(method(self.member)) + actual_output = await method(self.member) self.assertEqual((default_header, expected_output), actual_output) - def test_basic_user_infraction_counts_returns_correct_strings(self): + async def test_basic_user_infraction_counts_returns_correct_strings(self): """The method should correctly list both the total and active number of non-hidden infractions.""" test_values = ( # No infractions means zero counts @@ -251,9 +174,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase): header = "Infractions" - self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) + await self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) - def test_expanded_user_infraction_counts_returns_correct_strings(self): + async def test_expanded_user_infraction_counts_returns_correct_strings(self): """The method should correctly list the total and active number of all infractions split by infraction type.""" test_values = ( { @@ -306,9 +229,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase): header = "Infractions" - self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) + await self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) - def test_user_nomination_counts_returns_correct_strings(self): + async def test_user_nomination_counts_returns_correct_strings(self): """The method should list the number of active and historical nominations for the user.""" test_values = ( { @@ -336,12 +259,12 @@ class UserInfractionHelperMethodTests(unittest.TestCase): header = "Nominations" - self._method_subtests(self.cog.user_nomination_counts, test_values, header) + await self._method_subtests(self.cog.user_nomination_counts, test_values, header) @unittest.mock.patch("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) @unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50]) -class UserEmbedTests(unittest.TestCase): +class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """Tests for the creation of the `!user` embed.""" def setUp(self): @@ -354,14 +277,14 @@ class UserEmbedTests(unittest.TestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) - def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): + async def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) user = helpers.MockMember() user.nick = None user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) self.assertEqual(embed.title, "Mr. Hemlock") @@ -369,14 +292,14 @@ class UserEmbedTests(unittest.TestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) - def test_create_user_embed_uses_nick_in_title_if_available(self): + async def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) user = helpers.MockMember() user.nick = "Cat lover" user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") @@ -384,7 +307,7 @@ class UserEmbedTests(unittest.TestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) - def test_create_user_embed_ignores_everyone_role(self): + async def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) admins_role = helpers.MockRole(name='Admins') @@ -393,14 +316,18 @@ class UserEmbedTests(unittest.TestCase): # A `MockMember` has the @Everyone role by default; we add the Admins to that. user = helpers.MockMember(roles=[admins_role], top_role=admins_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) self.assertIn("&Admins", embed.fields[1].value) self.assertNotIn("&Everyone", embed.fields[1].value) @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): + async def test_create_user_embed_expanded_information_in_moderation_channels( + self, + nomination_counts, + infraction_counts + ): """The embed should contain expanded infractions and nomination info in mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -411,7 +338,7 @@ class UserEmbedTests(unittest.TestCase): nomination_counts.return_value = ("Nominations", "nomination info") user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) infraction_counts.assert_called_once_with(user) nomination_counts.assert_called_once_with(user) @@ -434,7 +361,7 @@ class UserEmbedTests(unittest.TestCase): ) @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) - def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): + async def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): """The embed should contain only basic infraction data outside of mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -444,7 +371,7 @@ class UserEmbedTests(unittest.TestCase): infraction_counts.return_value = ("Infractions", "basic infractions info") user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) infraction_counts.assert_called_once_with(user) @@ -467,14 +394,14 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual( "basic infractions info", - embed.fields[3].value + embed.fields[2].value ) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) - def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): + async def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -482,7 +409,7 @@ class UserEmbedTests(unittest.TestCase): moderators_role.colour = 100 user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) @@ -490,12 +417,12 @@ class UserEmbedTests(unittest.TestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) - def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): + async def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): """The embed should be created with a blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() user = helpers.MockMember(id=217) - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) self.assertEqual(embed.colour, discord.Colour.blurple()) @@ -503,20 +430,20 @@ class UserEmbedTests(unittest.TestCase): f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) ) - def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): + async def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() user = helpers.MockMember(id=217) user.avatar_url_as.return_value = "avatar url" - embed = asyncio.run(self.cog.create_user_embed(ctx, user)) + embed = await self.cog.create_user_embed(ctx, user) user.avatar_url_as.assert_called_once_with(static_format="png") self.assertEqual(embed.thumbnail.url, "avatar url") @unittest.mock.patch("bot.exts.info.information.constants") -class UserCommandTests(unittest.TestCase): +class UserCommandTests(unittest.IsolatedAsyncioTestCase): """Tests for the `!user` command.""" def setUp(self): @@ -536,16 +463,16 @@ class UserCommandTests(unittest.TestCase): # used as a default value for a parameter, which gets defined upon import. self.bot_command_channel = helpers.MockTextChannel(id=constants.Channels.bot_commands) - def test_regular_member_cannot_target_another_member(self, constants): + async def test_regular_member_cannot_target_another_member(self, constants): """A regular user should not be able to use `!user` targeting another user.""" constants.MODERATION_ROLES = [self.moderator_role.id] ctx = helpers.MockContext(author=self.author) - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + await self.cog.user_info(self.cog, ctx, self.target) ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") - def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): + async def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): """A regular user should not be able to use this command outside of bot-commands.""" constants.MODERATION_ROLES = [self.moderator_role.id] constants.STAFF_ROLES = [self.moderator_role.id] @@ -553,49 +480,49 @@ class UserCommandTests(unittest.TestCase): msg = "Sorry, but you may only use this command within <#50>." with self.assertRaises(InWhitelistCheckFailure, msg=msg): - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + await self.cog.user_info(self.cog, ctx) @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") - def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): + async def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" constants.STAFF_ROLES = [self.moderator_role.id] ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel) - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + await self.cog.user_info(self.cog, ctx) create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") - def test_regular_user_can_explicitly_target_themselves(self, create_embed, _): + async def test_regular_user_can_explicitly_target_themselves(self, create_embed, _): """A user should target itself with `!user` when a `user` argument was not provided.""" constants.STAFF_ROLES = [self.moderator_role.id] ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel) - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) + await self.cog.user_info(self.cog, ctx, self.author) create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") - def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): + async def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): """Staff members should be able to bypass the bot-commands channel restriction.""" constants.STAFF_ROLES = [self.moderator_role.id] ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) - asyncio.run(self.cog.user_info.callback(self.cog, ctx)) + await self.cog.user_info(self.cog, ctx) create_embed.assert_called_once_with(ctx, self.moderator) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") - def test_moderators_can_target_another_member(self, create_embed, constants): + async def test_moderators_can_target_another_member(self, create_embed, constants): """A moderator should be able to use `!user` targeting another user.""" constants.MODERATION_ROLES = [self.moderator_role.id] constants.STAFF_ROLES = [self.moderator_role.id] ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) - asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) + await self.cog.user_info(self.cog, ctx, self.target) create_embed.assert_called_once_with(ctx, self.target) ctx.send.assert_called_once() diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index be1b649e1..bf557a484 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,7 +1,8 @@ import textwrap import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch +from bot.constants import Event from bot.exts.moderation.infraction.infractions import Infractions from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.cog.apply_infraction.assert_awaited_once_with( self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value ) + + +@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) +class VoiceBanTests(unittest.IsolatedAsyncioTestCase): + """Tests for voice ban related functions and commands.""" + + def setUp(self): + self.bot = MockBot() + self.mod = MockMember(top_role=10) + self.user = MockMember(top_role=1, roles=[MockRole(id=123456)]) + self.guild = MockGuild() + self.ctx = MockContext(bot=self.bot, author=self.mod) + self.cog = Infractions(self.bot) + + async def test_permanent_voice_ban(self): + """Should call voice ban applying function without expiry.""" + self.cog.apply_voice_ban = AsyncMock() + self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) + self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar") + + async def test_temporary_voice_ban(self): + """Should call voice ban applying function with expiry.""" + self.cog.apply_voice_ban = AsyncMock() + self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) + self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + + async def test_voice_unban(self): + """Should call infraction pardoning function.""" + self.cog.pardon_infraction = AsyncMock() + self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) + self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + + @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") + @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") + async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): + """Should return early when user already have Voice Ban infraction.""" + get_active_infraction.return_value = {"foo": "bar"} + self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") + post_infraction_mock.assert_not_awaited() + + @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") + @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") + async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): + """Should return early when posting infraction fails.""" + self.cog.mod_log.ignore = MagicMock() + get_active_infraction.return_value = None + post_infraction_mock.return_value = None + self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + post_infraction_mock.assert_awaited_once() + self.cog.mod_log.ignore.assert_not_called() + + @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") + @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") + async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): + """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" + get_active_infraction.return_value = None + # We don't want that this continue yet + post_infraction_mock.return_value = None + self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) + post_infraction_mock.assert_awaited_once_with( + self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 + ) + + @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") + @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") + async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): + """Should ignore Voice Verified role removing.""" + self.cog.mod_log.ignore = MagicMock() + self.cog.apply_infraction = AsyncMock() + self.user.remove_roles = MagicMock(return_value="my_return_value") + + get_active_infraction.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) + + @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") + @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") + async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): + """Should ignore Voice Verified role removing.""" + self.cog.mod_log.ignore = MagicMock() + self.cog.apply_infraction = AsyncMock() + self.user.remove_roles = MagicMock(return_value="my_return_value") + + get_active_infraction.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar") + self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + + @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") + @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") + async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): + """Should truncate reason for voice ban.""" + self.cog.mod_log.ignore = MagicMock() + self.cog.apply_infraction = AsyncMock() + self.user.remove_roles = MagicMock(return_value="my_return_value") + + get_active_infraction.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) + self.user.remove_roles.assert_called_once_with( + self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...") + ) + self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + + async def test_voice_unban_user_not_found(self): + """Should include info to return dict when user was not found from guild.""" + self.guild.get_member.return_value = None + result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + self.assertEqual(result, {"Info": "User was not found in the guild."}) + + @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") + @patch("bot.exts.moderation.infraction.infractions.format_user") + async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): + """Should add role back with ignoring, notify user and return log dictionary..""" + self.guild.get_member.return_value = self.user + notify_pardon_mock.return_value = True + format_user_mock.return_value = "my-user" + + result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + self.assertEqual(result, { + "Member": "my-user", + "DM": "Sent" + }) + notify_pardon_mock.assert_awaited_once() + + @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") + @patch("bot.exts.moderation.infraction.infractions.format_user") + async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): + """Should add role back with ignoring, notify user and return log dictionary..""" + self.guild.get_member.return_value = self.user + notify_pardon_mock.return_value = False + format_user_mock.return_value = "my-user" + + result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + self.assertEqual(result, { + "Member": "my-user", + "DM": "**Failed**" + }) + notify_pardon_mock.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index e2d44c637..104293d8e 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,23 +1,49 @@ +import asyncio import unittest +from datetime import datetime, timezone from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock +from async_rediscache import RedisSession from discord import PermissionOverwrite -from bot.constants import Channels, Emojis, Guild, Roles -from bot.exts.moderation.silence import Silence, SilenceNotifier -from tests.helpers import MockBot, MockContext, MockTextChannel +from bot.constants import Channels, Guild, Roles +from bot.exts.moderation import silence +from tests.helpers import MockBot, MockContext, MockTextChannel, autospec + +redis_session = None +redis_loop = asyncio.get_event_loop() + + +def setUpModule(): # noqa: N802 + """Create and connect to the fakeredis session.""" + global redis_session + redis_session = RedisSession(use_fakeredis=True) + redis_loop.run_until_complete(redis_session.connect()) + + +def tearDownModule(): # noqa: N802 + """Close the fakeredis session.""" + if redis_session: + redis_loop.run_until_complete(redis_session.close()) + + +# Have to subclass it because builtins can't be patched. +class PatchedDatetime(datetime): + """A datetime object with a mocked now() function.""" + + now = mock.create_autospec(datetime, "now") class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.alert_channel = MockTextChannel() - self.notifier = SilenceNotifier(self.alert_channel) + self.notifier = silence.SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() self.notifier.start = self.notifier_start_mock = Mock() def test_add_channel_adds_channel(self): - """Channel in FirstHash with current loop is added to internal set.""" + """Channel is added to `_silenced_channels` with the current loop.""" channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.add_channel(channel) @@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.notifier_start_mock.assert_not_called() def test_remove_channel_removes_channel(self): - """Channel in FirstHash is removed from `_silenced_channels`.""" + """Channel is removed from `_silenced_channels`.""" channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.remove_channel(channel) @@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): with self.subTest(current_loop=current_loop): with mock.patch.object(self.notifier, "_current_loop", new=current_loop): await self.notifier._notifier() - self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") + self.alert_channel.send.assert_called_once_with( + f"<@&{Roles.moderators}> currently silenced channels: " + ) self.alert_channel.send.reset_mock() async def test_notifier_skips_alert(self): @@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.alert_channel.send.assert_not_called() -class SilenceTests(unittest.IsolatedAsyncioTestCase): +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceCogTests(unittest.IsolatedAsyncioTestCase): + """Tests for the general functionality of the Silence cog.""" + + @autospec(silence, "Scheduler", pass_mocks=False) def setUp(self) -> None: self.bot = MockBot() - self.cog = Silence(self.bot) - self.ctx = MockContext() - self.cog._verified_role = None - # Set event so command callbacks can continue. - self.cog._get_instance_vars_event.set() + self.cog = silence.Silence(self.bot) - async def test_instance_vars_got_guild(self): + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_got_guild(self): """Bot got guild after it became available.""" - await self.cog._get_instance_vars() - self.bot.wait_until_guild_available.assert_called_once() + await self.cog._async_init() + self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) - async def test_instance_vars_got_role(self): + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_got_role(self): """Got `Roles.verified` role from guild.""" - await self.cog._get_instance_vars() guild = self.bot.get_guild() - guild.get_role.assert_called_once_with(Roles.verified) + guild.get_role.side_effect = lambda id_: Mock(id=id_) - async def test_instance_vars_got_channels(self): + await self.cog._async_init() + self.assertEqual(self.cog._verified_role.id, Roles.verified) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_got_channels(self): """Got channels from bot.""" - await self.cog._get_instance_vars() - self.bot.get_channel.called_once_with(Channels.mod_alerts) - self.bot.get_channel.called_once_with(Channels.mod_log) + self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + + await self.cog._async_init() + self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) - @mock.patch("bot.exts.moderation.silence.SilenceNotifier") - async def test_instance_vars_got_notifier(self, notifier): + @autospec(silence, "SilenceNotifier") + async def test_async_init_got_notifier(self, notifier): """Notifier was started with channel.""" - mod_log = MockTextChannel() - self.bot.get_channel.side_effect = (None, mod_log) - await self.cog._get_instance_vars() - notifier.assert_called_once_with(mod_log) - self.bot.get_channel.side_effect = None - - async def test_silence_sent_correct_discord_message(self): - """Check if proper message was sent when called with duration in channel with previous state.""" + self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + + await self.cog._async_init() + notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) + self.assertEqual(self.cog.notifier, notifier.return_value) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def test_async_init_rescheduled(self): + """`_reschedule_` coroutine was awaited.""" + self.cog._reschedule = mock.create_autospec(self.cog._reschedule) + await self.cog._async_init() + self.cog._reschedule.assert_awaited_once_with() + + def test_cog_unload_cancelled_tasks(self): + """The init task was cancelled.""" + self.cog._init_task = asyncio.Future() + self.cog.cog_unload() + + # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. + self.assertTrue(self.cog._init_task.cancelled()) + + @autospec("discord.ext.commands", "has_any_role") + @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) + async def test_cog_check(self, role_check): + """Role check was called with `MODERATION_ROLES`""" + ctx = MockContext() + role_check.return_value.predicate = mock.AsyncMock() + + await self.cog.cog_check(ctx) + role_check.assert_called_once_with(*(1, 2, 3)) + role_check.return_value.predicate.assert_awaited_once_with(ctx) + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class RescheduleTests(unittest.IsolatedAsyncioTestCase): + """Tests for the rescheduling of cached unsilences.""" + + @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) + def setUp(self): + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) + + with mock.patch.object(self.cog, "_reschedule", autospec=True): + asyncio.run(self.cog._async_init()) # Populate instance attributes. + + async def test_skipped_missing_channel(self): + """Did nothing because the channel couldn't be retrieved.""" + self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] + self.bot.get_channel.return_value = None + + await self.cog._reschedule() + + self.cog.notifier.add_channel.assert_not_called() + self.cog._unsilence_wrapper.assert_not_called() + self.cog.scheduler.schedule_later.assert_not_called() + + async def test_added_permanent_to_notifier(self): + """Permanently silenced channels were added to the notifier.""" + channels = [MockTextChannel(id=123), MockTextChannel(id=456)] + self.bot.get_channel.side_effect = channels + self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] + + await self.cog._reschedule() + + self.cog.notifier.add_channel.assert_any_call(channels[0]) + self.cog.notifier.add_channel.assert_any_call(channels[1]) + + self.cog._unsilence_wrapper.assert_not_called() + self.cog.scheduler.schedule_later.assert_not_called() + + async def test_unsilenced_expired(self): + """Unsilenced expired silences.""" + channels = [MockTextChannel(id=123), MockTextChannel(id=456)] + self.bot.get_channel.side_effect = channels + self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] + + await self.cog._reschedule() + + self.cog._unsilence_wrapper.assert_any_call(channels[0]) + self.cog._unsilence_wrapper.assert_any_call(channels[1]) + + self.cog.notifier.add_channel.assert_not_called() + self.cog.scheduler.schedule_later.assert_not_called() + + @mock.patch.object(silence, "datetime", new=PatchedDatetime) + async def test_rescheduled_active(self): + """Rescheduled active silences.""" + channels = [MockTextChannel(id=123), MockTextChannel(id=456)] + self.bot.get_channel.side_effect = channels + self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] + silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc) + + self.cog._unsilence_wrapper = mock.MagicMock() + unsilence_return = self.cog._unsilence_wrapper.return_value + + await self.cog._reschedule() + + # Yuck. + calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)] + self.cog.scheduler.schedule_later.assert_has_calls(calls) + + unsilence_calls = [mock.call(channel) for channel in channels] + self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls) + + self.cog.notifier.add_channel.assert_not_called() + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceTests(unittest.IsolatedAsyncioTestCase): + """Tests for the silence command and its related helper methods.""" + + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + self.cog._init_task = asyncio.Future() + self.cog._init_task.set_result(None) + + # Avoid unawaited coroutine warnings. + self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() + + asyncio.run(self.cog._async_init()) # Populate instance attributes. + + self.channel = MockTextChannel() + self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) + self.channel.overwrites_for.return_value = self.overwrite + + async def test_sent_correct_message(self): + """Appropriate failure/success message was sent by the command.""" test_cases = ( - (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), - (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), - (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), + (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), + (None, silence.MSG_SILENCE_PERMANENT, True,), + (5, silence.MSG_SILENCE_FAIL, False,), ) - for duration, result_message, _silence_patch_return in test_cases: - with self.subTest( - silence_duration=duration, - result_message=result_message, - starting_unsilenced_state=_silence_patch_return - ): - with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): - await self.cog.silence.callback(self.cog, self.ctx, duration) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_unsilence_sent_correct_discord_message(self): - """Check if proper message was sent when unsilencing channel.""" - test_cases = ( - (True, f"{Emojis.check_mark} unsilenced current channel."), - (False, f"{Emojis.cross_mark} current channel was not silenced.") + for duration, message, was_silenced in test_cases: + ctx = MockContext() + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): + with self.subTest(was_silenced=was_silenced, message=message, duration=duration): + await self.cog.silence.callback(self.cog, ctx, duration) + ctx.send.assert_called_once_with(message) + + async def test_skipped_already_silenced(self): + """Permissions were not set and `False` was returned for an already silenced channel.""" + subtests = ( + (False, PermissionOverwrite(send_messages=False, add_reactions=False)), + (True, PermissionOverwrite(send_messages=True, add_reactions=True)), + (True, PermissionOverwrite(send_messages=False, add_reactions=False)), ) - for _unsilence_patch_return, result_message in test_cases: - with self.subTest( - starting_silenced_state=_unsilence_patch_return, - result_message=result_message - ): - with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): - await self.cog.unsilence.callback(self.cog, self.ctx) - self.ctx.send.assert_called_once_with(result_message) - self.ctx.reset_mock() - - async def test_silence_private_for_false(self): - """Permissions are not set and `False` is returned in an already silenced channel.""" - perm_overwrite = Mock(send_messages=False) - channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - - self.assertFalse(await self.cog._silence(channel, True, None)) - channel.set_permissions.assert_not_called() - async def test_silence_private_silenced_channel(self): - """Channel had `send_message` permissions revoked.""" - channel = MockTextChannel() - self.assertTrue(await self.cog._silence(channel, False, None)) - channel.set_permissions.assert_called_once() - self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) + for contains, overwrite in subtests: + with self.subTest(contains=contains, overwrite=overwrite): + self.cog.scheduler.__contains__.return_value = contains + channel = MockTextChannel() + channel.overwrites_for.return_value = overwrite + + self.assertFalse(await self.cog._set_silence_overwrites(channel)) + channel.set_permissions.assert_not_called() + + async def test_silenced_channel(self): + """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) + self.assertFalse(self.overwrite.send_messages) + self.assertFalse(self.overwrite.add_reactions) + self.channel.set_permissions.assert_awaited_once_with( + self.cog._verified_role, + overwrite=self.overwrite + ) - async def test_silence_private_preserves_permissions(self): - """Previous permissions were preserved when channel was silenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite() - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._silence(channel, False, None) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - async def test_silence_private_notifier(self): - """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" - channel = MockTextChannel() - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=True): - await self.cog._silence(channel, True, None) - self.cog.notifier.add_channel.assert_called_once() - - with mock.patch.object(self.cog, "notifier", create=True): - with self.subTest(persistent=False): - await self.cog._silence(channel, False, None) - self.cog.notifier.add_channel.assert_not_called() - - async def test_silence_private_added_muted_channel(self): - """Channel was added to `muted_channels` on silence.""" + async def test_preserved_other_overwrites(self): + """Channel's other unrelated overwrites were not changed.""" + prev_overwrite_dict = dict(self.overwrite) + await self.cog._set_silence_overwrites(self.channel) + new_overwrite_dict = dict(self.overwrite) + + # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. + del prev_overwrite_dict['send_messages'] + del prev_overwrite_dict['add_reactions'] + del new_overwrite_dict['send_messages'] + del new_overwrite_dict['add_reactions'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + async def test_temp_not_added_to_notifier(self): + """Channel was not added to notifier if a duration was set for the silence.""" + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): + await self.cog.silence.callback(self.cog, MockContext(), 15) + self.cog.notifier.add_channel.assert_not_called() + + async def test_indefinite_added_to_notifier(self): + """Channel was added to notifier if a duration was not set for the silence.""" + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): + await self.cog.silence.callback(self.cog, MockContext(), None) + self.cog.notifier.add_channel.assert_called_once() + + async def test_silenced_not_added_to_notifier(self): + """Channel was not added to the notifier if it was already silenced.""" + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): + await self.cog.silence.callback(self.cog, MockContext(), 15) + self.cog.notifier.add_channel.assert_not_called() + + async def test_cached_previous_overwrites(self): + """Channel's previous overwrites were cached.""" + overwrite_json = '{"send_messages": true, "add_reactions": false}' + await self.cog._set_silence_overwrites(self.channel) + self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + + @autospec(silence, "datetime") + async def test_cached_unsilence_time(self, datetime_mock): + """The UTC POSIX timestamp for the unsilence was cached.""" + now_timestamp = 100 + duration = 15 + timestamp = now_timestamp + duration * 60 + datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) + + ctx = MockContext(channel=self.channel) + await self.cog.silence.callback(self.cog, ctx, duration) + + self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) + datetime_mock.now.assert_called_once_with(tz=timezone.utc) # Ensure it's using an aware dt. + + async def test_cached_indefinite_time(self): + """A value of -1 was cached for a permanent silence.""" + ctx = MockContext(channel=self.channel) + await self.cog.silence.callback(self.cog, ctx, None) + self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) + + async def test_scheduled_task(self): + """An unsilence task was scheduled.""" + ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + + await self.cog.silence.callback(self.cog, ctx, 5) + + args = (300, ctx.channel.id, ctx.invoke.return_value) + self.cog.scheduler.schedule_later.assert_called_once_with(*args) + ctx.invoke.assert_called_once_with(self.cog.unsilence) + + async def test_permanent_not_scheduled(self): + """A task was not scheduled for a permanent silence.""" + ctx = MockContext(channel=self.channel) + await self.cog.silence.callback(self.cog, ctx, None) + self.cog.scheduler.schedule_later.assert_not_called() + + +@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) +class UnsilenceTests(unittest.IsolatedAsyncioTestCase): + """Tests for the unsilence command and its related helper methods.""" + + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot(get_channel=lambda _: MockTextChannel()) + self.cog = silence.Silence(self.bot) + self.cog._init_task = asyncio.Future() + self.cog._init_task.set_result(None) + + overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) + self.cog.previous_overwrites = overwrites_cache + + asyncio.run(self.cog._async_init()) # Populate instance attributes. + + self.cog.scheduler.__contains__.return_value = True + overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' + self.channel = MockTextChannel() + self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) + self.channel.overwrites_for.return_value = self.overwrite + + async def test_sent_correct_message(self): + """Appropriate failure/success message was sent by the command.""" + unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) + test_cases = ( + (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), + (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), + (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), + (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), + (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), + ) + for was_unsilenced, message, overwrite in test_cases: + ctx = MockContext() + with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): + with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): + ctx.channel.overwrites_for.return_value = overwrite + await self.cog.unsilence.callback(self.cog, ctx) + ctx.channel.send.assert_called_once_with(message) + + async def test_skipped_already_unsilenced(self): + """Permissions were not set and `False` was returned for an already unsilenced channel.""" + self.cog.scheduler.__contains__.return_value = False + self.cog.previous_overwrites.get.return_value = None channel = MockTextChannel() - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._silence(channel, False, None) - muted_channels.add.assert_called_once_with(channel) - async def test_unsilence_private_for_false(self): - """Permissions are not set and `False` is returned in an unsilenced channel.""" - channel = Mock() self.assertFalse(await self.cog._unsilence(channel)) channel.set_permissions.assert_not_called() - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_unsilenced_channel(self, _): - """Channel had `send_message` permissions restored""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - self.assertTrue(await self.cog._unsilence(channel)) - channel.set_permissions.assert_called_once() - self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_notifier(self, notifier): - """Channel was removed from `notifier` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - await self.cog._unsilence(channel) - notifier.remove_channel.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_removed_muted_channel(self, _): - """Channel was removed from `muted_channels` on unsilence.""" - perm_overwrite = MagicMock(send_messages=False) - channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._unsilence(channel) - muted_channels.discard.assert_called_once_with(channel) - - @mock.patch.object(Silence, "notifier", create=True) - async def test_unsilence_private_preserves_permissions(self, _): - """Previous permissions were preserved when channel was unsilenced.""" - channel = MockTextChannel() - # Set up mock channel permission state. - mock_permissions = PermissionOverwrite(send_messages=False) - mock_permissions_dict = dict(mock_permissions) - channel.overwrites_for.return_value = mock_permissions - await self.cog._unsilence(channel) - new_permissions = channel.set_permissions.call_args.kwargs - # Remove 'send_messages' key because it got changed in the method. - del new_permissions['send_messages'] - del mock_permissions_dict['send_messages'] - self.assertDictEqual(mock_permissions_dict, new_permissions) - - @mock.patch("bot.exts.moderation.silence.asyncio") - @mock.patch.object(Silence, "_mod_alerts_channel", create=True) - def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): - """Task for sending an alert was created with present `muted_channels`.""" - with mock.patch.object(self.cog, "muted_channels"): - self.cog.cog_unload() - alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") - asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - - @mock.patch("bot.exts.moderation.silence.asyncio") - def test_cog_unload_skips_task_start(self, asyncio_mock): - """No task created with no channels.""" - self.cog.cog_unload() - asyncio_mock.create_task.assert_not_called() + async def test_restored_overwrites(self): + """Channel's `send_message` and `add_reactions` overwrites were restored.""" + await self.cog._unsilence(self.channel) + self.channel.set_permissions.assert_awaited_once_with( + self.cog._verified_role, + overwrite=self.overwrite, + ) - @mock.patch("discord.ext.commands.has_any_role") - @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) - async def test_cog_check(self, role_check): - """Role check is called with `MODERATION_ROLES`""" - role_check.return_value.predicate = mock.AsyncMock() - await self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(*(1, 2, 3)) - role_check.return_value.predicate.assert_awaited_once_with(self.ctx) + # Recall that these values are determined by the fixture. + self.assertTrue(self.overwrite.send_messages) + self.assertFalse(self.overwrite.add_reactions) + + async def test_cache_miss_used_default_overwrites(self): + """Both overwrites were set to None due previous values not being found in the cache.""" + self.cog.previous_overwrites.get.return_value = None + + await self.cog._unsilence(self.channel) + self.channel.set_permissions.assert_awaited_once_with( + self.cog._verified_role, + overwrite=self.overwrite, + ) + + self.assertIsNone(self.overwrite.send_messages) + self.assertIsNone(self.overwrite.add_reactions) + + async def test_cache_miss_sent_mod_alert(self): + """A message was sent to the mod alerts channel.""" + self.cog.previous_overwrites.get.return_value = None + + await self.cog._unsilence(self.channel) + self.cog._mod_alerts_channel.send.assert_awaited_once() + + async def test_removed_notifier(self): + """Channel was removed from `notifier`.""" + await self.cog._unsilence(self.channel) + self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + + async def test_deleted_cached_overwrite(self): + """Channel was deleted from the overwrites cache.""" + await self.cog._unsilence(self.channel) + self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + + async def test_deleted_cached_time(self): + """Channel was deleted from the timestamp cache.""" + await self.cog._unsilence(self.channel) + self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + + async def test_cancelled_task(self): + """The scheduled unsilence task should be cancelled.""" + await self.cog._unsilence(self.channel) + self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + + async def test_preserved_other_overwrites(self): + """Channel's other unrelated overwrites were not changed, including cache misses.""" + for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): + with self.subTest(overwrite_json=overwrite_json): + self.cog.previous_overwrites.get.return_value = overwrite_json + + prev_overwrite_dict = dict(self.overwrite) + await self.cog._unsilence(self.channel) + new_overwrite_dict = dict(self.overwrite) + + # Remove these keys because they were modified by the unsilence. + del prev_overwrite_dict['send_messages'] + del prev_overwrite_dict['add_reactions'] + del new_overwrite_dict['send_messages'] + del new_overwrite_dict['add_reactions'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 40b2202aa..321a92445 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -42,9 +42,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with( - self.bot.http_session, "Test output.", extension="txt" - ) + mock_paste_util.assert_called_once_with("Test output.", extension="txt") def test_prepare_input(self): cases = ( @@ -52,6 +50,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'), + ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', + 'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'), + ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', + 'print("How\'s it going?")', 'code block preceded by inline code'), + ('`print("Hello world!")`\ntext\n`print("Hello world!")`', + 'print("Hello world!")', 'one inline code block of two') ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): @@ -154,7 +159,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock(return_value=None) - await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') self.cog.continue_eval.assert_called_once_with(ctx, response) @@ -168,7 +173,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.continue_eval = AsyncMock() self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) - await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') self.cog.continue_eval.assert_called_with(ctx, response) @@ -180,7 +185,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.author.mention = '@LemonLemonishBeard#0042' ctx.send = AsyncMock() self.cog.jobs = (42,) - await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') ctx.send.assert_called_once_with( "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) @@ -188,8 +193,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_eval_command_call_help(self): """Test if the eval command call the help command if no code is provided.""" ctx = MockContext(command="sentinel") - await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with("sentinel") + await self.cog.eval_command(self.cog, ctx=ctx, code='') + ctx.send_help.assert_called_once_with(ctx.command) async def test_send_eval(self): """Test the send_eval function.""" @@ -290,7 +295,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) ) ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) - ctx.message.clear_reactions.assert_called_once() + ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) response.delete.assert_called_once() async def test_continue_eval_does_not_continue(self): @@ -299,7 +304,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_eval(ctx, MockMessage()) self.assertEqual(actual, None) - ctx.message.clear_reactions.assert_called_once() + ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" diff --git a/tests/bot/patches/__init__.py b/tests/bot/patches/__init__.py deleted file mode 100644 index e69de29bb..000000000 --- a/tests/bot/patches/__init__.py +++ /dev/null diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 9a72723e2..66c2d9f92 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -5,11 +5,12 @@ from tests.bot.rules import DisallowedCase, RuleTest from tests.helpers import MockMessage discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> +unicode_emoji = "🧪" -def make_msg(author: str, n_emojis: int) -> MockMessage: +def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage: """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" - return MockMessage(author=author, content=discord_emoji * n_emojis) + return MockMessage(author=author, content=emoji * n_emojis) class DiscordEmojisRuleTests(RuleTest): @@ -20,16 +21,22 @@ class DiscordEmojisRuleTests(RuleTest): self.config = {"max": 2, "interval": 10} async def test_allows_messages_within_limit(self): - """Cases with a total amount of discord emojis within limit.""" + """Cases with a total amount of discord and unicode emojis within limit.""" cases = ( [make_msg("bob", 2)], [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], + [make_msg("bob", 2, unicode_emoji)], + [ + make_msg("alice", 1, unicode_emoji), + make_msg("bob", 2, unicode_emoji), + make_msg("alice", 1, unicode_emoji) + ], ) await self.run_allowed(cases) async def test_disallows_messages_beyond_limit(self): - """Cases with more than the allowed amount of discord emojis.""" + """Cases with more than the allowed amount of discord and unicode emojis.""" cases = ( DisallowedCase( [make_msg("bob", 3)], @@ -41,6 +48,20 @@ class DiscordEmojisRuleTests(RuleTest): ("alice",), 4, ), + DisallowedCase( + [make_msg("bob", 3, unicode_emoji)], + ("bob",), + 3, + ), + DisallowedCase( + [ + make_msg("alice", 2, unicode_emoji), + make_msg("bob", 2, unicode_emoji), + make_msg("alice", 2, unicode_emoji) + ], + ("alice",), + 4 + ) ) await self.run_disallowed(cases) diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 5e0855704..1b48f6560 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -5,11 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch from aiohttp import ClientConnectorError from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from tests.helpers import MockBot class PasteTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: - self.http_session = MagicMock() + patcher = patch("bot.instance", new=MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") async def test_url_and_sent_contents(self): @@ -17,10 +20,10 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): response = MagicMock( json=AsyncMock(return_value={"key": ""}) ) - self.http_session.post().__aenter__.return_value = response - self.http_session.post.reset_mock() - await send_to_paste_service(self.http_session, "Content") - self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + self.bot.http_session.post.return_value.__aenter__.return_value = response + self.bot.http_session.post.reset_mock() + await send_to_paste_service("Content") + self.bot.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") async def test_paste_returns_correct_url_on_success(self): @@ -34,41 +37,41 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): response = MagicMock( json=AsyncMock(return_value={"key": key}) ) - self.http_session.post().__aenter__.return_value = response + self.bot.http_session.post.return_value.__aenter__.return_value = response for expected_output, extension in test_cases: with self.subTest(msg=f"Send contents with extension {repr(extension)}"): self.assertEqual( - await send_to_paste_service(self.http_session, "", extension=extension), + await send_to_paste_service("", extension=extension), expected_output ) async def test_request_repeated_on_json_errors(self): """Json with error message and invalid json are handled as errors and requests repeated.""" test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) - self.http_session.post().__aenter__.return_value = response = MagicMock() - self.http_session.post.reset_mock() + self.bot.http_session.post.return_value.__aenter__.return_value = response = MagicMock() + self.bot.http_session.post.reset_mock() for error_json in test_cases: with self.subTest(error_json=error_json): response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + result = await send_to_paste_service("") + self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertIsNone(result) - self.http_session.post.reset_mock() + self.bot.http_session.post.reset_mock() async def test_request_repeated_on_connection_errors(self): """Requests are repeated in the case of connection errors.""" - self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service("") + self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertIsNone(result) async def test_general_error_handled_and_request_repeated(self): """All `Exception`s are handled, logged and request repeated.""" - self.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.bot.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service("") + self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertLogs("bot.utils", logging.ERROR) self.assertIsNone(result) diff --git a/tests/helpers.py b/tests/helpers.py index d35012278..ea5e2822c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools import logging import unittest.mock from asyncio import AbstractEventLoop -from typing import Callable, Iterable, Optional +from typing import Iterable, Optional import discord from aiohttp import ClientSession @@ -14,6 +14,7 @@ from discord.ext.commands import Context from bot.api import APIClient from bot.async_stats import AsyncStatsClient from bot.bot import Bot +from tests._autospec import autospec # noqa: F401 other modules import it via this module for logger in logging.Logger.manager.loggerDict.values(): @@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> Callable: - """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" - # Caller's kwargs should take priority and overwrite the defaults. - kwargs = {'spec_set': True, 'autospec': True, **kwargs} - - # Import the target if it's a string. - # This is to support both object and string targets like patch.multiple. - if type(target) is str: - target = unittest.mock._importer(target) - - def decorator(func): - for attribute in attributes: - patcher = unittest.mock.patch.object(target, attribute, **kwargs) - func = patcher(func) - return func - return decorator - - class HashableMixin(discord.mixins.EqualityComparable): """ Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. |