diff options
Diffstat (limited to 'tests')
26 files changed, 196 insertions, 512 deletions
| diff --git a/tests/base.py b/tests/base.py index 88693f382..42174e911 100644 --- a/tests/base.py +++ b/tests/base.py @@ -22,8 +22,13 @@ class _CaptureLogHandler(logging.Handler):          self.records.append(record) -class LoggingTestCase(unittest.TestCase): -    """TestCase subclass that adds more logging assertion tools.""" +class LoggingTestsMixin: +    """ +    A mixin that defines additional test methods for logging behavior. + +    This mixin relies on the availability of the `fail` attribute defined by the +    test classes included in Python's unittest method to signal test failure. +    """      @contextmanager      def assertNotLogs(self, logger=None, level=None, msg=None): @@ -73,10 +78,9 @@ class LoggingTestCase(unittest.TestCase):              self.fail(msg) -class CommandTestCase(unittest.TestCase): +class CommandTestCase(unittest.IsolatedAsyncioTestCase):      """TestCase with additional assertions that are useful for testing Discord commands.""" -    @helpers.async_test      async def assertHasPermissionsCheck(          self,          cmd: commands.Command, diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index c2e143865..fe0594efe 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -13,8 +13,8 @@ class TestSyncer(Syncer):      """Syncer subclass with mocks for abstract methods for testing purposes."""      name = "test" -    _get_diff = helpers.AsyncMock() -    _sync = helpers.AsyncMock() +    _get_diff = mock.AsyncMock() +    _sync = mock.AsyncMock()  class SyncerBaseTests(unittest.TestCase): @@ -29,7 +29,7 @@ class SyncerBaseTests(unittest.TestCase):              Syncer(self.bot) -class SyncerSendPromptTests(unittest.TestCase): +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):      """Tests for sending the sync confirmation prompt."""      def setUp(self): @@ -61,7 +61,6 @@ class SyncerSendPromptTests(unittest.TestCase):          return mock_channel, mock_message -    @helpers.async_test      async def test_send_prompt_edits_and_returns_message(self):          """The given message should be edited to display the prompt and then should be returned."""          msg = helpers.MockMessage() @@ -71,7 +70,6 @@ class SyncerSendPromptTests(unittest.TestCase):          self.assertIn("content", msg.edit.call_args[1])          self.assertEqual(ret_val, msg) -    @helpers.async_test      async def test_send_prompt_gets_dev_core_channel(self):          """The dev-core channel should be retrieved if an extant message isn't given."""          subtests = ( @@ -86,7 +84,6 @@ class SyncerSendPromptTests(unittest.TestCase):                  method.assert_called_once_with(constants.Channels.dev_core) -    @helpers.async_test      async def test_send_prompt_returns_None_if_channel_fetch_fails(self):          """None should be returned if there's an HTTPException when fetching the channel."""          self.bot.get_channel.return_value = None @@ -96,7 +93,6 @@ class SyncerSendPromptTests(unittest.TestCase):          self.assertIsNone(ret_val) -    @helpers.async_test      async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):          """A new message mentioning core devs should be sent and returned if message isn't given."""          for mock_ in (self.mock_get_channel, self.mock_fetch_channel): @@ -108,7 +104,6 @@ class SyncerSendPromptTests(unittest.TestCase):                  self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])                  self.assertEqual(ret_val, mock_message) -    @helpers.async_test      async def test_send_prompt_adds_reactions(self):          """The message should have reactions for confirmation added."""          extant_message = helpers.MockMessage() @@ -129,7 +124,7 @@ class SyncerSendPromptTests(unittest.TestCase):                  mock_message.add_reaction.assert_has_calls(calls) -class SyncerConfirmationTests(unittest.TestCase): +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):      """Tests for waiting for a sync confirmation reaction on the prompt."""      def setUp(self): @@ -211,7 +206,6 @@ class SyncerConfirmationTests(unittest.TestCase):                  ret_val = self.syncer._reaction_check(*args)                  self.assertFalse(ret_val) -    @helpers.async_test      async def test_wait_for_confirmation(self):          """The message should always be edited and only return True if the emoji is a check mark."""          subtests = ( @@ -251,14 +245,13 @@ class SyncerConfirmationTests(unittest.TestCase):                      self.assertIs(actual_return, ret_val) -class SyncerSyncTests(unittest.TestCase): +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for main function orchestrating the sync."""      def setUp(self):          self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))          self.syncer = TestSyncer(self.bot) -    @helpers.async_test      async def test_sync_respects_confirmation_result(self):          """The sync should abort if confirmation fails and continue if confirmed."""          mock_message = helpers.MockMessage() @@ -274,7 +267,7 @@ class SyncerSyncTests(unittest.TestCase):                  diff = _Diff({1, 2, 3}, {4, 5}, None)                  self.syncer._get_diff.return_value = diff -                self.syncer._get_confirmation_result = helpers.AsyncMock( +                self.syncer._get_confirmation_result = mock.AsyncMock(                      return_value=(confirmed, message)                  ) @@ -289,7 +282,6 @@ class SyncerSyncTests(unittest.TestCase):                  else:                      self.syncer._sync.assert_not_called() -    @helpers.async_test      async def test_sync_diff_size(self):          """The diff size should be correctly calculated."""          subtests = ( @@ -303,7 +295,7 @@ class SyncerSyncTests(unittest.TestCase):              with self.subTest(size=size, diff=diff):                  self.syncer._get_diff.reset_mock()                  self.syncer._get_diff.return_value = diff -                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) +                self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))                  guild = helpers.MockGuild()                  await self.syncer.sync(guild) @@ -312,7 +304,6 @@ class SyncerSyncTests(unittest.TestCase):                  self.syncer._get_confirmation_result.assert_called_once()                  self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) -    @helpers.async_test      async def test_sync_message_edited(self):          """The message should be edited if one was sent, even if the sync has an API error."""          subtests = ( @@ -324,7 +315,7 @@ class SyncerSyncTests(unittest.TestCase):          for message, side_effect, should_edit in subtests:              with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):                  self.syncer._sync.side_effect = side_effect -                self.syncer._get_confirmation_result = helpers.AsyncMock( +                self.syncer._get_confirmation_result = mock.AsyncMock(                      return_value=(True, message)                  ) @@ -335,7 +326,6 @@ class SyncerSyncTests(unittest.TestCase):                      message.edit.assert_called_once()                      self.assertIn("content", message.edit.call_args[1]) -    @helpers.async_test      async def test_sync_confirmation_context_redirect(self):          """If ctx is given, a new message should be sent and author should be ctx's author."""          mock_member = helpers.MockMember() @@ -349,7 +339,10 @@ class SyncerSyncTests(unittest.TestCase):                  if ctx is not None:                      ctx.send.return_value = message -                self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) +                # Make sure `_get_diff` returns a MagicMock, not an AsyncMock +                self.syncer._get_diff.return_value = mock.MagicMock() + +                self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))                  guild = helpers.MockGuild()                  await self.syncer.sync(guild, ctx) @@ -362,16 +355,15 @@ class SyncerSyncTests(unittest.TestCase):                  self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)      @mock.patch.object(constants.Sync, "max_diff", new=3) -    @helpers.async_test      async def test_confirmation_result_small_diff(self):          """Should always return True and the given message if the diff size is too small."""          author = helpers.MockMember()          expected_message = helpers.MockMessage() -        for size in (3, 2): +        for size in (3, 2):  # pragma: no cover              with self.subTest(size=size): -                self.syncer._send_prompt = helpers.AsyncMock() -                self.syncer._wait_for_confirmation = helpers.AsyncMock() +                self.syncer._send_prompt = mock.AsyncMock() +                self.syncer._wait_for_confirmation = mock.AsyncMock()                  coro = self.syncer._get_confirmation_result(size, author, expected_message)                  result, actual_message = await coro @@ -382,7 +374,6 @@ class SyncerSyncTests(unittest.TestCase):                  self.syncer._wait_for_confirmation.assert_not_called()      @mock.patch.object(constants.Sync, "max_diff", new=3) -    @helpers.async_test      async def test_confirmation_result_large_diff(self):          """Should return True if confirmed and False if _send_prompt fails or aborted."""          author = helpers.MockMember() @@ -394,10 +385,10 @@ class SyncerSyncTests(unittest.TestCase):              (False, mock_message, False, "aborted"),          ) -        for expected_result, expected_message, confirmed, msg in subtests: +        for expected_result, expected_message, confirmed, msg in subtests:  # pragma: no cover              with self.subTest(msg=msg): -                self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) -                self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) +                self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) +                self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed)                  coro = self.syncer._get_confirmation_result(4, author)                  actual_result, actual_message = await coro diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 98c9afc0d..81398c61f 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -11,19 +11,7 @@ from tests import helpers  from tests.base import CommandTestCase -class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): -    """ -    A MagicMock subclass to mock Syncer objects. - -    Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` -    instances. For more information, see the `MockGuild` docstring. -    """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=Syncer, **kwargs) - - -class SyncExtensionTests(unittest.TestCase): +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase):      """Tests for the sync extension."""      @staticmethod @@ -34,22 +22,21 @@ class SyncExtensionTests(unittest.TestCase):          bot.add_cog.assert_called_once() -class SyncCogTestCase(unittest.TestCase): +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):      """Base class for Sync cog tests. Sets up patches for syncers."""      def setUp(self):          self.bot = helpers.MockBot() -        # These patch the type. When the type is called, a MockSyncer instanced is returned. -        # MockSyncer is needed so that our custom AsyncMock is used. -        # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.          self.role_syncer_patcher = mock.patch(              "bot.cogs.sync.syncers.RoleSyncer", -            new=mock.MagicMock(return_value=MockSyncer()) +            autospec=Syncer, +            spec_set=True          )          self.user_syncer_patcher = mock.patch(              "bot.cogs.sync.syncers.UserSyncer", -            new=mock.MagicMock(return_value=MockSyncer()) +            autospec=Syncer, +            spec_set=True          )          self.RoleSyncer = self.role_syncer_patcher.start()          self.UserSyncer = self.user_syncer_patcher.start() @@ -72,13 +59,13 @@ class SyncCogTestCase(unittest.TestCase):  class SyncCogTests(SyncCogTestCase):      """Tests for the Sync cog.""" -    @mock.patch.object(sync.Sync, "sync_guild") +    @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock)      def test_sync_cog_init(self, sync_guild):          """Should instantiate syncers and run a sync for the guild."""          # Reset because a Sync cog was already instantiated in setUp.          self.RoleSyncer.reset_mock()          self.UserSyncer.reset_mock() -        self.bot.loop.create_task.reset_mock() +        self.bot.loop.create_task = mock.MagicMock()          mock_sync_guild_coro = mock.MagicMock()          sync_guild.return_value = mock_sync_guild_coro @@ -90,7 +77,6 @@ class SyncCogTests(SyncCogTestCase):          sync_guild.assert_called_once_with()          self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) -    @helpers.async_test      async def test_sync_cog_sync_guild(self):          """Roles and users should be synced only if a guild is successfully retrieved."""          for guild in (helpers.MockGuild(), None): @@ -126,14 +112,12 @@ class SyncCogTests(SyncCogTestCase):              json=updated_information,          ) -    @helpers.async_test      async def test_sync_cog_patch_user(self):          """A PATCH request should be sent and 404 errors ignored."""          for side_effect in (None, self.response_error(404)):              with self.subTest(side_effect=side_effect):                  await self.patch_user_helper(side_effect) -    @helpers.async_test      async def test_sync_cog_patch_user_non_404(self):          """A PATCH request should be sent and the error raised if it's not a 404."""          with self.assertRaises(ResponseCodeError): @@ -145,9 +129,8 @@ class SyncCogListenerTests(SyncCogTestCase):      def setUp(self):          super().setUp() -        self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) +        self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) -    @helpers.async_test      async def test_sync_cog_on_guild_role_create(self):          """A POST request should be sent with the new role's data."""          self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -164,7 +147,6 @@ class SyncCogListenerTests(SyncCogTestCase):          self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) -    @helpers.async_test      async def test_sync_cog_on_guild_role_delete(self):          """A DELETE request should be sent."""          self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) @@ -174,7 +156,6 @@ class SyncCogListenerTests(SyncCogTestCase):          self.bot.api_client.delete.assert_called_once_with("bot/roles/99") -    @helpers.async_test      async def test_sync_cog_on_guild_role_update(self):          """A PUT request should be sent if the colour, name, permissions, or position changes."""          self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -212,7 +193,6 @@ class SyncCogListenerTests(SyncCogTestCase):                      else:                          self.bot.api_client.put.assert_not_called() -    @helpers.async_test      async def test_sync_cog_on_member_remove(self):          """Member should patched to set in_guild as False."""          self.assertTrue(self.cog.on_member_remove.__cog_listener__) @@ -225,7 +205,6 @@ class SyncCogListenerTests(SyncCogTestCase):              updated_information={"in_guild": False}          ) -    @helpers.async_test      async def test_sync_cog_on_member_update_roles(self):          """Members should be patched if their roles have changed."""          self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -240,7 +219,6 @@ class SyncCogListenerTests(SyncCogTestCase):          data = {"roles": sorted(role.id for role in after_member.roles)}          self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) -    @helpers.async_test      async def test_sync_cog_on_member_update_other(self):          """Members should not be patched if other attributes have changed."""          self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -262,7 +240,6 @@ class SyncCogListenerTests(SyncCogTestCase):                  self.cog.patch_user.assert_not_called() -    @helpers.async_test      async def test_sync_cog_on_user_update(self):          """A user should be patched only if the name, discriminator, or avatar changes."""          self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -341,7 +318,6 @@ class SyncCogListenerTests(SyncCogTestCase):          return data -    @helpers.async_test      async def test_sync_cog_on_member_join(self):          """Should PUT user's data or POST it if the user doesn't exist."""          for side_effect in (None, self.response_error(404)): @@ -354,7 +330,6 @@ class SyncCogListenerTests(SyncCogTestCase):                  else:                      self.bot.api_client.post.assert_not_called() -    @helpers.async_test      async def test_sync_cog_on_member_join_non_404(self):          """ResponseCodeError should be re-raised if status code isn't a 404."""          with self.assertRaises(ResponseCodeError): @@ -366,7 +341,6 @@ class SyncCogListenerTests(SyncCogTestCase):  class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):      """Tests for the commands in the Sync cog.""" -    @helpers.async_test      async def test_sync_roles_command(self):          """sync() should be called on the RoleSyncer."""          ctx = helpers.MockContext() @@ -374,7 +348,6 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) -    @helpers.async_test      async def test_sync_users_command(self):          """sync() should be called on the UserSyncer."""          ctx = helpers.MockContext() @@ -382,7 +355,7 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) -    def test_commands_require_admin(self): +    async def test_commands_require_admin(self):          """The sync commands should only run if the author has the administrator permission."""          cmds = (              self.cog.sync_group, @@ -392,4 +365,4 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          for cmd in cmds:              with self.subTest(cmd=cmd): -                self.assertHasPermissionsCheck(cmd, {"administrator": True}) +                await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 14fb2577a..79eee98f4 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -18,7 +18,7 @@ def fake_role(**kwargs):      return kwargs -class RoleSyncerDiffTests(unittest.TestCase): +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between roles in the DB and roles in the Guild cache."""      def setUp(self): @@ -39,7 +39,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          return guild -    @helpers.async_test      async def test_empty_diff_for_identical_roles(self):          """No differences should be found if the roles in the guild and DB are identical."""          self.bot.api_client.get.return_value = [fake_role()] @@ -50,7 +49,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_updated_roles(self):          """Only updated roles should be added to the 'updated' set of the diff."""          updated_role = fake_role(id=41, name="new") @@ -63,7 +61,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_roles(self):          """Only new roles should be added to the 'created' set of the diff."""          new_role = fake_role(id=41, name="new") @@ -76,7 +73,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_deleted_roles(self):          """Only deleted roles should be added to the 'deleted' set of the diff."""          deleted_role = fake_role(id=61, name="deleted") @@ -89,7 +85,6 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_updated_and_deleted_roles(self):          """When roles are added, updated, and removed, all of them are returned properly."""          new = fake_role(id=41, name="new") @@ -109,14 +104,13 @@ class RoleSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -class RoleSyncerSyncTests(unittest.TestCase): +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync roles."""      def setUp(self):          self.bot = helpers.MockBot()          self.syncer = RoleSyncer(self.bot) -    @helpers.async_test      async def test_sync_created_roles(self):          """Only POST requests should be made with the correct payload."""          roles = [fake_role(id=111), fake_role(id=222)] @@ -132,7 +126,6 @@ class RoleSyncerSyncTests(unittest.TestCase):          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() -    @helpers.async_test      async def test_sync_updated_roles(self):          """Only PUT requests should be made with the correct payload."""          roles = [fake_role(id=111), fake_role(id=222)] @@ -148,7 +141,6 @@ class RoleSyncerSyncTests(unittest.TestCase):          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() -    @helpers.async_test      async def test_sync_deleted_roles(self):          """Only DELETE requests should be made with the correct payload."""          roles = [fake_role(id=111), fake_role(id=222)] diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 421bf6bb6..818883012 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -17,7 +17,7 @@ def fake_user(**kwargs):      return kwargs -class UserSyncerDiffTests(unittest.TestCase): +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between users in the DB and users in the Guild cache."""      def setUp(self): @@ -42,7 +42,6 @@ class UserSyncerDiffTests(unittest.TestCase):          return guild -    @helpers.async_test      async def test_empty_diff_for_no_users(self):          """When no users are given, an empty diff should be returned."""          guild = self.get_guild() @@ -52,7 +51,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_empty_diff_for_identical_users(self):          """No differences should be found if the users in the guild and DB are identical."""          self.bot.api_client.get.return_value = [fake_user()] @@ -63,7 +61,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_updated_users(self):          """Only updated users should be added to the 'updated' set of the diff."""          updated_user = fake_user(id=99, name="new") @@ -76,7 +73,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_users(self):          """Only new users should be added to the 'created' set of the diff."""          new_user = fake_user(id=99, name="new") @@ -89,7 +85,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`."""          leaving_user = fake_user(id=63, in_guild=False) @@ -102,7 +97,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_diff_for_new_updated_and_leaving_users(self):          """When users are added, updated, and removed, all of them are returned properly."""          new_user = fake_user(id=99, name="new") @@ -117,7 +111,6 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -    @helpers.async_test      async def test_empty_diff_for_db_users_not_in_guild(self):          """When the DB knows a user the guild doesn't, no difference is found."""          self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] @@ -129,14 +122,13 @@ class UserSyncerDiffTests(unittest.TestCase):          self.assertEqual(actual_diff, expected_diff) -class UserSyncerSyncTests(unittest.TestCase): +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync users."""      def setUp(self):          self.bot = helpers.MockBot()          self.syncer = UserSyncer(self.bot) -    @helpers.async_test      async def test_sync_created_users(self):          """Only POST requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] @@ -152,7 +144,6 @@ class UserSyncerSyncTests(unittest.TestCase):          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() -    @helpers.async_test      async def test_sync_updated_users(self):          """Only PUT requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 5b0a3b8c3..7e6bfc748 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -2,7 +2,7 @@ import asyncio  import logging  import typing  import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch  import discord @@ -14,7 +14,7 @@ from tests import helpers  MODULE_PATH = "bot.cogs.duck_pond" -class DuckPondTests(base.LoggingTestCase): +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):      """Tests for DuckPond functionality."""      @classmethod @@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase):              with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):                  self.assertEqual(expected_return, actual_return) -    @helpers.async_test      async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):          """The `has_green_checkmark` method should only return `True` if one is present."""          test_cases = ( @@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase):          nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]          return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) -    @helpers.async_test      async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):          """The `count_ducks` method should return the number of unique staffers who gave a duck."""          test_cases = ( @@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase):              with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):                  self.assertEqual(expected_count, actual_count) -    @helpers.async_test      async def test_relay_message_correctly_relays_content_and_attachments(self):          """The `relay_message` method should correctly relay message content and attachments."""          send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" @@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase):          )          for message, expect_webhook_call, expect_attachment_call in test_values: -            with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: -                with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: +            with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: +                with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:                      with self.subTest(clean_content=message.clean_content, attachments=message.attachments):                          await self.cog.relay_message(message) @@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase):                          message.add_reaction.assert_called_once_with(self.checkmark_emoji) -    @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) -    @helpers.async_test +    @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase):          self.cog.webhook = helpers.MockAsyncWebhook()          log = logging.getLogger("bot.cogs.duck_pond") -        for side_effect in side_effects: +        for side_effect in side_effects:  # pragma: no cover              send_attachments.side_effect = side_effect -            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: +            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:                  with self.subTest(side_effect=type(side_effect).__name__):                      with self.assertNotLogs(logger=log, level=logging.ERROR):                          await self.cog.relay_message(message)                      self.assertEqual(send_webhook.call_count, 2) -    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) -    @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) -    @helpers.async_test +    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) +    @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase):          payload.emoji.name = emoji_name          return payload -    @helpers.async_test      async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):          """The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""          test_values = ( @@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase):          return channel, message, member, payload -    @helpers.async_test      async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):          """The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""          channel_id = 1234 @@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase):                  channel.fetch_message.reset_mock()      @patch(f"{MODULE_PATH}.DuckPond.is_staff") -    @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) +    @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)      def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):          """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""          channel_id = 31415926535 @@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase):          # Assert that we've made it past `self.is_staff`          is_staff.assert_called_once() -    @helpers.async_test      async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):          """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""          test_cases = ( @@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase):          payload.emoji = self.duck_pond_emoji          for duck_count, should_relay in test_cases: -            with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: -                with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: +            with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: +                with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:                      count_ducks.return_value = duck_count                      with self.subTest(duck_count=duck_count, should_relay=should_relay):                          await self.cog.on_raw_reaction_add(payload) @@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase):                          if should_relay:                              relay_message.assert_called_once_with(message) -    @helpers.async_test      async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):          """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""          checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) @@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase):              (constants.DuckPond.threshold + 1, True),          )          for duck_count, should_re_add_checkmark in test_cases: -            with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: +            with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:                  count_ducks.return_value = duck_count                  with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):                      await self.cog.on_raw_reaction_remove(payload) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 8443cfe71..5693d2946 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase):          """Test if the `role_info` command correctly returns the `moderator_role`."""          self.ctx.guild.roles.append(self.moderator_role) -        self.cog.roles_info.can_run = helpers.AsyncMock() +        self.cog.roles_info.can_run = unittest.mock.AsyncMock()          self.cog.roles_info.can_run.return_value = True          coroutine = self.cog.roles_info.callback(self.cog, self.ctx) @@ -72,7 +72,7 @@ class InformationCogTests(unittest.TestCase):          self.ctx.guild.roles.append([dummy_role, admin_role]) -        self.cog.role_info.can_run = helpers.AsyncMock() +        self.cog.role_info.can_run = unittest.mock.AsyncMock()          self.cog.role_info.can_run.return_value = True          coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) @@ -174,7 +174,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):      def setUp(self):          """Common set-up steps done before for each test."""          self.bot = helpers.MockBot() -        self.bot.api_client.get = helpers.AsyncMock() +        self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot)          self.member = helpers.MockMember(id=1234) @@ -345,10 +345,10 @@ class UserEmbedTests(unittest.TestCase):      def setUp(self):          """Common set-up steps done before for each test."""          self.bot = helpers.MockBot() -        self.bot.api_client.get = helpers.AsyncMock() +        self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -360,7 +360,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Mr. Hemlock") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_nick_in_title_if_available(self):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -372,7 +372,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_ignores_everyone_role(self):          """Created `!user` embeds should not contain mention of the @everyone-role."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -387,8 +387,8 @@ class UserEmbedTests(unittest.TestCase):          self.assertIn("&Admins", embed.description)          self.assertNotIn("&Everyone", embed.description) -    @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock) -    @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)      def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):          """The embed should contain expanded infractions and nomination info in mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -423,7 +423,7 @@ class UserEmbedTests(unittest.TestCase):              embed.description          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)      def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):          """The embed should contain only basic infraction data outside of mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -454,7 +454,7 @@ class UserEmbedTests(unittest.TestCase):              embed.description          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):          """The embed should be created with the colour of the top role, if a top role is available."""          ctx = helpers.MockContext() @@ -467,7 +467,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):          """The embed should be created with a blurple colour if the user has no assigned roles."""          ctx = helpers.MockContext() @@ -477,7 +477,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour.blurple()) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):          """The embed thumbnail should be set to the user's avatar in `png` format."""          ctx = helpers.MockContext() @@ -529,7 +529,7 @@ class UserCommandTests(unittest.TestCase):          with self.assertRaises(InChannelCheckFailure, msg=msg):              asyncio.run(self.cog.user_info.callback(self.cog, ctx)) -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):          """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -542,7 +542,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):          """A user should target itself with `!user` when a `user` argument was not provided."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -555,7 +555,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -568,7 +568,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.moderator)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_moderators_can_target_another_member(self, create_embed, constants):          """A moderator should be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 985bc66a1..9cd7f0154 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -1,74 +1,68 @@  import asyncio  import logging  import unittest -from functools import partial -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch  from bot.cogs import snekbox  from bot.cogs.snekbox import Snekbox  from bot.constants import URLs -from tests.helpers import ( -    AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test -) +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser -class SnekboxTests(unittest.TestCase): +class SnekboxTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          """Add mocked bot and cog to the instance."""          self.bot = MockBot() - -        self.mocked_post = MagicMock() -        self.mocked_post.json = AsyncMock() -        self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post)) -          self.cog = Snekbox(bot=self.bot) -    @async_test      async def test_post_eval(self):          """Post the eval code to the URLs.snekbox_eval_api endpoint.""" -        self.mocked_post.json.return_value = {'lemon': 'AI'} +        resp = MagicMock() +        resp.json = AsyncMock(return_value="return") +        self.bot.http_session.post().__aenter__.return_value = resp -        self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'}) -        self.bot.http_session.post.assert_called_once_with( +        self.assertEqual(await self.cog.post_eval("import random"), "return") +        self.bot.http_session.post.assert_called_with(              URLs.snekbox_eval_api,              json={"input": "import random"},              raise_for_status=True          ) +        resp.json.assert_awaited_once() -    @async_test      async def test_upload_output_reject_too_long(self):          """Reject output longer than MAX_PASTE_LEN."""          result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))          self.assertEqual(result, "too long to upload") -    @async_test      async def test_upload_output(self):          """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" -        key = "RainbowDash" -        self.mocked_post.json.return_value = {"key": key} +        key = "MarkDiamond" +        resp = MagicMock() +        resp.json = AsyncMock(return_value={"key": key}) +        self.bot.http_session.post().__aenter__.return_value = resp          self.assertEqual(              await self.cog.upload_output("My awesome output"),              URLs.paste_service.format(key=key)          ) -        self.bot.http_session.post.assert_called_once_with( +        self.bot.http_session.post.assert_called_with(              URLs.paste_service.format(key="documents"),              data="My awesome output",              raise_for_status=True          ) -    @async_test      async def test_upload_output_gracefully_fallback_if_exception_during_request(self):          """Output upload gracefully fallback if the upload fail.""" -        self.mocked_post.json.side_effect = Exception +        resp = MagicMock() +        resp.json = AsyncMock(side_effect=Exception) +        self.bot.http_session.post().__aenter__.return_value = resp +          log = logging.getLogger("bot.cogs.snekbox")          with self.assertLogs(logger=log, level='ERROR'):              await self.cog.upload_output('My awesome output!') -    @async_test      async def test_upload_output_gracefully_fallback_if_no_key_in_response(self):          """Output upload gracefully fallback if there is no key entry in the response body.""" -        self.mocked_post.json.return_value = {}          self.assertEqual((await self.cog.upload_output('My awesome output!')), None)      def test_prepare_input(self): @@ -121,7 +115,6 @@ class SnekboxTests(unittest.TestCase):                  actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})                  self.assertEqual(actual, expected) -    @async_test      async def test_format_output(self):          """Test output formatting."""          self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') @@ -172,7 +165,6 @@ class SnekboxTests(unittest.TestCase):              with self.subTest(msg=testname, case=case, expected=expected):                  self.assertEqual(await self.cog.format_output(case), expected) -    @async_test      async def test_eval_command_evaluate_once(self):          """Test the eval command procedure."""          ctx = MockContext() @@ -186,7 +178,6 @@ class SnekboxTests(unittest.TestCase):          self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')          self.cog.continue_eval.assert_called_once_with(ctx, response) -    @async_test      async def test_eval_command_evaluate_twice(self):          """Test the eval and re-eval command procedure."""          ctx = MockContext() @@ -201,7 +192,6 @@ class SnekboxTests(unittest.TestCase):          self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')          self.cog.continue_eval.assert_called_with(ctx, response) -    @async_test      async def test_eval_command_reject_two_eval_at_the_same_time(self):          """Test if the eval command rejects an eval if the author already have a running eval."""          ctx = MockContext() @@ -214,7 +204,6 @@ class SnekboxTests(unittest.TestCase):              "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"          ) -    @async_test      async def test_eval_command_call_help(self):          """Test if the eval command call the help command if no code is provided."""          ctx = MockContext() @@ -222,14 +211,13 @@ class SnekboxTests(unittest.TestCase):          await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')          ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") -    @async_test      async def test_send_eval(self):          """Test the send_eval function."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +          self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})          self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))          self.cog.get_status_emoji = MagicMock(return_value=':yay!:') @@ -244,14 +232,13 @@ class SnekboxTests(unittest.TestCase):          self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})          self.cog.format_output.assert_called_once_with('') -    @async_test      async def test_send_eval_with_paste_link(self):          """Test the send_eval function with a too long output that generate a paste link."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +          self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})          self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))          self.cog.get_status_emoji = MagicMock(return_value=':yay!:') @@ -267,14 +254,12 @@ class SnekboxTests(unittest.TestCase):          self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})          self.cog.format_output.assert_called_once_with('Way too long beard') -    @async_test      async def test_send_eval_with_non_zero_eval(self):          """Test the send_eval function with a code returning a non-zero code."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))          self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})          self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))          self.cog.get_status_emoji = MagicMock(return_value=':nope!:') @@ -289,8 +274,8 @@ class SnekboxTests(unittest.TestCase):          self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})          self.cog.format_output.assert_not_called() -    @async_test -    async def test_continue_eval_does_continue(self): +    @patch("bot.cogs.snekbox.partial") +    async def test_continue_eval_does_continue(self, partial_mock):          """Test that the continue_eval function does continue if required conditions are met."""          ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))          response = MockMessage(delete=AsyncMock()) @@ -299,15 +284,16 @@ class SnekboxTests(unittest.TestCase):          actual = await self.cog.continue_eval(ctx, response)          self.assertEqual(actual, 'NewCode') -        self.bot.wait_for.has_calls( -            call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10), -            call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +        self.bot.wait_for.assert_has_awaits( +            ( +                call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), +                call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +            )          )          ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)          ctx.message.clear_reactions.assert_called_once()          response.delete.assert_called_once() -    @async_test      async def test_continue_eval_does_not_continue(self):          ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))          self.bot.wait_for.side_effect = asyncio.TimeoutError diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a54b839d7..33d1ec170 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,7 +1,7 @@  import asyncio  import logging  import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock  from discord import Colour @@ -11,7 +11,7 @@ from bot.cogs.token_remover import (      setup as setup_cog,  )  from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock, MockBot, MockMessage +from tests.helpers import MockBot, MockMessage  class TokenRemoverTests(unittest.TestCase): diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index 36c986fe1..0d570f5a3 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple):      n_violations: int -class RuleTest(unittest.TestCase, metaclass=ABCMeta): +class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):      """      Abstract class for antispam rule test cases. @@ -68,9 +68,9 @@ class RuleTest(unittest.TestCase, metaclass=ABCMeta):      @abstractmethod      def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:          """Give expected relevant messages for `case`.""" -        raise NotImplementedError +        raise NotImplementedError  # pragma: no cover      @abstractmethod      def get_report(self, case: DisallowedCase) -> str:          """Give expected error report for `case`.""" -        raise NotImplementedError +        raise NotImplementedError  # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index e54b4b5b8..d7e779221 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import attachments  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, total_attachments: int) -> MockMessage: @@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest):          self.apply = attachments.apply          self.config = {"max": 5, "interval": 10} -    @async_test      async def test_allows_messages_without_too_many_attachments(self):          """Messages without too many attachments are allowed as-is."""          cases = ( @@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py index 72f0be0c7..03682966b 100644 --- a/tests/bot/rules/test_burst.py +++ b/tests/bot/rules/test_burst.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import burst  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest):          self.apply = burst.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases which do not violate the rule."""          cases = ( @@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the amount of messages exceeds the limit, triggering the rule."""          cases = ( diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py index 47367a5f8..3275143d5 100644 --- a/tests/bot/rules/test_burst_shared.py +++ b/tests/bot/rules/test_burst_shared.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import burst_shared  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest):          self.apply = burst_shared.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """          Cases that do not violate the rule. @@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the amount of messages exceeds the limit, triggering the rule."""          cases = ( diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py index 7cc36f49e..f1e3c76a7 100644 --- a/tests/bot/rules/test_chars.py +++ b/tests/bot/rules/test_chars.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import chars  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, n_chars: int) -> MockMessage: @@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest):              "interval": 10,          } -    @async_test      async def test_allows_messages_within_limit(self):          """Cases with a total amount of chars within limit."""          cases = ( @@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the total amount of chars exceeds the limit, triggering the rule."""          cases = ( diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 0239b0b00..9a72723e2 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import discord_emojis  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> @@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest):          self.apply = discord_emojis.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases with a total amount of discord emojis within limit."""          cases = ( @@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with more than the allowed amount of discord emojis."""          cases = ( diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py index 59e0fb6ef..9bd886a77 100644 --- a/tests/bot/rules/test_duplicates.py +++ b/tests/bot/rules/test_duplicates.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import duplicates  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, content: str) -> MockMessage: @@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest):          self.apply = duplicates.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases which do not violate the rule."""          cases = ( @@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with too many duplicate messages from the same author."""          cases = ( diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 3c3f90e5f..b091bd9d7 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import links  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, total_links: int) -> MockMessage: @@ -21,7 +21,6 @@ class LinksTests(RuleTest):              "interval": 10          } -    @async_test      async def test_links_within_limit(self):          """Messages with an allowed amount of links."""          cases = ( @@ -34,7 +33,6 @@ class LinksTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_links_exceeding_limit(self):          """Messages with a a higher than allowed amount of links."""          cases = ( diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ebcdabac6..6444532f2 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, total_mentions: int) -> MockMessage: @@ -20,7 +20,6 @@ class TestMentions(RuleTest):              "interval": 10,          } -    @async_test      async def test_mentions_within_limit(self):          """Messages with an allowed amount of mentions."""          cases = ( @@ -32,7 +31,6 @@ class TestMentions(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_mentions_exceeding_limit(self):          """Messages with a higher than allowed amount of mentions."""          cases = ( diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py index d61c4609d..e35377773 100644 --- a/tests/bot/rules/test_newlines.py +++ b/tests/bot/rules/test_newlines.py @@ -2,7 +2,7 @@ from typing import Iterable, List  from bot.rules import newlines  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, newline_groups: List[int]) -> MockMessage: @@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest):              "interval": 10,          } -    @async_test      async def test_allows_messages_within_limit(self):          """Cases which do not violate the rule."""          cases = ( @@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_total(self):          """Cases which violate the rule by having too many newlines in total."""          cases = ( @@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest):          self.apply = newlines.apply          self.config = {"max": 5, "max_consecutive": 3, "interval": 10} -    @async_test      async def test_disallows_messages_consecutive(self):          """Cases which violate the rule due to having too many consecutive newlines."""          cases = ( diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py index b339cccf7..26c05d527 100644 --- a/tests/bot/rules/test_role_mentions.py +++ b/tests/bot/rules/test_role_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable  from bot.rules import role_mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage  def make_msg(author: str, n_mentions: int) -> MockMessage: @@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest):          self.apply = role_mentions.apply          self.config = {"max": 2, "interval": 10} -    @async_test      async def test_allows_messages_within_limit(self):          """Cases with a total amount of role mentions within limit."""          cases = ( @@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest):          await self.run_allowed(cases) -    @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with more than the allowed amount of role mentions."""          cases = ( diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index bdfcc73e4..99e942813 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -2,10 +2,9 @@ import unittest  from unittest.mock import MagicMock  from bot import api -from tests.helpers import async_test -class APIClientTests(unittest.TestCase): +class APIClientTests(unittest.IsolatedAsyncioTestCase):      """Tests for the bot's API client."""      @classmethod @@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase):          """The event loop should not be running by default."""          self.assertFalse(api.loop_is_running()) -    @async_test      async def test_loop_is_running_in_async_context(self):          """The event loop should be running in an async context."""          self.assertTrue(api.loop_is_running()) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 69f35f2f5..694d3a40f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,12 +1,11 @@  import asyncio  import unittest  from datetime import datetime, timezone -from unittest.mock import patch +from unittest.mock import AsyncMock, patch  from dateutil.relativedelta import relativedelta  from bot.utils import time -from tests.helpers import AsyncMock  class TimeTests(unittest.TestCase): @@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase):          for max_units in test_cases:              with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:                  time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) -                self.assertEqual(str(error), 'max_units must be positive') +            self.assertEqual(str(error.exception), 'max_units must be positive')      def test_parse_rfc1123(self):          """Testing parse_rfc1123.""" diff --git a/tests/helpers.py b/tests/helpers.py index 6f50f6ae3..8e13f0f28 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,13 +1,10 @@  from __future__ import annotations -import asyncio  import collections -import functools -import inspect  import itertools  import logging  import unittest.mock -from typing import Any, Iterable, Optional +from typing import Iterable, Optional  import discord  from discord.ext.commands import Context @@ -26,21 +23,6 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) -def async_test(wrapped): -    """ -    Run a test case via asyncio. -    Example: -        >>> @async_test -        ... async def lemon_wins(): -        ...     assert True -    """ - -    @functools.wraps(wrapped) -    def wrapper(*args, **kwargs): -        return asyncio.run(wrapped(*args, **kwargs)) -    return wrapper - -  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. @@ -69,24 +51,31 @@ class CustomMockMixin:      """      Provides common functionality for our custom Mock types. -    The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine -    function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care -    of making sure child mocks are instantiated with the correct class. By default, the mock of the -    children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute -    `child_mock_type` on the custom mock inheriting from this mixin. +    The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock +    object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the +    class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional +    attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The +    class method `spec_set` can be overwritten with the object that should be uses as the specification +    for the mock. + +    Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to +    implement custom behavior.      """      child_mock_type = unittest.mock.MagicMock      discord_id = itertools.count(0) +    spec_set = None +    additional_spec_asyncs = None -    def __init__(self, spec_set: Any = None, **kwargs): +    def __init__(self, **kwargs):          name = kwargs.pop('name', None)  # `name` has special meaning for Mock classes, so we need to set it manually. -        super().__init__(spec_set=spec_set, **kwargs) +        super().__init__(spec_set=self.spec_set, **kwargs) + +        if self.additional_spec_asyncs: +            self._spec_asyncs.extend(self.additional_spec_asyncs)          if name:              self.name = name -        if spec_set: -            self._extract_coroutine_methods_from_spec_instance(spec_set)      def _get_child_mock(self, **kw):          """ @@ -100,7 +89,16 @@ class CustomMockMixin:          This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.          """ -        klass = self.child_mock_type +        _new_name = kw.get("_new_name") +        if _new_name in self.__dict__['_spec_asyncs']: +            return unittest.mock.AsyncMock(**kw) + +        _type = type(self) +        if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics: +            # Any asynchronous magic becomes an AsyncMock +            klass = unittest.mock.AsyncMock +        else: +            klass = self.child_mock_type          if self._mock_sealed:              attribute = "." + kw["name"] if "name" in kw else "()" @@ -109,107 +107,6 @@ class CustomMockMixin:          return klass(**kw) -    def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None: -        """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes.""" -        for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction): -            setattr(self, name, AsyncMock()) - - -# TODO: Remove me in Python 3.8 -class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): -    """ -    A MagicMock subclass to mock async callables. - -    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more -    features; this stand-in only overwrites the `__call__` method to an async version. -    """ - -    async def __call__(self, *args, **kwargs): -        return super().__call__(*args, **kwargs) - - -class AsyncContextManagerMock(unittest.mock.MagicMock): -    def __init__(self, return_value: Any): -        super().__init__() -        self._return_value = return_value - -    async def __aenter__(self): -        return self._return_value - -    async def __aexit__(self, *args): -        pass - - -class AsyncIteratorMock: -    """ -    A class to mock asynchronous iterators. - -    This allows async for, which is used in certain Discord.py objects. For example, -    an async iterator is returned by the Reaction.users() method. -    """ - -    def __init__(self, iterable: Iterable = None): -        if iterable is None: -            iterable = [] - -        self.iter = iter(iterable) -        self.iterable = iterable - -        self.call_count = 0 - -    def __aiter__(self): -        return self - -    async def __anext__(self): -        try: -            return next(self.iter) -        except StopIteration: -            raise StopAsyncIteration - -    def __call__(self): -        """ -        Keeps track of the number of times an instance has been called. - -        This is useful, since it typically shows that the iterator has actually been used somewhere after we have -        instantiated the mock for an attribute that normally returns an iterator when called. -        """ -        self.call_count += 1 -        return self - -    @property -    def return_value(self): -        """Makes `self.iterable` accessible as self.return_value.""" -        return self.iterable - -    @return_value.setter -    def return_value(self, iterable): -        """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" -        self.iter = iter(iterable) -        self.iterable = iterable - -    def assert_called(self): -        """Asserts if the AsyncIteratorMock instance has been called at least once.""" -        if self.call_count == 0: -            raise AssertionError("Expected AsyncIteratorMock to have been called.") - -    def assert_called_once(self): -        """Asserts if the AsyncIteratorMock instance has been called exactly once.""" -        if self.call_count != 1: -            raise AssertionError( -                f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." -            ) - -    def assert_not_called(self): -        """Asserts if the AsyncIteratorMock instance has not been called.""" -        if self.call_count != 0: -            raise AssertionError( -                f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." -            ) - -    def reset_mock(self): -        """Resets the call count, but not the return value or iterator.""" -        self.call_count = 0 -  # Create a guild instance to get a realistic Mock of `discord.Guild`  guild_data = { @@ -260,9 +157,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):      For more info, see the `Mocking` section in `tests/README.md`.      """ +    spec_set = guild_instance +      def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'members': []} -        super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles: @@ -281,6 +180,8 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.Role` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = role_instance +      def __init__(self, **kwargs) -> None:          default_kwargs = {              'id': next(self.discord_id), @@ -289,7 +190,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):              'colour': discord.Colour(0xdeadbf),              'permissions': discord.Permissions(),          } -        super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if isinstance(self.colour, int):              self.colour = discord.Colour(self.colour) @@ -318,9 +219,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin      Instances of this class will follow the specifications of `discord.Member` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = member_instance +      def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:          default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} -        super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles: @@ -341,9 +244,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.User` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = user_instance +      def __init__(self, **kwargs) -> None:          default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} -        super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f"@{self.name}" @@ -356,9 +261,7 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `bot.api.APIClient` instances.      For more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=APIClient, **kwargs) +    spec_set = APIClient  # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` @@ -374,16 +277,13 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ +    spec_set = bot_instance +    additional_spec_asyncs = ("wait_for",)      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=bot_instance, **kwargs) +        super().__init__(**kwargs)          self.api_client = MockAPIClient() -        # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and -        # and should therefore be awaited. (The documentation calls it a coroutine as well, which -        # is technically incorrect, since it's a regular def.) -        self.wait_for = AsyncMock() -          # Since calling `create_task` on our MockBot does not actually schedule the coroutine object          # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object          # to prevent "has not been awaited"-warnings. @@ -413,10 +313,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ +    spec_set = channel_instance      def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} -        super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f"#{self.name}" @@ -455,9 +356,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Context`      instances. For more information, see the `MockGuild` docstring.      """ +    spec_set = context_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=context_instance, **kwargs) +        super().__init__(**kwargs)          self.bot = kwargs.get('bot', MockBot())          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember()) @@ -474,8 +376,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Attachment` instances. For      more information, see the `MockGuild` docstring.      """ -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=attachment_instance, **kwargs) +    spec_set = attachment_instance  class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -485,10 +386,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Message` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = message_instance      def __init__(self, **kwargs) -> None:          default_kwargs = {'attachments': []} -        super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) @@ -504,9 +406,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Emoji` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = emoji_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=emoji_instance, **kwargs) +        super().__init__(**kwargs)          self.guild = kwargs.get('guild', MockGuild()) @@ -520,9 +423,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For      more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=partial_emoji_instance, **kwargs) +    spec_set = partial_emoji_instance  reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -535,12 +436,18 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Reaction` instances. For      more information, see the `MockGuild` docstring.      """ +    spec_set = reaction_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=reaction_instance, **kwargs) +        _users = kwargs.pop("users", []) +        super().__init__(**kwargs)          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage()) -        self.users = AsyncIteratorMock(kwargs.get('users', [])) + +        user_iterator = unittest.mock.AsyncMock() +        user_iterator.__aiter__.return_value = _users +        self.users.return_value = user_iterator +          self.__str__.return_value = str(self.emoji) @@ -554,13 +461,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Webhook` instances. For      more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=webhook_instance, **kwargs) - -        # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined -        # as coroutines. That's why we need to set the methods manually. -        self.send = AsyncMock() -        self.edit = AsyncMock() -        self.delete = AsyncMock() -        self.execute = AsyncMock() +    spec_set = webhook_instance +    additional_spec_asyncs = ("send", "edit", "delete", "execute") diff --git a/tests/test_base.py b/tests/test_base.py index a16e2af8f..a7db4bf3e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -3,7 +3,11 @@ import unittest  import unittest.mock -from tests.base import LoggingTestCase, _CaptureLogHandler +from tests.base import LoggingTestsMixin, _CaptureLogHandler + + +class LoggingTestCase(LoggingTestsMixin, unittest.TestCase): +    pass  class LoggingTestCaseTests(unittest.TestCase): @@ -18,24 +22,14 @@ class LoggingTestCaseTests(unittest.TestCase):          try:              with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):                  pass -        except AssertionError: +        except AssertionError:  # pragma: no cover              self.fail("`self.assertNotLogs` raised an AssertionError when it should not!") -    @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs") -    def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs): -        """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly.""" -        assertNotLogs.return_value = iter([None]) -        assertNotLogs.side_effect = AssertionError - -        message = "`self.assertNotLogs` raised an AssertionError when it should not!" -        with self.assertRaises(AssertionError, msg=message): -            self.test_assert_not_logs_does_not_raise_with_no_logs() -      def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self):          """Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted."""          msg_regex = (              r"1 logs of DEBUG or higher were triggered on root:\n" -            r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">' +            r'<LogRecord: tests\.test_base, [\d]+, .+[/\\]tests[/\\]test_base\.py, [\d]+, "Log!">'          )          with self.assertRaisesRegex(AssertionError, msg_regex):              with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 7894e104a..81285e009 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,4 @@  import asyncio -import inspect  import unittest  import unittest.mock @@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase):          with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"):              asyncio.run(coroutine_object) +    def test_user_mock_uses_explicitly_passed_mention_attribute(self): +        """MockUser should use an explicitly passed value for user.mention.""" +        user = helpers.MockUser(mention="hello") +        self.assertEqual(user.mention, "hello") +  class MockObjectTests(unittest.TestCase):      """Tests the mock objects and mixins we've defined.""" @@ -341,65 +345,10 @@ class MockObjectTests(unittest.TestCase):                  attribute = getattr(mock, valid_attribute)                  self.assertTrue(isinstance(attribute, mock_type.child_mock_type)) -    def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self): -        """Test if all coroutine functions are extracted, but not regular methods or attributes.""" -        class CoroutineDonor: -            def __init__(self): -                self.some_attribute = 'alpha' - -            async def first_coroutine(): -                """This coroutine function should be extracted.""" - -            async def second_coroutine(): -                """This coroutine function should be extracted.""" - -            def regular_method(): -                """This regular function should not be extracted.""" - -        class Receiver: +    def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self): +        """The CustomMockMixin should mock async magic methods with an AsyncMock.""" +        class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):              pass -        donor = CoroutineDonor() -        receiver = Receiver() - -        helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor) - -        self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock) -        self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock) -        self.assertFalse(hasattr(receiver, 'regular_method')) -        self.assertFalse(hasattr(receiver, 'some_attribute')) - -    @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) -    @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") -    def test_custom_mock_mixin_init_with_spec(self, extract_method_mock): -        """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" -        spec_set = "pydis" - -        helpers.CustomMockMixin(spec_set=spec_set) - -        extract_method_mock.assert_called_once_with(spec_set) - -    @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) -    @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") -    def test_custom_mock_mixin_init_without_spec(self, extract_method_mock): -        """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" -        helpers.CustomMockMixin() - -        extract_method_mock.assert_not_called() - -    def test_async_mock_provides_coroutine_for_dunder_call(self): -        """Test if AsyncMock objects have a coroutine for their __call__ method.""" -        async_mock = helpers.AsyncMock() -        self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__)) - -        coroutine = async_mock() -        self.assertTrue(inspect.iscoroutine(coroutine)) -        self.assertIsNotNone(asyncio.run(coroutine)) - -    def test_async_test_decorator_allows_synchronous_call_to_async_def(self): -        """Test if the `async_test` decorator allows an `async def` to be called synchronously.""" -        @helpers.async_test -        async def kosayoda(): -            return "return value" - -        self.assertEqual(kosayoda(), "return value") +        mock = MyMock() +        self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py deleted file mode 100644 index 4baa6395c..000000000 --- a/tests/utils/test_time.py +++ /dev/null @@ -1,62 +0,0 @@ -import asyncio -from datetime import datetime, timezone -from unittest.mock import patch - -import pytest -from dateutil.relativedelta import relativedelta - -from bot.utils import time -from tests.helpers import AsyncMock - - -    ('delta', 'precision', 'max_units', 'expected'), -    ( -        (relativedelta(days=2), 'seconds', 1, '2 days'), -        (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), -        (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), -        (relativedelta(days=2, hours=2), 'days', 2, '2 days'), - -        # Does not abort for unknown units, as the unit name is checked -        # against the attribute of the relativedelta instance. -        (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), - -        # Very high maximum units, but it only ever iterates over -        # each value the relativedelta might have. -        (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), -    ) -) -def test_humanize_delta( -        delta: relativedelta, -        precision: str, -        max_units: int, -        expected: str -): -    assert time.humanize_delta(delta, precision, max_units) == expected - - [email protected]('max_units', (-1, 0)) -def test_humanize_delta_raises_for_invalid_max_units(max_units: int): -    with pytest.raises(ValueError, match='max_units must be positive'): -        time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - - -    ('stamp', 'expected'), -    ( -        ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), -    ) -) -def test_parse_rfc1123(stamp: str, expected: str): -    assert time.parse_rfc1123(stamp) == expected - - -@patch('asyncio.sleep', new_callable=AsyncMock) -def test_wait_until(sleep_patch): -    start = datetime(2019, 1, 1, 0, 0) -    then = datetime(2019, 1, 1, 0, 10) - -    # No return value -    assert asyncio.run(time.wait_until(then, start)) is None - -    sleep_patch.assert_called_once_with(10 * 60) | 
