diff options
| author | 2024-04-15 08:34:44 -0400 | |
|---|---|---|
| committer | 2024-04-15 08:34:44 -0400 | |
| commit | c071e39685c1d54ccb4a5b322bf127c3a0e16737 (patch) | |
| tree | 1d1da6fb472bac863a9f9f3bfaf8daa931ecf726 /tests/bot | |
| parent | Update comment for clarity on skipping adding of phishing button (diff) | |
| parent | Update site namespace in constants (diff) | |
Merge branch 'main' into vivek/fix-phishing-button
Diffstat (limited to 'tests/bot')
| -rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py | 64 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 1 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 20 | ||||
| -rw-r--r-- | tests/bot/test_constants.py | 4 | 
4 files changed, 59 insertions, 30 deletions
diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 2ce950965..6d7356bf2 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -1,4 +1,6 @@ +import types  import unittest +import unittest.mock  from unittest import mock  import discord @@ -60,40 +62,54 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):  class SyncCogTests(SyncCogTestCase):      """Tests for the Sync cog.""" -    async def test_sync_cog_sync_on_load(self): -        """Roles and users should be synced on cog load.""" -        guild = helpers.MockGuild() -        self.bot.get_guild = mock.MagicMock(return_value=guild) - -        self.RoleSyncer.reset_mock() -        self.UserSyncer.reset_mock() - -        await self.cog.cog_load() - -        self.RoleSyncer.sync.assert_called_once_with(guild) -        self.UserSyncer.sync.assert_called_once_with(guild) - -    async def test_sync_cog_sync_guild(self): -        """Roles and users should be synced only if a guild is successfully retrieved.""" +    @unittest.mock.patch("bot.exts.backend.sync._cog.create_task", new_callable=unittest.mock.MagicMock) +    async def test_sync_cog_sync_on_load(self, mock_create_task: unittest.mock.MagicMock): +        """Sync function should be synced on cog load only if guild is found."""          for guild in (helpers.MockGuild(), None):              with self.subTest(guild=guild): +                mock_create_task.reset_mock()                  self.bot.reset_mock()                  self.RoleSyncer.reset_mock()                  self.UserSyncer.reset_mock()                  self.bot.get_guild = mock.MagicMock(return_value=guild) - -                await self.cog.cog_load() - -                self.bot.wait_until_guild_available.assert_called_once() -                self.bot.get_guild.assert_called_once_with(constants.Guild.id) +                error_raised = False +                try: +                    await self.cog.cog_load() +                except ValueError: +                    if guild is None: +                        error_raised = True +                    else: +                        raise                  if guild is None: -                    self.RoleSyncer.sync.assert_not_called() -                    self.UserSyncer.sync.assert_not_called() +                    self.assertTrue(error_raised) +                    mock_create_task.assert_not_called()                  else: -                    self.RoleSyncer.sync.assert_called_once_with(guild) -                    self.UserSyncer.sync.assert_called_once_with(guild) +                    mock_create_task.assert_called_once() +                    create_task_arg = mock_create_task.call_args[0][0] +                    self.assertIsInstance(create_task_arg, types.CoroutineType) +                    self.assertEqual(create_task_arg.__qualname__, self.cog.sync.__qualname__) +                    create_task_arg.close() + +    async def test_sync_cog_sync_guild(self): +        """Roles and users should be synced only if a guild is successfully retrieved.""" +        guild = helpers.MockGuild() +        self.bot.reset_mock() +        self.RoleSyncer.reset_mock() +        self.UserSyncer.reset_mock() + +        self.bot.get_guild = mock.MagicMock(return_value=guild) +        await self.cog.cog_load() + +        with mock.patch("asyncio.sleep", new_callable=unittest.mock.AsyncMock): +            await self.cog.sync() + +        self.bot.wait_until_guild_available.assert_called_once() +        self.bot.get_guild.assert_called_once_with(constants.Guild.id) + +        self.RoleSyncer.sync.assert_called_once() +        self.UserSyncer.sync.assert_called_once()      async def patch_user_helper(self, side_effect: BaseException) -> None:          """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 2fc97af2d..2fc000446 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -11,6 +11,7 @@ def fake_user(**kwargs):      """Fixture to return a dictionary representing a user with default values set."""      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man") +    kwargs.setdefault("display_name", "bob")      kwargs.setdefault("discriminator", 1337)      kwargs.setdefault("roles", [helpers.MockRole(id=666)])      kwargs.setdefault("in_guild", True) diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 26ba770dc..f257bec7d 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -37,7 +37,9 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.mod_log.ignore = Mock()          self.ctx.guild.ban = AsyncMock() -        await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) +        infraction_reason = "foo bar" * 3000 + +        await self.cog.apply_ban(self.ctx, self.target, infraction_reason)          self.cog.apply_infraction.assert_awaited_once_with(              self.ctx, {"foo": "bar", "purge": ""}, self.target, ANY          ) @@ -46,10 +48,14 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          await action()          self.ctx.guild.ban.assert_awaited_once_with(              self.target, -            reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), +            reason=textwrap.shorten(infraction_reason, 512, placeholder="..."),              delete_message_days=0          ) +        # Assert that the reason sent to the database isn't truncated. +        post_infraction_mock.assert_awaited_once() +        self.assertEqual(post_infraction_mock.call_args.args[3], infraction_reason) +      @patch("bot.exts.moderation.infraction._utils.post_infraction")      async def test_apply_kick_reason_truncation(self, post_infraction_mock):          """Should truncate reason for `Member.kick`.""" @@ -59,14 +65,20 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.mod_log.ignore = Mock()          self.target.kick = AsyncMock() -        await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) +        infraction_reason = "foo bar" * 3000 + +        await self.cog.apply_kick(self.ctx, self.target, infraction_reason)          self.cog.apply_infraction.assert_awaited_once_with(              self.ctx, {"foo": "bar"}, self.target, ANY          )          action = self.cog.apply_infraction.call_args.args[-1]          await action() -        self.target.kick.assert_awaited_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) +        self.target.kick.assert_awaited_once_with(reason=textwrap.shorten(infraction_reason, 512, placeholder="...")) + +        # Assert that the reason sent to the database isn't truncated. +        post_infraction_mock.assert_awaited_once() +        self.assertEqual(post_infraction_mock.call_args.args[3], infraction_reason)  @patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index 87933d59a..916e1d5bb 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -10,7 +10,7 @@ current_path = Path(__file__)  env_file_path = current_path.parent / ".testenv" -class TestEnvConfig( +class _TestEnvConfig(      EnvConfig,      env_file=env_file_path,  ): @@ -21,7 +21,7 @@ class NestedModel(BaseModel):      server_name: str -class _TestConfig(TestEnvConfig, env_prefix="unittests_"): +class _TestConfig(_TestEnvConfig, env_prefix="unittests_"):      goat: str      execution_env: str = "local"  |