diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 548 | ||||
-rw-r--r-- | tests/bot/exts/__init__.py (renamed from tests/bot/cogs/__init__.py) | 0 | ||||
-rw-r--r-- | tests/bot/exts/backend/__init__.py (renamed from tests/bot/cogs/moderation/__init__.py) | 0 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/__init__.py (renamed from tests/bot/cogs/sync/__init__.py) | 0 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_base.py (renamed from tests/bot/cogs/sync/test_base.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py (renamed from tests/bot/cogs/sync/test_cog.py) | 17 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_roles.py (renamed from tests/bot/cogs/sync/test_roles.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/backend/sync/test_users.py (renamed from tests/bot/cogs/sync/test_users.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_logging.py (renamed from tests/bot/cogs/test_logging.py) | 6 | ||||
-rw-r--r-- | tests/bot/exts/filters/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/exts/filters/test_antimalware.py (renamed from tests/bot/cogs/test_antimalware.py) | 24 | ||||
-rw-r--r-- | tests/bot/exts/filters/test_antispam.py (renamed from tests/bot/cogs/test_antispam.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/filters/test_security.py (renamed from tests/bot/cogs/test_security.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/filters/test_token_remover.py (renamed from tests/bot/cogs/test_token_remover.py) | 26 | ||||
-rw-r--r-- | tests/bot/exts/info/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/exts/info/test_information.py (renamed from tests/bot/cogs/test_information.py) | 20 | ||||
-rw-r--r-- | tests/bot/exts/moderation/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py (renamed from tests/bot/cogs/moderation/test_infractions.py) | 8 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 359 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_incidents.py (renamed from tests/bot/cogs/moderation/test_incidents.py) | 66 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_modlog.py (renamed from tests/bot/cogs/moderation/test_modlog.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py (renamed from tests/bot/cogs/moderation/test_silence.py) | 20 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_slowmode.py (renamed from tests/bot/cogs/test_slowmode.py) | 14 | ||||
-rw-r--r-- | tests/bot/exts/test_cogs.py (renamed from tests/bot/cogs/test_cogs.py) | 8 | ||||
-rw-r--r-- | tests/bot/exts/utils/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_jams.py (renamed from tests/bot/cogs/test_jams.py) | 2 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py (renamed from tests/bot/cogs/test_snekbox.py) | 54 | ||||
-rw-r--r-- | tests/bot/test_pagination.py | 15 | ||||
-rw-r--r-- | tests/bot/utils/test_checks.py | 44 | ||||
-rw-r--r-- | tests/bot/utils/test_redis_cache.py | 265 | ||||
-rw-r--r-- | tests/bot/utils/test_services.py | 74 | ||||
-rw-r--r-- | tests/helpers.py | 6 |
33 files changed, 599 insertions, 989 deletions
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py deleted file mode 100644 index cfe10aebf..000000000 --- a/tests/bot/cogs/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.cogs import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.cogs.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger('bot.cogs.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/__init__.py b/tests/bot/exts/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/cogs/__init__.py +++ b/tests/bot/exts/__init__.py diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/exts/backend/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/cogs/moderation/__init__.py +++ b/tests/bot/exts/backend/__init__.py diff --git a/tests/bot/cogs/sync/__init__.py b/tests/bot/exts/backend/sync/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/cogs/sync/__init__.py +++ b/tests/bot/exts/backend/sync/__init__.py diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 70aea2bab..886c243cf 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -6,7 +6,7 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs.sync.syncers import Syncer, _Diff +from bot.exts.backend.sync._syncers import Syncer, _Diff from tests import helpers diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 120bc991d..1b89564f2 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -5,8 +5,9 @@ import discord from bot import constants from bot.api import ResponseCodeError -from bot.cogs import sync -from bot.cogs.sync.syncers import Syncer +from bot.exts.backend import sync +from bot.exts.backend.sync._cog import Sync +from bot.exts.backend.sync._syncers import Syncer from tests import helpers from tests.base import CommandTestCase @@ -29,19 +30,19 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): self.bot = helpers.MockBot() self.role_syncer_patcher = mock.patch( - "bot.cogs.sync.syncers.RoleSyncer", + "bot.exts.backend.sync._syncers.RoleSyncer", autospec=Syncer, spec_set=True ) self.user_syncer_patcher = mock.patch( - "bot.cogs.sync.syncers.UserSyncer", + "bot.exts.backend.sync._syncers.UserSyncer", autospec=Syncer, spec_set=True ) self.RoleSyncer = self.role_syncer_patcher.start() self.UserSyncer = self.user_syncer_patcher.start() - self.cog = sync.Sync(self.bot) + self.cog = Sync(self.bot) def tearDown(self): self.role_syncer_patcher.stop() @@ -59,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) + @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild): """Should instantiate syncers and run a sync for the guild.""" # Reset because a Sync cog was already instantiated in setUp. @@ -70,7 +71,7 @@ class SyncCogTests(SyncCogTestCase): mock_sync_guild_coro = mock.MagicMock() sync_guild.return_value = mock_sync_guild_coro - sync.Sync(self.bot) + Sync(self.bot) self.RoleSyncer.assert_called_once_with(self.bot) self.UserSyncer.assert_called_once_with(self.bot) @@ -131,7 +132,7 @@ class SyncCogListenerTests(SyncCogTestCase): super().setUp() self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - self.guild_id_patcher = mock.patch("bot.cogs.sync.cog.constants.Guild.id", 5) + self.guild_id_patcher = mock.patch("bot.exts.backend.sync._cog.constants.Guild.id", 5) self.guild_id = self.guild_id_patcher.start() self.guild = helpers.MockGuild(id=self.guild_id) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 79eee98f4..7b9f40cad 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -3,7 +3,7 @@ from unittest import mock import discord -from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 002a947ad..c0a1da35c 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User from tests import helpers diff --git a/tests/bot/cogs/test_logging.py b/tests/bot/exts/backend/test_logging.py index 8a18fdcd6..466f207d9 100644 --- a/tests/bot/cogs/test_logging.py +++ b/tests/bot/exts/backend/test_logging.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import patch from bot import constants -from bot.cogs.logging import Logging +from bot.exts.backend.logging import Logging from tests.helpers import MockBot, MockTextChannel @@ -14,7 +14,7 @@ class LoggingTests(unittest.IsolatedAsyncioTestCase): self.cog = Logging(self.bot) self.dev_log = MockTextChannel(id=1234, name="dev-log") - @patch("bot.cogs.logging.DEBUG_MODE", False) + @patch("bot.exts.backend.logging.DEBUG_MODE", False) async def test_debug_mode_false(self): """Should send connected message to dev-log.""" self.bot.get_channel.return_value = self.dev_log @@ -24,7 +24,7 @@ class LoggingTests(unittest.IsolatedAsyncioTestCase): self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) self.dev_log.send.assert_awaited_once() - @patch("bot.cogs.logging.DEBUG_MODE", True) + @patch("bot.exts.backend.logging.DEBUG_MODE", True) async def test_debug_mode_true(self): """Should not send anything to dev-log.""" await self.cog.startup_greeting() diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filters/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/filters/__init__.py diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index ecb7abf00..3393c6cdc 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -3,8 +3,8 @@ from unittest.mock import AsyncMock, Mock from discord import NotFound -from bot.cogs import antimalware from bot.constants import Channels, STAFF_ROLES +from bot.exts.filters import antimalware from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole @@ -23,6 +23,8 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.message.webhook_id = None + self.message.author.bot = None self.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): @@ -48,6 +50,26 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.delete.assert_not_called() + async def test_webhook_message_with_illegal_extension(self): + """A webhook message containing an illegal extension should be ignored.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.webhook_id = 697140105563078727 + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_bot_message_with_illegal_extension(self): + """A bot message containing an illegal extension should be ignored.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.author.bot = 409107086526644234 + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + async def test_message_with_illegal_extension_gets_deleted(self): """A message containing an illegal extension should send an embed.""" attachment = MockAttachment(filename="python.disallowed") diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/exts/filters/test_antispam.py index ce5472c71..6a0e4fded 100644 --- a/tests/bot/cogs/test_antispam.py +++ b/tests/bot/exts/filters/test_antispam.py @@ -1,6 +1,6 @@ import unittest -from bot.cogs import antispam +from bot.exts.filters import antispam class AntispamConfigurationValidationTests(unittest.TestCase): diff --git a/tests/bot/cogs/test_security.py b/tests/bot/exts/filters/test_security.py index 9d1a62f7e..c0c3baa42 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock from discord.ext.commands import NoPrivateMessage -from bot.cogs import security +from bot.exts.filters import security from tests.helpers import MockBot, MockContext diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index 3349caa73..ea822053b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -6,9 +6,10 @@ from unittest.mock import MagicMock from discord import Colour, NotFound from bot import constants -from bot.cogs import token_remover -from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import Token, TokenRemover +from bot.exts.filters import token_remover +from bot.exts.filters.token_remover import Token, TokenRemover +from bot.exts.moderation.modlog import ModLog +from bot.utils.messages import format_user from tests.helpers import MockBot, MockMessage, autospec @@ -132,7 +133,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): await cog.on_message(msg) find_token_in_message.assert_not_called() - @autospec("bot.cogs.token_remover", "TOKEN_RE") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") def test_find_token_no_matches(self, token_re): """None should be returned if the regex matches no tokens in a message.""" token_re.finditer.return_value = () @@ -143,8 +144,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): token_re.finditer.assert_called_once_with(self.msg.content) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - @autospec("bot.cogs.token_remover", "Token") - @autospec("bot.cogs.token_remover", "TOKEN_RE") + @autospec("bot.exts.filters.token_remover", "Token") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): """The first match with a valid user ID and timestamp should be returned as a `Token`.""" matches = [ @@ -167,8 +168,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): token_re.finditer.assert_called_once_with(self.msg.content) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - @autospec("bot.cogs.token_remover", "Token") - @autospec("bot.cogs.token_remover", "TOKEN_RE") + @autospec("bot.exts.filters.token_remover", "Token") + @autospec("bot.exts.filters.token_remover", "TOKEN_RE") def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): """None should be returned if no matches have valid user IDs or timestamps.""" token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] @@ -230,7 +231,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = [match[0] for match in results] self.assertCountEqual((token_1, token_2), results) - @autospec("bot.cogs.token_remover", "LOG_MESSAGE") + @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") @@ -240,8 +241,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(return_value, log_message.format.return_value) log_message.format.assert_called_once_with( - author=self.msg.author, - author_id=self.msg.author.id, + author=format_user(self.msg.author), channel=self.msg.channel.mention, user_id=token.user_id, timestamp=token.timestamp, @@ -249,7 +249,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): ) @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - @autospec("bot.cogs.token_remover", "log") + @autospec("bot.exts.filters.token_remover", "log") @autospec(TokenRemover, "format_log_message") async def test_take_action(self, format_log_message, logger, mod_log_property): """Should delete the message and send a mod log.""" @@ -299,7 +299,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): class TokenRemoverExtensionTests(unittest.TestCase): """Tests for the token_remover extension.""" - @autospec("bot.cogs.token_remover", "TokenRemover") + @autospec("bot.exts.filters.token_remover", "TokenRemover") def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() diff --git a/tests/bot/exts/info/__init__.py b/tests/bot/exts/info/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/info/__init__.py diff --git a/tests/bot/cogs/test_information.py b/tests/bot/exts/info/test_information.py index 3438635f1..d3f2995fb 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -6,11 +6,11 @@ import unittest.mock import discord from bot import constants -from bot.cogs import information +from bot.exts.info import information from bot.utils.checks import InWhitelistCheckFailure from tests import helpers -COG_PATH = "bot.cogs.information.Information" +COG_PATH = "bot.exts.info.information.Information" class InformationCogTests(unittest.TestCase): @@ -97,7 +97,7 @@ class InformationCogTests(unittest.TestCase): self.assertEqual(admin_embed.title, "Admins info") self.assertEqual(admin_embed.colour, discord.Colour.red()) - @unittest.mock.patch('bot.cogs.information.time_since') + @unittest.mock.patch('bot.exts.info.information.time_since') def test_server_info_command(self, time_since_patch): time_since_patch.return_value = '2 days ago' @@ -339,8 +339,8 @@ class UserInfractionHelperMethodTests(unittest.TestCase): self._method_subtests(self.cog.user_nomination_counts, test_values, header) [email protected]("bot.cogs.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) [email protected]("bot.cogs.information.constants.MODERATION_CHANNELS", new=[50]) [email protected]("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) [email protected]("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50]) class UserEmbedTests(unittest.TestCase): """Tests for the creation of the `!user` embed.""" @@ -515,7 +515,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.thumbnail.url, "avatar url") [email protected]("bot.cogs.information.constants") [email protected]("bot.exts.info.information.constants") class UserCommandTests(unittest.TestCase): """Tests for the `!user` command.""" @@ -555,7 +555,7 @@ class UserCommandTests(unittest.TestCase): with self.assertRaises(InWhitelistCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" constants.STAFF_ROLES = [self.moderator_role.id] @@ -566,7 +566,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") def test_regular_user_can_explicitly_target_themselves(self, create_embed, _): """A user should target itself with `!user` when a `user` argument was not provided.""" constants.STAFF_ROLES = [self.moderator_role.id] @@ -577,7 +577,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): """Staff members should be able to bypass the bot-commands channel restriction.""" constants.STAFF_ROLES = [self.moderator_role.id] @@ -588,7 +588,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.moderator) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") def test_moderators_can_target_another_member(self, create_embed, constants): """A moderator should be able to use `!user` targeting another user.""" constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/exts/moderation/__init__.py b/tests/bot/exts/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/moderation/__init__.py diff --git a/tests/bot/exts/moderation/infraction/__init__.py b/tests/bot/exts/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/moderation/infraction/__init__.py diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index da4e92ccc..be1b649e1 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -2,7 +2,7 @@ import textwrap import unittest from unittest.mock import AsyncMock, Mock, patch -from bot.cogs.moderation.infractions import Infractions +from bot.exts.moderation.infraction.infractions import Infractions from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -17,8 +17,8 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.guild = MockGuild(id=4567) self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - @patch("bot.cogs.moderation.utils.get_active_infraction") - @patch("bot.cogs.moderation.utils.post_infraction") + @patch("bot.exts.moderation.infraction._utils.get_active_infraction") + @patch("bot.exts.moderation.infraction._utils.post_infraction") async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" get_active_mock.return_value = None @@ -39,7 +39,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value ) - @patch("bot.cogs.moderation.utils.post_infraction") + @patch("bot.exts.moderation.infraction._utils.post_infraction") async def test_apply_kick_reason_truncation(self, post_infraction_mock): """Should truncate reason for `Member.kick`.""" post_infraction_mock.return_value = {"foo": "bar"} diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py new file mode 100644 index 000000000..5b62463e0 --- /dev/null +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -0,0 +1,359 @@ +import unittest +from collections import namedtuple +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, call, patch + +from discord import Embed, Forbidden, HTTPException, NotFound + +from bot.api import ResponseCodeError +from bot.constants import Colours, Icons +from bot.exts.moderation.infraction import _utils as utils +from tests.helpers import MockBot, MockContext, MockMember, MockUser + + +class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): + """Tests Moderation utils.""" + + def setUp(self): + self.bot = MockBot() + self.member = MockMember(id=1234) + self.user = MockUser(id=1234) + self.ctx = MockContext(bot=self.bot, author=self.member) + + async def test_post_user(self): + """Should POST a new user and return the response if successful or otherwise send an error message.""" + user = MockUser(discriminator=5678, id=1234, name="Test user") + not_user = MagicMock(discriminator=3333, id=5678, name="Wrong user") + test_cases = [ + { + "user": user, + "post_result": "bar", + "raise_error": None, + "payload": { + "discriminator": 5678, + "id": self.user.id, + "in_guild": False, + "name": "Test user", + "roles": [] + } + }, + { + "user": self.member, + "post_result": "foo", + "raise_error": ResponseCodeError(MagicMock(status=400), "foo"), + "payload": { + "discriminator": 0, + "id": self.member.id, + "in_guild": False, + "name": "Name unknown", + "roles": [] + } + }, + { + "user": not_user, + "post_result": "bar", + "raise_error": None, + "payload": { + "discriminator": not_user.discriminator, + "id": not_user.id, + "in_guild": False, + "name": not_user.name, + "roles": [] + } + } + ] + + for case in test_cases: + user = case["user"] + post_result = case["post_result"] + raise_error = case["raise_error"] + payload = case["payload"] + + with self.subTest(user=user, post_result=post_result, raise_error=raise_error, payload=payload): + self.bot.api_client.post.reset_mock(side_effect=True) + self.ctx.bot.api_client.post.return_value = post_result + + self.ctx.bot.api_client.post.side_effect = raise_error + + result = await utils.post_user(self.ctx, user) + + if raise_error: + self.assertIsNone(result) + self.ctx.send.assert_awaited_once() + self.assertIn(str(raise_error.status), self.ctx.send.call_args[0][0]) + else: + self.assertEqual(result, post_result) + self.bot.api_client.post.assert_awaited_once_with("bot/users", json=payload) + + async def test_get_active_infraction(self): + """ + Should request the API for active infractions and return infraction if the user has one or `None` otherwise. + + A message should be sent to the context indicating a user already has an infraction, if that's the case. + """ + test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"]) + test_cases = [ + test_case([], None, None, True), + test_case([{"id": 123987}], {"id": 123987}, "123987", False), + test_case([{"id": 123987}], {"id": 123987}, "123987", True) + ] + + for case in test_cases: + with self.subTest(return_value=case.get_return_value, expected=case.expected_output): + self.bot.api_client.get.reset_mock() + self.ctx.send.reset_mock() + + params = { + "active": "true", + "type": "ban", + "user__id": str(self.member.id) + } + + self.bot.api_client.get.return_value = case.get_return_value + + result = await utils.get_active_infraction(self.ctx, self.member, "ban", send_msg=case.send_msg) + self.assertEqual(result, case.expected_output) + self.bot.api_client.get.assert_awaited_once_with("bot/infractions", params=params) + + if case.send_msg and case.get_return_value: + self.ctx.send.assert_awaited_once() + sent_message = self.ctx.send.call_args[0][0] + self.assertIn(case.infraction_nr, sent_message) + self.assertIn("ban", sent_message) + else: + self.ctx.send.assert_not_awaited() + + @patch("bot.exts.moderation.infraction._utils.send_private_embed") + async def test_notify_infraction(self, send_private_embed_mock): + """ + Should send an embed of a certain format as a DM and return `True` if DM successful. + + Appealable infractions should have the appeal message in the embed's footer. + """ + test_cases = [ + { + "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), + "expected_output": Embed( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Ban", + expires="2020-02-26 09:20 (23 hours and 59 minutes)", + reason="No reason provided." + ), + colour=Colours.soft_red, + url=utils.RULES_URL + ).set_author( + name=utils.INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.token_removed + ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), + "send_result": True + }, + { + "args": (self.user, "warning", None, "Test reason."), + "expected_output": Embed( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Warning", + expires="N/A", + reason="Test reason." + ), + colour=Colours.soft_red, + url=utils.RULES_URL + ).set_author( + name=utils.INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.token_removed + ), + "send_result": False + }, + { + "args": (self.user, "note", None, None, Icons.defcon_denied), + "expected_output": Embed( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Note", + expires="N/A", + reason="No reason provided." + ), + colour=Colours.soft_red, + url=utils.RULES_URL + ).set_author( + name=utils.INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.defcon_denied + ), + "send_result": False + }, + { + "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied), + "expected_output": Embed( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Mute", + expires="2020-02-26 09:20 (23 hours and 59 minutes)", + reason="Test" + ), + colour=Colours.soft_red, + url=utils.RULES_URL + ).set_author( + name=utils.INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.defcon_denied + ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), + "send_result": False + }, + { + "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied), + "expected_output": Embed( + title=utils.INFRACTION_TITLE, + description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( + type="Mute", + expires="N/A", + reason="foo bar" * 4000 + )[:2045] + "...", + colour=Colours.soft_red, + url=utils.RULES_URL + ).set_author( + name=utils.INFRACTION_AUTHOR_NAME, + url=utils.RULES_URL, + icon_url=Icons.defcon_denied + ).set_footer(text=utils.INFRACTION_APPEAL_FOOTER), + "send_result": True + } + ] + + for case in test_cases: + with self.subTest(args=case["args"], expected=case["expected_output"], send=case["send_result"]): + send_private_embed_mock.reset_mock() + + send_private_embed_mock.return_value = case["send_result"] + result = await utils.notify_infraction(*case["args"]) + + self.assertEqual(case["send_result"], result) + + embed = send_private_embed_mock.call_args[0][1] + + self.assertEqual(embed.to_dict(), case["expected_output"].to_dict()) + + send_private_embed_mock.assert_awaited_once_with(case["args"][0], embed) + + @patch("bot.exts.moderation.infraction._utils.send_private_embed") + async def test_notify_pardon(self, send_private_embed_mock): + """Should send an embed of a certain format as a DM and return `True` if DM successful.""" + test_case = namedtuple("test_case", ["args", "icon", "send_result"]) + test_cases = [ + test_case((self.user, "Test title", "Example content"), Icons.user_verified, True), + test_case((self.user, "Test title", "Example content", Icons.user_update), Icons.user_update, False) + ] + + for case in test_cases: + expected = Embed( + description="Example content", + colour=Colours.soft_green + ).set_author( + name="Test title", + icon_url=case.icon + ) + + with self.subTest(args=case.args, expected=expected): + send_private_embed_mock.reset_mock() + + send_private_embed_mock.return_value = case.send_result + + result = await utils.notify_pardon(*case.args) + self.assertEqual(case.send_result, result) + + embed = send_private_embed_mock.call_args[0][1] + self.assertEqual(embed.to_dict(), expected.to_dict()) + + send_private_embed_mock.assert_awaited_once_with(case.args[0], embed) + + async def test_send_private_embed(self): + """Should DM the user and return `True` on success or `False` on failure.""" + embed = Embed(title="Test", description="Test val") + + test_case = namedtuple("test_case", ["expected_output", "raised_exception"]) + test_cases = [ + test_case(True, None), + test_case(False, HTTPException(AsyncMock(), AsyncMock())), + test_case(False, Forbidden(AsyncMock(), AsyncMock())), + test_case(False, NotFound(AsyncMock(), AsyncMock())) + ] + + for case in test_cases: + with self.subTest(expected=case.expected_output, raised=case.raised_exception): + self.user.send.reset_mock(side_effect=True) + self.user.send.side_effect = case.raised_exception + + result = await utils.send_private_embed(self.user, embed) + + self.assertEqual(result, case.expected_output) + self.user.send.assert_awaited_once_with(embed=embed) + + +class TestPostInfraction(unittest.IsolatedAsyncioTestCase): + """Tests for the `post_infraction` function.""" + + def setUp(self): + self.bot = MockBot() + self.member = MockMember(id=1234) + self.user = MockUser(id=1234) + self.ctx = MockContext(bot=self.bot, author=self.member) + + async def test_normal_post_infraction(self): + """Should return response from POST request if there are no errors.""" + now = datetime.now() + payload = { + "actor": self.ctx.author.id, + "hidden": True, + "reason": "Test reason", + "type": "ban", + "user": self.member.id, + "active": False, + "expires_at": now.isoformat() + } + + self.ctx.bot.api_client.post.return_value = "foo" + actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) + + self.assertEqual(actual, "foo") + self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + + async def test_unknown_error_post_infraction(self): + """Should send an error message to chat when a non-400 error occurs.""" + self.ctx.bot.api_client.post.side_effect = ResponseCodeError(AsyncMock(), AsyncMock()) + self.ctx.bot.api_client.post.side_effect.status = 500 + + actual = await utils.post_infraction(self.ctx, self.user, "ban", "Test reason") + self.assertIsNone(actual) + + self.assertTrue("500" in self.ctx.send.call_args[0][0]) + + @patch("bot.exts.moderation.infraction._utils.post_user", return_value=None) + async def test_user_not_found_none_post_infraction(self, post_user_mock): + """Should abort and return `None` when a new user fails to be posted.""" + self.bot.api_client.post.side_effect = ResponseCodeError(MagicMock(status=400), {"user": "foo"}) + + actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") + self.assertIsNone(actual) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) + + @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar") + async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): + """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" + payload = { + "actor": self.ctx.author.id, + "hidden": False, + "reason": "Test reason", + "type": "mute", + "user": self.user.id, + "active": True + } + + self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] + + actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") + self.assertEqual(actual, "foo") + self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 435a1cd51..cbf7f7bcf 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,8 +8,8 @@ from unittest.mock import AsyncMock, MagicMock, call, patch import aiohttp import discord -from bot.cogs.moderation import Incidents, incidents from bot.constants import Colours +from bot.exts.moderation import incidents from tests.helpers import ( MockAsyncWebhook, MockAttachment, @@ -130,7 +130,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): incident = MockMessage(content="this is an incident", attachments=[attachment]) # Patch `download_file` to return our `file` - with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=file)): + with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) self.assertIs(file, returned_file) @@ -142,7 +142,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): incident = MockMessage(content="this is an incident", attachments=[attachment]) # Patch `download_file` to return None as if the download failed - with patch("bot.cogs.moderation.incidents.download_file", AsyncMock(return_value=None)): + with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) self.assertIsNone(returned_file) @@ -215,7 +215,7 @@ class TestOwnReactions(unittest.TestCase): self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) -@patch("bot.cogs.moderation.incidents.ALL_SIGNALS", {"A", "B"}) +@patch("bot.exts.moderation.incidents.ALL_SIGNALS", {"A", "B"}) class TestHasSignals(unittest.TestCase): """ Assertions for the `has_signals` function. @@ -229,7 +229,7 @@ class TestHasSignals(unittest.TestCase): message = MockMessage() own_reactions = MagicMock(return_value={"A", "B"}) - with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): + with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): self.assertTrue(incidents.has_signals(message)) def test_has_signals_false(self): @@ -237,11 +237,11 @@ class TestHasSignals(unittest.TestCase): message = MockMessage() own_reactions = MagicMock(return_value={"A", "C"}) - with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions): + with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): self.assertFalse(incidents.has_signals(message)) -@patch("bot.cogs.moderation.incidents.Signal", MockSignal) +@patch("bot.exts.moderation.incidents.Signal", MockSignal) class TestAddSignals(unittest.IsolatedAsyncioTestCase): """ Assertions for the `add_signals` coroutine. @@ -255,19 +255,19 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): """Prepare a mock incident message for tests to use.""" self.incident = MockMessage() - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value=set())) + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value=set())) async def test_add_signals_missing(self): """All emoji are added when none are present.""" await incidents.add_signals(self.incident) self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) async def test_add_signals_partial(self): """Only missing emoji are added when some are present.""" await incidents.add_signals(self.incident) self.incident.add_reaction.assert_has_calls([call("B")]) - @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) + @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) async def test_add_signals_present(self): """No emoji are added when all are present.""" await incidents.add_signals(self.incident) @@ -290,7 +290,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): Note that this will not schedule `crawl_incidents` in the background, as everything is being mocked. The `crawl_task` attribute will end up being None. """ - self.cog_instance = Incidents(MockBot()) + self.cog_instance = incidents.Incidents(MockBot()) @patch("asyncio.sleep", AsyncMock()) # Prevent the coro from sleeping to speed up the test @@ -326,25 +326,25 @@ class TestCrawlIncidents(TestIncidents): await self.cog_instance.crawl_incidents() self.cog_instance.bot.wait_until_guild_available.assert_awaited() - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) # Message doesn't qualify + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) async def test_crawl_incidents_noop_if_is_not_incident(self): """Signals are not added for a non-incident message.""" await self.cog_instance.crawl_incidents() incidents.add_signals.assert_not_awaited() - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=True)) # But already has signals async def test_crawl_incidents_noop_if_message_already_has_signals(self): """Signals are not added for messages which already have them.""" await self.cog_instance.crawl_incidents() incidents.add_signals.assert_not_awaited() - @patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies - @patch("bot.cogs.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals + @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) # Message qualifies + @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) # And doesn't have signals async def test_crawl_incidents_add_signals_called(self): """Message has signals added as it does not have them yet and qualifies as an incident.""" await self.cog_instance.crawl_incidents() @@ -384,7 +384,7 @@ class TestArchive(TestIncidents): ) built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this - with patch("bot.cogs.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): + with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) # Now we check that the webhook was given the correct args, and that `archive` returned True @@ -451,8 +451,8 @@ class TestMakeConfirmationTask(TestIncidents): self.assertFalse(created_check(payload=MagicMock(message_id=0))) -@patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2}) -@patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable +@patch("bot.exts.moderation.incidents.ALLOWED_ROLES", {1, 2}) +@patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable class TestProcessEvent(TestIncidents): """Tests for the `Incidents.process_event` coroutine.""" @@ -479,7 +479,7 @@ class TestProcessEvent(TestIncidents): async def test_process_event_no_archive_on_investigating(self): """Message is not archived on `Signal.INVESTIGATING`.""" - with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: + with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: await self.cog_instance.process_event( reaction=incidents.Signal.INVESTIGATING.value, incident=MockMessage(), @@ -497,7 +497,7 @@ class TestProcessEvent(TestIncidents): """ incident = MockMessage() - with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): + with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, incident=incident, @@ -510,7 +510,7 @@ class TestProcessEvent(TestIncidents): """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" mock_task = AsyncMock() - with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): + with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, incident=MockMessage(), @@ -530,7 +530,7 @@ class TestProcessEvent(TestIncidents): mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) try: - with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task): + with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, incident=MockMessage(), @@ -712,7 +712,7 @@ class TestOnRawReactionAdd(TestIncidents): self.cog_instance.process_event = AsyncMock() self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) - with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)): + with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)): await self.cog_instance.on_raw_reaction_add(self.payload) self.cog_instance.process_event.assert_not_called() @@ -733,7 +733,7 @@ class TestOnRawReactionAdd(TestIncidents): self.cog_instance.process_event = AsyncMock() self.cog_instance.resolve_message = AsyncMock(return_value=incident) - with patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)): + with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)): await self.cog_instance.on_raw_reaction_add(self.payload) self.cog_instance.process_event.assert_called_with( @@ -751,20 +751,20 @@ class TestOnMessage(TestIncidents): function is tested in `TestIsIncident` - here we do not worry about it. """ - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True)) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) async def test_on_message_incident(self): """Messages qualifying as incidents are passed to `add_signals`.""" incident = MockMessage() - with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: await self.cog_instance.on_message(incident) mock_add_signals.assert_called_once_with(incident) - @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False)) + @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) async def test_on_message_non_incident(self): """Messages not qualifying as incidents are ignored.""" - with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: + with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: await self.cog_instance.on_message(MockMessage()) mock_add_signals.assert_not_called() diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index f2809f40a..f8f142484 100644 --- a/tests/bot/cogs/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -2,7 +2,7 @@ import unittest import discord -from bot.cogs.moderation.modlog import ModLog +from bot.exts.moderation.modlog import ModLog from tests.helpers import MockBot, MockTextChannel diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index ab3d0742a..e2d44c637 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock, Mock from discord import PermissionOverwrite -from bot.cogs.moderation.silence import Silence, SilenceNotifier from bot.constants import Channels, Emojis, Guild, Roles +from bot.exts.moderation.silence import Silence, SilenceNotifier from tests.helpers import MockBot, MockContext, MockTextChannel @@ -99,7 +99,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.bot.get_channel.called_once_with(Channels.mod_alerts) self.bot.get_channel.called_once_with(Channels.mod_log) - @mock.patch("bot.cogs.moderation.silence.SilenceNotifier") + @mock.patch("bot.exts.moderation.silence.SilenceNotifier") async def test_instance_vars_got_notifier(self, notifier): """Notifier was started with channel.""" mod_log = MockTextChannel() @@ -238,7 +238,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): del mock_permissions_dict['send_messages'] self.assertDictEqual(mock_permissions_dict, new_permissions) - @mock.patch("bot.cogs.moderation.silence.asyncio") + @mock.patch("bot.exts.moderation.silence.asyncio") @mock.patch.object(Silence, "_mod_alerts_channel", create=True) def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): """Task for sending an alert was created with present `muted_channels`.""" @@ -247,15 +247,17 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - @mock.patch("bot.cogs.moderation.silence.asyncio") + @mock.patch("bot.exts.moderation.silence.asyncio") def test_cog_unload_skips_task_start(self, asyncio_mock): """No task created with no channels.""" self.cog.cog_unload() asyncio_mock.create_task.assert_not_called() - @mock.patch("bot.cogs.moderation.silence.with_role_check") - @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): + @mock.patch("discord.ext.commands.has_any_role") + @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) + async def test_cog_check(self, role_check): """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) + role_check.return_value.predicate = mock.AsyncMock() + await self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(*(1, 2, 3)) + role_check.return_value.predicate.assert_awaited_once_with(self.ctx) diff --git a/tests/bot/cogs/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py index f442814c8..dad751e0d 100644 --- a/tests/bot/cogs/test_slowmode.py +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -3,8 +3,8 @@ from unittest import mock from dateutil.relativedelta import relativedelta -from bot.cogs.moderation.slowmode import Slowmode from bot.constants import Emojis +from bot.exts.moderation.slowmode import Slowmode from tests.helpers import MockBot, MockContext, MockTextChannel @@ -103,9 +103,11 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase): f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' ) - @mock.patch("bot.cogs.moderation.slowmode.with_role_check") - @mock.patch("bot.cogs.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) - def test_cog_check(self, role_check): + @mock.patch("bot.exts.moderation.slowmode.has_any_role") + @mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) + async def test_cog_check(self, role_check): """Role check is called with `MODERATION_ROLES`""" - self.cog.cog_check(self.ctx) - role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) + role_check.return_value.predicate = mock.AsyncMock() + await self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(*(1, 2, 3)) + role_check.return_value.predicate.assert_awaited_once_with(self.ctx) diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/exts/test_cogs.py index fdda59a8f..f8e120262 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/exts/test_cogs.py @@ -10,7 +10,7 @@ from unittest import mock from discord.ext import commands -from bot import cogs +from bot import exts class CommandNameTests(unittest.TestCase): @@ -29,13 +29,14 @@ class CommandNameTests(unittest.TestCase): @staticmethod def walk_modules() -> t.Iterator[ModuleType]: - """Yield imported modules from the bot.cogs subpackage.""" + """Yield imported modules from the bot.exts subpackage.""" def on_error(name: str) -> t.NoReturn: raise ImportError(name=name) # pragma: no cover # The mock prevents asyncio.get_event_loop() from being called. with mock.patch("discord.ext.tasks.loop"): - for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): + prefix = f"{exts.__name__}." + for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): if not module.ispkg: yield importlib.import_module(module.name) @@ -53,6 +54,7 @@ class CommandNameTests(unittest.TestCase): """Return a list of all qualified names, including aliases, for the `command`.""" names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] names.append(command.qualified_name) + names += getattr(command, "root_aliases", []) return names diff --git a/tests/bot/exts/utils/__init__.py b/tests/bot/exts/utils/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/utils/__init__.py diff --git a/tests/bot/cogs/test_jams.py b/tests/bot/exts/utils/test_jams.py index b4ad8535f..45e7b5b51 100644 --- a/tests/bot/cogs/test_jams.py +++ b/tests/bot/exts/utils/test_jams.py @@ -3,8 +3,8 @@ from unittest.mock import AsyncMock, MagicMock, create_autospec from discord import CategoryChannel -from bot.cogs import jams from bot.constants import Roles +from bot.exts.utils import jams from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 343e37db9..40b2202aa 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -1,13 +1,12 @@ import asyncio -import logging import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch from discord.ext import commands from bot import constants -from bot.cogs import snekbox -from bot.cogs.snekbox import Snekbox +from bot.exts.utils import snekbox +from bot.exts.utils.snekbox import Snekbox from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser @@ -39,43 +38,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) self.assertEqual(result, "too long to upload") - async def test_upload_output(self): + @patch("bot.exts.utils.snekbox.send_to_paste_service") + async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "MarkDiamond" - resp = MagicMock() - resp.json = AsyncMock(return_value={"key": key}) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - self.assertEqual( - await self.cog.upload_output("My awesome output"), - constants.URLs.paste_service.format(key=key) - ) - self.bot.http_session.post.assert_called_with( - constants.URLs.paste_service.format(key="documents"), - data="My awesome output", - raise_for_status=True + await self.cog.upload_output("Test output.") + mock_paste_util.assert_called_once_with( + self.bot.http_session, "Test output.", extension="txt" ) - async def test_upload_output_gracefully_fallback_if_exception_during_request(self): - """Output upload gracefully fallback if the upload fail.""" - resp = MagicMock() - resp.json = AsyncMock(side_effect=Exception) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - log = logging.getLogger("bot.cogs.snekbox") - with self.assertLogs(logger=log, level='ERROR'): - await self.cog.upload_output('My awesome output!') - - async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): - """Output upload gracefully fallback if there is no key entry in the response body.""" - self.assertEqual((await self.cog.upload_output('My awesome output!')), None) - def test_prepare_input(self): cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), @@ -99,14 +69,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) self.assertEqual(actual, expected) - @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) + @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( self.cog.get_results_message({'stdout': '', 'returncode': 127}), ('Your eval job has completed with return code 127', '') ) - @patch('bot.cogs.snekbox.Signals') + @patch('bot.exts.utils.snekbox.Signals') def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( @@ -147,12 +117,12 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'), ( '\u202E\u202E\u202E', - ('Code block escape attempt detected; will not output result', None), + ('Code block escape attempt detected; will not output result', 'https://testificate.com/'), 'Detect RIGHT-TO-LEFT OVERRIDE' ), ( '\u200B\u200B\u200B', - ('Code block escape attempt detected; will not output result', None), + ('Code block escape attempt detected; will not output result', 'https://testificate.com/'), 'Detect ZERO WIDTH SPACE' ), ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'), @@ -296,7 +266,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.format_output.assert_not_called() - @patch("bot.cogs.snekbox.partial") + @patch("bot.exts.utils.snekbox.partial") async def test_continue_eval_does_continue(self, partial_mock): """Test that the continue_eval function does continue if required conditions are met.""" ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index ce880d457..630f2516d 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -44,18 +44,3 @@ class LinePaginatorTests(TestCase): self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) # Note: item at index 1 is the truncated line, index 0 is prefix self.assertEqual(self.paginator._current_page[1], 'x' * self.paginator.scale_to_size) - - -class ImagePaginatorTests(TestCase): - """Tests functionality of the `ImagePaginator`.""" - - def setUp(self): - """Create a paginator for the test method.""" - self.paginator = pagination.ImagePaginator() - - def test_add_image_appends_image(self): - """`add_image` appends the image to the image list.""" - image = 'lemon' - self.paginator.add_image(image) - - assert self.paginator.images == [image] diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index de72e5748..883465e0b 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,48 +1,50 @@ import unittest from unittest.mock import MagicMock +from discord import DMChannel + from bot.utils import checks from bot.utils.checks import InWhitelistCheckFailure from tests.helpers import MockContext, MockRole -class ChecksTests(unittest.TestCase): +class ChecksTests(unittest.IsolatedAsyncioTestCase): """Tests the check functions defined in `bot.checks`.""" def setUp(self): self.ctx = MockContext() - def test_with_role_check_without_guild(self): - """`with_role_check` returns `False` if `Context.guild` is None.""" - self.ctx.guild = None - self.assertFalse(checks.with_role_check(self.ctx)) + async def test_has_any_role_check_without_guild(self): + """`has_any_role_check` returns `False` for non-guild channels.""" + self.ctx.channel = MagicMock(DMChannel) + self.assertFalse(await checks.has_any_role_check(self.ctx)) - def test_with_role_check_without_required_roles(self): - """`with_role_check` returns `False` if `Context.author` lacks the required role.""" + async def test_has_any_role_check_without_required_roles(self): + """`has_any_role_check` returns `False` if `Context.author` lacks the required role.""" self.ctx.author.roles = [] - self.assertFalse(checks.with_role_check(self.ctx)) + self.assertFalse(await checks.has_any_role_check(self.ctx)) - def test_with_role_check_with_guild_and_required_role(self): - """`with_role_check` returns `True` if `Context.author` has the required role.""" + async def test_has_any_role_check_with_guild_and_required_role(self): + """`has_any_role_check` returns `True` if `Context.author` has the required role.""" self.ctx.author.roles.append(MockRole(id=10)) - self.assertTrue(checks.with_role_check(self.ctx, 10)) + self.assertTrue(await checks.has_any_role_check(self.ctx, 10)) - def test_without_role_check_without_guild(self): - """`without_role_check` should return `False` when `Context.guild` is None.""" - self.ctx.guild = None - self.assertFalse(checks.without_role_check(self.ctx)) + async def test_has_no_roles_check_without_guild(self): + """`has_no_roles_check` should return `False` when `Context.guild` is None.""" + self.ctx.channel = MagicMock(DMChannel) + self.assertFalse(await checks.has_no_roles_check(self.ctx)) - def test_without_role_check_returns_false_with_unwanted_role(self): - """`without_role_check` returns `False` if `Context.author` has unwanted role.""" + async def test_has_no_roles_check_returns_false_with_unwanted_role(self): + """`has_no_roles_check` returns `False` if `Context.author` has unwanted role.""" role_id = 42 self.ctx.author.roles.append(MockRole(id=role_id)) - self.assertFalse(checks.without_role_check(self.ctx, role_id)) + self.assertFalse(await checks.has_no_roles_check(self.ctx, role_id)) - def test_without_role_check_returns_true_without_unwanted_role(self): - """`without_role_check` returns `True` if `Context.author` does not have unwanted role.""" + async def test_has_no_roles_check_returns_true_without_unwanted_role(self): + """`has_no_roles_check` returns `True` if `Context.author` does not have unwanted role.""" role_id = 42 self.ctx.author.roles.append(MockRole(id=role_id)) - self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) + self.assertTrue(await checks.has_no_roles_check(self.ctx, role_id + 10)) def test_in_whitelist_check_correct_channel(self): """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py deleted file mode 100644 index a2f0fe55d..000000000 --- a/tests/bot/utils/test_redis_cache.py +++ /dev/null @@ -1,265 +0,0 @@ -import asyncio -import unittest - -import fakeredis.aioredis - -from bot.utils import RedisCache -from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError -from tests import helpers - - -class RedisCacheTests(unittest.IsolatedAsyncioTestCase): - """Tests the RedisCache class from utils.redis_dict.py.""" - - async def asyncSetUp(self): # noqa: N802 - """Sets up the objects that only have to be initialized once.""" - self.bot = helpers.MockBot() - self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - - # Okay, so this is necessary so that we can create a clean new - # class for every test method, and we want that because it will - # ensure we get a fresh loop, which is necessary for test_increment_lock - # to be able to pass. - class DummyCog: - """A dummy cog, for dummies.""" - - redis = RedisCache() - - def __init__(self, bot: helpers.MockBot): - self.bot = bot - - self.cog = DummyCog(self.bot) - - await self.cog.redis.clear() - - def test_class_attribute_namespace(self): - """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") - - async def test_class_attribute_required(self): - """Test that errors are raised when not assigned as a class attribute.""" - bad_cache = RedisCache() - self.assertIs(bad_cache._namespace, None) - - with self.assertRaises(RuntimeError): - await bad_cache.set("test", "me_up_deadman") - - async def test_set_get_item(self): - """Test that users can set and get items from the RedisDict.""" - test_cases = ( - ('favorite_fruit', 'melon'), - ('favorite_number', 86), - ('favorite_fraction', 86.54), - ('favorite_boolean', False), - ('other_boolean', True), - ) - - # Test that we can get and set different types. - for test in test_cases: - await self.cog.redis.set(*test) - self.assertEqual(await self.cog.redis.get(test[0]), test[1]) - - # Test that .get allows a default value - self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") - - async def test_set_item_type(self): - """Test that .set rejects keys and values that are not permitted.""" - fruits = ["lemon", "melon", "apple"] - - with self.assertRaises(TypeError): - await self.cog.redis.set(fruits, "nice") - - with self.assertRaises(TypeError): - await self.cog.redis.set(4.23, "nice") - - async def test_delete_item(self): - """Test that .delete allows us to delete stuff from the RedisCache.""" - # Add an item and verify that it gets added - await self.cog.redis.set("internet", "firetruck") - self.assertEqual(await self.cog.redis.get("internet"), "firetruck") - - # Delete that item and verify that it gets deleted - await self.cog.redis.delete("internet") - self.assertIs(await self.cog.redis.get("internet"), None) - - async def test_contains(self): - """Test that we can check membership with .contains.""" - await self.cog.redis.set('favorite_country', "Burkina Faso") - - self.assertIs(await self.cog.redis.contains('favorite_country'), True) - self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) - - async def test_items(self): - """Test that the RedisDict can be iterated.""" - # Set up our test cases in the Redis cache - test_cases = [ - ('favorite_turtle', 'Donatello'), - ('second_favorite_turtle', 'Leonardo'), - ('third_favorite_turtle', 'Raphael'), - ] - for key, value in test_cases: - await self.cog.redis.set(key, value) - - # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item for item in await self.cog.redis.items()] - - # These sequences are probably in the same order now, but probably - # isn't good enough for tests. Let's not rely on .hgetall always - # returning things in sequence, and just sort both lists to be safe. - redis_items = sorted(redis_items) - test_cases = sorted(test_cases) - - # If these are equal now, everything works fine. - self.assertSequenceEqual(test_cases, redis_items) - - async def test_length(self): - """Test that we can get the correct .length from the RedisDict.""" - await self.cog.redis.set('one', 1) - await self.cog.redis.set('two', 2) - await self.cog.redis.set('three', 3) - self.assertEqual(await self.cog.redis.length(), 3) - - await self.cog.redis.set('four', 4) - self.assertEqual(await self.cog.redis.length(), 4) - - async def test_to_dict(self): - """Test that the .to_dict method returns a workable dictionary copy.""" - copy = await self.cog.redis.to_dict() - local_copy = {key: value for key, value in await self.cog.redis.items()} - self.assertIs(type(copy), dict) - self.assertDictEqual(copy, local_copy) - - async def test_clear(self): - """Test that the .clear method removes the entire hash.""" - await self.cog.redis.set('teddy', 'with me') - await self.cog.redis.set('in my dreams', 'you have a weird hat') - self.assertEqual(await self.cog.redis.length(), 2) - - await self.cog.redis.clear() - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_pop(self): - """Test that we can .pop an item from the RedisDict.""" - await self.cog.redis.set('john', 'was afraid') - - self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') - self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_update(self): - """Test that we can .update the RedisDict with multiple items.""" - await self.cog.redis.set("reckfried", "lona") - await self.cog.redis.set("bel air", "prince") - await self.cog.redis.update({ - "reckfried": "jona", - "mega": "hungry, though", - }) - - result = { - "reckfried": "jona", - "bel air": "prince", - "mega": "hungry, though", - } - self.assertDictEqual(await self.cog.redis.to_dict(), result) - - def test_typestring_conversion(self): - """Test the typestring-related helper functions.""" - conversion_tests = ( - (12, "i|12"), - (12.4, "f|12.4"), - ("cowabunga", "s|cowabunga"), - ) - - # Test conversion to typestring - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) - - # Test conversion from typestrings - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) - - # Test that exceptions are raised on invalid input - with self.assertRaises(TypeError): - self.cog.redis._value_to_typestring(["internet"]) - self.cog.redis._value_from_typestring("o|firedog") - - async def test_increment_decrement(self): - """Test .increment and .decrement methods.""" - await self.cog.redis.set("entropic", 5) - await self.cog.redis.set("disentropic", 12.5) - - # Test default increment - await self.cog.redis.increment("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 6) - - # Test default decrement - await self.cog.redis.decrement("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # Test float increment with float - await self.cog.redis.increment("disentropic", 2.0) - self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) - - # Test float increment with int - await self.cog.redis.increment("disentropic", 2) - self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) - - # Test negative increments, because why not. - await self.cog.redis.increment("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 0) - - # Negative decrements? Sure. - await self.cog.redis.decrement("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # What about if we use a negative float to decrement an int? - # This should convert the type into a float. - await self.cog.redis.decrement("entropic", -2.5) - self.assertEqual(await self.cog.redis.get("entropic"), 7.5) - - # Let's test that they raise the right errors - with self.assertRaises(KeyError): - await self.cog.redis.increment("doesn't_exist!") - - await self.cog.redis.set("stringthing", "stringthing") - with self.assertRaises(TypeError): - await self.cog.redis.increment("stringthing") - - async def test_increment_lock(self): - """Test that we can't produce a race condition in .increment.""" - await self.cog.redis.set("test_key", 0) - tasks = [] - - # Increment this a lot in different tasks - for _ in range(100): - task = asyncio.create_task( - self.cog.redis.increment("test_key", 1) - ) - tasks.append(task) - await asyncio.gather(*tasks) - - # Confirm that the value has been incremented the exact right number of times. - value = await self.cog.redis.get("test_key") - self.assertEqual(value, 100) - - async def test_exceptions_raised(self): - """Testing that the various RuntimeErrors are reachable.""" - class MyCog: - cache = RedisCache() - - def __init__(self): - self.other_cache = RedisCache() - - cog = MyCog() - - # Raises "No Bot instance" - with self.assertRaises(NoBotInstanceError): - await cog.cache.get("john") - - # Raises "RedisCache has no namespace" - with self.assertRaises(NoNamespaceError): - await cog.other_cache.get("was") - - # Raises "You must access the RedisCache instance through the cog instance" - with self.assertRaises(NoParentInstanceError): - await MyCog.cache.get("afraid") diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py new file mode 100644 index 000000000..5e0855704 --- /dev/null +++ b/tests/bot/utils/test_services.py @@ -0,0 +1,74 @@ +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +from aiohttp import ClientConnectorError + +from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service + + +class PasteTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.http_session = MagicMock() + + @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") + async def test_url_and_sent_contents(self): + """Correct url was used and post was called with expected data.""" + response = MagicMock( + json=AsyncMock(return_value={"key": ""}) + ) + self.http_session.post().__aenter__.return_value = response + self.http_session.post.reset_mock() + await send_to_paste_service(self.http_session, "Content") + self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + + @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") + async def test_paste_returns_correct_url_on_success(self): + """Url with specified extension is returned on successful requests.""" + key = "paste_key" + test_cases = ( + (f"https://paste_service.com/{key}.txt", "txt"), + (f"https://paste_service.com/{key}.py", "py"), + (f"https://paste_service.com/{key}", ""), + ) + response = MagicMock( + json=AsyncMock(return_value={"key": key}) + ) + self.http_session.post().__aenter__.return_value = response + + for expected_output, extension in test_cases: + with self.subTest(msg=f"Send contents with extension {repr(extension)}"): + self.assertEqual( + await send_to_paste_service(self.http_session, "", extension=extension), + expected_output + ) + + async def test_request_repeated_on_json_errors(self): + """Json with error message and invalid json are handled as errors and requests repeated.""" + test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) + self.http_session.post().__aenter__.return_value = response = MagicMock() + self.http_session.post.reset_mock() + + for error_json in test_cases: + with self.subTest(error_json=error_json): + response.json = AsyncMock(return_value=error_json) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + self.http_session.post.reset_mock() + + async def test_request_repeated_on_connection_errors(self): + """Requests are repeated in the case of connection errors.""" + self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertIsNone(result) + + async def test_general_error_handled_and_request_repeated(self): + """All `Exception`s are handled, logged and request repeated.""" + self.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service(self.http_session, "") + self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.assertLogs("bot.utils", logging.ERROR) + self.assertIsNone(result) diff --git a/tests/helpers.py b/tests/helpers.py index facc4e1af..e47fdf28f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -308,7 +308,11 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ - spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) + spec_set = Bot( + command_prefix=unittest.mock.MagicMock(), + loop=_get_mock_loop(), + redis_session=unittest.mock.MagicMock(), + ) additional_spec_asyncs = ("wait_for", "redis_ready") def __init__(self, **kwargs) -> None: |