diff options
| author | 2020-06-17 21:05:47 +0100 | |
|---|---|---|
| committer | 2020-06-17 21:05:47 +0100 | |
| commit | c1312f97327733b60555644da49c0419eb6759cb (patch) | |
| tree | cbfb3f6060eb9338527466a01054315bf43a085e /tests/bot | |
| parent | Delete the loop argument from schedule_task calls (diff) | |
| parent | Merge pull request #991 from crazygmr101/feature/cooldown-tag (diff) | |
Merge branch 'master' into #364-offensive-msg-autodeletion
Diffstat (limited to 'tests/bot')
33 files changed, 1609 insertions, 432 deletions
| diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/cogs/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/cogs/moderation/__init__.py diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py new file mode 100644 index 000000000..da4e92ccc --- /dev/null +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): +    """Tests for ban and kick command reason truncation.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = Infractions(self.bot) +        self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) +        self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) +        self.guild = MockGuild(id=4567) +        self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + +    @patch("bot.cogs.moderation.utils.get_active_infraction") +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): +        """Should truncate reason for `ctx.guild.ban`.""" +        get_active_mock.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.bot.get_cog.return_value = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.ctx.guild.ban = Mock() + +        await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) +        self.ctx.guild.ban.assert_called_once_with( +            self.target, +            reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), +            delete_message_days=0 +        ) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value +        ) + +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_kick_reason_truncation(self, post_infraction_mock): +        """Should truncate reason for `Member.kick`.""" +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.target.kick = Mock() + +        await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) +        self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value +        ) diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py new file mode 100644 index 000000000..f2809f40a --- /dev/null +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.cogs.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for moderation logs.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = ModLog(self.bot) +        self.channel = MockTextChannel() + +    async def test_log_entry_description_truncation(self): +        """Test that embed description for ModLog entry is truncated.""" +        self.bot.get_channel.return_value = self.channel +        await self.cog.send_log_message( +            icon_url="foo", +            colour=discord.Colour.blue(), +            title="bar", +            text="foo bar" * 3000 +        ) +        embed = self.channel.send.call_args[1]["embed"] +        self.assertEqual( +            embed.description, ("foo bar" * 3000)[:2045] + "..." +        ) diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py new file mode 100644 index 000000000..ab3d0742a --- /dev/null +++ b/tests/bot/cogs/moderation/test_silence.py @@ -0,0 +1,261 @@ +import unittest +from unittest import mock +from unittest.mock import MagicMock, Mock + +from discord import PermissionOverwrite + +from bot.cogs.moderation.silence import Silence, SilenceNotifier +from bot.constants import Channels, Emojis, Guild, Roles +from tests.helpers import MockBot, MockContext, MockTextChannel + + +class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +    def setUp(self) -> None: +        self.alert_channel = MockTextChannel() +        self.notifier = SilenceNotifier(self.alert_channel) +        self.notifier.stop = self.notifier_stop_mock = Mock() +        self.notifier.start = self.notifier_start_mock = Mock() + +    def test_add_channel_adds_channel(self): +        """Channel in FirstHash with current loop is added to internal set.""" +        channel = Mock() +        with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: +            self.notifier.add_channel(channel) +        silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) + +    def test_add_channel_starts_loop(self): +        """Loop is started if `_silenced_channels` was empty.""" +        self.notifier.add_channel(Mock()) +        self.notifier_start_mock.assert_called_once() + +    def test_add_channel_skips_start_with_channels(self): +        """Loop start is not called when `_silenced_channels` is not empty.""" +        with mock.patch.object(self.notifier, "_silenced_channels"): +            self.notifier.add_channel(Mock()) +        self.notifier_start_mock.assert_not_called() + +    def test_remove_channel_removes_channel(self): +        """Channel in FirstHash is removed from `_silenced_channels`.""" +        channel = Mock() +        with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: +            self.notifier.remove_channel(channel) +        silenced_channels.__delitem__.assert_called_with(channel) + +    def test_remove_channel_stops_loop(self): +        """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" +        with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): +            self.notifier.remove_channel(Mock()) +        self.notifier_stop_mock.assert_called_once() + +    def test_remove_channel_skips_stop_with_channels(self): +        """Notifier loop is not stopped if `_silenced_channels` is not empty after remove.""" +        self.notifier.remove_channel(Mock()) +        self.notifier_stop_mock.assert_not_called() + +    async def test_notifier_private_sends_alert(self): +        """Alert is sent on 15 min intervals.""" +        test_cases = (900, 1800, 2700) +        for current_loop in test_cases: +            with self.subTest(current_loop=current_loop): +                with mock.patch.object(self.notifier, "_current_loop", new=current_loop): +                    await self.notifier._notifier() +                self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") +            self.alert_channel.send.reset_mock() + +    async def test_notifier_skips_alert(self): +        """Alert is skipped on first loop or not an increment of 900.""" +        test_cases = (0, 15, 5000) +        for current_loop in test_cases: +            with self.subTest(current_loop=current_loop): +                with mock.patch.object(self.notifier, "_current_loop", new=current_loop): +                    await self.notifier._notifier() +                    self.alert_channel.send.assert_not_called() + + +class SilenceTests(unittest.IsolatedAsyncioTestCase): +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = Silence(self.bot) +        self.ctx = MockContext() +        self.cog._verified_role = None +        # Set event so command callbacks can continue. +        self.cog._get_instance_vars_event.set() + +    async def test_instance_vars_got_guild(self): +        """Bot got guild after it became available.""" +        await self.cog._get_instance_vars() +        self.bot.wait_until_guild_available.assert_called_once() +        self.bot.get_guild.assert_called_once_with(Guild.id) + +    async def test_instance_vars_got_role(self): +        """Got `Roles.verified` role from guild.""" +        await self.cog._get_instance_vars() +        guild = self.bot.get_guild() +        guild.get_role.assert_called_once_with(Roles.verified) + +    async def test_instance_vars_got_channels(self): +        """Got channels from bot.""" +        await self.cog._get_instance_vars() +        self.bot.get_channel.called_once_with(Channels.mod_alerts) +        self.bot.get_channel.called_once_with(Channels.mod_log) + +    @mock.patch("bot.cogs.moderation.silence.SilenceNotifier") +    async def test_instance_vars_got_notifier(self, notifier): +        """Notifier was started with channel.""" +        mod_log = MockTextChannel() +        self.bot.get_channel.side_effect = (None, mod_log) +        await self.cog._get_instance_vars() +        notifier.assert_called_once_with(mod_log) +        self.bot.get_channel.side_effect = None + +    async def test_silence_sent_correct_discord_message(self): +        """Check if proper message was sent when called with duration in channel with previous state.""" +        test_cases = ( +            (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), +            (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), +            (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), +        ) +        for duration, result_message, _silence_patch_return in test_cases: +            with self.subTest( +                silence_duration=duration, +                result_message=result_message, +                starting_unsilenced_state=_silence_patch_return +            ): +                with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): +                    await self.cog.silence.callback(self.cog, self.ctx, duration) +                    self.ctx.send.assert_called_once_with(result_message) +            self.ctx.reset_mock() + +    async def test_unsilence_sent_correct_discord_message(self): +        """Check if proper message was sent when unsilencing channel.""" +        test_cases = ( +            (True, f"{Emojis.check_mark} unsilenced current channel."), +            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        ) +        for _unsilence_patch_return, result_message in test_cases: +            with self.subTest( +                starting_silenced_state=_unsilence_patch_return, +                result_message=result_message +            ): +                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): +                    await self.cog.unsilence.callback(self.cog, self.ctx) +                    self.ctx.send.assert_called_once_with(result_message) +            self.ctx.reset_mock() + +    async def test_silence_private_for_false(self): +        """Permissions are not set and `False` is returned in an already silenced channel.""" +        perm_overwrite = Mock(send_messages=False) +        channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) + +        self.assertFalse(await self.cog._silence(channel, True, None)) +        channel.set_permissions.assert_not_called() + +    async def test_silence_private_silenced_channel(self): +        """Channel had `send_message` permissions revoked.""" +        channel = MockTextChannel() +        self.assertTrue(await self.cog._silence(channel, False, None)) +        channel.set_permissions.assert_called_once() +        self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) + +    async def test_silence_private_preserves_permissions(self): +        """Previous permissions were preserved when channel was silenced.""" +        channel = MockTextChannel() +        # Set up mock channel permission state. +        mock_permissions = PermissionOverwrite() +        mock_permissions_dict = dict(mock_permissions) +        channel.overwrites_for.return_value = mock_permissions +        await self.cog._silence(channel, False, None) +        new_permissions = channel.set_permissions.call_args.kwargs +        # Remove 'send_messages' key because it got changed in the method. +        del new_permissions['send_messages'] +        del mock_permissions_dict['send_messages'] +        self.assertDictEqual(mock_permissions_dict, new_permissions) + +    async def test_silence_private_notifier(self): +        """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" +        channel = MockTextChannel() +        with mock.patch.object(self.cog, "notifier", create=True): +            with self.subTest(persistent=True): +                await self.cog._silence(channel, True, None) +                self.cog.notifier.add_channel.assert_called_once() + +        with mock.patch.object(self.cog, "notifier", create=True): +            with self.subTest(persistent=False): +                await self.cog._silence(channel, False, None) +                self.cog.notifier.add_channel.assert_not_called() + +    async def test_silence_private_added_muted_channel(self): +        """Channel was added to `muted_channels` on silence.""" +        channel = MockTextChannel() +        with mock.patch.object(self.cog, "muted_channels") as muted_channels: +            await self.cog._silence(channel, False, None) +        muted_channels.add.assert_called_once_with(channel) + +    async def test_unsilence_private_for_false(self): +        """Permissions are not set and `False` is returned in an unsilenced channel.""" +        channel = Mock() +        self.assertFalse(await self.cog._unsilence(channel)) +        channel.set_permissions.assert_not_called() + +    @mock.patch.object(Silence, "notifier", create=True) +    async def test_unsilence_private_unsilenced_channel(self, _): +        """Channel had `send_message` permissions restored""" +        perm_overwrite = MagicMock(send_messages=False) +        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) +        self.assertTrue(await self.cog._unsilence(channel)) +        channel.set_permissions.assert_called_once() +        self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) + +    @mock.patch.object(Silence, "notifier", create=True) +    async def test_unsilence_private_removed_notifier(self, notifier): +        """Channel was removed from `notifier` on unsilence.""" +        perm_overwrite = MagicMock(send_messages=False) +        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) +        await self.cog._unsilence(channel) +        notifier.remove_channel.assert_called_once_with(channel) + +    @mock.patch.object(Silence, "notifier", create=True) +    async def test_unsilence_private_removed_muted_channel(self, _): +        """Channel was removed from `muted_channels` on unsilence.""" +        perm_overwrite = MagicMock(send_messages=False) +        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) +        with mock.patch.object(self.cog, "muted_channels") as muted_channels: +            await self.cog._unsilence(channel) +        muted_channels.discard.assert_called_once_with(channel) + +    @mock.patch.object(Silence, "notifier", create=True) +    async def test_unsilence_private_preserves_permissions(self, _): +        """Previous permissions were preserved when channel was unsilenced.""" +        channel = MockTextChannel() +        # Set up mock channel permission state. +        mock_permissions = PermissionOverwrite(send_messages=False) +        mock_permissions_dict = dict(mock_permissions) +        channel.overwrites_for.return_value = mock_permissions +        await self.cog._unsilence(channel) +        new_permissions = channel.set_permissions.call_args.kwargs +        # Remove 'send_messages' key because it got changed in the method. +        del new_permissions['send_messages'] +        del mock_permissions_dict['send_messages'] +        self.assertDictEqual(mock_permissions_dict, new_permissions) + +    @mock.patch("bot.cogs.moderation.silence.asyncio") +    @mock.patch.object(Silence, "_mod_alerts_channel", create=True) +    def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): +        """Task for sending an alert was created with present `muted_channels`.""" +        with mock.patch.object(self.cog, "muted_channels"): +            self.cog.cog_unload() +            alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") +            asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) + +    @mock.patch("bot.cogs.moderation.silence.asyncio") +    def test_cog_unload_skips_task_start(self, asyncio_mock): +        """No task created with no channels.""" +        self.cog.cog_unload() +        asyncio_mock.create_task.assert_not_called() + +    @mock.patch("bot.cogs.moderation.silence.with_role_check") +    @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) +    def test_cog_check(self, role_check): +        """Role check is called with `MODERATION_ROLES`""" +        self.cog.cog_check(self.ctx) +        role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index c2e143865..70aea2bab 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -1,3 +1,4 @@ +import asyncio  import unittest  from unittest import mock @@ -13,8 +14,8 @@ class TestSyncer(Syncer):      """Syncer subclass with mocks for abstract methods for testing purposes."""      name = "test" -    _get_diff = helpers.AsyncMock() -    _sync = helpers.AsyncMock() +    _get_diff = mock.AsyncMock() +    _sync = mock.AsyncMock()  class SyncerBaseTests(unittest.TestCase): @@ -29,7 +30,7 @@ class SyncerBaseTests(unittest.TestCase):              Syncer(self.bot) -class SyncerSendPromptTests(unittest.TestCase): +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):      """Tests for sending the sync confirmation prompt."""      def setUp(self): @@ -61,7 +62,6 @@ class SyncerSendPromptTests(unittest.TestCase):          return mock_channel, mock_message -    @helpers.async_test      async def test_send_prompt_edits_and_returns_message(self):          """The given message should be edited to display the prompt and then should be returned."""          msg = helpers.MockMessage() @@ -71,7 +71,6 @@ class SyncerSendPromptTests(unittest.TestCase):          self.assertIn("content", msg.edit.call_args[1])          self.assertEqual(ret_val, msg) -    @helpers.async_test      async def test_send_prompt_gets_dev_core_channel(self):          """The dev-core channel should be retrieved if an extant message isn't given."""          subtests = ( @@ -86,8 +85,7 @@ class SyncerSendPromptTests(unittest.TestCase):                  method.assert_called_once_with(constants.Channels.dev_core) -    @helpers.async_test -    async def test_send_prompt_returns_None_if_channel_fetch_fails(self): +    async def test_send_prompt_returns_none_if_channel_fetch_fails(self):          """None should be returned if there's an HTTPException when fetching the channel."""          self.bot.get_channel.return_value = None          self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") @@ -96,7 +94,6 @@ class SyncerSendPromptTests(unittest.TestCase):          self.assertIsNone(ret_val) -    @helpers.async_test      async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):          """A new message mentioning core devs should be sent and returned if message isn't given."""          for mock_ in (self.mock_get_channel, self.mock_fetch_channel): @@ -108,7 +105,6 @@ class SyncerSendPromptTests(unittest.TestCase):                  self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])                  self.assertEqual(ret_val, mock_message) -    @helpers.async_test      async def test_send_prompt_adds_reactions(self):          """The message should have reactions for confirmation added."""          extant_message = helpers.MockMessage() @@ -129,7 +125,7 @@ class SyncerSendPromptTests(unittest.TestCase):                  mock_message.add_reaction.assert_has_calls(calls) -class SyncerConfirmationTests(unittest.TestCase): +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):      """Tests for waiting for a sync confirmation reaction on the prompt."""      def setUp(self): @@ -211,13 +207,12 @@ class SyncerConfirmationTests(unittest.TestCase):                  ret_val = self.syncer._reaction_check(*args)                  self.assertFalse(ret_val) -    @helpers.async_test      async def test_wait_for_confirmation(self):          """The message should always be edited and only return True if the emoji is a check mark."""          subtests = (              (constants.Emojis.check_mark, True, None),              ("InVaLiD", False, None), -            (None, False, TimeoutError), +            (None, False, asyncio.TimeoutError),          )          for emoji, ret_val, side_effect in subtests: @@ -251,14 +246,13 @@ class SyncerConfirmationTests(unittest.TestCase):                      self.assertIs(actual_return, ret_val) -class SyncerSyncTests(unittest.TestCase): +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for main function orchestrating the sync."""      def setUp(self):          self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))          self.syncer = TestSyncer(self.bot) -    @helpers.async_test      async def test_sync_respects_confirmation_result(self):          """The sync should abort if confirmation fails and continue if confirmed."""          mock_message = helpers.MockMessage() @@ -274,7 +268,7 @@ class SyncerSyncTests(unittest.TestCase):                  diff = _Diff({1, 2, 3}, {4, 5}, None)                  self.syncer._get_diff.return_value = diff -                self.syncer._get_confirmation_result = helpers.AsyncMock( +                self.syncer._get_confirmation_result = mock.AsyncMock(                      return_value=(confirmed, message)                  ) @@ -289,7 +283,6 @@ class SyncerSyncTests(unittest.TestCase):                  else:                      self.syncer._sync.assert_not_called() -    @helpers.async_test      async def test_sync_diff_size(self):          """The diff size should be correctly calculated."""          subtests = ( @@ -303,7 +296,7 @@ class SyncerSyncTests(unittest.TestCase):              with self.subTest(size=size, diff=diff):                  self.syncer._get_diff.reset_mock()                  self.syncer._get_diff.return_value = diff -                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) +                self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))                  guild = helpers.MockGuild()                  await self.syncer.sync(guild) @@ -312,7 +305,6 @@ class SyncerSyncTests(unittest.TestCase):                  self.syncer._get_confirmation_result.assert_called_once()                  self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) -    @helpers.async_test      async def test_sync_message_edited(self):          """The message should be edited if one was sent, even if the sync has an API error."""          subtests = ( @@ -324,7 +316,7 @@ class SyncerSyncTests(unittest.TestCase):          for message, side_effect, should_edit in subtests:              with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):                  self.syncer._sync.side_effect = side_effect -                self.syncer._get_confirmation_result = helpers.AsyncMock( +                self.syncer._get_confirmation_result = mock.AsyncMock(                      return_value=(True, message)                  ) @@ -335,7 +327,6 @@ class SyncerSyncTests(unittest.TestCase):                      message.edit.assert_called_once()                      self.assertIn("content", message.edit.call_args[1]) -    @helpers.async_test      async def test_sync_confirmation_context_redirect(self):          """If ctx is given, a new message should be sent and author should be ctx's author."""          mock_member = helpers.MockMember() @@ -349,7 +340,10 @@ class SyncerSyncTests(unittest.TestCase):                  if ctx is not None:                      ctx.send.return_value = message -                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) +                # Make sure `_get_diff` returns a MagicMock, not an AsyncMock +                self.syncer._get_diff.return_value = mock.MagicMock() + +                self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))                  guild = helpers.MockGuild()                  await self.syncer.sync(guild, ctx) @@ -362,16 +356,15 @@ class SyncerSyncTests(unittest.TestCase):                  self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)      @mock.patch.object(constants.Sync, "max_diff", new=3) -    @helpers.async_test      async def test_confirmation_result_small_diff(self):          """Should always return True and the given message if the diff size is too small."""          author = helpers.MockMember()          expected_message = helpers.MockMessage() -        for size in (3, 2): +        for size in (3, 2):  # pragma: no cover              with self.subTest(size=size): -                self.syncer._send_prompt = helpers.AsyncMock() -                self.syncer._wait_for_confirmation = helpers.AsyncMock() +                self.syncer._send_prompt = mock.AsyncMock() +                self.syncer._wait_for_confirmation = mock.AsyncMock()                  coro = self.syncer._get_confirmation_result(size, author, expected_message)                  result, actual_message = await coro @@ -382,7 +375,6 @@ class SyncerSyncTests(unittest.TestCase):                  self.syncer._wait_for_confirmation.assert_not_called()      @mock.patch.object(constants.Sync, "max_diff", new=3) -    @helpers.async_test      async def test_confirmation_result_large_diff(self):          """Should return True if confirmed and False if _send_prompt fails or aborted."""          author = helpers.MockMember() @@ -394,10 +386,10 @@ class SyncerSyncTests(unittest.TestCase):              (False, mock_message, False, "aborted"),          ) -        for expected_result, expected_message, confirmed, msg in subtests: +        for expected_result, expected_message, confirmed, msg in subtests:  # pragma: no cover              with self.subTest(msg=msg): -                self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) -                self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) +                self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) +                self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed)                  coro = self.syncer._get_confirmation_result(4, author)                  actual_result, actual_message = await coro diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 98c9afc0d..14fd909c4 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -11,19 +11,7 @@ from tests import helpers  from tests.base import CommandTestCase -class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): -    """ -    A MagicMock subclass to mock Syncer objects. - -    Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` -    instances. For more information, see the `MockGuild` docstring. -    """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=Syncer, **kwargs) - - -class SyncExtensionTests(unittest.TestCase): +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase):      """Tests for the sync extension."""      @staticmethod @@ -34,22 +22,21 @@ class SyncExtensionTests(unittest.TestCase):          bot.add_cog.assert_called_once() -class SyncCogTestCase(unittest.TestCase): +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):      """Base class for Sync cog tests. Sets up patches for syncers."""      def setUp(self):          self.bot = helpers.MockBot() -        # These patch the type. When the type is called, a MockSyncer instanced is returned. -        # MockSyncer is needed so that our custom AsyncMock is used. -        # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.          self.role_syncer_patcher = mock.patch(              "bot.cogs.sync.syncers.RoleSyncer", -            new=mock.MagicMock(return_value=MockSyncer()) +            autospec=Syncer, +            spec_set=True          )          self.user_syncer_patcher = mock.patch(              "bot.cogs.sync.syncers.UserSyncer", -            new=mock.MagicMock(return_value=MockSyncer()) +            autospec=Syncer, +            spec_set=True          )          self.RoleSyncer = self.role_syncer_patcher.start()          self.UserSyncer = self.user_syncer_patcher.start() @@ -72,13 +59,13 @@ class SyncCogTestCase(unittest.TestCase):  class SyncCogTests(SyncCogTestCase):      """Tests for the Sync cog.""" -    @mock.patch.object(sync.Sync, "sync_guild") +    @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock)      def test_sync_cog_init(self, sync_guild):          """Should instantiate syncers and run a sync for the guild."""          # Reset because a Sync cog was already instantiated in setUp.          self.RoleSyncer.reset_mock()          self.UserSyncer.reset_mock() -        self.bot.loop.create_task.reset_mock() +        self.bot.loop.create_task = mock.MagicMock()          mock_sync_guild_coro = mock.MagicMock()          sync_guild.return_value = mock_sync_guild_coro @@ -90,7 +77,6 @@ class SyncCogTests(SyncCogTestCase):          sync_guild.assert_called_once_with()          self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) -    @helpers.async_test      async def test_sync_cog_sync_guild(self):          """Roles and users should be synced only if a guild is successfully retrieved."""          for guild in (helpers.MockGuild(), None): @@ -126,14 +112,12 @@ class SyncCogTests(SyncCogTestCase):              json=updated_information,          ) -    @helpers.async_test      async def test_sync_cog_patch_user(self):          """A PATCH request should be sent and 404 errors ignored."""          for side_effect in (None, self.response_error(404)):              with self.subTest(side_effect=side_effect):                  await self.patch_user_helper(side_effect) -    @helpers.async_test      async def test_sync_cog_patch_user_non_404(self):          """A PATCH request should be sent and the error raised if it's not a 404."""          with self.assertRaises(ResponseCodeError): @@ -145,9 +129,8 @@ class SyncCogListenerTests(SyncCogTestCase):      def setUp(self):          super().setUp() -        self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) +        self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) -    @helpers.async_test      async def test_sync_cog_on_guild_role_create(self):          """A POST request should be sent with the new role's data."""          self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -164,7 +147,6 @@ class SyncCogListenerTests(SyncCogTestCase):          self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) -    @helpers.async_test      async def test_sync_cog_on_guild_role_delete(self):          """A DELETE request should be sent."""          self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) @@ -174,7 +156,6 @@ class SyncCogListenerTests(SyncCogTestCase):          self.bot.api_client.delete.assert_called_once_with("bot/roles/99") -    @helpers.async_test      async def test_sync_cog_on_guild_role_update(self):          """A PUT request should be sent if the colour, name, permissions, or position changes."""          self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -212,7 +193,6 @@ class SyncCogListenerTests(SyncCogTestCase):                      else:                          self.bot.api_client.put.assert_not_called() -    @helpers.async_test      async def test_sync_cog_on_member_remove(self):          """Member should patched to set in_guild as False."""          self.assertTrue(self.cog.on_member_remove.__cog_listener__) @@ -225,7 +205,6 @@ class SyncCogListenerTests(SyncCogTestCase):              updated_information={"in_guild": False}          ) -    @helpers.async_test      async def test_sync_cog_on_member_update_roles(self):          """Members should be patched if their roles have changed."""          self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -240,7 +219,6 @@ class SyncCogListenerTests(SyncCogTestCase):          data = {"roles": sorted(role.id for role in after_member.roles)}          self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) -    @helpers.async_test      async def test_sync_cog_on_member_update_other(self):          """Members should not be patched if other attributes have changed."""          self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -262,7 +240,6 @@ class SyncCogListenerTests(SyncCogTestCase):                  self.cog.patch_user.assert_not_called() -    @helpers.async_test      async def test_sync_cog_on_user_update(self):          """A user should be patched only if the name, discriminator, or avatar changes."""          self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -270,14 +247,12 @@ class SyncCogListenerTests(SyncCogTestCase):          before_data = {              "name": "old name",              "discriminator": "1234", -            "avatar": "old avatar",              "bot": False,          }          subtests = (              (True, "name", "name", "new name", "new name"),              (True, "discriminator", "discriminator", "8765", 8765), -            (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),              (False, "bot", "bot", True, True),          ) @@ -318,7 +293,6 @@ class SyncCogListenerTests(SyncCogTestCase):          )          data = { -            "avatar_hash": member.avatar,              "discriminator": int(member.discriminator),              "id": member.id,              "in_guild": True, @@ -341,7 +315,6 @@ class SyncCogListenerTests(SyncCogTestCase):          return data -    @helpers.async_test      async def test_sync_cog_on_member_join(self):          """Should PUT user's data or POST it if the user doesn't exist."""          for side_effect in (None, self.response_error(404)): @@ -354,7 +327,6 @@ class SyncCogListenerTests(SyncCogTestCase):                  else:                      self.bot.api_client.post.assert_not_called() -    @helpers.async_test      async def test_sync_cog_on_member_join_non_404(self):          """ResponseCodeError should be re-raised if status code isn't a 404."""          with self.assertRaises(ResponseCodeError): @@ -366,7 +338,6 @@ class SyncCogListenerTests(SyncCogTestCase):  class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):      """Tests for the commands in the Sync cog.""" -    @helpers.async_test      async def test_sync_roles_command(self):          """sync() should be called on the RoleSyncer."""          ctx = helpers.MockContext() @@ -374,7 +345,6 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) -    @helpers.async_test      async def test_sync_users_command(self):          """sync() should be called on the UserSyncer."""          ctx = helpers.MockContext() @@ -382,7 +352,7 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) -    def test_commands_require_admin(self): +    async def test_commands_require_admin(self):          """The sync commands should only run if the author has the administrator permission."""          cmds = (              self.cog.sync_group, @@ -392,4 +362,4 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          for cmd in cmds:              with self.subTest(cmd=cmd): -                self.assertHasPermissionsCheck(cmd, {"administrator": True}) +                await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 14fb2577a..79eee98f4 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -18,7 +18,7 @@ def fake_role(**kwargs):      return kwargs -class RoleSyncerDiffTests(unittest.TestCase): +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between roles in the DB and roles in the Guild cache."""      def setUp(self): @@ -39,7 +39,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          return guild -    @helpers.async_test      async def test_empty_diff_for_identical_roles(self):          """No differences should be found if the roles in the guild and DB are identical."""          self.bot.api_client.get.return_value = [fake_role()] @@ -50,7 +49,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_updated_roles(self):          """Only updated roles should be added to the 'updated' set of the diff."""          updated_role = fake_role(id=41, name="new") @@ -63,7 +61,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_roles(self):          """Only new roles should be added to the 'created' set of the diff."""          new_role = fake_role(id=41, name="new") @@ -76,7 +73,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_deleted_roles(self):          """Only deleted roles should be added to the 'deleted' set of the diff."""          deleted_role = fake_role(id=61, name="deleted") @@ -89,7 +85,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_updated_and_deleted_roles(self):          """When roles are added, updated, and removed, all of them are returned properly."""          new = fake_role(id=41, name="new") @@ -109,14 +104,13 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -class RoleSyncerSyncTests(unittest.TestCase): +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync roles."""      def setUp(self):          self.bot = helpers.MockBot()          self.syncer = RoleSyncer(self.bot) -    @helpers.async_test      async def test_sync_created_roles(self):          """Only POST requests should be made with the correct payload."""          roles = [fake_role(id=111), fake_role(id=222)] @@ -132,7 +126,6 @@ class RoleSyncerSyncTests(unittest.TestCase):          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() -    @helpers.async_test      async def test_sync_updated_roles(self):          """Only PUT requests should be made with the correct payload."""          roles = [fake_role(id=111), fake_role(id=222)] @@ -148,7 +141,6 @@ class RoleSyncerSyncTests(unittest.TestCase):          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() -    @helpers.async_test      async def test_sync_deleted_roles(self):          """Only DELETE requests should be made with the correct payload."""          roles = [fake_role(id=111), fake_role(id=222)] diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 421bf6bb6..002a947ad 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -10,14 +10,13 @@ def fake_user(**kwargs):      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man")      kwargs.setdefault("discriminator", 1337) -    kwargs.setdefault("avatar_hash", None)      kwargs.setdefault("roles", (666,))      kwargs.setdefault("in_guild", True)      return kwargs -class UserSyncerDiffTests(unittest.TestCase): +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between users in the DB and users in the Guild cache."""      def setUp(self): @@ -32,7 +31,6 @@ class UserSyncerDiffTests(unittest.TestCase):          for member in members:              member = member.copy() -            member["avatar"] = member.pop("avatar_hash")              del member["in_guild"]              mock_member = helpers.MockMember(**member) @@ -42,7 +40,6 @@ class UserSyncerDiffTests(unittest.TestCase):          return guild -    @helpers.async_test      async def test_empty_diff_for_no_users(self):          """When no users are given, an empty diff should be returned."""          guild = self.get_guild() @@ -52,7 +49,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_empty_diff_for_identical_users(self):          """No differences should be found if the users in the guild and DB are identical."""          self.bot.api_client.get.return_value = [fake_user()] @@ -63,7 +59,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_updated_users(self):          """Only updated users should be added to the 'updated' set of the diff."""          updated_user = fake_user(id=99, name="new") @@ -76,7 +71,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_users(self):          """Only new users should be added to the 'created' set of the diff."""          new_user = fake_user(id=99, name="new") @@ -89,7 +83,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`."""          leaving_user = fake_user(id=63, in_guild=False) @@ -102,7 +95,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_updated_and_leaving_users(self):          """When users are added, updated, and removed, all of them are returned properly."""          new_user = fake_user(id=99, name="new") @@ -117,7 +109,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_empty_diff_for_db_users_not_in_guild(self):          """When the DB knows a user the guild doesn't, no difference is found."""          self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] @@ -129,14 +120,13 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -class UserSyncerSyncTests(unittest.TestCase): +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync users."""      def setUp(self):          self.bot = helpers.MockBot()          self.syncer = UserSyncer(self.bot) -    @helpers.async_test      async def test_sync_created_users(self):          """Only POST requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] @@ -152,7 +142,6 @@ class UserSyncerSyncTests(unittest.TestCase):          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() -    @helpers.async_test      async def test_sync_updated_users(self):          """Only PUT requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..f219fc1ba --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from discord import NotFound + +from bot.cogs import antimalware +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + +MODULE = "bot.cogs.antimalware" + + +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): +    """Test the AntiMalware cog.""" + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.bot = MockBot() +        self.cog = antimalware.AntiMalware(self.bot) +        self.message = MockMessage() + +    async def test_message_with_allowed_attachment(self): +        """Messages with allowed extensions should not be deleted""" +        attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_message_without_attachment(self): +        """Messages without attachments should result in no action.""" +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_direct_message_with_attachment(self): +        """Direct messages should have no action taken.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.guild = None + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_message_with_illegal_extension_gets_deleted(self): +        """A message containing an illegal extension should send an embed.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_called_once() + +    async def test_message_send_by_staff(self): +        """A message send by a member of staff should be ignored.""" +        staff_role = MockRole(id=STAFF_ROLES[0]) +        self.message.author.roles.append(staff_role) +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_python_file_redirect_embed_description(self): +        """A message containing a .py file should result in an embed redirecting the user to our paste site""" +        attachment = MockAttachment(filename="python.py") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") + +        self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + +    async def test_txt_file_redirect_embed_description(self): +        """A message containing a .txt file should result in the correct embed.""" +        attachment = MockAttachment(filename="python.txt") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.TXT_EMBED_DESCRIPTION = Mock() +        antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        cmd_channel = self.bot.get_channel(Channels.bot_commands) + +        self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) +        antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + +    async def test_other_disallowed_extention_embed_description(self): +        """Test the description for a non .py/.txt disallowed extension.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        meta_channel = self.bot.get_channel(Channels.meta) + +        self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( +            blocked_extensions_str=".disallowed", +            meta_channel_mention=meta_channel.mention +        ) + +    async def test_removing_deleted_message_logs(self): +        """Removing an already deleted message logs the correct message""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) +        self.message.delete.assert_called_once() + +    async def test_message_with_illegal_attachment_logs(self): +        """Deleting a message with an illegal attachment should result in a log.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) + +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            ([], []), +            (AntiMalwareConfig.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], [".disallowed"]), +            ([".disallowed"], [".disallowed"]), +            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] +                disallowed_extensions = self.cog.get_disallowed_extensions(self.message) +                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): +    """Tests setup of the `AntiMalware` cog.""" + +    def test_setup(self): +        """Setup of the extension should call add_cog.""" +        bot = MockBot() +        antimalware.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py new file mode 100644 index 000000000..fdda59a8f --- /dev/null +++ b/tests/bot/cogs/test_cogs.py @@ -0,0 +1,80 @@ +"""Test suite for general tests which apply to all cogs.""" + +import importlib +import pkgutil +import typing as t +import unittest +from collections import defaultdict +from types import ModuleType +from unittest import mock + +from discord.ext import commands + +from bot import cogs + + +class CommandNameTests(unittest.TestCase): +    """Tests for shadowing command names and aliases.""" + +    @staticmethod +    def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: +        """An iterator that recursively walks through `cog`'s commands and subcommands.""" +        # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. +        for command in cog.__cog_commands__: +            if command.parent is None: +                yield command +                if isinstance(command, commands.GroupMixin): +                    # Annoyingly it returns duplicates for each alias so use a set to fix that +                    yield from set(command.walk_commands()) + +    @staticmethod +    def walk_modules() -> t.Iterator[ModuleType]: +        """Yield imported modules from the bot.cogs subpackage.""" +        def on_error(name: str) -> t.NoReturn: +            raise ImportError(name=name)  # pragma: no cover + +        # The mock prevents asyncio.get_event_loop() from being called. +        with mock.patch("discord.ext.tasks.loop"): +            for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): +                if not module.ispkg: +                    yield importlib.import_module(module.name) + +    @staticmethod +    def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: +        """Yield all cogs defined in an extension.""" +        for obj in module.__dict__.values(): +            # Check if it's a class type cause otherwise issubclass() may raise a TypeError. +            is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) +            if is_cog and obj.__module__ == module.__name__: +                yield obj + +    @staticmethod +    def get_qualified_names(command: commands.Command) -> t.List[str]: +        """Return a list of all qualified names, including aliases, for the `command`.""" +        names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] +        names.append(command.qualified_name) + +        return names + +    def get_all_commands(self) -> t.Iterator[commands.Command]: +        """Yield all commands for all cogs in all extensions.""" +        for module in self.walk_modules(): +            for cog in self.walk_cogs(module): +                for cmd in self.walk_commands(cog): +                    yield cmd + +    def test_names_dont_shadow(self): +        """Names and aliases of commands should be unique.""" +        all_names = defaultdict(list) +        for cmd in self.get_all_commands(): +            func_name = f"{cmd.module}.{cmd.callback.__qualname__}" + +            for name in self.get_qualified_names(cmd): +                with self.subTest(cmd=func_name, name=name): +                    if name in all_names:  # pragma: no cover +                        conflicts = ", ".join(all_names.get(name, "")) +                        self.fail( +                            f"Name '{name}' of the command {func_name} conflicts with {conflicts}." +                        ) + +                all_names[name].append(func_name) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 5b0a3b8c3..a8c0107c6 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -2,7 +2,7 @@ import asyncio  import logging  import typing  import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch  import discord @@ -14,7 +14,7 @@ from tests import helpers  MODULE_PATH = "bot.cogs.duck_pond" -class DuckPondTests(base.LoggingTestCase): +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):      """Tests for DuckPond functionality."""      @classmethod @@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestCase):          self.assertEqual(cog.bot, bot)          self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) -        bot.loop.create_loop.called_once_with(cog.fetch_webhook()) +        bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())      def test_fetch_webhook_succeeds_without_connectivity_issues(self):          """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" @@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase):              with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):                  self.assertEqual(expected_return, actual_return) -    @helpers.async_test      async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):          """The `has_green_checkmark` method should only return `True` if one is present."""          test_cases = ( @@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase):          nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]          return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) -    @helpers.async_test      async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):          """The `count_ducks` method should return the number of unique staffers who gave a duck."""          test_cases = ( @@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase):              with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):                  self.assertEqual(expected_count, actual_count) -    @helpers.async_test      async def test_relay_message_correctly_relays_content_and_attachments(self):          """The `relay_message` method should correctly relay message content and attachments."""          send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" @@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase):          )          for message, expect_webhook_call, expect_attachment_call in test_values: -            with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: -                with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: +            with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: +                with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:                      with self.subTest(clean_content=message.clean_content, attachments=message.attachments):                          await self.cog.relay_message(message) @@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase):                          message.add_reaction.assert_called_once_with(self.checkmark_emoji) -    @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) -    @helpers.async_test +    @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase):          self.cog.webhook = helpers.MockAsyncWebhook()          log = logging.getLogger("bot.cogs.duck_pond") -        for side_effect in side_effects: +        for side_effect in side_effects:  # pragma: no cover              send_attachments.side_effect = side_effect -            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: +            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:                  with self.subTest(side_effect=type(side_effect).__name__):                      with self.assertNotLogs(logger=log, level=logging.ERROR):                          await self.cog.relay_message(message)                      self.assertEqual(send_webhook.call_count, 2) -    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) -    @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) -    @helpers.async_test +    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) +    @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase):          payload.emoji.name = emoji_name          return payload -    @helpers.async_test      async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):          """The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""          test_values = ( @@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase):          return channel, message, member, payload -    @helpers.async_test      async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):          """The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""          channel_id = 1234 @@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase):                  channel.fetch_message.reset_mock()      @patch(f"{MODULE_PATH}.DuckPond.is_staff") -    @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) +    @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)      def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):          """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""          channel_id = 31415926535 @@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase):          # Assert that we've made it past `self.is_staff`          is_staff.assert_called_once() -    @helpers.async_test      async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):          """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""          test_cases = ( @@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase):          payload.emoji = self.duck_pond_emoji          for duck_count, should_relay in test_cases: -            with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: -                with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: +            with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: +                with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:                      count_ducks.return_value = duck_count                      with self.subTest(duck_count=duck_count, should_relay=should_relay):                          await self.cog.on_raw_reaction_add(payload) @@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase):                          if should_relay:                              relay_message.assert_called_once_with(message) -    @helpers.async_test      async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):          """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""          checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) @@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase):              (constants.DuckPond.threshold + 1, True),          )          for duck_count, should_re_add_checkmark in test_cases: -            with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: +            with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:                  count_ducks.return_value = duck_count                  with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):                      await self.cog.on_raw_reaction_remove(payload) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 8443cfe71..79c0e0ad3 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,10 +7,9 @@ import discord  from bot import constants  from bot.cogs import information -from bot.decorators import InChannelCheckFailure +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  COG_PATH = "bot.cogs.information.Information" @@ -34,7 +33,7 @@ class InformationCogTests(unittest.TestCase):          """Test if the `role_info` command correctly returns the `moderator_role`."""          self.ctx.guild.roles.append(self.moderator_role) -        self.cog.roles_info.can_run = helpers.AsyncMock() +        self.cog.roles_info.can_run = unittest.mock.AsyncMock()          self.cog.roles_info.can_run.return_value = True          coroutine = self.cog.roles_info.callback(self.cog, self.ctx) @@ -45,10 +44,9 @@ class InformationCogTests(unittest.TestCase):          _, kwargs = self.ctx.send.call_args          embed = kwargs.pop('embed') -        self.assertEqual(embed.title, "Role information") +        self.assertEqual(embed.title, "Role information (Total 1 role)")          self.assertEqual(embed.colour, discord.Colour.blurple()) -        self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n") -        self.assertEqual(embed.footer.text, "Total roles: 1") +        self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")      def test_role_info_command(self):          """Tests the `role info` command.""" @@ -72,7 +70,7 @@ class InformationCogTests(unittest.TestCase):          self.ctx.guild.roles.append([dummy_role, admin_role]) -        self.cog.role_info.can_run = helpers.AsyncMock() +        self.cog.role_info.can_run = unittest.mock.AsyncMock()          self.cog.role_info.can_run.return_value = True          coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) @@ -150,14 +148,18 @@ class InformationCogTests(unittest.TestCase):                  Voice region: {self.ctx.guild.region}                  Features: {', '.join(self.ctx.guild.features)} -                **Counts** -                Members: {self.ctx.guild.member_count:,} -                Roles: {len(self.ctx.guild.roles)} +                **Channel counts**                  Category channels: 1                  Text channels: 1                  Voice channels: 1 +                Staff channels: 0 + +                **Member counts** +                Members: {self.ctx.guild.member_count:,} +                Staff members: 0 +                Roles: {len(self.ctx.guild.roles)} -                **Members** +                **Member statuses**                  {constants.Emojis.status_online} 2                  {constants.Emojis.status_idle} 1                  {constants.Emojis.status_dnd} 4 @@ -174,7 +176,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):      def setUp(self):          """Common set-up steps done before for each test."""          self.bot = helpers.MockBot() -        self.bot.api_client.get = helpers.AsyncMock() +        self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot)          self.member = helpers.MockMember(id=1234) @@ -345,10 +347,10 @@ class UserEmbedTests(unittest.TestCase):      def setUp(self):          """Common set-up steps done before for each test."""          self.bot = helpers.MockBot() -        self.bot.api_client.get = helpers.AsyncMock() +        self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -360,7 +362,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Mr. Hemlock") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_nick_in_title_if_available(self):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -372,7 +374,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_ignores_everyone_role(self):          """Created `!user` embeds should not contain mention of the @everyone-role."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -387,8 +389,8 @@ class UserEmbedTests(unittest.TestCase):          self.assertIn("&Admins", embed.description)          self.assertNotIn("&Everyone", embed.description) -    @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock) -    @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)      def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):          """The embed should contain expanded infractions and nomination info in mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -423,7 +425,7 @@ class UserEmbedTests(unittest.TestCase):              embed.description          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)      def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):          """The embed should contain only basic infraction data outside of mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -454,7 +456,7 @@ class UserEmbedTests(unittest.TestCase):              embed.description          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):          """The embed should be created with the colour of the top role, if a top role is available."""          ctx = helpers.MockContext() @@ -467,7 +469,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):          """The embed should be created with a blurple colour if the user has no assigned roles."""          ctx = helpers.MockContext() @@ -477,7 +479,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour.blurple()) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):          """The embed thumbnail should be set to the user's avatar in `png` format."""          ctx = helpers.MockContext() @@ -486,7 +488,7 @@ class UserEmbedTests(unittest.TestCase):          user.avatar_url_as.return_value = "avatar url"          embed = asyncio.run(self.cog.create_user_embed(ctx, user)) -        user.avatar_url_as.assert_called_once_with(format="png") +        user.avatar_url_as.assert_called_once_with(static_format="png")          self.assertEqual(embed.thumbnail.url, "avatar url") @@ -526,10 +528,10 @@ class UserCommandTests(unittest.TestCase):          ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))          msg = "Sorry, but you may only use this command within <#50>." -        with self.assertRaises(InChannelCheckFailure, msg=msg): +        with self.assertRaises(InWhitelistCheckFailure, msg=msg):              asyncio.run(self.cog.user_info.callback(self.cog, ctx)) -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):          """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -542,7 +544,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):          """A user should target itself with `!user` when a `user` argument was not provided."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -555,7 +557,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -568,7 +570,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.moderator)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_moderators_can_target_another_member(self, create_embed, constants):          """A moderator should be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 985bc66a1..cf9adbee0 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -1,74 +1,79 @@  import asyncio  import logging  import unittest -from functools import partial -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch +from discord.ext import commands + +from bot import constants  from bot.cogs import snekbox  from bot.cogs.snekbox import Snekbox -from bot.constants import URLs -from tests.helpers import ( -    AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test -) +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser -class SnekboxTests(unittest.TestCase): +class SnekboxTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          """Add mocked bot and cog to the instance."""          self.bot = MockBot() - -        self.mocked_post = MagicMock() -        self.mocked_post.json = AsyncMock() -        self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post)) -          self.cog = Snekbox(bot=self.bot) -    @async_test      async def test_post_eval(self):          """Post the eval code to the URLs.snekbox_eval_api endpoint.""" -        self.mocked_post.json.return_value = {'lemon': 'AI'} +        resp = MagicMock() +        resp.json = AsyncMock(return_value="return") -        self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'}) -        self.bot.http_session.post.assert_called_once_with( -            URLs.snekbox_eval_api, +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager + +        self.assertEqual(await self.cog.post_eval("import random"), "return") +        self.bot.http_session.post.assert_called_with( +            constants.URLs.snekbox_eval_api,              json={"input": "import random"},              raise_for_status=True          ) +        resp.json.assert_awaited_once() -    @async_test      async def test_upload_output_reject_too_long(self):          """Reject output longer than MAX_PASTE_LEN."""          result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))          self.assertEqual(result, "too long to upload") -    @async_test      async def test_upload_output(self):          """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" -        key = "RainbowDash" -        self.mocked_post.json.return_value = {"key": key} +        key = "MarkDiamond" +        resp = MagicMock() +        resp.json = AsyncMock(return_value={"key": key}) + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(              await self.cog.upload_output("My awesome output"), -            URLs.paste_service.format(key=key) +            constants.URLs.paste_service.format(key=key)          ) -        self.bot.http_session.post.assert_called_once_with( -            URLs.paste_service.format(key="documents"), +        self.bot.http_session.post.assert_called_with( +            constants.URLs.paste_service.format(key="documents"),              data="My awesome output",              raise_for_status=True          ) -    @async_test      async def test_upload_output_gracefully_fallback_if_exception_during_request(self):          """Output upload gracefully fallback if the upload fail.""" -        self.mocked_post.json.side_effect = Exception +        resp = MagicMock() +        resp.json = AsyncMock(side_effect=Exception) + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager +          log = logging.getLogger("bot.cogs.snekbox")          with self.assertLogs(logger=log, level='ERROR'):              await self.cog.upload_output('My awesome output!') -    @async_test      async def test_upload_output_gracefully_fallback_if_no_key_in_response(self):          """Output upload gracefully fallback if there is no key entry in the response body.""" -        self.mocked_post.json.return_value = {}          self.assertEqual((await self.cog.upload_output('My awesome output!')), None)      def test_prepare_input(self): @@ -95,15 +100,15 @@ class SnekboxTests(unittest.TestCase):                  self.assertEqual(actual, expected)      @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) -    def test_get_results_message_invalid_signal(self, mock_Signals: Mock): +    def test_get_results_message_invalid_signal(self, mock_signals: Mock):          self.assertEqual(              self.cog.get_results_message({'stdout': '', 'returncode': 127}),              ('Your eval job has completed with return code 127', '')          )      @patch('bot.cogs.snekbox.Signals') -    def test_get_results_message_valid_signal(self, mock_Signals: Mock): -        mock_Signals.return_value.name = 'SIGTEST' +    def test_get_results_message_valid_signal(self, mock_signals: Mock): +        mock_signals.return_value.name = 'SIGTEST'          self.assertEqual(              self.cog.get_results_message({'stdout': '', 'returncode': 127}),              ('Your eval job has completed with return code 127 (SIGTEST)', '') @@ -121,7 +126,6 @@ class SnekboxTests(unittest.TestCase):                  actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})                  self.assertEqual(actual, expected) -    @async_test      async def test_format_output(self):          """Test output formatting."""          self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') @@ -172,7 +176,6 @@ class SnekboxTests(unittest.TestCase):              with self.subTest(msg=testname, case=case, expected=expected):                  self.assertEqual(await self.cog.format_output(case), expected) -    @async_test      async def test_eval_command_evaluate_once(self):          """Test the eval command procedure."""          ctx = MockContext() @@ -186,7 +189,6 @@ class SnekboxTests(unittest.TestCase):          self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')          self.cog.continue_eval.assert_called_once_with(ctx, response) -    @async_test      async def test_eval_command_evaluate_twice(self):          """Test the eval and re-eval command procedure."""          ctx = MockContext() @@ -201,7 +203,6 @@ class SnekboxTests(unittest.TestCase):          self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')          self.cog.continue_eval.assert_called_with(ctx, response) -    @async_test      async def test_eval_command_reject_two_eval_at_the_same_time(self):          """Test if the eval command rejects an eval if the author already have a running eval."""          ctx = MockContext() @@ -214,22 +215,19 @@ class SnekboxTests(unittest.TestCase):              "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"          ) -    @async_test      async def test_eval_command_call_help(self):          """Test if the eval command call the help command if no code is provided.""" -        ctx = MockContext() -        ctx.invoke = AsyncMock() +        ctx = MockContext(command="sentinel")          await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') -        ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") +        ctx.send_help.assert_called_once_with("sentinel") -    @async_test      async def test_send_eval(self):          """Test the send_eval function."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +          self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})          self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))          self.cog.get_status_emoji = MagicMock(return_value=':yay!:') @@ -244,14 +242,13 @@ class SnekboxTests(unittest.TestCase):          self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})          self.cog.format_output.assert_called_once_with('') -    @async_test      async def test_send_eval_with_paste_link(self):          """Test the send_eval function with a too long output that generate a paste link."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +          self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})          self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))          self.cog.get_status_emoji = MagicMock(return_value=':yay!:') @@ -267,14 +264,12 @@ class SnekboxTests(unittest.TestCase):          self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})          self.cog.format_output.assert_called_once_with('Way too long beard') -    @async_test      async def test_send_eval_with_non_zero_eval(self):          """Test the send_eval function with a code returning a non-zero code."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))          self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})          self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))          self.cog.get_status_emoji = MagicMock(return_value=':nope!:') @@ -289,25 +284,33 @@ class SnekboxTests(unittest.TestCase):          self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})          self.cog.format_output.assert_not_called() -    @async_test -    async def test_continue_eval_does_continue(self): +    @patch("bot.cogs.snekbox.partial") +    async def test_continue_eval_does_continue(self, partial_mock):          """Test that the continue_eval function does continue if required conditions are met."""          ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))          response = MockMessage(delete=AsyncMock()) -        new_msg = MockMessage(content='!e NewCode') +        new_msg = MockMessage()          self.bot.wait_for.side_effect = ((None, new_msg), None) +        expected = "NewCode" +        self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)          actual = await self.cog.continue_eval(ctx, response) -        self.assertEqual(actual, 'NewCode') -        self.bot.wait_for.has_calls( -            call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10), -            call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +        self.cog.get_code.assert_awaited_once_with(new_msg) +        self.assertEqual(actual, expected) +        self.bot.wait_for.assert_has_awaits( +            ( +                call( +                    'message_edit', +                    check=partial_mock(snekbox.predicate_eval_message_edit, ctx), +                    timeout=snekbox.REEVAL_TIMEOUT, +                ), +                call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +            )          )          ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)          ctx.message.clear_reactions.assert_called_once()          response.delete.assert_called_once() -    @async_test      async def test_continue_eval_does_not_continue(self):          ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))          self.bot.wait_for.side_effect = asyncio.TimeoutError @@ -316,6 +319,32 @@ class SnekboxTests(unittest.TestCase):          self.assertEqual(actual, None)          ctx.message.clear_reactions.assert_called_once() +    async def test_get_code(self): +        """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" +        prefix = constants.Bot.prefix +        subtests = ( +            (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"), +            (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None), +            (MagicMock(spec=commands.Command), f"{prefix}tags get foo"), +            (None, "print(123)") +        ) + +        for command, content, *expected_code in subtests: +            if not expected_code: +                expected_code = content +            else: +                [expected_code] = expected_code + +            with self.subTest(content=content, expected_code=expected_code): +                self.bot.get_context.reset_mock() +                self.bot.get_context.return_value = MockContext(command=command) +                message = MockMessage(content=content) + +                actual_code = await self.cog.get_code(message) + +                self.bot.get_context.assert_awaited_once_with(message) +                self.assertEqual(actual_code, expected_code) +      def test_predicate_eval_message_edit(self):          """Test the predicate_eval_message_edit function."""          msg0 = MockMessage(id=1, content='abc') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a54b839d7..a10124d2d 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,56 +1,89 @@ -import asyncio -import logging  import unittest +from re import Match +from unittest import mock  from unittest.mock import MagicMock  from discord import Colour -from bot.cogs.token_remover import ( -    DELETION_MESSAGE_TEMPLATE, -    TokenRemover, -    setup as setup_cog, -) -from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock, MockBot, MockMessage +from bot import constants +from bot.cogs import token_remover +from bot.cogs.moderation import ModLog +from bot.cogs.token_remover import Token, TokenRemover +from tests.helpers import MockBot, MockMessage, autospec -class TokenRemoverTests(unittest.TestCase): +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):      """Tests the `TokenRemover` cog."""      def setUp(self):          """Adds the cog, a bot, and a message to the instance for usage in tests."""          self.bot = MockBot() -        self.bot.get_cog.return_value = MagicMock() -        self.bot.get_cog.return_value.send_log_message = AsyncMock()          self.cog = TokenRemover(bot=self.bot) -        self.msg = MockMessage(id=555, content='') -        self.msg.author.__str__ = MagicMock() -        self.msg.author.__str__.return_value = 'lemon' -        self.msg.author.bot = False -        self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' -        self.msg.author.id = 42 -        self.msg.author.mention = '@lemon' +        self.msg = MockMessage(id=555, content="hello world")          self.msg.channel.mention = "#lemonade-stand" +        self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) +        self.msg.author.avatar_url_as.return_value = "picture-lemon.png" -    def test_is_valid_user_id_is_true_for_numeric_content(self): -        """A string decoding to numeric characters is a valid user ID.""" -        # MTIz = base64(123) -        self.assertTrue(TokenRemover.is_valid_user_id('MTIz')) +    def test_is_valid_user_id_valid(self): +        """Should consider user IDs valid if they decode entirely to ASCII digits.""" +        ids = ( +            "NDcyMjY1OTQzMDYyNDEzMzMy", +            "NDc1MDczNjI5Mzk5NTQ3OTA0", +            "NDY3MjIzMjMwNjUwNzc3NjQx", +        ) + +        for user_id in ids: +            with self.subTest(user_id=user_id): +                result = TokenRemover.is_valid_user_id(user_id) +                self.assertTrue(result) -    def test_is_valid_user_id_is_false_for_alphabetic_content(self): -        """A string decoding to alphabetic characters is not a valid user ID.""" -        # YWJj = base64(abc) -        self.assertFalse(TokenRemover.is_valid_user_id('YWJj')) +    def test_is_valid_user_id_invalid(self): +        """Should consider non-digit and non-ASCII IDs invalid.""" +        ids = ( +            ("SGVsbG8gd29ybGQ", "non-digit ASCII"), +            ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), +            ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), +            ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), +            ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) -    def test_is_valid_timestamp_is_true_for_valid_timestamps(self): -        """A string decoding to a valid timestamp should be recognized as such.""" -        self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A')) +        for user_id, msg in ids: +            with self.subTest(msg=msg): +                result = TokenRemover.is_valid_user_id(user_id) +                self.assertFalse(result) -    def test_is_valid_timestamp_is_false_for_invalid_values(self): -        """A string not decoding to a valid timestamp should not be recognized as such.""" -        # MTIz = base64(123) -        self.assertFalse(TokenRemover.is_valid_timestamp('MTIz')) +    def test_is_valid_timestamp_valid(self): +        """Should consider timestamps valid if they're greater than the Discord epoch.""" +        timestamps = ( +            "XsyRkw", +            "Xrim9Q", +            "XsyR-w", +            "XsySD_", +            "Dn9r_A", +        ) + +        for timestamp in timestamps: +            with self.subTest(timestamp=timestamp): +                result = TokenRemover.is_valid_timestamp(timestamp) +                self.assertTrue(result) + +    def test_is_valid_timestamp_invalid(self): +        """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" +        timestamps = ( +            ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), +            ("ew", "123"), +            ("AoIKgA", "42076800"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for timestamp, msg in timestamps: +            with self.subTest(msg=msg): +                result = TokenRemover.is_valid_timestamp(timestamp) +                self.assertFalse(result)      def test_mod_log_property(self):          """The `mod_log` property should ask the bot to return the `ModLog` cog.""" @@ -58,74 +91,206 @@ class TokenRemoverTests(unittest.TestCase):          self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value)          self.bot.get_cog.assert_called_once_with('ModLog') -    def test_ignores_bot_messages(self): -        """When the message event handler is called with a bot message, nothing is done.""" +    async def test_on_message_edit_uses_on_message(self): +        """The edit listener should delegate handling of the message to the normal listener.""" +        self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + +        await self.cog.on_message_edit(MockMessage(), self.msg) +        self.cog.on_message.assert_awaited_once_with(self.msg) + +    @autospec(TokenRemover, "find_token_in_message", "take_action") +    async def test_on_message_takes_action(self, find_token_in_message, take_action): +        """Should take action if a valid token is found when a message is sent.""" +        cog = TokenRemover(self.bot) +        found_token = "foobar" +        find_token_in_message.return_value = found_token + +        await cog.on_message(self.msg) + +        find_token_in_message.assert_called_once_with(self.msg) +        take_action.assert_awaited_once_with(cog, self.msg, found_token) + +    @autospec(TokenRemover, "find_token_in_message", "take_action") +    async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): +        """Shouldn't take action if a valid token isn't found when a message is sent.""" +        cog = TokenRemover(self.bot) +        find_token_in_message.return_value = False + +        await cog.on_message(self.msg) + +        find_token_in_message.assert_called_once_with(self.msg) +        take_action.assert_not_awaited() + +    @autospec("bot.cogs.token_remover", "TOKEN_RE") +    def test_find_token_ignores_bot_messages(self, token_re): +        """The token finder should ignore messages authored by bots."""          self.msg.author.bot = True -        coroutine = self.cog.on_message(self.msg) -        self.assertIsNone(asyncio.run(coroutine)) - -    def test_ignores_messages_without_tokens(self): -        """Messages without anything looking like a token are ignored.""" -        for content in ('', 'lemon wins'): -            with self.subTest(content=content): -                self.msg.content = content -                coroutine = self.cog.on_message(self.msg) -                self.assertIsNone(asyncio.run(coroutine)) - -    def test_ignores_messages_with_invalid_tokens(self): -        """Messages with values that are invalid tokens are ignored.""" -        for content in ('foo.bar.baz', 'x.y.'): -            with self.subTest(content=content): -                self.msg.content = content -                coroutine = self.cog.on_message(self.msg) -                self.assertIsNone(asyncio.run(coroutine)) - -    def test_censors_valid_tokens(self): -        """Valid tokens are censored.""" -        cases = ( -            # (content, censored_token) -            ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) +        token_re.finditer.assert_not_called() + +    @autospec("bot.cogs.token_remover", "TOKEN_RE") +    def test_find_token_no_matches(self, token_re): +        """None should be returned if the regex matches no tokens in a message.""" +        token_re.finditer.return_value = () + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) +        token_re.finditer.assert_called_once_with(self.msg.content) + +    @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") +    @autospec("bot.cogs.token_remover", "Token") +    @autospec("bot.cogs.token_remover", "TOKEN_RE") +    def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): +        """The first match with a valid user ID and timestamp should be returned as a `Token`.""" +        matches = [ +            mock.create_autospec(Match, spec_set=True, instance=True), +            mock.create_autospec(Match, spec_set=True, instance=True), +        ] +        tokens = [ +            mock.create_autospec(Token, spec_set=True, instance=True), +            mock.create_autospec(Token, spec_set=True, instance=True), +        ] + +        token_re.finditer.return_value = matches +        token_cls.side_effect = tokens +        is_valid_id.side_effect = (False, True)  # The 1st match will be invalid, 2nd one valid. +        is_valid_timestamp.return_value = True + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertEqual(tokens[1], return_value) +        token_re.finditer.assert_called_once_with(self.msg.content) + +    @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") +    @autospec("bot.cogs.token_remover", "Token") +    @autospec("bot.cogs.token_remover", "TOKEN_RE") +    def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): +        """None should be returned if no matches have valid user IDs or timestamps.""" +        token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] +        token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) +        is_valid_id.return_value = False +        is_valid_timestamp.return_value = False + +        return_value = TokenRemover.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) +        token_re.finditer.assert_called_once_with(self.msg.content) + +    def test_regex_invalid_tokens(self): +        """Messages without anything looking like a token are not matched.""" +        tokens = ( +            "", +            "lemon wins", +            "..", +            "x.y", +            "x.y.", +            ".y.z", +            ".y.", +            "..z", +            "x..z", +            " . . ", +            "\n.\n.\n", +            "hellö.world.bye", +            "base64.nötbåse64.morebase64", +            "19jd3J.dfkm3d.€víł§tüff", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = token_remover.TOKEN_RE.findall(token) +                self.assertEqual(len(results), 0) + +    def test_regex_valid_tokens(self): +        """Messages that look like tokens should be matched.""" +        # Don't worry, these tokens have been invalidated. +        tokens = ( +            "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", +            "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", +            "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", +            "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",          ) -        for content, censored_token in cases: -            with self.subTest(content=content, censored_token=censored_token): -                self.msg.content = content -                coroutine = self.cog.on_message(self.msg) -                with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm: -                    self.assertIsNone(asyncio.run(coroutine))  # no return value - -                [line] = cm.output -                log_message = ( -                    "Censored a seemingly valid token sent by " -                    "lemon (`42`) in #lemonade-stand, " -                    f"token was `{censored_token}`" -                ) -                self.assertIn(log_message, line) - -                self.msg.delete.assert_called_once_with() -                self.msg.channel.send.assert_called_once_with( -                    DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') -                ) -                self.bot.get_cog.assert_called_with('ModLog') -                self.msg.author.avatar_url_as.assert_called_once_with(static_format='png') - -                mod_log = self.bot.get_cog.return_value -                mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) -                mod_log.send_log_message.assert_called_once_with( -                    icon_url=Icons.token_removed, -                    colour=Colour(Colours.soft_red), -                    title="Token removed!", -                    text=log_message, -                    thumbnail='picture-lemon.png', -                    channel_id=Channels.mod_alerts -                ) - - -class TokenRemoverSetupTests(unittest.TestCase): -    """Tests setup of the `TokenRemover` cog.""" - -    def test_setup(self): -        """Setup of the extension should call add_cog.""" +        for token in tokens: +            with self.subTest(token=token): +                results = token_remover.TOKEN_RE.fullmatch(token) +                self.assertIsNotNone(results, f"{token} was not matched by the regex") + +    def test_regex_matches_multiple_valid(self): +        """Should support multiple matches in the middle of a string.""" +        token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" +        token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" +        message = f"garbage {token_1} hello {token_2} world" + +        results = token_remover.TOKEN_RE.finditer(message) +        results = [match[0] for match in results] +        self.assertCountEqual((token_1, token_2), results) + +    @autospec("bot.cogs.token_remover", "LOG_MESSAGE") +    def test_format_log_message(self, log_message): +        """Should correctly format the log message with info from the message and token.""" +        token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        log_message.format.return_value = "Howdy" + +        return_value = TokenRemover.format_log_message(self.msg, token) + +        self.assertEqual(return_value, log_message.format.return_value) +        log_message.format.assert_called_once_with( +            author=self.msg.author, +            author_id=self.msg.author.id, +            channel=self.msg.channel.mention, +            user_id=token.user_id, +            timestamp=token.timestamp, +            hmac="x" * len(token.hmac), +        ) + +    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) +    @autospec("bot.cogs.token_remover", "log") +    @autospec(TokenRemover, "format_log_message") +    async def test_take_action(self, format_log_message, logger, mod_log_property): +        """Should delete the message and send a mod log.""" +        cog = TokenRemover(self.bot) +        mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) +        token = mock.create_autospec(Token, spec_set=True, instance=True) +        log_msg = "testing123" + +        mod_log_property.return_value = mod_log +        format_log_message.return_value = log_msg + +        await cog.take_action(self.msg, token) + +        self.msg.delete.assert_called_once_with() +        self.msg.channel.send.assert_called_once_with( +            token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) +        ) + +        format_log_message.assert_called_once_with(self.msg, token) +        logger.debug.assert_called_with(log_msg) +        self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + +        mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) +        mod_log.send_log_message.assert_called_once_with( +            icon_url=constants.Icons.token_removed, +            colour=Colour(constants.Colours.soft_red), +            title="Token removed!", +            text=log_msg, +            thumbnail=self.msg.author.avatar_url_as.return_value, +            channel_id=constants.Channels.mod_alerts +        ) + + +class TokenRemoverExtensionTests(unittest.TestCase): +    """Tests for the token_remover extension.""" + +    @autospec("bot.cogs.token_remover", "TokenRemover") +    def test_extension_setup(self, cog): +        """The TokenRemover cog should be added."""          bot = MockBot() -        setup_cog(bot) +        token_remover.setup(bot) + +        cog.assert_called_once_with(bot)          bot.add_cog.assert_called_once() +        self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index 36c986fe1..0d570f5a3 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple):      n_violations: int -class RuleTest(unittest.TestCase, metaclass=ABCMeta): +class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):      """      Abstract class for antispam rule test cases. @@ -68,9 +68,9 @@ class RuleTest(unittest.TestCase, metaclass=ABCMeta):      @abstractmethod      def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:          """Give expected relevant messages for `case`.""" -        raise NotImplementedError +        raise NotImplementedError  # pragma: no cover      @abstractmethod      def get_report(self, case: DisallowedCase) -> str:          """Give expected error report for `case`.""" -        raise NotImplementedError +        raise NotImplementedError  # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index e54b4b5b8..d7e779221 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import attachments  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, total_attachments: int) -> MockMessage: @@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest):          self.apply = attachments.apply          self.config = {"max": 5, "interval": 10} -    @async_test      async def test_allows_messages_without_too_many_attachments(self):          """Messages without too many attachments are allowed as-is."""          cases = ( @@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py index 72f0be0c7..03682966b 100644 --- a/tests/bot/rules/test_burst.py +++ b/tests/bot/rules/test_burst.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import burst  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest):          self.apply = burst.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases which do not violate the rule."""          cases = ( @@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the amount of messages exceeds the limit, triggering the rule."""          cases = ( diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py index 47367a5f8..3275143d5 100644 --- a/tests/bot/rules/test_burst_shared.py +++ b/tests/bot/rules/test_burst_shared.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import burst_shared  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest):          self.apply = burst_shared.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """          Cases that do not violate the rule. @@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the amount of messages exceeds the limit, triggering the rule."""          cases = ( diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py index 7cc36f49e..f1e3c76a7 100644 --- a/tests/bot/rules/test_chars.py +++ b/tests/bot/rules/test_chars.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import chars  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, n_chars: int) -> MockMessage: @@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest):              "interval": 10,          } -    @async_test      async def test_allows_messages_within_limit(self):          """Cases with a total amount of chars within limit."""          cases = ( @@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the total amount of chars exceeds the limit, triggering the rule."""          cases = ( diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 0239b0b00..9a72723e2 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import discord_emojis  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> @@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest):          self.apply = discord_emojis.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases with a total amount of discord emojis within limit."""          cases = ( @@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with more than the allowed amount of discord emojis."""          cases = ( diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py index 59e0fb6ef..9bd886a77 100644 --- a/tests/bot/rules/test_duplicates.py +++ b/tests/bot/rules/test_duplicates.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import duplicates  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, content: str) -> MockMessage: @@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest):          self.apply = duplicates.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases which do not violate the rule."""          cases = ( @@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with too many duplicate messages from the same author."""          cases = ( diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 3c3f90e5f..b091bd9d7 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import links  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, total_links: int) -> MockMessage: @@ -21,7 +21,6 @@ class LinksTests(RuleTest):              "interval": 10          } -    @async_test      async def test_links_within_limit(self):          """Messages with an allowed amount of links."""          cases = ( @@ -34,7 +33,6 @@ class LinksTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_links_exceeding_limit(self):          """Messages with a a higher than allowed amount of links."""          cases = ( diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ebcdabac6..6444532f2 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, total_mentions: int) -> MockMessage: @@ -20,7 +20,6 @@ class TestMentions(RuleTest):              "interval": 10,          } -    @async_test      async def test_mentions_within_limit(self):          """Messages with an allowed amount of mentions."""          cases = ( @@ -32,7 +31,6 @@ class TestMentions(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_mentions_exceeding_limit(self):          """Messages with a higher than allowed amount of mentions."""          cases = ( diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py index d61c4609d..e35377773 100644 --- a/tests/bot/rules/test_newlines.py +++ b/tests/bot/rules/test_newlines.py @@ -2,7 +2,7 @@ from typing import Iterable, List  from bot.rules import newlines  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, newline_groups: List[int]) -> MockMessage: @@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest):              "interval": 10,          } -    @async_test      async def test_allows_messages_within_limit(self):          """Cases which do not violate the rule."""          cases = ( @@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_total(self):          """Cases which violate the rule by having too many newlines in total."""          cases = ( @@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest):          self.apply = newlines.apply          self.config = {"max": 5, "max_consecutive": 3, "interval": 10} -    @async_test      async def test_disallows_messages_consecutive(self):          """Cases which violate the rule due to having too many consecutive newlines."""          cases = ( diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py index b339cccf7..26c05d527 100644 --- a/tests/bot/rules/test_role_mentions.py +++ b/tests/bot/rules/test_role_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import role_mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, n_mentions: int) -> MockMessage: @@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest):          self.apply = role_mentions.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases with a total amount of role mentions within limit."""          cases = ( @@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with more than the allowed amount of role mentions."""          cases = ( diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index bdfcc73e4..99e942813 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -2,10 +2,9 @@ import unittest  from unittest.mock import MagicMock  from bot import api -from tests.helpers import async_test -class APIClientTests(unittest.TestCase): +class APIClientTests(unittest.IsolatedAsyncioTestCase):      """Tests for the bot's API client."""      @classmethod @@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase):          """The event loop should not be running by default."""          self.assertFalse(api.loop_is_running()) -    @async_test      async def test_loop_is_running_in_async_context(self):          """The event loop should be running in an async context."""          self.assertTrue(api.loop_is_running()) diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index dae7c066c..f10d6fbe8 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -1,14 +1,40 @@  import inspect +import typing  import unittest  from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: +    """ +    Return True if `value` is an instance of the type represented by `annotation`. + +    This doesn't account for things like Unions or checking for homogenous types in collections. +    """ +    origin = typing.get_origin(annotation) + +    # This is done in case a bare e.g. `typing.List` is used. +    # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. +    # `get_origin()` does this normalisation for us. +    type_ = annotation if origin is None else origin + +    return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: +    """Return True if `value` is an instance of any type in `types`.""" +    for type_ in types: +        if is_annotation_instance(value, type_): +            return True + +    return False + +  class ConstantsTests(unittest.TestCase):      """Tests for our constants."""      def test_section_configuration_matches_type_specification(self): -        """The section annotations should match the actual types of the sections.""" +        """"The section annotations should match the actual types of the sections."""          sections = (              cls @@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):          )          for section in sections:              for name, annotation in section.__annotations__.items(): -                with self.subTest(section=section, name=name, annotation=annotation): +                with self.subTest(section=section.__name__, name=name, annotation=annotation):                      value = getattr(section, name) +                    origin = typing.get_origin(annotation) +                    annotation_args = typing.get_args(annotation) +                    failure_msg = f"{value} is not an instance of {annotation}" -                    if getattr(annotation, '_name', None) in ('Dict', 'List'): -                        self.skipTest("Cannot validate containers yet.") - -                    self.assertIsInstance(value, annotation) +                    if origin is typing.Union: +                        is_instance = is_any_instance(value, annotation_args) +                        self.assertTrue(is_instance, failure_msg) +                    else: +                        is_instance = is_annotation_instance(value, annotation) +                        self.assertTrue(is_instance, failure_msg) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 1e5ca62ae..c42111f3f 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -1,5 +1,5 @@ -import asyncio  import datetime +import re  import unittest  from unittest.mock import MagicMock, patch @@ -8,6 +8,7 @@ from discord.ext.commands import BadArgument  from bot.converters import (      Duration, +    HushDurationConverter,      ISODateTime,      TagContentConverter,      TagNameConverter, @@ -15,7 +16,7 @@ from bot.converters import (  ) -class ConverterTests(unittest.TestCase): +class ConverterTests(unittest.IsolatedAsyncioTestCase):      """Tests our custom argument converters."""      @classmethod @@ -25,7 +26,7 @@ class ConverterTests(unittest.TestCase):          cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') -    def test_tag_content_converter_for_valid(self): +    async def test_tag_content_converter_for_valid(self):          """TagContentConverter should return correct values for valid input."""          test_values = (              ('hello', 'hello'), @@ -34,10 +35,10 @@ class ConverterTests(unittest.TestCase):          for content, expected_conversion in test_values:              with self.subTest(content=content, expected_conversion=expected_conversion): -                conversion = asyncio.run(TagContentConverter.convert(self.context, content)) +                conversion = await TagContentConverter.convert(self.context, content)                  self.assertEqual(conversion, expected_conversion) -    def test_tag_content_converter_for_invalid(self): +    async def test_tag_content_converter_for_invalid(self):          """TagContentConverter should raise the proper exception for invalid input."""          test_values = (              ('', "Tag contents should not be empty, or filled with whitespace."), @@ -46,10 +47,10 @@ class ConverterTests(unittest.TestCase):          for value, exception_message in test_values:              with self.subTest(tag_content=value, exception_message=exception_message): -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(TagContentConverter.convert(self.context, value)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await TagContentConverter.convert(self.context, value) -    def test_tag_name_converter_for_valid(self): +    async def test_tag_name_converter_for_valid(self):          """TagNameConverter should return the correct values for valid tag names."""          test_values = (              ('tracebacks', 'tracebacks'), @@ -59,10 +60,10 @@ class ConverterTests(unittest.TestCase):          for name, expected_conversion in test_values:              with self.subTest(name=name, expected_conversion=expected_conversion): -                conversion = asyncio.run(TagNameConverter.convert(self.context, name)) +                conversion = await TagNameConverter.convert(self.context, name)                  self.assertEqual(conversion, expected_conversion) -    def test_tag_name_converter_for_invalid(self): +    async def test_tag_name_converter_for_invalid(self):          """TagNameConverter should raise the correct exception for invalid tag names."""          test_values = (              ('👋', "Don't be ridiculous, you can't use that character!"), @@ -74,29 +75,29 @@ class ConverterTests(unittest.TestCase):          for invalid_name, exception_message in test_values:              with self.subTest(invalid_name=invalid_name, exception_message=exception_message): -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(TagNameConverter.convert(self.context, invalid_name)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await TagNameConverter.convert(self.context, invalid_name) -    def test_valid_python_identifier_for_valid(self): +    async def test_valid_python_identifier_for_valid(self):          """ValidPythonIdentifier returns valid identifiers unchanged."""          test_values = ('foo', 'lemon')          for name in test_values:              with self.subTest(identifier=name): -                conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name)) +                conversion = await ValidPythonIdentifier.convert(self.context, name)                  self.assertEqual(name, conversion) -    def test_valid_python_identifier_for_invalid(self): +    async def test_valid_python_identifier_for_invalid(self):          """ValidPythonIdentifier raises the proper exception for invalid identifiers."""          test_values = ('nested.stuff', '#####')          for name in test_values:              with self.subTest(identifier=name):                  exception_message = f'`{name}` is not a valid Python identifier' -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(ValidPythonIdentifier.convert(self.context, name)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await ValidPythonIdentifier.convert(self.context, name) -    def test_duration_converter_for_valid(self): +    async def test_duration_converter_for_valid(self):          """Duration returns the correct `datetime` for valid duration strings."""          test_values = (              # Simple duration strings @@ -158,35 +159,35 @@ class ConverterTests(unittest.TestCase):                  mock_datetime.utcnow.return_value = self.fixed_utc_now                  with self.subTest(duration=duration, duration_dict=duration_dict): -                    converted_datetime = asyncio.run(converter.convert(self.context, duration)) +                    converted_datetime = await converter.convert(self.context, duration)                      self.assertEqual(converted_datetime, expected_datetime) -    def test_duration_converter_for_invalid(self): +    async def test_duration_converter_for_invalid(self):          """Duration raises the right exception for invalid duration strings."""          test_values = (              # Units in wrong order -            ('1d1w'), -            ('1s1y'), +            '1d1w', +            '1s1y',              # Duplicated units -            ('1 year 2 years'), -            ('1 M 10 minutes'), +            '1 year 2 years', +            '1 M 10 minutes',              # Unknown substrings -            ('1MVes'), -            ('1y3breads'), +            '1MVes', +            '1y3breads',              # Missing amount -            ('ym'), +            'ym',              # Incorrect whitespace -            (" 1y"), -            ("1S "), -            ("1y  1m"), +            " 1y", +            "1S ", +            "1y  1m",              # Garbage -            ('Guido van Rossum'), -            ('lemon lemon lemon lemon lemon lemon lemon'), +            'Guido van Rossum', +            'lemon lemon lemon lemon lemon lemon lemon',          )          converter = Duration() @@ -194,10 +195,21 @@ class ConverterTests(unittest.TestCase):          for invalid_duration in test_values:              with self.subTest(invalid_duration=invalid_duration):                  exception_message = f'`{invalid_duration}` is not a valid duration string.' -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(converter.convert(self.context, invalid_duration)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await converter.convert(self.context, invalid_duration) -    def test_isodatetime_converter_for_valid(self): +    @patch("bot.converters.datetime") +    async def test_duration_converter_out_of_range(self, mock_datetime): +        """Duration converter should raise BadArgument if datetime raises a ValueError.""" +        mock_datetime.__add__.side_effect = ValueError +        mock_datetime.utcnow.return_value = mock_datetime + +        duration = f"{datetime.MAXYEAR}y" +        exception_message = f"`{duration}` results in a datetime outside the supported range." +        with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +            await Duration().convert(self.context, duration) + +    async def test_isodatetime_converter_for_valid(self):          """ISODateTime converter returns correct datetime for valid datetime string."""          test_values = (              # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` @@ -242,32 +254,61 @@ class ConverterTests(unittest.TestCase):          for datetime_string, expected_dt in test_values:              with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt): -                converted_dt = asyncio.run(converter.convert(self.context, datetime_string)) +                converted_dt = await converter.convert(self.context, datetime_string)                  self.assertIsNone(converted_dt.tzinfo)                  self.assertEqual(converted_dt, expected_dt) -    def test_isodatetime_converter_for_invalid(self): +    async def test_isodatetime_converter_for_invalid(self):          """ISODateTime converter raises the correct exception for invalid datetime strings."""          test_values = (              # Make sure it doesn't interfere with the Duration converter -            ('1Y'), -            ('1d'), -            ('1H'), +            '1Y', +            '1d', +            '1H',              # Check if it fails when only providing the optional time part -            ('10:10:10'), -            ('10:00'), +            '10:10:10', +            '10:00',              # Invalid date format -            ('19-01-01'), +            '19-01-01',              # Other non-valid strings -            ('fisk the tag master'), +            'fisk the tag master',          )          converter = ISODateTime()          for datetime_string in test_values:              with self.subTest(datetime_string=datetime_string):                  exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(converter.convert(self.context, datetime_string)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await converter.convert(self.context, datetime_string) + +    async def test_hush_duration_converter_for_valid(self): +        """HushDurationConverter returns correct value for minutes duration or `"forever"` strings.""" +        test_values = ( +            ("0", 0), +            ("15", 15), +            ("10", 10), +            ("5m", 5), +            ("5M", 5), +            ("forever", None), +        ) +        converter = HushDurationConverter() +        for minutes_string, expected_minutes in test_values: +            with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes): +                converted = await converter.convert(self.context, minutes_string) +                self.assertEqual(expected_minutes, converted) + +    async def test_hush_duration_converter_for_invalid(self): +        """HushDurationConverter raises correct exception for invalid minutes duration strings.""" +        test_values = ( +            ("16", "Duration must be at most 15 minutes."), +            ("10d", "10d is not a valid minutes duration."), +            ("-1", "-1 is not a valid minutes duration."), +        ) +        converter = HushDurationConverter() +        for invalid_minutes_string, exception_message in test_values: +            with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message): +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await converter.convert(self.context, invalid_minutes_string) diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py new file mode 100644 index 000000000..3d450caa0 --- /dev/null +++ b/tests/bot/test_decorators.py @@ -0,0 +1,147 @@ +import collections +import unittest +import unittest.mock + +from bot import constants +from bot.decorators import in_whitelist +from bot.utils.checks import InWhitelistCheckFailure +from tests import helpers + +InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) + + +class InWhitelistTests(unittest.TestCase): +    """Tests for the `in_whitelist` check.""" + +    @classmethod +    def setUpClass(cls): +        """Set up helpers that only need to be defined once.""" +        cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456) +        cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654) +        cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666) +        cls.dm_channel = helpers.MockDMChannel() + +        cls.non_staff_member = helpers.MockMember() +        cls.staff_role = helpers.MockRole(id=121212) +        cls.staff_member = helpers.MockMember(roles=(cls.staff_role,)) + +        cls.channels = (cls.bot_commands.id,) +        cls.categories = (cls.help_channel.category_id,) +        cls.roles = (cls.staff_role.id,) + +    def test_predicate_returns_true_for_whitelisted_context(self): +        """The predicate should return `True` if a whitelisted context was passed to it.""" +        test_cases = ( +            InWhitelistTestCase( +                kwargs={"channels": self.channels}, +                ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member), +                description="In whitelisted channels by members without whitelisted roles", +            ), +            InWhitelistTestCase( +                kwargs={"redirect": self.bot_commands.id}, +                ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member), +                description="`redirect` should be implicitly added to `channels`", +            ), +            InWhitelistTestCase( +                kwargs={"categories": self.categories}, +                ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member), +                description="Whitelisted category without whitelisted role", +            ), +            InWhitelistTestCase( +                kwargs={"roles": self.roles}, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member), +                description="Whitelisted role outside of whitelisted channel/category" +            ), +            InWhitelistTestCase( +                kwargs={ +                    "channels": self.channels, +                    "categories": self.categories, +                    "roles": self.roles, +                    "redirect": self.bot_commands, +                }, +                ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member), +                description="Case with all whitelist kwargs used", +            ), +        ) + +        for test_case in test_cases: +            # patch `commands.check` with a no-op lambda that just returns the predicate passed to it +            # so we can test the predicate that was generated from the specified kwargs. +            with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): +                predicate = in_whitelist(**test_case.kwargs) + +            with self.subTest(test_description=test_case.description): +                self.assertTrue(predicate(test_case.ctx)) + +    def test_predicate_raises_exception_for_non_whitelisted_context(self): +        """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context.""" +        test_cases = ( +            # Failing check with explicit `redirect` +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                    "redirect": self.bot_commands.id, +                }, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), +                description="Failing check with an explicit redirect channel", +            ), + +            # Failing check with implicit `redirect` +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                }, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), +                description="Failing check with an implicit redirect channel", +            ), + +            # Failing check without `redirect` +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                    "redirect": None, +                }, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), +                description="Failing check without a redirect channel", +            ), + +            # Command issued in DM channel +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                    "redirect": None, +                }, +                ctx=helpers.MockContext(channel=self.dm_channel, author=self.dm_channel.me), +                description="Commands issued in DM channel should be rejected", +            ), +        ) + +        for test_case in test_cases: +            if "redirect" not in test_case.kwargs or test_case.kwargs["redirect"] is not None: +                # There are two cases in which we have a redirect channel: +                #   1. No redirect channel was passed; the default value of `bot_commands` is used +                #   2. An explicit `redirect` is set that is "not None" +                redirect_channel = test_case.kwargs.get("redirect", constants.Channels.bot_commands) +                redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +            else: +                # If an explicit `None` was passed for `redirect`, there is no redirect channel +                redirect_message = "" + +            exception_message = f"You are not allowed to use that command{redirect_message}." + +            # patch `commands.check` with a no-op lambda that just returns the predicate passed to it +            # so we can test the predicate that was generated from the specified kwargs. +            with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): +                predicate = in_whitelist(**test_case.kwargs) + +            with self.subTest(test_description=test_case.description): +                with self.assertRaisesRegex(InWhitelistCheckFailure, exception_message): +                    predicate(test_case.ctx) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py deleted file mode 100644 index d7bcc3ba6..000000000 --- a/tests/bot/test_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import unittest - -from bot import utils - - -class CaseInsensitiveDictTests(unittest.TestCase): -    """Tests for the `CaseInsensitiveDict` container.""" - -    def test_case_insensitive_key_access(self): -        """Tests case insensitive key access and storage.""" -        instance = utils.CaseInsensitiveDict() - -        key = 'LEMON' -        value = 'trees' - -        instance[key] = value -        self.assertIn(key, instance) -        self.assertEqual(instance.get(key), value) -        self.assertEqual(instance.get(key.casefold()), value) -        self.assertEqual(instance.pop(key.casefold()), value) -        self.assertNotIn(key, instance) -        self.assertNotIn(key.casefold(), instance) - -        instance.setdefault(key, value) -        del instance[key] -        self.assertNotIn(key, instance) - -    def test_initialization_from_kwargs(self): -        """Tests creating the dictionary from keyword arguments.""" -        instance = utils.CaseInsensitiveDict({'FOO': 'bar'}) -        self.assertEqual(instance['foo'], 'bar') - -    def test_update_from_other_mapping(self): -        """Tests updating the dictionary from another mapping.""" -        instance = utils.CaseInsensitiveDict() -        instance.update({'FOO': 'bar'}) -        self.assertEqual(instance['foo'], 'bar') diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 9610771e5..de72e5748 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,6 +1,8 @@  import unittest +from unittest.mock import MagicMock  from bot.utils import checks +from bot.utils.checks import InWhitelistCheckFailure  from tests.helpers import MockContext, MockRole @@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):          self.ctx.author.roles.append(MockRole(id=role_id))          self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) -    def test_in_channel_check_for_correct_channel(self): -        self.ctx.channel.id = 42 -        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_correct_channel(self): +        """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" +        channel_id = 3 +        self.ctx.channel.id = channel_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id])) -    def test_in_channel_check_for_incorrect_channel(self): -        self.ctx.channel.id = 42 + 10 -        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_incorrect_channel(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match.""" +        self.ctx.channel.id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, [4]) + +    def test_in_whitelist_check_correct_category(self): +        """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list.""" +        category_id = 3 +        self.ctx.channel.category_id = category_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id])) + +    def test_in_whitelist_check_incorrect_category(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match.""" +        self.ctx.channel.category_id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, categories=[4]) + +    def test_in_whitelist_check_correct_role(self): +        """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6])) + +    def test_in_whitelist_check_incorrect_role(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, roles=[4]) + +    def test_in_whitelist_check_fail_silently(self): +        """`in_whitelist_check` test no exception raised if `fail_silently` is `True`""" +        self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True)) + +    def test_in_whitelist_check_complex(self): +        """`in_whitelist_check` test with multiple parameters""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.ctx.channel.category_id = 3 +        self.ctx.channel.id = 5 +        self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2])) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py new file mode 100644 index 000000000..a2f0fe55d --- /dev/null +++ b/tests/bot/utils/test_redis_cache.py @@ -0,0 +1,265 @@ +import asyncio +import unittest + +import fakeredis.aioredis + +from bot.utils import RedisCache +from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError +from tests import helpers + + +class RedisCacheTests(unittest.IsolatedAsyncioTestCase): +    """Tests the RedisCache class from utils.redis_dict.py.""" + +    async def asyncSetUp(self):  # noqa: N802 +        """Sets up the objects that only have to be initialized once.""" +        self.bot = helpers.MockBot() +        self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() + +        # Okay, so this is necessary so that we can create a clean new +        # class for every test method, and we want that because it will +        # ensure we get a fresh loop, which is necessary for test_increment_lock +        # to be able to pass. +        class DummyCog: +            """A dummy cog, for dummies.""" + +            redis = RedisCache() + +            def __init__(self, bot: helpers.MockBot): +                self.bot = bot + +        self.cog = DummyCog(self.bot) + +        await self.cog.redis.clear() + +    def test_class_attribute_namespace(self): +        """Test that RedisDict creates a namespace automatically for class attributes.""" +        self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") + +    async def test_class_attribute_required(self): +        """Test that errors are raised when not assigned as a class attribute.""" +        bad_cache = RedisCache() +        self.assertIs(bad_cache._namespace, None) + +        with self.assertRaises(RuntimeError): +            await bad_cache.set("test", "me_up_deadman") + +    async def test_set_get_item(self): +        """Test that users can set and get items from the RedisDict.""" +        test_cases = ( +            ('favorite_fruit', 'melon'), +            ('favorite_number', 86), +            ('favorite_fraction', 86.54), +            ('favorite_boolean', False), +            ('other_boolean', True), +        ) + +        # Test that we can get and set different types. +        for test in test_cases: +            await self.cog.redis.set(*test) +            self.assertEqual(await self.cog.redis.get(test[0]), test[1]) + +        # Test that .get allows a default value +        self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + +    async def test_set_item_type(self): +        """Test that .set rejects keys and values that are not permitted.""" +        fruits = ["lemon", "melon", "apple"] + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(fruits, "nice") + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(4.23, "nice") + +    async def test_delete_item(self): +        """Test that .delete allows us to delete stuff from the RedisCache.""" +        # Add an item and verify that it gets added +        await self.cog.redis.set("internet", "firetruck") +        self.assertEqual(await self.cog.redis.get("internet"), "firetruck") + +        # Delete that item and verify that it gets deleted +        await self.cog.redis.delete("internet") +        self.assertIs(await self.cog.redis.get("internet"), None) + +    async def test_contains(self): +        """Test that we can check membership with .contains.""" +        await self.cog.redis.set('favorite_country', "Burkina Faso") + +        self.assertIs(await self.cog.redis.contains('favorite_country'), True) +        self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) + +    async def test_items(self): +        """Test that the RedisDict can be iterated.""" +        # Set up our test cases in the Redis cache +        test_cases = [ +            ('favorite_turtle', 'Donatello'), +            ('second_favorite_turtle', 'Leonardo'), +            ('third_favorite_turtle', 'Raphael'), +        ] +        for key, value in test_cases: +            await self.cog.redis.set(key, value) + +        # Consume the AsyncIterator into a regular list, easier to compare that way. +        redis_items = [item for item in await self.cog.redis.items()] + +        # These sequences are probably in the same order now, but probably +        # isn't good enough for tests. Let's not rely on .hgetall always +        # returning things in sequence, and just sort both lists to be safe. +        redis_items = sorted(redis_items) +        test_cases = sorted(test_cases) + +        # If these are equal now, everything works fine. +        self.assertSequenceEqual(test_cases, redis_items) + +    async def test_length(self): +        """Test that we can get the correct .length from the RedisDict.""" +        await self.cog.redis.set('one', 1) +        await self.cog.redis.set('two', 2) +        await self.cog.redis.set('three', 3) +        self.assertEqual(await self.cog.redis.length(), 3) + +        await self.cog.redis.set('four', 4) +        self.assertEqual(await self.cog.redis.length(), 4) + +    async def test_to_dict(self): +        """Test that the .to_dict method returns a workable dictionary copy.""" +        copy = await self.cog.redis.to_dict() +        local_copy = {key: value for key, value in await self.cog.redis.items()} +        self.assertIs(type(copy), dict) +        self.assertDictEqual(copy, local_copy) + +    async def test_clear(self): +        """Test that the .clear method removes the entire hash.""" +        await self.cog.redis.set('teddy', 'with me') +        await self.cog.redis.set('in my dreams', 'you have a weird hat') +        self.assertEqual(await self.cog.redis.length(), 2) + +        await self.cog.redis.clear() +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_pop(self): +        """Test that we can .pop an item from the RedisDict.""" +        await self.cog.redis.set('john', 'was afraid') + +        self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') +        self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_update(self): +        """Test that we can .update the RedisDict with multiple items.""" +        await self.cog.redis.set("reckfried", "lona") +        await self.cog.redis.set("bel air", "prince") +        await self.cog.redis.update({ +            "reckfried": "jona", +            "mega": "hungry, though", +        }) + +        result = { +            "reckfried": "jona", +            "bel air": "prince", +            "mega": "hungry, though", +        } +        self.assertDictEqual(await self.cog.redis.to_dict(), result) + +    def test_typestring_conversion(self): +        """Test the typestring-related helper functions.""" +        conversion_tests = ( +            (12, "i|12"), +            (12.4, "f|12.4"), +            ("cowabunga", "s|cowabunga"), +        ) + +        # Test conversion to typestring +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) + +        # Test conversion from typestrings +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) + +        # Test that exceptions are raised on invalid input +        with self.assertRaises(TypeError): +            self.cog.redis._value_to_typestring(["internet"]) +            self.cog.redis._value_from_typestring("o|firedog") + +    async def test_increment_decrement(self): +        """Test .increment and .decrement methods.""" +        await self.cog.redis.set("entropic", 5) +        await self.cog.redis.set("disentropic", 12.5) + +        # Test default increment +        await self.cog.redis.increment("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 6) + +        # Test default decrement +        await self.cog.redis.decrement("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # Test float increment with float +        await self.cog.redis.increment("disentropic", 2.0) +        self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) + +        # Test float increment with int +        await self.cog.redis.increment("disentropic", 2) +        self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) + +        # Test negative increments, because why not. +        await self.cog.redis.increment("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 0) + +        # Negative decrements? Sure. +        await self.cog.redis.decrement("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # What about if we use a negative float to decrement an int? +        # This should convert the type into a float. +        await self.cog.redis.decrement("entropic", -2.5) +        self.assertEqual(await self.cog.redis.get("entropic"), 7.5) + +        # Let's test that they raise the right errors +        with self.assertRaises(KeyError): +            await self.cog.redis.increment("doesn't_exist!") + +        await self.cog.redis.set("stringthing", "stringthing") +        with self.assertRaises(TypeError): +            await self.cog.redis.increment("stringthing") + +    async def test_increment_lock(self): +        """Test that we can't produce a race condition in .increment.""" +        await self.cog.redis.set("test_key", 0) +        tasks = [] + +        # Increment this a lot in different tasks +        for _ in range(100): +            task = asyncio.create_task( +                self.cog.redis.increment("test_key", 1) +            ) +            tasks.append(task) +        await asyncio.gather(*tasks) + +        # Confirm that the value has been incremented the exact right number of times. +        value = await self.cog.redis.get("test_key") +        self.assertEqual(value, 100) + +    async def test_exceptions_raised(self): +        """Testing that the various RuntimeErrors are reachable.""" +        class MyCog: +            cache = RedisCache() + +            def __init__(self): +                self.other_cache = RedisCache() + +        cog = MyCog() + +        # Raises "No Bot instance" +        with self.assertRaises(NoBotInstanceError): +            await cog.cache.get("john") + +        # Raises "RedisCache has no namespace" +        with self.assertRaises(NoNamespaceError): +            await cog.other_cache.get("was") + +        # Raises "You must access the RedisCache instance through the cog instance" +        with self.assertRaises(NoParentInstanceError): +            await MyCog.cache.get("afraid") diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 69f35f2f5..694d3a40f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,12 +1,11 @@  import asyncio  import unittest  from datetime import datetime, timezone -from unittest.mock import patch +from unittest.mock import AsyncMock, patch  from dateutil.relativedelta import relativedelta  from bot.utils import time -from tests.helpers import AsyncMock  class TimeTests(unittest.TestCase): @@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase):          for max_units in test_cases:              with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:                  time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) -                self.assertEqual(str(error), 'max_units must be positive') +            self.assertEqual(str(error.exception), 'max_units must be positive')      def test_parse_rfc1123(self):          """Testing parse_rfc1123.""" | 
