diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/cogs/test_token_remover.py | 131 | ||||
| -rw-r--r-- | tests/bot/exts/__init__.py (renamed from tests/bot/cogs/__init__.py) | 0 | ||||
| -rw-r--r-- | tests/bot/exts/backend/__init__.py (renamed from tests/bot/cogs/moderation/__init__.py) | 0 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/__init__.py (renamed from tests/bot/cogs/sync/__init__.py) | 0 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_base.py (renamed from tests/bot/cogs/sync/test_base.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py (renamed from tests/bot/cogs/sync/test_cog.py) | 99 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_roles.py (renamed from tests/bot/cogs/sync/test_roles.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py (renamed from tests/bot/cogs/sync/test_users.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_logging.py | 32 | ||||
| -rw-r--r-- | tests/bot/exts/filters/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/filters/test_antimalware.py (renamed from tests/bot/cogs/test_antimalware.py) | 48 | ||||
| -rw-r--r-- | tests/bot/exts/filters/test_antispam.py (renamed from tests/bot/cogs/test_antispam.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/filters/test_security.py (renamed from tests/bot/cogs/test_security.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/filters/test_token_remover.py | 310 | ||||
| -rw-r--r-- | tests/bot/exts/fun/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/fun/test_duck_pond.py (renamed from tests/bot/cogs/test_duck_pond.py) | 61 | ||||
| -rw-r--r-- | tests/bot/exts/info/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/info/test_information.py (renamed from tests/bot/cogs/test_information.py) | 107 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py (renamed from tests/bot/cogs/moderation/test_infractions.py) | 8 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py (renamed from tests/bot/cogs/moderation/test_utils.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 770 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_modlog.py (renamed from tests/bot/cogs/moderation/test_modlog.py) | 2 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py (renamed from tests/bot/cogs/moderation/test_silence.py) | 30 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_slowmode.py | 111 | ||||
| -rw-r--r-- | tests/bot/exts/test_cogs.py (renamed from tests/bot/cogs/test_cogs.py) | 8 | ||||
| -rw-r--r-- | tests/bot/exts/utils/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_jams.py | 173 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_snekbox.py (renamed from tests/bot/cogs/test_snekbox.py) | 68 | ||||
| -rw-r--r-- | tests/bot/test_pagination.py | 54 | ||||
| -rw-r--r-- | tests/bot/utils/test_messages.py | 27 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 14 | ||||
| -rw-r--r-- | tests/bot/utils/test_services.py | 74 | ||||
| -rw-r--r-- | tests/helpers.py | 20 | 
35 files changed, 1807 insertions, 352 deletions
| diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py deleted file mode 100644 index 33d1ec170..000000000 --- a/tests/bot/cogs/test_token_remover.py +++ /dev/null @@ -1,131 +0,0 @@ -import asyncio -import logging -import unittest -from unittest.mock import AsyncMock, MagicMock - -from discord import Colour - -from bot.cogs.token_remover import ( -    DELETION_MESSAGE_TEMPLATE, -    TokenRemover, -    setup as setup_cog, -) -from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import MockBot, MockMessage - - -class TokenRemoverTests(unittest.TestCase): -    """Tests the `TokenRemover` cog.""" - -    def setUp(self): -        """Adds the cog, a bot, and a message to the instance for usage in tests.""" -        self.bot = MockBot() -        self.bot.get_cog.return_value = MagicMock() -        self.bot.get_cog.return_value.send_log_message = AsyncMock() -        self.cog = TokenRemover(bot=self.bot) - -        self.msg = MockMessage(id=555, content='') -        self.msg.author.__str__ = MagicMock() -        self.msg.author.__str__.return_value = 'lemon' -        self.msg.author.bot = False -        self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' -        self.msg.author.id = 42 -        self.msg.author.mention = '@lemon' -        self.msg.channel.mention = "#lemonade-stand" - -    def test_is_valid_user_id_is_true_for_numeric_content(self): -        """A string decoding to numeric characters is a valid user ID.""" -        # MTIz = base64(123) -        self.assertTrue(TokenRemover.is_valid_user_id('MTIz')) - -    def test_is_valid_user_id_is_false_for_alphabetic_content(self): -        """A string decoding to alphabetic characters is not a valid user ID.""" -        # YWJj = base64(abc) -        self.assertFalse(TokenRemover.is_valid_user_id('YWJj')) - -    def test_is_valid_timestamp_is_true_for_valid_timestamps(self): -        """A string decoding to a valid timestamp should be recognized as such.""" -        self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A')) - -    def test_is_valid_timestamp_is_false_for_invalid_values(self): -        """A string not decoding to a valid timestamp should not be recognized as such.""" -        # MTIz = base64(123) -        self.assertFalse(TokenRemover.is_valid_timestamp('MTIz')) - -    def test_mod_log_property(self): -        """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -        self.bot.get_cog.return_value = 'lemon' -        self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) -        self.bot.get_cog.assert_called_once_with('ModLog') - -    def test_ignores_bot_messages(self): -        """When the message event handler is called with a bot message, nothing is done.""" -        self.msg.author.bot = True -        coroutine = self.cog.on_message(self.msg) -        self.assertIsNone(asyncio.run(coroutine)) - -    def test_ignores_messages_without_tokens(self): -        """Messages without anything looking like a token are ignored.""" -        for content in ('', 'lemon wins'): -            with self.subTest(content=content): -                self.msg.content = content -                coroutine = self.cog.on_message(self.msg) -                self.assertIsNone(asyncio.run(coroutine)) - -    def test_ignores_messages_with_invalid_tokens(self): -        """Messages with values that are invalid tokens are ignored.""" -        for content in ('foo.bar.baz', 'x.y.'): -            with self.subTest(content=content): -                self.msg.content = content -                coroutine = self.cog.on_message(self.msg) -                self.assertIsNone(asyncio.run(coroutine)) - -    def test_censors_valid_tokens(self): -        """Valid tokens are censored.""" -        cases = ( -            # (content, censored_token) -            ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), -        ) - -        for content, censored_token in cases: -            with self.subTest(content=content, censored_token=censored_token): -                self.msg.content = content -                coroutine = self.cog.on_message(self.msg) -                with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm: -                    self.assertIsNone(asyncio.run(coroutine))  # no return value - -                [line] = cm.output -                log_message = ( -                    "Censored a seemingly valid token sent by " -                    "lemon (`42`) in #lemonade-stand, " -                    f"token was `{censored_token}`" -                ) -                self.assertIn(log_message, line) - -                self.msg.delete.assert_called_once_with() -                self.msg.channel.send.assert_called_once_with( -                    DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') -                ) -                self.bot.get_cog.assert_called_with('ModLog') -                self.msg.author.avatar_url_as.assert_called_once_with(static_format='png') - -                mod_log = self.bot.get_cog.return_value -                mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) -                mod_log.send_log_message.assert_called_once_with( -                    icon_url=Icons.token_removed, -                    colour=Colour(Colours.soft_red), -                    title="Token removed!", -                    text=log_message, -                    thumbnail='picture-lemon.png', -                    channel_id=Channels.mod_alerts -                ) - - -class TokenRemoverSetupTests(unittest.TestCase): -    """Tests setup of the `TokenRemover` cog.""" - -    def test_setup(self): -        """Setup of the extension should call add_cog.""" -        bot = MockBot() -        setup_cog(bot) -        bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/__init__.py b/tests/bot/exts/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/cogs/__init__.py +++ b/tests/bot/exts/__init__.py diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/exts/backend/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/cogs/moderation/__init__.py +++ b/tests/bot/exts/backend/__init__.py diff --git a/tests/bot/cogs/sync/__init__.py b/tests/bot/exts/backend/sync/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/cogs/sync/__init__.py +++ b/tests/bot/exts/backend/sync/__init__.py diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 70aea2bab..886c243cf 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -6,7 +6,7 @@ import discord  from bot import constants  from bot.api import ResponseCodeError -from bot.cogs.sync.syncers import Syncer, _Diff +from bot.exts.backend.sync._syncers import Syncer, _Diff  from tests import helpers diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 14fd909c4..1b89564f2 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -5,8 +5,9 @@ import discord  from bot import constants  from bot.api import ResponseCodeError -from bot.cogs import sync -from bot.cogs.sync.syncers import Syncer +from bot.exts.backend import sync +from bot.exts.backend.sync._cog import Sync +from bot.exts.backend.sync._syncers import Syncer  from tests import helpers  from tests.base import CommandTestCase @@ -29,19 +30,19 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):          self.bot = helpers.MockBot()          self.role_syncer_patcher = mock.patch( -            "bot.cogs.sync.syncers.RoleSyncer", +            "bot.exts.backend.sync._syncers.RoleSyncer",              autospec=Syncer,              spec_set=True          )          self.user_syncer_patcher = mock.patch( -            "bot.cogs.sync.syncers.UserSyncer", +            "bot.exts.backend.sync._syncers.UserSyncer",              autospec=Syncer,              spec_set=True          )          self.RoleSyncer = self.role_syncer_patcher.start()          self.UserSyncer = self.user_syncer_patcher.start() -        self.cog = sync.Sync(self.bot) +        self.cog = Sync(self.bot)      def tearDown(self):          self.role_syncer_patcher.stop() @@ -59,7 +60,7 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):  class SyncCogTests(SyncCogTestCase):      """Tests for the Sync cog.""" -    @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) +    @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock)      def test_sync_cog_init(self, sync_guild):          """Should instantiate syncers and run a sync for the guild."""          # Reset because a Sync cog was already instantiated in setUp. @@ -70,7 +71,7 @@ class SyncCogTests(SyncCogTestCase):          mock_sync_guild_coro = mock.MagicMock()          sync_guild.return_value = mock_sync_guild_coro -        sync.Sync(self.bot) +        Sync(self.bot)          self.RoleSyncer.assert_called_once_with(self.bot)          self.UserSyncer.assert_called_once_with(self.bot) @@ -131,6 +132,15 @@ class SyncCogListenerTests(SyncCogTestCase):          super().setUp()          self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) +        self.guild_id_patcher = mock.patch("bot.exts.backend.sync._cog.constants.Guild.id", 5) +        self.guild_id = self.guild_id_patcher.start() + +        self.guild = helpers.MockGuild(id=self.guild_id) +        self.other_guild = helpers.MockGuild(id=0) + +    def tearDown(self): +        self.guild_id_patcher.stop() +      async def test_sync_cog_on_guild_role_create(self):          """A POST request should be sent with the new role's data."""          self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -142,20 +152,32 @@ class SyncCogListenerTests(SyncCogTestCase):              "permissions": 8,              "position": 23,          } -        role = helpers.MockRole(**role_data) +        role = helpers.MockRole(**role_data, guild=self.guild)          await self.cog.on_guild_role_create(role)          self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) +    async def test_sync_cog_on_guild_role_create_ignores_guilds(self): +        """Events from other guilds should be ignored.""" +        role = helpers.MockRole(guild=self.other_guild) +        await self.cog.on_guild_role_create(role) +        self.bot.api_client.post.assert_not_awaited() +      async def test_sync_cog_on_guild_role_delete(self):          """A DELETE request should be sent."""          self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) -        role = helpers.MockRole(id=99) +        role = helpers.MockRole(id=99, guild=self.guild)          await self.cog.on_guild_role_delete(role)          self.bot.api_client.delete.assert_called_once_with("bot/roles/99") +    async def test_sync_cog_on_guild_role_delete_ignores_guilds(self): +        """Events from other guilds should be ignored.""" +        role = helpers.MockRole(guild=self.other_guild) +        await self.cog.on_guild_role_delete(role) +        self.bot.api_client.delete.assert_not_awaited() +      async def test_sync_cog_on_guild_role_update(self):          """A PUT request should be sent if the colour, name, permissions, or position changes."""          self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -180,8 +202,8 @@ class SyncCogListenerTests(SyncCogTestCase):                      after_role_data = role_data.copy()                      after_role_data[attribute] = 876 -                    before_role = helpers.MockRole(**role_data) -                    after_role = helpers.MockRole(**after_role_data) +                    before_role = helpers.MockRole(**role_data, guild=self.guild) +                    after_role = helpers.MockRole(**after_role_data, guild=self.guild)                      await self.cog.on_guild_role_update(before_role, after_role) @@ -193,31 +215,43 @@ class SyncCogListenerTests(SyncCogTestCase):                      else:                          self.bot.api_client.put.assert_not_called() +    async def test_sync_cog_on_guild_role_update_ignores_guilds(self): +        """Events from other guilds should be ignored.""" +        role = helpers.MockRole(guild=self.other_guild) +        await self.cog.on_guild_role_update(role, role) +        self.bot.api_client.put.assert_not_awaited() +      async def test_sync_cog_on_member_remove(self): -        """Member should patched to set in_guild as False.""" +        """Member should be patched to set in_guild as False."""          self.assertTrue(self.cog.on_member_remove.__cog_listener__) -        member = helpers.MockMember() +        member = helpers.MockMember(guild=self.guild)          await self.cog.on_member_remove(member)          self.cog.patch_user.assert_called_once_with(              member.id, -            updated_information={"in_guild": False} +            json={"in_guild": False}          ) +    async def test_sync_cog_on_member_remove_ignores_guilds(self): +        """Events from other guilds should be ignored.""" +        member = helpers.MockMember(guild=self.other_guild) +        await self.cog.on_member_remove(member) +        self.cog.patch_user.assert_not_awaited() +      async def test_sync_cog_on_member_update_roles(self):          """Members should be patched if their roles have changed."""          self.assertTrue(self.cog.on_member_update.__cog_listener__)          # Roles are intentionally unsorted.          before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] -        before_member = helpers.MockMember(roles=before_roles) -        after_member = helpers.MockMember(roles=before_roles[1:]) +        before_member = helpers.MockMember(roles=before_roles, guild=self.guild) +        after_member = helpers.MockMember(roles=before_roles[1:], guild=self.guild)          await self.cog.on_member_update(before_member, after_member)          data = {"roles": sorted(role.id for role in after_member.roles)} -        self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) +        self.cog.patch_user.assert_called_once_with(after_member.id, json=data)      async def test_sync_cog_on_member_update_other(self):          """Members should not be patched if other attributes have changed.""" @@ -233,13 +267,19 @@ class SyncCogListenerTests(SyncCogTestCase):              with self.subTest(attribute=attribute):                  self.cog.patch_user.reset_mock() -                before_member = helpers.MockMember(**{attribute: old_value}) -                after_member = helpers.MockMember(**{attribute: new_value}) +                before_member = helpers.MockMember(**{attribute: old_value}, guild=self.guild) +                after_member = helpers.MockMember(**{attribute: new_value}, guild=self.guild)                  await self.cog.on_member_update(before_member, after_member)                  self.cog.patch_user.assert_not_called() +    async def test_sync_cog_on_member_update_ignores_guilds(self): +        """Events from other guilds should be ignored.""" +        member = helpers.MockMember(guild=self.other_guild) +        await self.cog.on_member_update(member, member) +        self.cog.patch_user.assert_not_awaited() +      async def test_sync_cog_on_user_update(self):          """A user should be patched only if the name, discriminator, or avatar changes."""          self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -272,12 +312,15 @@ class SyncCogListenerTests(SyncCogTestCase):                      # Don't care if *all* keys are present; only the changed one is required                      call_args = self.cog.patch_user.call_args -                    self.assertEqual(call_args[0][0], after_user.id) -                    self.assertIn("updated_information", call_args[1]) +                    self.assertEqual(call_args.args[0], after_user.id) +                    self.assertIn("json", call_args.kwargs) + +                    self.assertIn("ignore_404", call_args.kwargs) +                    self.assertTrue(call_args.kwargs["ignore_404"]) -                    updated_information = call_args[1]["updated_information"] -                    self.assertIn(api_field, updated_information) -                    self.assertEqual(updated_information[api_field], api_value) +                    json = call_args.kwargs["json"] +                    self.assertIn(api_field, json) +                    self.assertEqual(json[api_field], api_value)                  else:                      self.cog.patch_user.assert_not_called() @@ -290,6 +333,7 @@ class SyncCogListenerTests(SyncCogTestCase):          member = helpers.MockMember(              discriminator="1234",              roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], +            guild=self.guild,          )          data = { @@ -334,6 +378,13 @@ class SyncCogListenerTests(SyncCogTestCase):          self.bot.api_client.post.assert_not_called() +    async def test_sync_cog_on_member_join_ignores_guilds(self): +        """Events from other guilds should be ignored.""" +        member = helpers.MockMember(guild=self.other_guild) +        await self.cog.on_member_join(member) +        self.bot.api_client.post.assert_not_awaited() +        self.bot.api_client.put.assert_not_awaited() +  class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):      """Tests for the commands in the Sync cog.""" diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 79eee98f4..7b9f40cad 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -3,7 +3,7 @@ from unittest import mock  import discord -from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role  from tests import helpers diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 002a947ad..c0a1da35c 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@  import unittest  from unittest import mock -from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User  from tests import helpers diff --git a/tests/bot/exts/backend/test_logging.py b/tests/bot/exts/backend/test_logging.py new file mode 100644 index 000000000..466f207d9 --- /dev/null +++ b/tests/bot/exts/backend/test_logging.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import patch + +from bot import constants +from bot.exts.backend.logging import Logging +from tests.helpers import MockBot, MockTextChannel + + +class LoggingTests(unittest.IsolatedAsyncioTestCase): +    """Test cases for connected login.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = Logging(self.bot) +        self.dev_log = MockTextChannel(id=1234, name="dev-log") + +    @patch("bot.exts.backend.logging.DEBUG_MODE", False) +    async def test_debug_mode_false(self): +        """Should send connected message to dev-log.""" +        self.bot.get_channel.return_value = self.dev_log + +        await self.cog.startup_greeting() +        self.bot.wait_until_guild_available.assert_awaited_once_with() +        self.bot.get_channel.assert_called_once_with(constants.Channels.dev_log) +        self.dev_log.send.assert_awaited_once() + +    @patch("bot.exts.backend.logging.DEBUG_MODE", True) +    async def test_debug_mode_true(self): +        """Should not send anything to dev-log.""" +        await self.cog.startup_greeting() +        self.bot.wait_until_guild_available.assert_awaited_once_with() +        self.bot.get_channel.assert_not_called() diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filters/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/filters/__init__.py diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index f219fc1ba..3393c6cdc 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -1,28 +1,35 @@  import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock  from discord import NotFound -from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from bot.constants import Channels, STAFF_ROLES +from bot.exts.filters import antimalware  from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole -MODULE = "bot.cogs.antimalware" - -@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])  class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):      """Test the AntiMalware cog."""      def setUp(self):          """Sets up fresh objects for each test."""          self.bot = MockBot() +        self.bot.filter_list_cache = { +            "FILE_FORMAT.True": { +                ".first": {}, +                ".second": {}, +                ".third": {}, +            } +        }          self.cog = antimalware.AntiMalware(self.bot)          self.message = MockMessage() +        self.message.webhook_id = None +        self.message.author.bot = None +        self.whitelist = [".first", ".second", ".third"]      async def test_message_with_allowed_attachment(self):          """Messages with allowed extensions should not be deleted""" -        attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") +        attachment = MockAttachment(filename="python.first")          self.message.attachments = [attachment]          await self.cog.on_message(self.message) @@ -43,6 +50,26 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          self.message.delete.assert_not_called() +    async def test_webhook_message_with_illegal_extension(self): +        """A webhook message containing an illegal extension should be ignored.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.webhook_id = 697140105563078727 +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_bot_message_with_illegal_extension(self): +        """A bot message containing an illegal extension should be ignored.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.author.bot = 409107086526644234 +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() +      async def test_message_with_illegal_extension_gets_deleted(self):          """A message containing an illegal extension should send an embed."""          attachment = MockAttachment(filename="python.disallowed") @@ -93,7 +120,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)          antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) -    async def test_other_disallowed_extention_embed_description(self): +    async def test_other_disallowed_extension_embed_description(self):          """Test the description for a non .py/.txt disallowed extension."""          attachment = MockAttachment(filename="python.disallowed")          self.message.attachments = [attachment] @@ -109,6 +136,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)          antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( +            joined_whitelist=", ".join(self.whitelist),              blocked_extensions_str=".disallowed",              meta_channel_mention=meta_channel.mention          ) @@ -135,7 +163,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          """The return value should include all non-whitelisted extensions."""          test_values = (              ([], []), -            (AntiMalwareConfig.whitelist, []), +            (self.whitelist, []),              ([".first"], []),              ([".first", ".disallowed"], [".disallowed"]),              ([".disallowed"], [".disallowed"]), @@ -145,7 +173,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          for extensions, expected_disallowed_extensions in test_values:              with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):                  self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] -                disallowed_extensions = self.cog.get_disallowed_extensions(self.message) +                disallowed_extensions = self.cog._get_disallowed_extensions(self.message)                  self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/exts/filters/test_antispam.py index ce5472c71..6a0e4fded 100644 --- a/tests/bot/cogs/test_antispam.py +++ b/tests/bot/exts/filters/test_antispam.py @@ -1,6 +1,6 @@  import unittest -from bot.cogs import antispam +from bot.exts.filters import antispam  class AntispamConfigurationValidationTests(unittest.TestCase): diff --git a/tests/bot/cogs/test_security.py b/tests/bot/exts/filters/test_security.py index 9d1a62f7e..c0c3baa42 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock  from discord.ext.commands import NoPrivateMessage -from bot.cogs import security +from bot.exts.filters import security  from tests.helpers import MockBot, MockContext diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py new file mode 100644 index 000000000..a0ff8a877 --- /dev/null +++ b/tests/bot/exts/filters/test_token_remover.py @@ -0,0 +1,310 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock + +from discord import Colour, NotFound + +from bot import constants +from bot.exts.filters import token_remover +from bot.exts.filters.token_remover import Token, TokenRemover +from bot.exts.moderation.modlog import ModLog +from tests.helpers import MockBot, MockMessage, autospec + + +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): +    """Tests the `TokenRemover` cog.""" + +    def setUp(self): +        """Adds the cog, a bot, and a message to the instance for usage in tests.""" +        self.bot = MockBot() +        self.cog = TokenRemover(bot=self.bot) + +        self.msg = MockMessage(id=555, content="hello world") +        self.msg.channel.mention = "#lemonade-stand" +        self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) +        self.msg.author.avatar_url_as.return_value = "picture-lemon.png" + +    def test_is_valid_user_id_valid(self): +        """Should consider user IDs valid if they decode entirely to ASCII digits.""" +        ids = ( +            "NDcyMjY1OTQzMDYyNDEzMzMy", +            "NDc1MDczNjI5Mzk5NTQ3OTA0", +            "NDY3MjIzMjMwNjUwNzc3NjQx", +        ) + +        for user_id in ids: +            with self.subTest(user_id=user_id): +                result = TokenRemover.is_valid_user_id(user_id) +                self.assertTrue(result) + +    def test_is_valid_user_id_invalid(self): +        """Should consider non-digit and non-ASCII IDs invalid.""" +        ids = ( +            ("SGVsbG8gd29ybGQ", "non-digit ASCII"), +            ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), +            ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), +            ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), +            ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for user_id, msg in ids: +            with self.subTest(msg=msg): +                result = TokenRemover.is_valid_user_id(user_id) +                self.assertFalse(result) + +    def test_is_valid_timestamp_valid(self): +        """Should consider timestamps valid if they're greater than the Discord epoch.""" +        timestamps = ( +            "XsyRkw", +            "Xrim9Q", +            "XsyR-w", +            "XsySD_", +            "Dn9r_A", +        ) + +        for timestamp in timestamps: +            with self.subTest(timestamp=timestamp): +                result = TokenRemover.is_valid_timestamp(timestamp) +                self.assertTrue(result) + +    def test_is_valid_timestamp_invalid(self): +        """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" +        timestamps = ( +            ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), +            ("ew", "123"), +            ("AoIKgA", "42076800"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for timestamp, msg in timestamps: +            with self.subTest(msg=msg): +                result = TokenRemover.is_valid_timestamp(timestamp) +                self.assertFalse(result) + +    def test_mod_log_property(self): +        """The `mod_log` property should ask the bot to return the `ModLog` cog.""" +        self.bot.get_cog.return_value = 'lemon' +        self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) +        self.bot.get_cog.assert_called_once_with('ModLog') + +    async def test_on_message_edit_uses_on_message(self): +        """The edit listener should delegate handling of the message to the normal listener.""" +        self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + +        await self.cog.on_message_edit(MockMessage(), self.msg) +        self.cog.on_message.assert_awaited_once_with(self.msg) + +    @autospec(TokenRemover, "find_token_in_message", "take_action") +    async def test_on_message_takes_action(self, find_token_in_message, take_action): +        """Should take action if a valid token is found when a message is sent.""" +        cog = TokenRemover(self.bot) +        found_token = "foobar" +        find_token_in_message.return_value = found_token + +        await cog.on_message(self.msg) + +        find_token_in_message.assert_called_once_with(self.msg) +        take_action.assert_awaited_once_with(cog, self.msg, found_token) + +    @autospec(TokenRemover, "find_token_in_message", "take_action") +    async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): +        """Shouldn't take action if a valid token isn't found when a message is sent.""" +        cog = TokenRemover(self.bot) +        find_token_in_message.return_value = False + +        await cog.on_message(self.msg) + +        find_token_in_message.assert_called_once_with(self.msg) +        take_action.assert_not_awaited() + +    @autospec(TokenRemover, "find_token_in_message") +    async def test_on_message_ignores_dms_bots(self, find_token_in_message): +        """Shouldn't parse a message if it is a DM or authored by a bot.""" +        cog = TokenRemover(self.bot) +        dm_msg = MockMessage(guild=None) +        bot_msg = MockMessage(author=MagicMock(bot=True)) + +        for msg in (dm_msg, bot_msg): +            await cog.on_message(msg) +            find_token_in_message.assert_not_called() + +    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") +    def test_find_token_no_matches(self, token_re): +        """None should be returned if the regex matches no tokens in a message.""" +        token_re.finditer.return_value = () + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) +        token_re.finditer.assert_called_once_with(self.msg.content) + +    @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") +    @autospec("bot.exts.filters.token_remover", "Token") +    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") +    def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): +        """The first match with a valid user ID and timestamp should be returned as a `Token`.""" +        matches = [ +            mock.create_autospec(Match, spec_set=True, instance=True), +            mock.create_autospec(Match, spec_set=True, instance=True), +        ] +        tokens = [ +            mock.create_autospec(Token, spec_set=True, instance=True), +            mock.create_autospec(Token, spec_set=True, instance=True), +        ] + +        token_re.finditer.return_value = matches +        token_cls.side_effect = tokens +        is_valid_id.side_effect = (False, True)  # The 1st match will be invalid, 2nd one valid. +        is_valid_timestamp.return_value = True + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertEqual(tokens[1], return_value) +        token_re.finditer.assert_called_once_with(self.msg.content) + +    @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") +    @autospec("bot.exts.filters.token_remover", "Token") +    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") +    def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): +        """None should be returned if no matches have valid user IDs or timestamps.""" +        token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] +        token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) +        is_valid_id.return_value = False +        is_valid_timestamp.return_value = False + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) +        token_re.finditer.assert_called_once_with(self.msg.content) + +    def test_regex_invalid_tokens(self): +        """Messages without anything looking like a token are not matched.""" +        tokens = ( +            "", +            "lemon wins", +            "..", +            "x.y", +            "x.y.", +            ".y.z", +            ".y.", +            "..z", +            "x..z", +            " . . ", +            "\n.\n.\n", +            "hellö.world.bye", +            "base64.nötbåse64.morebase64", +            "19jd3J.dfkm3d.€víł§tüff", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = token_remover.TOKEN_RE.findall(token) +                self.assertEqual(len(results), 0) + +    def test_regex_valid_tokens(self): +        """Messages that look like tokens should be matched.""" +        # Don't worry, these tokens have been invalidated. +        tokens = ( +            "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", +            "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", +            "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", +            "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = token_remover.TOKEN_RE.fullmatch(token) +                self.assertIsNotNone(results, f"{token} was not matched by the regex") + +    def test_regex_matches_multiple_valid(self): +        """Should support multiple matches in the middle of a string.""" +        token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" +        token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" +        message = f"garbage {token_1} hello {token_2} world" + +        results = token_remover.TOKEN_RE.finditer(message) +        results = [match[0] for match in results] +        self.assertCountEqual((token_1, token_2), results) + +    @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") +    def test_format_log_message(self, log_message): +        """Should correctly format the log message with info from the message and token.""" +        token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        log_message.format.return_value = "Howdy" + +        return_value = TokenRemover.format_log_message(self.msg, token) + +        self.assertEqual(return_value, log_message.format.return_value) +        log_message.format.assert_called_once_with( +            author=self.msg.author, +            author_id=self.msg.author.id, +            channel=self.msg.channel.mention, +            user_id=token.user_id, +            timestamp=token.timestamp, +            hmac="x" * len(token.hmac), +        ) + +    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) +    @autospec("bot.exts.filters.token_remover", "log") +    @autospec(TokenRemover, "format_log_message") +    async def test_take_action(self, format_log_message, logger, mod_log_property): +        """Should delete the message and send a mod log.""" +        cog = TokenRemover(self.bot) +        mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) +        token = mock.create_autospec(Token, spec_set=True, instance=True) +        log_msg = "testing123" + +        mod_log_property.return_value = mod_log +        format_log_message.return_value = log_msg + +        await cog.take_action(self.msg, token) + +        self.msg.delete.assert_called_once_with() +        self.msg.channel.send.assert_called_once_with( +            token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) +        ) + +        format_log_message.assert_called_once_with(self.msg, token) +        logger.debug.assert_called_with(log_msg) +        self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + +        mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) +        mod_log.send_log_message.assert_called_once_with( +            icon_url=constants.Icons.token_removed, +            colour=Colour(constants.Colours.soft_red), +            title="Token removed!", +            text=log_msg, +            thumbnail=self.msg.author.avatar_url_as.return_value, +            channel_id=constants.Channels.mod_alerts +        ) + +    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) +    async def test_take_action_delete_failure(self, mod_log_property): +        """Shouldn't send any messages if the token message can't be deleted.""" +        cog = TokenRemover(self.bot) +        mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) +        self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) + +        token = mock.create_autospec(Token, spec_set=True, instance=True) +        await cog.take_action(self.msg, token) + +        self.msg.delete.assert_called_once_with() +        self.msg.channel.send.assert_not_awaited() + + +class TokenRemoverExtensionTests(unittest.TestCase): +    """Tests for the token_remover extension.""" + +    @autospec("bot.exts.filters.token_remover", "TokenRemover") +    def test_extension_setup(self, cog): +        """The TokenRemover cog should be added.""" +        bot = MockBot() +        token_remover.setup(bot) + +        cog.assert_called_once_with(bot) +        bot.add_cog.assert_called_once() +        self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/fun/__init__.py b/tests/bot/exts/fun/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/fun/__init__.py diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/exts/fun/test_duck_pond.py index a8c0107c6..704b08066 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/exts/fun/test_duck_pond.py @@ -7,11 +7,11 @@ from unittest.mock import AsyncMock, MagicMock, patch  import discord  from bot import constants -from bot.cogs import duck_pond +from bot.exts.fun import duck_pond  from tests import base  from tests import helpers -MODULE_PATH = "bot.cogs.duck_pond" +MODULE_PATH = "bot.exts.fun.duck_pond"  class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): @@ -63,7 +63,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.")          self.cog.webhook_id = 1 -        log = logging.getLogger('bot.cogs.duck_pond') +        log = logging.getLogger(MODULE_PATH)          with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:              asyncio.run(self.cog.fetch_webhook()) @@ -129,38 +129,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):              ):                  self.assertEqual(expected_return, actual_return) -    def test_send_webhook_correctly_passes_on_arguments(self): -        """The `send_webhook` method should pass the arguments to the webhook correctly.""" -        self.cog.webhook = helpers.MockAsyncWebhook() - -        content = "fake content" -        username = "fake username" -        avatar_url = "fake avatar_url" -        embed = "fake embed" - -        asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed)) - -        self.cog.webhook.send.assert_called_once_with( -            content=content, -            username=username, -            avatar_url=avatar_url, -            embed=embed -        ) - -    def test_send_webhook_logs_when_sending_message_fails(self): -        """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly.""" -        self.cog.webhook = helpers.MockAsyncWebhook() -        self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.") - -        log = logging.getLogger('bot.cogs.duck_pond') -        with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: -            asyncio.run(self.cog.send_webhook()) - -        self.assertEqual(len(log_watcher.records), 1) - -        record = log_watcher.records[0] -        self.assertEqual(record.levelno, logging.ERROR) -      def _get_reaction(          self,          emoji: typing.Union[str, helpers.MockEmoji], @@ -280,16 +248,20 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):      async def test_relay_message_correctly_relays_content_and_attachments(self):          """The `relay_message` method should correctly relay message content and attachments.""" -        send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" +        send_webhook_path = f"{MODULE_PATH}.send_webhook"          send_attachments_path = f"{MODULE_PATH}.send_attachments" +        author = MagicMock( +            display_name="x", +            avatar_url="https://" +        )          self.cog.webhook = helpers.MockAsyncWebhook()          test_values = ( -            (helpers.MockMessage(clean_content="", attachments=[]), False, False), -            (helpers.MockMessage(clean_content="message", attachments=[]), True, False), -            (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True), -            (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True), +            (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), +            (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), +            (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), +            (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True),          )          for message, expect_webhook_call, expect_attachment_call in test_values: @@ -310,25 +282,25 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), ""))          self.cog.webhook = helpers.MockAsyncWebhook() -        log = logging.getLogger("bot.cogs.duck_pond") +        log = logging.getLogger(MODULE_PATH)          for side_effect in side_effects:  # pragma: no cover              send_attachments.side_effect = side_effect -            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook: +            with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook:                  with self.subTest(side_effect=type(side_effect).__name__):                      with self.assertNotLogs(logger=log, level=logging.ERROR):                          await self.cog.relay_message(message)                      self.assertEqual(send_webhook.call_count, 2) -    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) +    @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock)      @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"])          self.cog.webhook = helpers.MockAsyncWebhook() -        log = logging.getLogger("bot.cogs.duck_pond") +        log = logging.getLogger(MODULE_PATH)          side_effect = discord.HTTPException(MagicMock(), "")          send_attachments.side_effect = side_effect @@ -337,6 +309,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):                  await self.cog.relay_message(message)              send_webhook.assert_called_once_with( +                webhook=self.cog.webhook,                  content=message.clean_content,                  username=message.author.display_name,                  avatar_url=message.author.avatar_url diff --git a/tests/bot/exts/info/__init__.py b/tests/bot/exts/info/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/info/__init__.py diff --git a/tests/bot/cogs/test_information.py b/tests/bot/exts/info/test_information.py index 79c0e0ad3..ba8d5d608 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -6,11 +6,11 @@ import unittest.mock  import discord  from bot import constants -from bot.cogs import information +from bot.exts.info import information  from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -COG_PATH = "bot.cogs.information.Information" +COG_PATH = "bot.exts.info.information.Information"  class InformationCogTests(unittest.TestCase): @@ -97,7 +97,7 @@ class InformationCogTests(unittest.TestCase):          self.assertEqual(admin_embed.title, "Admins info")          self.assertEqual(admin_embed.colour, discord.Colour.red()) -    @unittest.mock.patch('bot.cogs.information.time_since') +    @unittest.mock.patch('bot.exts.info.information.time_since')      def test_server_info_command(self, time_since_patch):          time_since_patch.return_value = '2 days ago' @@ -215,10 +215,10 @@ class UserInfractionHelperMethodTests(unittest.TestCase):              with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines):                  self.bot.api_client.get.return_value = api_response -                expected_output = "\n".join(default_header + expected_lines) +                expected_output = "\n".join(expected_lines)                  actual_output = asyncio.run(method(self.member)) -                self.assertEqual(expected_output, actual_output) +                self.assertEqual((default_header, expected_output), actual_output)      def test_basic_user_infraction_counts_returns_correct_strings(self):          """The method should correctly list both the total and active number of non-hidden infractions.""" @@ -249,7 +249,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):              },          ) -        header = ["**Infractions**"] +        header = "Infractions"          self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) @@ -258,7 +258,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          test_values = (              {                  "api response": [], -                "expected_lines": ["This user has never received an infraction."], +                "expected_lines": ["No infractions"],              },              # Shows non-hidden inactive infraction as expected              { @@ -304,7 +304,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):              },          ) -        header = ["**Infractions**"] +        header = "Infractions"          self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) @@ -313,15 +313,15 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          test_values = (              {                  "api response": [], -                "expected_lines": ["This user has never been nominated."], +                "expected_lines": ["No nominations"],              },              {                  "api response": [{'active': True}], -                "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], +                "expected_lines": ["This user is **currently** nominated", "(1 nomination in total)"],              },              {                  "api response": [{'active': True}, {'active': False}], -                "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], +                "expected_lines": ["This user is **currently** nominated", "(2 nominations in total)"],              },              {                  "api response": [{'active': False}], @@ -334,13 +334,13 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          ) -        header = ["**Nominations**"] +        header = "Nominations"          self._method_subtests(self.cog.user_nomination_counts, test_values, header) [email protected]("bot.cogs.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) [email protected]("bot.cogs.information.constants.MODERATION_CHANNELS", new=[50]) [email protected]("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago")) [email protected]("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50])  class UserEmbedTests(unittest.TestCase):      """Tests for the creation of the `!user` embed.""" @@ -350,7 +350,10 @@ class UserEmbedTests(unittest.TestCase):          self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) +    @unittest.mock.patch( +        f"{COG_PATH}.basic_user_infraction_counts", +        new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) +    )      def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -362,7 +365,10 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Mr. Hemlock") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) +    @unittest.mock.patch( +        f"{COG_PATH}.basic_user_infraction_counts", +        new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) +    )      def test_create_user_embed_uses_nick_in_title_if_available(self):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -374,7 +380,10 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) +    @unittest.mock.patch( +        f"{COG_PATH}.basic_user_infraction_counts", +        new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) +    )      def test_create_user_embed_ignores_everyone_role(self):          """Created `!user` embeds should not contain mention of the @everyone-role."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -386,8 +395,8 @@ class UserEmbedTests(unittest.TestCase):          embed = asyncio.run(self.cog.create_user_embed(ctx, user)) -        self.assertIn("&Admins", embed.description) -        self.assertNotIn("&Everyone", embed.description) +        self.assertIn("&Admins", embed.fields[1].value) +        self.assertNotIn("&Everyone", embed.fields[1].value)      @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)      @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) @@ -398,8 +407,8 @@ class UserEmbedTests(unittest.TestCase):          moderators_role = helpers.MockRole(name='Moderators')          moderators_role.colour = 100 -        infraction_counts.return_value = "expanded infractions info" -        nomination_counts.return_value = "nomination info" +        infraction_counts.return_value = ("Infractions", "expanded infractions info") +        nomination_counts.return_value = ("Nominations", "nomination info")          user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)          embed = asyncio.run(self.cog.create_user_embed(ctx, user)) @@ -409,20 +418,19 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(              textwrap.dedent(f""" -                **User Information**                  Created: {"1 year ago"}                  Profile: {user.mention}                  ID: {user.id} +            """).strip(), +            embed.fields[0].value +        ) -                **Member Information** +        self.assertEqual( +            textwrap.dedent(f"""                  Joined: {"1 year ago"}                  Roles: &Moderators - -                expanded infractions info - -                nomination info              """).strip(), -            embed.description +            embed.fields[1].value          )      @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @@ -433,7 +441,7 @@ class UserEmbedTests(unittest.TestCase):          moderators_role = helpers.MockRole(name='Moderators')          moderators_role.colour = 100 -        infraction_counts.return_value = "basic infractions info" +        infraction_counts.return_value = ("Infractions", "basic infractions info")          user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)          embed = asyncio.run(self.cog.create_user_embed(ctx, user)) @@ -442,21 +450,30 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(              textwrap.dedent(f""" -                **User Information**                  Created: {"1 year ago"}                  Profile: {user.mention}                  ID: {user.id} +            """).strip(), +            embed.fields[0].value +        ) -                **Member Information** +        self.assertEqual( +            textwrap.dedent(f"""                  Joined: {"1 year ago"}                  Roles: &Moderators - -                basic infractions info              """).strip(), -            embed.description +            embed.fields[1].value +        ) + +        self.assertEqual( +            "basic infractions info", +            embed.fields[3].value          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) +    @unittest.mock.patch( +        f"{COG_PATH}.basic_user_infraction_counts", +        new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) +    )      def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):          """The embed should be created with the colour of the top role, if a top role is available."""          ctx = helpers.MockContext() @@ -469,7 +486,10 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) +    @unittest.mock.patch( +        f"{COG_PATH}.basic_user_infraction_counts", +        new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) +    )      def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):          """The embed should be created with a blurple colour if the user has no assigned roles."""          ctx = helpers.MockContext() @@ -479,7 +499,10 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour.blurple()) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) +    @unittest.mock.patch( +        f"{COG_PATH}.basic_user_infraction_counts", +        new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) +    )      def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):          """The embed thumbnail should be set to the user's avatar in `png` format."""          ctx = helpers.MockContext() @@ -492,7 +515,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.thumbnail.url, "avatar url") [email protected]("bot.cogs.information.constants") [email protected]("bot.exts.info.information.constants")  class UserCommandTests(unittest.TestCase):      """Tests for the `!user` command.""" @@ -531,7 +554,7 @@ class UserCommandTests(unittest.TestCase):          with self.assertRaises(InWhitelistCheckFailure, msg=msg):              asyncio.run(self.cog.user_info.callback(self.cog, ctx)) -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")      def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):          """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -544,7 +567,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")      def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):          """A user should target itself with `!user` when a `user` argument was not provided."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -557,7 +580,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")      def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -570,7 +593,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.moderator)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")      def test_moderators_can_target_another_member(self, create_embed, constants):          """A moderator should be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/exts/moderation/__init__.py b/tests/bot/exts/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/moderation/__init__.py diff --git a/tests/bot/exts/moderation/infraction/__init__.py b/tests/bot/exts/moderation/infraction/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/moderation/infraction/__init__.py diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index da4e92ccc..be1b649e1 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -2,7 +2,7 @@ import textwrap  import unittest  from unittest.mock import AsyncMock, Mock, patch -from bot.cogs.moderation.infractions import Infractions +from bot.exts.moderation.infraction.infractions import Infractions  from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -17,8 +17,8 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.guild = MockGuild(id=4567)          self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) -    @patch("bot.cogs.moderation.utils.get_active_infraction") -    @patch("bot.cogs.moderation.utils.post_infraction") +    @patch("bot.exts.moderation.infraction._utils.get_active_infraction") +    @patch("bot.exts.moderation.infraction._utils.post_infraction")      async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock):          """Should truncate reason for `ctx.guild.ban`."""          get_active_mock.return_value = None @@ -39,7 +39,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):              self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value          ) -    @patch("bot.cogs.moderation.utils.post_infraction") +    @patch("bot.exts.moderation.infraction._utils.post_infraction")      async def test_apply_kick_reason_truncation(self, post_infraction_mock):          """Should truncate reason for `Member.kick`."""          post_infraction_mock.return_value = {"foo": "bar"} diff --git a/tests/bot/cogs/moderation/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5f649e136..674993862 100644 --- a/tests/bot/cogs/moderation/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -356,4 +356,4 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):          actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason")          self.assertEqual(actual, "foo")          self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) -        post_user_mock.assert_awaited_once_with(self.ctx, self.user) +        post_user_mock.assert_awaited_once_with(self.ctx, self.user)
\ No newline at end of file diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py new file mode 100644 index 000000000..cbf7f7bcf --- /dev/null +++ b/tests/bot/exts/moderation/test_incidents.py @@ -0,0 +1,770 @@ +import asyncio +import enum +import logging +import typing as t +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +import aiohttp +import discord + +from bot.constants import Colours +from bot.exts.moderation import incidents +from tests.helpers import ( +    MockAsyncWebhook, +    MockAttachment, +    MockBot, +    MockMember, +    MockMessage, +    MockReaction, +    MockRole, +    MockTextChannel, +    MockUser, +) + + +class MockAsyncIterable: +    """ +    Helper for mocking asynchronous for loops. + +    It does not appear that the `unittest` library currently provides anything that would +    allow us to simply mock an async iterator, such as `discord.TextChannel.history`. + +    We therefore write our own helper to wrap a regular synchronous iterable, and feed +    its values via `__anext__` rather than `__next__`. + +    This class was written for the purposes of testing the `Incidents` cog - it may not +    be generic enough to be placed in the `tests.helpers` module. +    """ + +    def __init__(self, messages: t.Iterable): +        """Take a sync iterable to be wrapped.""" +        self.iter_messages = iter(messages) + +    def __aiter__(self): +        """Return `self` as we provide the `__anext__` method.""" +        return self + +    async def __anext__(self): +        """ +        Feed the next item, or raise `StopAsyncIteration`. + +        Since we're wrapping a sync iterator, it will communicate that it has been depleted +        by raising a `StopIteration`. The `async for` construct does not expect it, and we +        therefore need to substitute it for the appropriate exception type. +        """ +        try: +            return next(self.iter_messages) +        except StopIteration: +            raise StopAsyncIteration + + +class MockSignal(enum.Enum): +    A = "A" +    B = "B" + + +mock_404 = discord.NotFound( +    response=MagicMock(aiohttp.ClientResponse),  # Mock the erroneous response +    message="Not found", +) + + +class TestDownloadFile(unittest.IsolatedAsyncioTestCase): +    """Collection of tests for the `download_file` helper function.""" + +    async def test_download_file_success(self): +        """If `to_file` succeeds, function returns the acquired `discord.File`.""" +        file = MagicMock(discord.File, filename="bigbadlemon.jpg") +        attachment = MockAttachment(to_file=AsyncMock(return_value=file)) + +        acquired_file = await incidents.download_file(attachment) +        self.assertIs(file, acquired_file) + +    async def test_download_file_404(self): +        """If `to_file` encounters a 404, function handles the exception & returns None.""" +        attachment = MockAttachment(to_file=AsyncMock(side_effect=mock_404)) + +        acquired_file = await incidents.download_file(attachment) +        self.assertIsNone(acquired_file) + +    async def test_download_file_fail(self): +        """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" +        arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") +        attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) + +        with self.assertLogs(logger=incidents.log, level=logging.ERROR): +            acquired_file = await incidents.download_file(attachment) + +        self.assertIsNone(acquired_file) + + +class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): +    """Collection of tests for the `make_embed` helper function.""" + +    async def test_make_embed_actioned(self): +        """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" +        embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + +        self.assertEqual(embed.colour.value, Colours.soft_green) +        self.assertIn("Actioned", embed.footer.text) + +    async def test_make_embed_not_actioned(self): +        """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" +        embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + +        self.assertEqual(embed.colour.value, Colours.soft_red) +        self.assertIn("Rejected", embed.footer.text) + +    async def test_make_embed_content(self): +        """Incident content appears as embed description.""" +        incident = MockMessage(content="this is an incident") +        embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + +        self.assertEqual(incident.content, embed.description) + +    async def test_make_embed_with_attachment_succeeds(self): +        """Incident's attachment is downloaded and displayed in the embed's image field.""" +        file = MagicMock(discord.File, filename="bigbadjoe.jpg") +        attachment = MockAttachment(filename="bigbadjoe.jpg") +        incident = MockMessage(content="this is an incident", attachments=[attachment]) + +        # Patch `download_file` to return our `file` +        with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): +            embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + +        self.assertIs(file, returned_file) +        self.assertEqual("attachment://bigbadjoe.jpg", embed.image.url) + +    async def test_make_embed_with_attachment_fails(self): +        """Incident's attachment fails to download, proxy url is linked instead.""" +        attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") +        incident = MockMessage(content="this is an incident", attachments=[attachment]) + +        # Patch `download_file` to return None as if the download failed +        with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): +            embed, returned_file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) + +        self.assertIsNone(returned_file) + +        # The author name field is simply expected to have something in it, we do not assert the message +        self.assertGreater(len(embed.author.name), 0) +        self.assertEqual(embed.author.url, "discord.com/bigbadjoe.jpg")  # However, it should link the exact url + + +@patch("bot.constants.Channels.incidents", 123) +class TestIsIncident(unittest.TestCase): +    """ +    Collection of tests for the `is_incident` helper function. + +    In `setUp`, we will create a mock message which should qualify as an incident. Each +    test case will then mutate this instance to make it **not** qualify, in various ways. + +    Notice that we patch the #incidents channel id globally for this class. +    """ + +    def setUp(self) -> None: +        """Prepare a mock message which should qualify as an incident.""" +        self.incident = MockMessage( +            channel=MockTextChannel(id=123), +            content="this is an incident", +            author=MockUser(bot=False), +            pinned=False, +        ) + +    def test_is_incident_true(self): +        """Message qualifies as an incident if unchanged.""" +        self.assertTrue(incidents.is_incident(self.incident)) + +    def check_false(self): +        """Assert that `self.incident` does **not** qualify as an incident.""" +        self.assertFalse(incidents.is_incident(self.incident)) + +    def test_is_incident_false_channel(self): +        """Message doesn't qualify if sent outside of #incidents.""" +        self.incident.channel = MockTextChannel(id=456) +        self.check_false() + +    def test_is_incident_false_content(self): +        """Message doesn't qualify if content begins with hash symbol.""" +        self.incident.content = "# this is a comment message" +        self.check_false() + +    def test_is_incident_false_author(self): +        """Message doesn't qualify if author is a bot.""" +        self.incident.author = MockUser(bot=True) +        self.check_false() + +    def test_is_incident_false_pinned(self): +        """Message doesn't qualify if it is pinned.""" +        self.incident.pinned = True +        self.check_false() + + +class TestOwnReactions(unittest.TestCase): +    """Assertions for the `own_reactions` function.""" + +    def test_own_reactions(self): +        """Only bot's own emoji are extracted from the input incident.""" +        reactions = ( +            MockReaction(emoji="A", me=True), +            MockReaction(emoji="B", me=True), +            MockReaction(emoji="C", me=False), +        ) +        message = MockMessage(reactions=reactions) +        self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) + + +@patch("bot.exts.moderation.incidents.ALL_SIGNALS", {"A", "B"}) +class TestHasSignals(unittest.TestCase): +    """ +    Assertions for the `has_signals` function. + +    We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` +    as appropriate. +    """ + +    def test_has_signals_true(self): +        """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" +        message = MockMessage() +        own_reactions = MagicMock(return_value={"A", "B"}) + +        with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): +            self.assertTrue(incidents.has_signals(message)) + +    def test_has_signals_false(self): +        """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" +        message = MockMessage() +        own_reactions = MagicMock(return_value={"A", "C"}) + +        with patch("bot.exts.moderation.incidents.own_reactions", own_reactions): +            self.assertFalse(incidents.has_signals(message)) + + +@patch("bot.exts.moderation.incidents.Signal", MockSignal) +class TestAddSignals(unittest.IsolatedAsyncioTestCase): +    """ +    Assertions for the `add_signals` coroutine. + +    These are all fairly similar and could go into a single test function, but I found the +    patching & sub-testing fairly awkward in that case and decided to split them up +    to avoid unnecessary syntax noise. +    """ + +    def setUp(self): +        """Prepare a mock incident message for tests to use.""" +        self.incident = MockMessage() + +    @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value=set())) +    async def test_add_signals_missing(self): +        """All emoji are added when none are present.""" +        await incidents.add_signals(self.incident) +        self.incident.add_reaction.assert_has_calls([call("A"), call("B")]) + +    @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A"})) +    async def test_add_signals_partial(self): +        """Only missing emoji are added when some are present.""" +        await incidents.add_signals(self.incident) +        self.incident.add_reaction.assert_has_calls([call("B")]) + +    @patch("bot.exts.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"})) +    async def test_add_signals_present(self): +        """No emoji are added when all are present.""" +        await incidents.add_signals(self.incident) +        self.incident.add_reaction.assert_not_called() + + +class TestIncidents(unittest.IsolatedAsyncioTestCase): +    """ +    Tests for bound methods of the `Incidents` cog. + +    Use this as a base class for `Incidents` tests - it will prepare a fresh instance +    for each test function, but not make any assertions on its own. Tests can mutate +    the instance as they wish. +    """ + +    def setUp(self): +        """ +        Prepare a fresh `Incidents` instance for each test. + +        Note that this will not schedule `crawl_incidents` in the background, as everything +        is being mocked. The `crawl_task` attribute will end up being None. +        """ +        self.cog_instance = incidents.Incidents(MockBot()) + + +@patch("asyncio.sleep", AsyncMock())  # Prevent the coro from sleeping to speed up the test +class TestCrawlIncidents(TestIncidents): +    """ +    Tests for the `Incidents.crawl_incidents` coroutine. + +    Apart from `test_crawl_incidents_waits_until_cache_ready`, all tests in this class +    will patch the return values of `is_incident` and `has_signal` and then observe +    whether the `AsyncMock` for `add_signals` was awaited or not. + +    The `add_signals` mock is added by each test separately to ensure it is clean (has not +    been awaited by another test yet). The mock can be reset, but this appears to be the +    cleaner way. + +    For each test, we inject a mock channel with a history of 1 message only (see: `setUp`). +    """ + +    def setUp(self): +        """For each test, ensure `bot.get_channel` returns a channel with 1 arbitrary message.""" +        super().setUp()  # First ensure we get `cog_instance` from parent + +        incidents_history = MagicMock(return_value=MockAsyncIterable([MockMessage()])) +        self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(history=incidents_history)) + +    async def test_crawl_incidents_waits_until_cache_ready(self): +        """ +        The coroutine will await the `wait_until_guild_available` event. + +        Since this task is schedule in the `__init__`, it is critical that it waits for the +        cache to be ready, so that it can safely get the #incidents channel. +        """ +        await self.cog_instance.crawl_incidents() +        self.cog_instance.bot.wait_until_guild_available.assert_awaited() + +    @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) +    @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False))  # Message doesn't qualify +    @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False)) +    async def test_crawl_incidents_noop_if_is_not_incident(self): +        """Signals are not added for a non-incident message.""" +        await self.cog_instance.crawl_incidents() +        incidents.add_signals.assert_not_awaited() + +    @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) +    @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True))  # Message qualifies +    @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=True))  # But already has signals +    async def test_crawl_incidents_noop_if_message_already_has_signals(self): +        """Signals are not added for messages which already have them.""" +        await self.cog_instance.crawl_incidents() +        incidents.add_signals.assert_not_awaited() + +    @patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) +    @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True))  # Message qualifies +    @patch("bot.exts.moderation.incidents.has_signals", MagicMock(return_value=False))  # And doesn't have signals +    async def test_crawl_incidents_add_signals_called(self): +        """Message has signals added as it does not have them yet and qualifies as an incident.""" +        await self.cog_instance.crawl_incidents() +        incidents.add_signals.assert_awaited_once() + + +class TestArchive(TestIncidents): +    """Tests for the `Incidents.archive` coroutine.""" + +    async def test_archive_webhook_not_found(self): +        """ +        Method recovers and returns False when the webhook is not found. + +        Implicitly, this also tests that the error is handled internally and doesn't +        propagate out of the method, which is just as important. +        """ +        self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) +        self.assertFalse( +            await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) +        ) + +    async def test_archive_relays_incident(self): +        """ +        If webhook is found, method relays `incident` properly. + +        This test will assert that the fetched webhook's `send` method is fed the correct arguments, +        and that the `archive` method returns True. +        """ +        webhook = MockAsyncWebhook() +        self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook)  # Patch in our webhook + +        # Define our own `incident` to be archived +        incident = MockMessage( +            content="this is an incident", +            author=MockUser(name="author_name", avatar_url="author_avatar"), +            id=123, +        ) +        built_embed = MagicMock(discord.Embed, id=123)  # We patch `make_embed` to return this + +        with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): +            archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) + +        # Now we check that the webhook was given the correct args, and that `archive` returned True +        webhook.send.assert_called_once_with( +            embed=built_embed, +            username="author_name", +            avatar_url="author_avatar", +            file=None, +        ) +        self.assertTrue(archive_return) + +    async def test_archive_clyde_username(self): +        """ +        The archive webhook username is cleansed using `sub_clyde`. + +        Discord will reject any webhook with "clyde" in the username field, as it impersonates +        the official Clyde bot. Since we do not control what the username will be (the incident +        author name is used), we must ensure the name is cleansed, otherwise the relay may fail. + +        This test assumes the username is passed as a kwarg. If this test fails, please review +        whether the passed argument is being retrieved correctly. +        """ +        webhook = MockAsyncWebhook() +        self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) + +        message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) +        await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) + +        self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) + + +class TestMakeConfirmationTask(TestIncidents): +    """ +    Tests for the `Incidents.make_confirmation_task` method. + +    Writing tests for this method is difficult, as it mostly just delegates the provided +    information elsewhere. There is very little internal logic. Whether our approach +    works conceptually is difficult to prove using unit tests. +    """ + +    def test_make_confirmation_task_check(self): +        """ +        The internal check will recognize the passed incident. + +        This is a little tricky - we first pass a message with a specific `id` in, and then +        retrieve the built check from the `call_args` of the `wait_for` method. This relies +        on the check being passed as a kwarg. + +        Once the check is retrieved, we assert that it gives True for our incident's `id`, +        and False for any other. + +        If this function begins to fail, first check that `created_check` is being retrieved +        correctly. It should be the function that is built locally in the tested method. +        """ +        self.cog_instance.make_confirmation_task(MockMessage(id=123)) + +        self.cog_instance.bot.wait_for.assert_called_once() +        created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"] + +        # The `message_id` matches the `id` of our incident +        self.assertTrue(created_check(payload=MagicMock(message_id=123))) + +        # This `message_id` does not match +        self.assertFalse(created_check(payload=MagicMock(message_id=0))) + + +@patch("bot.exts.moderation.incidents.ALLOWED_ROLES", {1, 2}) +@patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", AsyncMock())  # Generic awaitable +class TestProcessEvent(TestIncidents): +    """Tests for the `Incidents.process_event` coroutine.""" + +    async def test_process_event_bad_role(self): +        """The reaction is removed when the author lacks all allowed roles.""" +        incident = MockMessage() +        member = MockMember(roles=[MockRole(id=0)])  # Must have role 1 or 2 + +        await self.cog_instance.process_event("reaction", incident, member) +        incident.remove_reaction.assert_called_once_with("reaction", member) + +    async def test_process_event_bad_emoji(self): +        """ +        The reaction is removed when an invalid emoji is used. + +        This requires that we pass in a `member` with valid roles, as we need the role check +        to succeed. +        """ +        incident = MockMessage() +        member = MockMember(roles=[MockRole(id=1)])  # Member has allowed role + +        await self.cog_instance.process_event("invalid_signal", incident, member) +        incident.remove_reaction.assert_called_once_with("invalid_signal", member) + +    async def test_process_event_no_archive_on_investigating(self): +        """Message is not archived on `Signal.INVESTIGATING`.""" +        with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive: +            await self.cog_instance.process_event( +                reaction=incidents.Signal.INVESTIGATING.value, +                incident=MockMessage(), +                member=MockMember(roles=[MockRole(id=1)]), +            ) + +        mocked_archive.assert_not_called() + +    async def test_process_event_no_delete_if_archive_fails(self): +        """ +        Original message is not deleted when `Incidents.archive` returns False. + +        This is the way of signaling that the relay failed, and we should not remove the original, +        as that would result in losing the incident record. +        """ +        incident = MockMessage() + +        with patch("bot.exts.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)): +            await self.cog_instance.process_event( +                reaction=incidents.Signal.ACTIONED.value, +                incident=incident, +                member=MockMember(roles=[MockRole(id=1)]) +            ) + +        incident.delete.assert_not_called() + +    async def test_process_event_confirmation_task_is_awaited(self): +        """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" +        mock_task = AsyncMock() + +        with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): +            await self.cog_instance.process_event( +                reaction=incidents.Signal.ACTIONED.value, +                incident=MockMessage(), +                member=MockMember(roles=[MockRole(id=1)]) +            ) + +        mock_task.assert_awaited() + +    async def test_process_event_confirmation_task_timeout_is_handled(self): +        """ +        Confirmation task `asyncio.TimeoutError` is handled gracefully. + +        We have `make_confirmation_task` return a mock with a side effect, and then catch the +        exception should it propagate out of `process_event`. This is so that we can then manually +        fail the test with a more informative message than just the plain traceback. +        """ +        mock_task = AsyncMock(side_effect=asyncio.TimeoutError()) + +        try: +            with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): +                await self.cog_instance.process_event( +                    reaction=incidents.Signal.ACTIONED.value, +                    incident=MockMessage(), +                    member=MockMember(roles=[MockRole(id=1)]) +                ) +        except asyncio.TimeoutError: +            self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!") + + +class TestResolveMessage(TestIncidents): +    """Tests for the `Incidents.resolve_message` coroutine.""" + +    async def test_resolve_message_pass_message_id(self): +        """Method will call `_get_message` with the passed `message_id`.""" +        await self.cog_instance.resolve_message(123) +        self.cog_instance.bot._connection._get_message.assert_called_once_with(123) + +    async def test_resolve_message_in_cache(self): +        """ +        No API call is made if the queried message exists in the cache. + +        We mock the `_get_message` return value regardless of input. Whether it finds the message +        internally is considered d.py's responsibility, not ours. +        """ +        cached_message = MockMessage(id=123) +        self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message) + +        return_value = await self.cog_instance.resolve_message(123) + +        self.assertIs(return_value, cached_message) +        self.cog_instance.bot.get_channel.assert_not_called()  # The `fetch_message` line was never hit + +    async def test_resolve_message_not_in_cache(self): +        """ +        The message is retrieved from the API if it isn't cached. + +        This is desired behaviour for messages which exist, but were sent before the bot's +        current session. +        """ +        self.cog_instance.bot._connection._get_message = MagicMock(return_value=None)  # Cache returns None + +        # API returns our message +        uncached_message = MockMessage() +        fetch_message = AsyncMock(return_value=uncached_message) +        self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + +        retrieved_message = await self.cog_instance.resolve_message(123) +        self.assertIs(retrieved_message, uncached_message) + +    async def test_resolve_message_doesnt_exist(self): +        """ +        If the API returns a 404, the function handles it gracefully and returns None. + +        This is an edge-case happening with racing events - event A will relay the message +        to the archive and delete the original. Once event B acquires the `event_lock`, +        it will not find the message in the cache, and will ask the API. +        """ +        self.cog_instance.bot._connection._get_message = MagicMock(return_value=None)  # Cache returns None + +        fetch_message = AsyncMock(side_effect=mock_404) +        self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + +        self.assertIsNone(await self.cog_instance.resolve_message(123)) + +    async def test_resolve_message_fetch_fails(self): +        """ +        Non-404 errors are handled, logged & None is returned. + +        In contrast with a 404, this should make an error-level log. We assert that at least +        one such log was made - we do not make any assertions about the log's message. +        """ +        self.cog_instance.bot._connection._get_message = MagicMock(return_value=None)  # Cache returns None + +        arbitrary_error = discord.HTTPException( +            response=MagicMock(aiohttp.ClientResponse), +            message="Arbitrary error", +        ) +        fetch_message = AsyncMock(side_effect=arbitrary_error) +        self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message)) + +        with self.assertLogs(logger=incidents.log, level=logging.ERROR): +            self.assertIsNone(await self.cog_instance.resolve_message(123)) + + +@patch("bot.constants.Channels.incidents", 123) +class TestOnRawReactionAdd(TestIncidents): +    """ +    Tests for the `Incidents.on_raw_reaction_add` listener. + +    Writing tests for this listener comes with additional complexity due to the listener +    awaiting the `crawl_task` task. See `asyncSetUp` for further details, which attempts +    to make unit testing this function possible. +    """ + +    def setUp(self): +        """ +        Prepare & assign `payload` attribute. + +        This attribute represents an *ideal* payload which will not be rejected by the +        listener. As each test will receive a fresh instance, it can be mutated to +        observe how the listener's behaviour changes with different attributes on +        the passed payload. +        """ +        super().setUp()  # Ensure `cog_instance` is assigned + +        self.payload = MagicMock( +            discord.RawReactionActionEvent, +            channel_id=123,  # Patched at class level +            message_id=456, +            member=MockMember(bot=False), +            emoji="reaction", +        ) + +    async def asyncSetUp(self):  # noqa: N802 +        """ +        Prepare an empty task and assign it as `crawl_task`. + +        It appears that the `unittest` framework does not provide anything for mocking +        asyncio tasks. An `AsyncMock` instance can be called and then awaited, however, +        it does not provide the `done` method or any other parts of the `asyncio.Task` +        interface. + +        Although we do not need to make any assertions about the task itself while +        testing the listener, the code will still await it and call the `done` method, +        and so we must inject something that will not fail on either action. + +        Note that this is done in an `asyncSetUp`, which runs after `setUp`. +        The justification is that creating an actual task requires the event +        loop to be ready, which is not the case in the `setUp`. +        """ +        mock_task = asyncio.create_task(AsyncMock()())  # Mock async func, then a coro +        self.cog_instance.crawl_task = mock_task + +    async def test_on_raw_reaction_add_wrong_channel(self): +        """ +        Events outside of #incidents will be ignored. + +        We check this by asserting that `resolve_message` was never queried. +        """ +        self.payload.channel_id = 0 +        self.cog_instance.resolve_message = AsyncMock() + +        await self.cog_instance.on_raw_reaction_add(self.payload) +        self.cog_instance.resolve_message.assert_not_called() + +    async def test_on_raw_reaction_add_user_is_bot(self): +        """ +        Events dispatched by bot accounts will be ignored. + +        We check this by asserting that `resolve_message` was never queried. +        """ +        self.payload.member = MockMember(bot=True) +        self.cog_instance.resolve_message = AsyncMock() + +        await self.cog_instance.on_raw_reaction_add(self.payload) +        self.cog_instance.resolve_message.assert_not_called() + +    async def test_on_raw_reaction_add_message_doesnt_exist(self): +        """ +        Listener gracefully handles the case where `resolve_message` gives None. + +        We check this by asserting that `process_event` was never called. +        """ +        self.cog_instance.process_event = AsyncMock() +        self.cog_instance.resolve_message = AsyncMock(return_value=None) + +        await self.cog_instance.on_raw_reaction_add(self.payload) +        self.cog_instance.process_event.assert_not_called() + +    async def test_on_raw_reaction_add_message_is_not_an_incident(self): +        """ +        The event won't be processed if the related message is not an incident. + +        This is an edge-case that can happen if someone manually leaves a reaction +        on a pinned message, or a comment. + +        We check this by asserting that `process_event` was never called. +        """ +        self.cog_instance.process_event = AsyncMock() +        self.cog_instance.resolve_message = AsyncMock(return_value=MockMessage()) + +        with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)): +            await self.cog_instance.on_raw_reaction_add(self.payload) + +        self.cog_instance.process_event.assert_not_called() + +    async def test_on_raw_reaction_add_valid_event_is_processed(self): +        """ +        If the reaction event is valid, it is passed to `process_event`. + +        This is the case when everything goes right: +            * The reaction was placed in #incidents, and not by a bot +            * The message was found successfully +            * The message qualifies as an incident + +        Additionally, we check that all arguments were passed as expected. +        """ +        incident = MockMessage(id=1) + +        self.cog_instance.process_event = AsyncMock() +        self.cog_instance.resolve_message = AsyncMock(return_value=incident) + +        with patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)): +            await self.cog_instance.on_raw_reaction_add(self.payload) + +        self.cog_instance.process_event.assert_called_with( +            "reaction",  # Defined in `self.payload` +            incident, +            self.payload.member, +        ) + + +class TestOnMessage(TestIncidents): +    """ +    Tests for the `Incidents.on_message` listener. + +    Notice the decorators mocking the `is_incident` return value. The `is_incidents` +    function is tested in `TestIsIncident` - here we do not worry about it. +    """ + +    @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=True)) +    async def test_on_message_incident(self): +        """Messages qualifying as incidents are passed to `add_signals`.""" +        incident = MockMessage() + +        with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: +            await self.cog_instance.on_message(incident) + +        mock_add_signals.assert_called_once_with(incident) + +    @patch("bot.exts.moderation.incidents.is_incident", MagicMock(return_value=False)) +    async def test_on_message_non_incident(self): +        """Messages not qualifying as incidents are ignored.""" +        with patch("bot.exts.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals: +            await self.cog_instance.on_message(MockMessage()) + +        mock_add_signals.assert_not_called() diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index f2809f40a..f8f142484 100644 --- a/tests/bot/cogs/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -2,7 +2,7 @@ import unittest  import discord -from bot.cogs.moderation.modlog import ModLog +from bot.exts.moderation.modlog import ModLog  from tests.helpers import MockBot, MockTextChannel diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 3fd149f04..8c4fb764a 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock, Mock  from discord import PermissionOverwrite -from bot.cogs.moderation.silence import Silence, SilenceNotifier  from bot.constants import Channels, Emojis, Guild, Roles +from bot.exts.moderation.silence import Silence, SilenceNotifier  from tests.helpers import MockBot, MockContext, MockTextChannel @@ -99,7 +99,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          self.bot.get_channel.called_once_with(Channels.mod_alerts)          self.bot.get_channel.called_once_with(Channels.mod_log) -    @mock.patch("bot.cogs.moderation.silence.SilenceNotifier") +    @mock.patch("bot.exts.moderation.silence.SilenceNotifier")      async def test_instance_vars_got_notifier(self, notifier):          """Notifier was started with channel."""          mod_log = MockTextChannel() @@ -127,10 +127,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):              self.ctx.reset_mock()      async def test_unsilence_sent_correct_discord_message(self): -        """Proper reply after a successful unsilence.""" -        with mock.patch.object(self.cog, "_unsilence", return_value=True): -            await self.cog.unsilence.callback(self.cog, self.ctx) -            self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") +        """Check if proper message was sent when unsilencing channel.""" +        test_cases = ( +            (True, f"{Emojis.check_mark} unsilenced current channel."), +            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        ) +        for _unsilence_patch_return, result_message in test_cases: +            with self.subTest( +                starting_silenced_state=_unsilence_patch_return, +                result_message=result_message +            ): +                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): +                    await self.cog.unsilence.callback(self.cog, self.ctx) +                    self.ctx.send.assert_called_once_with(result_message) +            self.ctx.reset_mock()      async def test_silence_private_for_false(self):          """Permissions are not set and `False` is returned in an already silenced channel.""" @@ -228,7 +238,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          del mock_permissions_dict['send_messages']          self.assertDictEqual(mock_permissions_dict, new_permissions) -    @mock.patch("bot.cogs.moderation.silence.asyncio") +    @mock.patch("bot.exts.moderation.silence.asyncio")      @mock.patch.object(Silence, "_mod_alerts_channel", create=True)      def test_cog_unload_starts_task(self, alert_channel, asyncio_mock):          """Task for sending an alert was created with present `muted_channels`.""" @@ -237,14 +247,14 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):              alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ")              asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) -    @mock.patch("bot.cogs.moderation.silence.asyncio") +    @mock.patch("bot.exts.moderation.silence.asyncio")      def test_cog_unload_skips_task_start(self, asyncio_mock):          """No task created with no channels."""          self.cog.cog_unload()          asyncio_mock.create_task.assert_not_called() -    @mock.patch("bot.cogs.moderation.silence.with_role_check") -    @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) +    @mock.patch("bot.exts.moderation.silence.with_role_check") +    @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))      def test_cog_check(self, role_check):          """Role check is called with `MODERATION_ROLES`"""          self.cog.cog_check(self.ctx) diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py new file mode 100644 index 000000000..e90394ab9 --- /dev/null +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -0,0 +1,111 @@ +import unittest +from unittest import mock + +from dateutil.relativedelta import relativedelta + +from bot.constants import Emojis +from bot.exts.moderation.slowmode import Slowmode +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SlowmodeTests(unittest.IsolatedAsyncioTestCase): + +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = Slowmode(self.bot) +        self.ctx = MockContext() + +    async def test_get_slowmode_no_channel(self) -> None: +        """Get slowmode without a given channel.""" +        self.ctx.channel = MockTextChannel(name='python-general', slowmode_delay=5) + +        await self.cog.get_slowmode(self.cog, self.ctx, None) +        self.ctx.send.assert_called_once_with("The slowmode delay for #python-general is 5 seconds.") + +    async def test_get_slowmode_with_channel(self) -> None: +        """Get slowmode with a given channel.""" +        text_channel = MockTextChannel(name='python-language', slowmode_delay=2) + +        await self.cog.get_slowmode(self.cog, self.ctx, text_channel) +        self.ctx.send.assert_called_once_with('The slowmode delay for #python-language is 2 seconds.') + +    async def test_set_slowmode_no_channel(self) -> None: +        """Set slowmode without a given channel.""" +        test_cases = ( +            ('helpers', 23, True, f'{Emojis.check_mark} The slowmode delay for #helpers is now 23 seconds.'), +            ('mods', 76526, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.'), +            ('admins', 97, True, f'{Emojis.check_mark} The slowmode delay for #admins is now 1 minute and 37 seconds.') +        ) + +        for channel_name, seconds, edited, result_msg in test_cases: +            with self.subTest( +                channel_mention=channel_name, +                seconds=seconds, +                edited=edited, +                result_msg=result_msg +            ): +                self.ctx.channel = MockTextChannel(name=channel_name) + +                await self.cog.set_slowmode(self.cog, self.ctx, None, relativedelta(seconds=seconds)) + +                if edited: +                    self.ctx.channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) +                else: +                    self.ctx.channel.edit.assert_not_called() + +                self.ctx.send.assert_called_once_with(result_msg) + +            self.ctx.reset_mock() + +    async def test_set_slowmode_with_channel(self) -> None: +        """Set slowmode with a given channel.""" +        test_cases = ( +            ('bot-commands', 12, True, f'{Emojis.check_mark} The slowmode delay for #bot-commands is now 12 seconds.'), +            ('mod-spam', 21, True, f'{Emojis.check_mark} The slowmode delay for #mod-spam is now 21 seconds.'), +            ('admin-spam', 4323598, False, f'{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours.') +        ) + +        for channel_name, seconds, edited, result_msg in test_cases: +            with self.subTest( +                channel_mention=channel_name, +                seconds=seconds, +                edited=edited, +                result_msg=result_msg +            ): +                text_channel = MockTextChannel(name=channel_name) + +                await self.cog.set_slowmode(self.cog, self.ctx, text_channel, relativedelta(seconds=seconds)) + +                if edited: +                    text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds)) +                else: +                    text_channel.edit.assert_not_called() + +                self.ctx.send.assert_called_once_with(result_msg) + +            self.ctx.reset_mock() + +    async def test_reset_slowmode_no_channel(self) -> None: +        """Reset slowmode without a given channel.""" +        self.ctx.channel = MockTextChannel(name='careers', slowmode_delay=6) + +        await self.cog.reset_slowmode(self.cog, self.ctx, None) +        self.ctx.send.assert_called_once_with( +            f'{Emojis.check_mark} The slowmode delay for #careers has been reset to 0 seconds.' +        ) + +    async def test_reset_slowmode_with_channel(self) -> None: +        """Reset slowmode with a given channel.""" +        text_channel = MockTextChannel(name='meta', slowmode_delay=1) + +        await self.cog.reset_slowmode(self.cog, self.ctx, text_channel) +        self.ctx.send.assert_called_once_with( +            f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.' +        ) + +    @mock.patch("bot.exts.moderation.slowmode.with_role_check") +    @mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) +    def test_cog_check(self, role_check): +        """Role check is called with `MODERATION_ROLES`""" +        self.cog.cog_check(self.ctx) +        role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/exts/test_cogs.py index fdda59a8f..f8e120262 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/exts/test_cogs.py @@ -10,7 +10,7 @@ from unittest import mock  from discord.ext import commands -from bot import cogs +from bot import exts  class CommandNameTests(unittest.TestCase): @@ -29,13 +29,14 @@ class CommandNameTests(unittest.TestCase):      @staticmethod      def walk_modules() -> t.Iterator[ModuleType]: -        """Yield imported modules from the bot.cogs subpackage.""" +        """Yield imported modules from the bot.exts subpackage."""          def on_error(name: str) -> t.NoReturn:              raise ImportError(name=name)  # pragma: no cover          # The mock prevents asyncio.get_event_loop() from being called.          with mock.patch("discord.ext.tasks.loop"): -            for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): +            prefix = f"{exts.__name__}." +            for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error):                  if not module.ispkg:                      yield importlib.import_module(module.name) @@ -53,6 +54,7 @@ class CommandNameTests(unittest.TestCase):          """Return a list of all qualified names, including aliases, for the `command`."""          names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases]          names.append(command.qualified_name) +        names += getattr(command, "root_aliases", [])          return names diff --git a/tests/bot/exts/utils/__init__.py b/tests/bot/exts/utils/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/utils/__init__.py diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/utils/test_jams.py new file mode 100644 index 000000000..45e7b5b51 --- /dev/null +++ b/tests/bot/exts/utils/test_jams.py @@ -0,0 +1,173 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, create_autospec + +from discord import CategoryChannel + +from bot.constants import Roles +from bot.exts.utils import jams +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel + + +def get_mock_category(channel_count: int, name: str) -> CategoryChannel: +    """Return a mocked code jam category.""" +    category = create_autospec(CategoryChannel, spec_set=True, instance=True) +    category.name = name +    category.channels = [MockTextChannel() for _ in range(channel_count)] + +    return category + + +class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): +    """Tests for `createteam` command.""" + +    def setUp(self): +        self.bot = MockBot() +        self.admin_role = MockRole(name="Admins", id=Roles.admins) +        self.command_user = MockMember([self.admin_role]) +        self.guild = MockGuild([self.admin_role]) +        self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) +        self.cog = jams.CodeJams(self.bot) + +    async def test_too_small_amount_of_team_members_passed(self): +        """Should `ctx.send` and exit early when too small amount of members.""" +        for case in (1, 2): +            with self.subTest(amount_of_members=case): +                self.cog.create_channels = AsyncMock() +                self.cog.add_roles = AsyncMock() + +                self.ctx.reset_mock() +                members = (MockMember() for _ in range(case)) +                await self.cog.createteam(self.cog, self.ctx, "foo", members) + +                self.ctx.send.assert_awaited_once() +                self.cog.create_channels.assert_not_awaited() +                self.cog.add_roles.assert_not_awaited() + +    async def test_duplicate_members_provided(self): +        """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" +        self.cog.create_channels = AsyncMock() +        self.cog.add_roles = AsyncMock() + +        member = MockMember() +        await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) + +        self.ctx.send.assert_awaited_once() +        self.cog.create_channels.assert_not_awaited() +        self.cog.add_roles.assert_not_awaited() + +    async def test_result_sending(self): +        """Should call `ctx.send` when everything goes right.""" +        self.cog.create_channels = AsyncMock() +        self.cog.add_roles = AsyncMock() + +        members = [MockMember() for _ in range(5)] +        await self.cog.createteam(self.cog, self.ctx, "foo", members) + +        self.cog.create_channels.assert_awaited_once() +        self.cog.add_roles.assert_awaited_once() +        self.ctx.send.assert_awaited_once() + +    async def test_category_doesnt_exist(self): +        """Should create a new code jam category.""" +        subtests = ( +            [], +            [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], +            [get_mock_category(jams.MAX_CHANNELS - 2, "other")], +        ) + +        for categories in subtests: +            self.guild.reset_mock() +            self.guild.categories = categories + +            with self.subTest(categories=categories): +                actual_category = await self.cog.get_category(self.guild) + +                self.guild.create_category_channel.assert_awaited_once() +                category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + +                self.assertFalse(category_overwrites[self.guild.default_role].read_messages) +                self.assertTrue(category_overwrites[self.guild.me].read_messages) +                self.assertEqual(self.guild.create_category_channel.return_value, actual_category) + +    async def test_category_channel_exist(self): +        """Should not try to create category channel.""" +        expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) +        self.guild.categories = [ +            get_mock_category(jams.MAX_CHANNELS - 2, "other"), +            expected_category, +            get_mock_category(0, jams.CATEGORY_NAME), +        ] + +        actual_category = await self.cog.get_category(self.guild) +        self.assertEqual(expected_category, actual_category) + +    async def test_channel_overwrites(self): +        """Should have correct permission overwrites for users and roles.""" +        leader = MockMember() +        members = [leader] + [MockMember() for _ in range(4)] +        overwrites = self.cog.get_overwrites(members, self.guild) + +        # Leader permission overwrites +        self.assertTrue(overwrites[leader].manage_messages) +        self.assertTrue(overwrites[leader].read_messages) +        self.assertTrue(overwrites[leader].manage_webhooks) +        self.assertTrue(overwrites[leader].connect) + +        # Other members permission overwrites +        for member in members[1:]: +            self.assertTrue(overwrites[member].read_messages) +            self.assertTrue(overwrites[member].connect) + +        # Everyone and verified role overwrite +        self.assertFalse(overwrites[self.guild.default_role].read_messages) +        self.assertFalse(overwrites[self.guild.default_role].connect) +        self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].read_messages) +        self.assertFalse(overwrites[self.guild.get_role(Roles.verified)].connect) + +    async def test_team_channels_creation(self): +        """Should create new voice and text channel for team.""" +        members = [MockMember() for _ in range(5)] + +        self.cog.get_overwrites = MagicMock() +        self.cog.get_category = AsyncMock() +        self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") +        actual = await self.cog.create_channels(self.guild, "my-team", members) + +        self.assertEqual("foobar-channel", actual) +        self.cog.get_overwrites.assert_called_once_with(members, self.guild) +        self.cog.get_category.assert_awaited_once_with(self.guild) + +        self.guild.create_text_channel.assert_awaited_once_with( +            "my-team", +            overwrites=self.cog.get_overwrites.return_value, +            category=self.cog.get_category.return_value +        ) +        self.guild.create_voice_channel.assert_awaited_once_with( +            "My Team", +            overwrites=self.cog.get_overwrites.return_value, +            category=self.cog.get_category.return_value +        ) + +    async def test_jam_roles_adding(self): +        """Should add team leader role to leader and jam role to every team member.""" +        leader_role = MockRole(name="Team Leader") +        jam_role = MockRole(name="Jammer") +        self.guild.get_role.side_effect = [leader_role, jam_role] + +        leader = MockMember() +        members = [leader] + [MockMember() for _ in range(4)] +        await self.cog.add_roles(self.guild, members) + +        leader.add_roles.assert_any_await(leader_role) +        for member in members: +            member.add_roles.assert_any_await(jam_role) + + +class CodeJamSetup(unittest.TestCase): +    """Test for `setup` function of `CodeJam` cog.""" + +    def test_setup(self): +        """Should call `bot.add_cog`.""" +        bot = MockBot() +        jams.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index cf9adbee0..c272a4756 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -1,13 +1,12 @@  import asyncio -import logging  import unittest  from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch  from discord.ext import commands  from bot import constants -from bot.cogs import snekbox -from bot.cogs.snekbox import Snekbox +from bot.exts.utils import snekbox +from bot.exts.utils.snekbox import Snekbox  from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser @@ -39,43 +38,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))          self.assertEqual(result, "too long to upload") -    async def test_upload_output(self): +    @patch("bot.exts.utils.snekbox.send_to_paste_service") +    async def test_upload_output(self, mock_paste_util):          """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" -        key = "MarkDiamond" -        resp = MagicMock() -        resp.json = AsyncMock(return_value={"key": key}) - -        context_manager = MagicMock() -        context_manager.__aenter__.return_value = resp -        self.bot.http_session.post.return_value = context_manager - -        self.assertEqual( -            await self.cog.upload_output("My awesome output"), -            constants.URLs.paste_service.format(key=key) -        ) -        self.bot.http_session.post.assert_called_with( -            constants.URLs.paste_service.format(key="documents"), -            data="My awesome output", -            raise_for_status=True +        await self.cog.upload_output("Test output.") +        mock_paste_util.assert_called_once_with( +            self.bot.http_session, "Test output.", extension="txt"          ) -    async def test_upload_output_gracefully_fallback_if_exception_during_request(self): -        """Output upload gracefully fallback if the upload fail.""" -        resp = MagicMock() -        resp.json = AsyncMock(side_effect=Exception) - -        context_manager = MagicMock() -        context_manager.__aenter__.return_value = resp -        self.bot.http_session.post.return_value = context_manager - -        log = logging.getLogger("bot.cogs.snekbox") -        with self.assertLogs(logger=log, level='ERROR'): -            await self.cog.upload_output('My awesome output!') - -    async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): -        """Output upload gracefully fallback if there is no key entry in the response body.""" -        self.assertEqual((await self.cog.upload_output('My awesome output!')), None) -      def test_prepare_input(self):          cases = (              ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), @@ -99,14 +69,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):                  actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode})                  self.assertEqual(actual, expected) -    @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) +    @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)      def test_get_results_message_invalid_signal(self, mock_signals: Mock):          self.assertEqual(              self.cog.get_results_message({'stdout': '', 'returncode': 127}),              ('Your eval job has completed with return code 127', '')          ) -    @patch('bot.cogs.snekbox.Signals') +    @patch('bot.exts.utils.snekbox.Signals')      def test_get_results_message_valid_signal(self, mock_signals: Mock):          mock_signals.return_value.name = 'SIGTEST'          self.assertEqual( @@ -233,9 +203,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.get_status_emoji = MagicMock(return_value=':yay!:')          self.cog.format_output = AsyncMock(return_value=('[No output]', None)) +        mocked_filter_cog = MagicMock() +        mocked_filter_cog.filter_eval = AsyncMock(return_value=False) +        self.bot.get_cog.return_value = mocked_filter_cog +          await self.cog.send_eval(ctx, 'MyAwesomeCode')          ctx.send.assert_called_once_with( -            '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' +            '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```'          )          self.cog.post_eval.assert_called_once_with('MyAwesomeCode')          self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) @@ -254,10 +228,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.get_status_emoji = MagicMock(return_value=':yay!:')          self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) +        mocked_filter_cog = MagicMock() +        mocked_filter_cog.filter_eval = AsyncMock(return_value=False) +        self.bot.get_cog.return_value = mocked_filter_cog +          await self.cog.send_eval(ctx, 'MyAwesomeCode')          ctx.send.assert_called_once_with(              '@LemonLemonishBeard#0042 :yay!: Return code 0.' -            '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com' +            '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'          )          self.cog.post_eval.assert_called_once_with('MyAwesomeCode')          self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) @@ -275,16 +253,20 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.get_status_emoji = MagicMock(return_value=':nope!:')          self.cog.format_output = AsyncMock()  # This function isn't called +        mocked_filter_cog = MagicMock() +        mocked_filter_cog.filter_eval = AsyncMock(return_value=False) +        self.bot.get_cog.return_value = mocked_filter_cog +          await self.cog.send_eval(ctx, 'MyAwesomeCode')          ctx.send.assert_called_once_with( -            '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' +            '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'          )          self.cog.post_eval.assert_called_once_with('MyAwesomeCode')          self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})          self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})          self.cog.format_output.assert_not_called() -    @patch("bot.cogs.snekbox.partial") +    @patch("bot.exts.utils.snekbox.partial")      async def test_continue_eval_does_continue(self, partial_mock):          """Test that the continue_eval function does continue if required conditions are met."""          ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index 0a734b505..630f2516d 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -8,29 +8,39 @@ class LinePaginatorTests(TestCase):      def setUp(self):          """Create a paginator for the test method.""" -        self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30) - -    def test_add_line_raises_on_too_long_lines(self): -        """`add_line` should raise a `RuntimeError` for too long lines.""" -        message = f"Line exceeds maximum page size {self.paginator.max_size - 2}" -        with self.assertRaises(RuntimeError, msg=message): -            self.paginator.add_line('x' * self.paginator.max_size) +        self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30, +                                                  scale_to_size=50)      def test_add_line_works_on_small_lines(self):          """`add_line` should allow small lines to be added."""          self.paginator.add_line('x' * (self.paginator.max_size - 3)) - - -class ImagePaginatorTests(TestCase): -    """Tests functionality of the `ImagePaginator`.""" - -    def setUp(self): -        """Create a paginator for the test method.""" -        self.paginator = pagination.ImagePaginator() - -    def test_add_image_appends_image(self): -        """`add_image` appends the image to the image list.""" -        image = 'lemon' -        self.paginator.add_image(image) - -        assert self.paginator.images == [image] +        # Note that the page isn't added to _pages until it's full. +        self.assertEqual(len(self.paginator._pages), 0) + +    def test_add_line_works_on_long_lines(self): +        """After additional lines after `max_size` is exceeded should go on the next page.""" +        self.paginator.add_line('x' * self.paginator.max_size) +        self.assertEqual(len(self.paginator._pages), 0) + +        # Any additional lines should start a new page after `max_size` is exceeded. +        self.paginator.add_line('x') +        self.assertEqual(len(self.paginator._pages), 1) + +    def test_add_line_continuation(self): +        """When `scale_to_size` is exceeded, remaining words should be split onto the next page.""" +        self.paginator.add_line('zyz ' * (self.paginator.scale_to_size//4 + 1)) +        self.assertEqual(len(self.paginator._pages), 1) + +    def test_add_line_no_continuation(self): +        """If adding a new line to an existing page would exceed `max_size`, it should start a new +        page rather than using continuation. +        """ +        self.paginator.add_line('z' * (self.paginator.max_size - 3)) +        self.paginator.add_line('z') +        self.assertEqual(len(self.paginator._pages), 1) + +    def test_add_line_truncates_very_long_words(self): +        """`add_line` should truncate if a single long word exceeds `scale_to_size`.""" +        self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) +        # Note: item at index 1 is the truncated line, index 0 is prefix +        self.assertEqual(self.paginator._current_page[1], 'x' * self.paginator.scale_to_size) diff --git a/tests/bot/utils/test_messages.py b/tests/bot/utils/test_messages.py new file mode 100644 index 000000000..9c22c9751 --- /dev/null +++ b/tests/bot/utils/test_messages.py @@ -0,0 +1,27 @@ +import unittest + +from bot.utils import messages + + +class TestMessages(unittest.TestCase): +    """Tests for functions in the `bot.utils.messages` module.""" + +    def test_sub_clyde(self): +        """Uppercase E's and lowercase e's are substituted with their cyrillic counterparts.""" +        sub_e = "\u0435" +        sub_E = "\u0415"  # noqa: N806: Uppercase E in variable name + +        test_cases = ( +            (None, None), +            ("", ""), +            ("clyde", f"clyd{sub_e}"), +            ("CLYDE", f"CLYD{sub_E}"), +            ("cLyDe", f"cLyD{sub_e}"), +            ("BIGclyde", f"BIGclyd{sub_e}"), +            ("small clydeus the unholy", f"small clyd{sub_e}us the unholy"), +            ("BIGCLYDE, babyclyde", f"BIGCLYD{sub_E}, babyclyd{sub_e}"), +        ) + +        for username_in, username_out in test_cases: +            with self.subTest(input=username_in, expected_output=username_out): +                self.assertEqual(messages.sub_clyde(username_in), username_out) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 8c1a40640..a2f0fe55d 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -44,22 +44,14 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase):          with self.assertRaises(RuntimeError):              await bad_cache.set("test", "me_up_deadman") -    def test_namespace_collision(self): -        """Test that we prevent colliding namespaces.""" -        bob_cache_1 = RedisCache() -        bob_cache_1._set_namespace("BobRoss") -        self.assertEqual(bob_cache_1._namespace, "BobRoss") - -        bob_cache_2 = RedisCache() -        bob_cache_2._set_namespace("BobRoss") -        self.assertEqual(bob_cache_2._namespace, "BobRoss_") -      async def test_set_get_item(self):          """Test that users can set and get items from the RedisDict."""          test_cases = (              ('favorite_fruit', 'melon'),              ('favorite_number', 86), -            ('favorite_fraction', 86.54) +            ('favorite_fraction', 86.54), +            ('favorite_boolean', False), +            ('other_boolean', True),          )          # Test that we can get and set different types. diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py new file mode 100644 index 000000000..5e0855704 --- /dev/null +++ b/tests/bot/utils/test_services.py @@ -0,0 +1,74 @@ +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +from aiohttp import ClientConnectorError + +from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service + + +class PasteTests(unittest.IsolatedAsyncioTestCase): +    def setUp(self) -> None: +        self.http_session = MagicMock() + +    @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") +    async def test_url_and_sent_contents(self): +        """Correct url was used and post was called with expected data.""" +        response = MagicMock( +            json=AsyncMock(return_value={"key": ""}) +        ) +        self.http_session.post().__aenter__.return_value = response +        self.http_session.post.reset_mock() +        await send_to_paste_service(self.http_session, "Content") +        self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + +    @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") +    async def test_paste_returns_correct_url_on_success(self): +        """Url with specified extension is returned on successful requests.""" +        key = "paste_key" +        test_cases = ( +            (f"https://paste_service.com/{key}.txt", "txt"), +            (f"https://paste_service.com/{key}.py", "py"), +            (f"https://paste_service.com/{key}", ""), +        ) +        response = MagicMock( +            json=AsyncMock(return_value={"key": key}) +        ) +        self.http_session.post().__aenter__.return_value = response + +        for expected_output, extension in test_cases: +            with self.subTest(msg=f"Send contents with extension {repr(extension)}"): +                self.assertEqual( +                    await send_to_paste_service(self.http_session, "", extension=extension), +                    expected_output +                ) + +    async def test_request_repeated_on_json_errors(self): +        """Json with error message and invalid json are handled as errors and requests repeated.""" +        test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) +        self.http_session.post().__aenter__.return_value = response = MagicMock() +        self.http_session.post.reset_mock() + +        for error_json in test_cases: +            with self.subTest(error_json=error_json): +                response.json = AsyncMock(return_value=error_json) +                result = await send_to_paste_service(self.http_session, "") +                self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +                self.assertIsNone(result) + +            self.http_session.post.reset_mock() + +    async def test_request_repeated_on_connection_errors(self): +        """Requests are repeated in the case of connection errors.""" +        self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) +        result = await send_to_paste_service(self.http_session, "") +        self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +        self.assertIsNone(result) + +    async def test_general_error_handled_and_request_repeated(self): +        """All `Exception`s are handled, logged and request repeated.""" +        self.http_session.post = MagicMock(side_effect=Exception) +        result = await send_to_paste_service(self.http_session, "") +        self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +        self.assertLogs("bot.utils", logging.ERROR) +        self.assertIsNone(result) diff --git a/tests/helpers.py b/tests/helpers.py index faa839370..facc4e1af 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools  import logging  import unittest.mock  from asyncio import AbstractEventLoop -from typing import Iterable, Optional +from typing import Callable, Iterable, Optional  import discord  from aiohttp import ClientSession @@ -26,6 +26,24 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) +def autospec(target, *attributes: str, **kwargs) -> Callable: +    """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" +    # Caller's kwargs should take priority and overwrite the defaults. +    kwargs = {'spec_set': True, 'autospec': True, **kwargs} + +    # Import the target if it's a string. +    # This is to support both object and string targets like patch.multiple. +    if type(target) is str: +        target = unittest.mock._importer(target) + +    def decorator(func): +        for attribute in attributes: +            patcher = unittest.mock.patch.object(target, attribute, **kwargs) +            func = patcher(func) +        return func +    return decorator + +  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. | 
