diff options
| author | 2020-03-03 21:05:18 -0500 | |
|---|---|---|
| committer | 2020-03-03 21:05:18 -0500 | |
| commit | 524692f49f4c98402b8b94ff8dd55d95b89f8fc8 (patch) | |
| tree | 65e98a5ceb276e099226c55d79cd691f44382e5b /tests/bot | |
| parent | Add logging to antimalware cog & expand user feedback (diff) | |
| parent | Merge pull request #750 from python-discord/bug/backend/b748/resolver-in-coro (diff) | |
Merge branch 'master' into antimalware-logging
Diffstat (limited to 'tests/bot')
23 files changed, 443 insertions, 190 deletions
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index e6a6f9688..fe0594efe 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -13,8 +13,8 @@ class TestSyncer(Syncer): """Syncer subclass with mocks for abstract methods for testing purposes.""" name = "test" - _get_diff = helpers.AsyncMock() - _sync = helpers.AsyncMock() + _get_diff = mock.AsyncMock() + _sync = mock.AsyncMock() class SyncerBaseTests(unittest.TestCase): @@ -29,7 +29,7 @@ class SyncerBaseTests(unittest.TestCase): Syncer(self.bot) -class SyncerSendPromptTests(unittest.TestCase): +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): """Tests for sending the sync confirmation prompt.""" def setUp(self): @@ -61,7 +61,6 @@ class SyncerSendPromptTests(unittest.TestCase): return mock_channel, mock_message - @helpers.async_test async def test_send_prompt_edits_and_returns_message(self): """The given message should be edited to display the prompt and then should be returned.""" msg = helpers.MockMessage() @@ -71,7 +70,6 @@ class SyncerSendPromptTests(unittest.TestCase): self.assertIn("content", msg.edit.call_args[1]) self.assertEqual(ret_val, msg) - @helpers.async_test async def test_send_prompt_gets_dev_core_channel(self): """The dev-core channel should be retrieved if an extant message isn't given.""" subtests = ( @@ -84,9 +82,8 @@ class SyncerSendPromptTests(unittest.TestCase): mock_() await self.syncer._send_prompt() - method.assert_called_once_with(constants.Channels.devcore) + method.assert_called_once_with(constants.Channels.dev_core) - @helpers.async_test async def test_send_prompt_returns_None_if_channel_fetch_fails(self): """None should be returned if there's an HTTPException when fetching the channel.""" self.bot.get_channel.return_value = None @@ -96,7 +93,6 @@ class SyncerSendPromptTests(unittest.TestCase): self.assertIsNone(ret_val) - @helpers.async_test async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): """A new message mentioning core devs should be sent and returned if message isn't given.""" for mock_ in (self.mock_get_channel, self.mock_fetch_channel): @@ -108,7 +104,6 @@ class SyncerSendPromptTests(unittest.TestCase): self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) self.assertEqual(ret_val, mock_message) - @helpers.async_test async def test_send_prompt_adds_reactions(self): """The message should have reactions for confirmation added.""" extant_message = helpers.MockMessage() @@ -129,13 +124,13 @@ class SyncerSendPromptTests(unittest.TestCase): mock_message.add_reaction.assert_has_calls(calls) -class SyncerConfirmationTests(unittest.TestCase): +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): """Tests for waiting for a sync confirmation reaction on the prompt.""" def setUp(self): self.bot = helpers.MockBot() self.syncer = TestSyncer(self.bot) - self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) @staticmethod def get_message_reaction(emoji): @@ -211,7 +206,6 @@ class SyncerConfirmationTests(unittest.TestCase): ret_val = self.syncer._reaction_check(*args) self.assertFalse(ret_val) - @helpers.async_test async def test_wait_for_confirmation(self): """The message should always be edited and only return True if the emoji is a check mark.""" subtests = ( @@ -251,14 +245,13 @@ class SyncerConfirmationTests(unittest.TestCase): self.assertIs(actual_return, ret_val) -class SyncerSyncTests(unittest.TestCase): +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for main function orchestrating the sync.""" def setUp(self): self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) self.syncer = TestSyncer(self.bot) - @helpers.async_test async def test_sync_respects_confirmation_result(self): """The sync should abort if confirmation fails and continue if confirmed.""" mock_message = helpers.MockMessage() @@ -274,7 +267,7 @@ class SyncerSyncTests(unittest.TestCase): diff = _Diff({1, 2, 3}, {4, 5}, None) self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = helpers.AsyncMock( + self.syncer._get_confirmation_result = mock.AsyncMock( return_value=(confirmed, message) ) @@ -289,7 +282,6 @@ class SyncerSyncTests(unittest.TestCase): else: self.syncer._sync.assert_not_called() - @helpers.async_test async def test_sync_diff_size(self): """The diff size should be correctly calculated.""" subtests = ( @@ -303,7 +295,7 @@ class SyncerSyncTests(unittest.TestCase): with self.subTest(size=size, diff=diff): self.syncer._get_diff.reset_mock() self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) guild = helpers.MockGuild() await self.syncer.sync(guild) @@ -312,7 +304,6 @@ class SyncerSyncTests(unittest.TestCase): self.syncer._get_confirmation_result.assert_called_once() self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) - @helpers.async_test async def test_sync_message_edited(self): """The message should be edited if one was sent, even if the sync has an API error.""" subtests = ( @@ -324,7 +315,7 @@ class SyncerSyncTests(unittest.TestCase): for message, side_effect, should_edit in subtests: with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = helpers.AsyncMock( + self.syncer._get_confirmation_result = mock.AsyncMock( return_value=(True, message) ) @@ -335,7 +326,6 @@ class SyncerSyncTests(unittest.TestCase): message.edit.assert_called_once() self.assertIn("content", message.edit.call_args[1]) - @helpers.async_test async def test_sync_confirmation_context_redirect(self): """If ctx is given, a new message should be sent and author should be ctx's author.""" mock_member = helpers.MockMember() @@ -349,7 +339,10 @@ class SyncerSyncTests(unittest.TestCase): if ctx is not None: ctx.send.return_value = message - self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) guild = helpers.MockGuild() await self.syncer.sync(guild, ctx) @@ -362,16 +355,15 @@ class SyncerSyncTests(unittest.TestCase): self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) @mock.patch.object(constants.Sync, "max_diff", new=3) - @helpers.async_test async def test_confirmation_result_small_diff(self): """Should always return True and the given message if the diff size is too small.""" author = helpers.MockMember() expected_message = helpers.MockMessage() - for size in (3, 2): + for size in (3, 2): # pragma: no cover with self.subTest(size=size): - self.syncer._send_prompt = helpers.AsyncMock() - self.syncer._wait_for_confirmation = helpers.AsyncMock() + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() coro = self.syncer._get_confirmation_result(size, author, expected_message) result, actual_message = await coro @@ -382,7 +374,6 @@ class SyncerSyncTests(unittest.TestCase): self.syncer._wait_for_confirmation.assert_not_called() @mock.patch.object(constants.Sync, "max_diff", new=3) - @helpers.async_test async def test_confirmation_result_large_diff(self): """Should return True if confirmed and False if _send_prompt fails or aborted.""" author = helpers.MockMember() @@ -394,10 +385,10 @@ class SyncerSyncTests(unittest.TestCase): (False, mock_message, False, "aborted"), ) - for expected_result, expected_message, confirmed, msg in subtests: + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover with self.subTest(msg=msg): - self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) coro = self.syncer._get_confirmation_result(4, author) actual_result, actual_message = await coro diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 98c9afc0d..81398c61f 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -11,19 +11,7 @@ from tests import helpers from tests.base import CommandTestCase -class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): - """ - A MagicMock subclass to mock Syncer objects. - - Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` - instances. For more information, see the `MockGuild` docstring. - """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=Syncer, **kwargs) - - -class SyncExtensionTests(unittest.TestCase): +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @staticmethod @@ -34,22 +22,21 @@ class SyncExtensionTests(unittest.TestCase): bot.add_cog.assert_called_once() -class SyncCogTestCase(unittest.TestCase): +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): """Base class for Sync cog tests. Sets up patches for syncers.""" def setUp(self): self.bot = helpers.MockBot() - # These patch the type. When the type is called, a MockSyncer instanced is returned. - # MockSyncer is needed so that our custom AsyncMock is used. - # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. self.role_syncer_patcher = mock.patch( "bot.cogs.sync.syncers.RoleSyncer", - new=mock.MagicMock(return_value=MockSyncer()) + autospec=Syncer, + spec_set=True ) self.user_syncer_patcher = mock.patch( "bot.cogs.sync.syncers.UserSyncer", - new=mock.MagicMock(return_value=MockSyncer()) + autospec=Syncer, + spec_set=True ) self.RoleSyncer = self.role_syncer_patcher.start() self.UserSyncer = self.user_syncer_patcher.start() @@ -72,13 +59,13 @@ class SyncCogTestCase(unittest.TestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch.object(sync.Sync, "sync_guild") + @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild): """Should instantiate syncers and run a sync for the guild.""" # Reset because a Sync cog was already instantiated in setUp. self.RoleSyncer.reset_mock() self.UserSyncer.reset_mock() - self.bot.loop.create_task.reset_mock() + self.bot.loop.create_task = mock.MagicMock() mock_sync_guild_coro = mock.MagicMock() sync_guild.return_value = mock_sync_guild_coro @@ -90,7 +77,6 @@ class SyncCogTests(SyncCogTestCase): sync_guild.assert_called_once_with() self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) - @helpers.async_test async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" for guild in (helpers.MockGuild(), None): @@ -126,14 +112,12 @@ class SyncCogTests(SyncCogTestCase): json=updated_information, ) - @helpers.async_test async def test_sync_cog_patch_user(self): """A PATCH request should be sent and 404 errors ignored.""" for side_effect in (None, self.response_error(404)): with self.subTest(side_effect=side_effect): await self.patch_user_helper(side_effect) - @helpers.async_test async def test_sync_cog_patch_user_non_404(self): """A PATCH request should be sent and the error raised if it's not a 404.""" with self.assertRaises(ResponseCodeError): @@ -145,9 +129,8 @@ class SyncCogListenerTests(SyncCogTestCase): def setUp(self): super().setUp() - self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - @helpers.async_test async def test_sync_cog_on_guild_role_create(self): """A POST request should be sent with the new role's data.""" self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -164,7 +147,6 @@ class SyncCogListenerTests(SyncCogTestCase): self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) - @helpers.async_test async def test_sync_cog_on_guild_role_delete(self): """A DELETE request should be sent.""" self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) @@ -174,7 +156,6 @@ class SyncCogListenerTests(SyncCogTestCase): self.bot.api_client.delete.assert_called_once_with("bot/roles/99") - @helpers.async_test async def test_sync_cog_on_guild_role_update(self): """A PUT request should be sent if the colour, name, permissions, or position changes.""" self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -212,7 +193,6 @@ class SyncCogListenerTests(SyncCogTestCase): else: self.bot.api_client.put.assert_not_called() - @helpers.async_test async def test_sync_cog_on_member_remove(self): """Member should patched to set in_guild as False.""" self.assertTrue(self.cog.on_member_remove.__cog_listener__) @@ -225,7 +205,6 @@ class SyncCogListenerTests(SyncCogTestCase): updated_information={"in_guild": False} ) - @helpers.async_test async def test_sync_cog_on_member_update_roles(self): """Members should be patched if their roles have changed.""" self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -240,7 +219,6 @@ class SyncCogListenerTests(SyncCogTestCase): data = {"roles": sorted(role.id for role in after_member.roles)} self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) - @helpers.async_test async def test_sync_cog_on_member_update_other(self): """Members should not be patched if other attributes have changed.""" self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -262,7 +240,6 @@ class SyncCogListenerTests(SyncCogTestCase): self.cog.patch_user.assert_not_called() - @helpers.async_test async def test_sync_cog_on_user_update(self): """A user should be patched only if the name, discriminator, or avatar changes.""" self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -341,7 +318,6 @@ class SyncCogListenerTests(SyncCogTestCase): return data - @helpers.async_test async def test_sync_cog_on_member_join(self): """Should PUT user's data or POST it if the user doesn't exist.""" for side_effect in (None, self.response_error(404)): @@ -354,7 +330,6 @@ class SyncCogListenerTests(SyncCogTestCase): else: self.bot.api_client.post.assert_not_called() - @helpers.async_test async def test_sync_cog_on_member_join_non_404(self): """ResponseCodeError should be re-raised if status code isn't a 404.""" with self.assertRaises(ResponseCodeError): @@ -366,7 +341,6 @@ class SyncCogListenerTests(SyncCogTestCase): class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): """Tests for the commands in the Sync cog.""" - @helpers.async_test async def test_sync_roles_command(self): """sync() should be called on the RoleSyncer.""" ctx = helpers.MockContext() @@ -374,7 +348,6 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) - @helpers.async_test async def test_sync_users_command(self): """sync() should be called on the UserSyncer.""" ctx = helpers.MockContext() @@ -382,7 +355,7 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) - def test_commands_require_admin(self): + async def test_commands_require_admin(self): """The sync commands should only run if the author has the administrator permission.""" cmds = ( self.cog.sync_group, @@ -392,4 +365,4 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): for cmd in cmds: with self.subTest(cmd=cmd): - self.assertHasPermissionsCheck(cmd, {"administrator": True}) + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 14fb2577a..79eee98f4 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -18,7 +18,7 @@ def fake_role(**kwargs): return kwargs -class RoleSyncerDiffTests(unittest.TestCase): +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between roles in the DB and roles in the Guild cache.""" def setUp(self): @@ -39,7 +39,6 @@ class RoleSyncerDiffTests(unittest.TestCase): return guild - @helpers.async_test async def test_empty_diff_for_identical_roles(self): """No differences should be found if the roles in the guild and DB are identical.""" self.bot.api_client.get.return_value = [fake_role()] @@ -50,7 +49,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_updated_roles(self): """Only updated roles should be added to the 'updated' set of the diff.""" updated_role = fake_role(id=41, name="new") @@ -63,7 +61,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_roles(self): """Only new roles should be added to the 'created' set of the diff.""" new_role = fake_role(id=41, name="new") @@ -76,7 +73,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_deleted_roles(self): """Only deleted roles should be added to the 'deleted' set of the diff.""" deleted_role = fake_role(id=61, name="deleted") @@ -89,7 +85,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_updated_and_deleted_roles(self): """When roles are added, updated, and removed, all of them are returned properly.""" new = fake_role(id=41, name="new") @@ -109,14 +104,13 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) -class RoleSyncerSyncTests(unittest.TestCase): +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync roles.""" def setUp(self): self.bot = helpers.MockBot() self.syncer = RoleSyncer(self.bot) - @helpers.async_test async def test_sync_created_roles(self): """Only POST requests should be made with the correct payload.""" roles = [fake_role(id=111), fake_role(id=222)] @@ -132,7 +126,6 @@ class RoleSyncerSyncTests(unittest.TestCase): self.bot.api_client.put.assert_not_called() self.bot.api_client.delete.assert_not_called() - @helpers.async_test async def test_sync_updated_roles(self): """Only PUT requests should be made with the correct payload.""" roles = [fake_role(id=111), fake_role(id=222)] @@ -148,7 +141,6 @@ class RoleSyncerSyncTests(unittest.TestCase): self.bot.api_client.post.assert_not_called() self.bot.api_client.delete.assert_not_called() - @helpers.async_test async def test_sync_deleted_roles(self): """Only DELETE requests should be made with the correct payload.""" roles = [fake_role(id=111), fake_role(id=222)] diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 421bf6bb6..818883012 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -17,7 +17,7 @@ def fake_user(**kwargs): return kwargs -class UserSyncerDiffTests(unittest.TestCase): +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between users in the DB and users in the Guild cache.""" def setUp(self): @@ -42,7 +42,6 @@ class UserSyncerDiffTests(unittest.TestCase): return guild - @helpers.async_test async def test_empty_diff_for_no_users(self): """When no users are given, an empty diff should be returned.""" guild = self.get_guild() @@ -52,7 +51,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_empty_diff_for_identical_users(self): """No differences should be found if the users in the guild and DB are identical.""" self.bot.api_client.get.return_value = [fake_user()] @@ -63,7 +61,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_updated_users(self): """Only updated users should be added to the 'updated' set of the diff.""" updated_user = fake_user(id=99, name="new") @@ -76,7 +73,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_users(self): """Only new users should be added to the 'created' set of the diff.""" new_user = fake_user(id=99, name="new") @@ -89,7 +85,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" leaving_user = fake_user(id=63, in_guild=False) @@ -102,7 +97,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_updated_and_leaving_users(self): """When users are added, updated, and removed, all of them are returned properly.""" new_user = fake_user(id=99, name="new") @@ -117,7 +111,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_empty_diff_for_db_users_not_in_guild(self): """When the DB knows a user the guild doesn't, no difference is found.""" self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] @@ -129,14 +122,13 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) -class UserSyncerSyncTests(unittest.TestCase): +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync users.""" def setUp(self): self.bot = helpers.MockBot() self.syncer = UserSyncer(self.bot) - @helpers.async_test async def test_sync_created_users(self): """Only POST requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] @@ -152,7 +144,6 @@ class UserSyncerSyncTests(unittest.TestCase): self.bot.api_client.put.assert_not_called() self.bot.api_client.delete.assert_not_called() - @helpers.async_test async def test_sync_updated_users(self): """Only PUT requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 5b0a3b8c3..7e6bfc748 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -2,7 +2,7 @@ import asyncio import logging import typing import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import discord @@ -14,7 +14,7 @@ from tests import helpers MODULE_PATH = "bot.cogs.duck_pond" -class DuckPondTests(base.LoggingTestCase): +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): """Tests for DuckPond functionality.""" @classmethod @@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase): with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): self.assertEqual(expected_return, actual_return) - @helpers.async_test async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): """The `has_green_checkmark` method should only return `True` if one is present.""" test_cases = ( @@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase): nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - @helpers.async_test async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): """The `count_ducks` method should return the number of unique staffers who gave a duck.""" test_cases = ( @@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase): with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): self.assertEqual(expected_count, actual_count) - @helpers.async_test async def test_relay_message_correctly_relays_content_and_attachments(self): """The `relay_message` method should correctly relay message content and attachments.""" send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" @@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase): ) for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: with self.subTest(clean_content=message.clean_content, attachments=message.attachments): await self.cog.relay_message(message) @@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase): message.add_reaction.assert_called_once_with(self.checkmark_emoji) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) - @helpers.async_test + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase): self.cog.webhook = helpers.MockAsyncWebhook() log = logging.getLogger("bot.cogs.duck_pond") - for side_effect in side_effects: + for side_effect in side_effects: # pragma: no cover send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: + with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook: with self.subTest(side_effect=type(side_effect).__name__): with self.assertNotLogs(logger=log, level=logging.ERROR): await self.cog.relay_message(message) self.assertEqual(send_webhook.call_count, 2) - @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) - @helpers.async_test + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase): payload.emoji.name = emoji_name return payload - @helpers.async_test async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" test_values = ( @@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase): return channel, message, member, payload - @helpers.async_test async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" channel_id = 1234 @@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase): channel.fetch_message.reset_mock() @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" channel_id = 31415926535 @@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase): # Assert that we've made it past `self.is_staff` is_staff.assert_called_once() - @helpers.async_test async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" test_cases = ( @@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase): payload.emoji = self.duck_pond_emoji for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: count_ducks.return_value = duck_count with self.subTest(duck_count=duck_count, should_relay=should_relay): await self.cog.on_raw_reaction_add(payload) @@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase): if should_relay: relay_message.assert_called_once_with(message) - @helpers.async_test async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) @@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase): (constants.DuckPond.threshold + 1, True), ) for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: count_ducks.return_value = duck_count with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): await self.cog.on_raw_reaction_remove(payload) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index deae7ebad..5693d2946 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase): @classmethod def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator) + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) def setUp(self): """Sets up fresh objects for each test.""" @@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase): """Test if the `role_info` command correctly returns the `moderator_role`.""" self.ctx.guild.roles.append(self.moderator_role) - self.cog.roles_info.can_run = helpers.AsyncMock() + self.cog.roles_info.can_run = unittest.mock.AsyncMock() self.cog.roles_info.can_run.return_value = True coroutine = self.cog.roles_info.callback(self.cog, self.ctx) @@ -72,7 +72,7 @@ class InformationCogTests(unittest.TestCase): self.ctx.guild.roles.append([dummy_role, admin_role]) - self.cog.role_info.can_run = helpers.AsyncMock() + self.cog.role_info.can_run = unittest.mock.AsyncMock() self.cog.role_info.can_run.return_value = True coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) @@ -174,7 +174,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): def setUp(self): """Common set-up steps done before for each test.""" self.bot = helpers.MockBot() - self.bot.api_client.get = helpers.AsyncMock() + self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) self.member = helpers.MockMember(id=1234) @@ -345,10 +345,10 @@ class UserEmbedTests(unittest.TestCase): def setUp(self): """Common set-up steps done before for each test.""" self.bot = helpers.MockBot() - self.bot.api_client.get = helpers.AsyncMock() + self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -360,7 +360,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Mr. Hemlock") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -372,7 +372,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -387,8 +387,8 @@ class UserEmbedTests(unittest.TestCase): self.assertIn("&Admins", embed.description) self.assertNotIn("&Everyone", embed.description) - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): """The embed should contain expanded infractions and nomination info in mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -423,7 +423,7 @@ class UserEmbedTests(unittest.TestCase): embed.description ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): """The embed should contain only basic infraction data outside of mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -454,7 +454,7 @@ class UserEmbedTests(unittest.TestCase): embed.description ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -467,7 +467,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): """The embed should be created with a blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() @@ -477,7 +477,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour.blurple()) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() @@ -521,7 +521,7 @@ class UserCommandTests(unittest.TestCase): """A regular user should not be able to use this command outside of bot-commands.""" constants.MODERATION_ROLES = [self.moderator_role.id] constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) @@ -529,11 +529,11 @@ class UserCommandTests(unittest.TestCase): with self.assertRaises(InChannelCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -542,11 +542,11 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): """A user should target itself with `!user` when a `user` argument was not provided.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -555,11 +555,11 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): """Staff members should be able to bypass the bot-commands channel restriction.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) @@ -568,7 +568,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.moderator) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_moderators_can_target_another_member(self, create_embed, constants): """A moderator should be able to use `!user` targeting another user.""" constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py new file mode 100644 index 000000000..9cd7f0154 --- /dev/null +++ b/tests/bot/cogs/test_snekbox.py @@ -0,0 +1,354 @@ +import asyncio +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch + +from bot.cogs import snekbox +from bot.cogs.snekbox import Snekbox +from bot.constants import URLs +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + + async def test_post_eval(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + self.bot.http_session.post().__aenter__.return_value = resp + + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( + URLs.snekbox_eval_api, + json={"input": "import random"}, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LEN.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + self.assertEqual(result, "too long to upload") + + async def test_upload_output(self): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + self.bot.http_session.post().__aenter__.return_value = resp + + self.assertEqual( + await self.cog.upload_output("My awesome output"), + URLs.paste_service.format(key=key) + ) + self.bot.http_session.post.assert_called_with( + URLs.paste_service.format(key="documents"), + data="My awesome output", + raise_for_status=True + ) + + async def test_upload_output_gracefully_fallback_if_exception_during_request(self): + """Output upload gracefully fallback if the upload fail.""" + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + self.bot.http_session.post().__aenter__.return_value = resp + + log = logging.getLogger("bot.cogs.snekbox") + with self.assertLogs(logger=log, level='ERROR'): + await self.cog.upload_output('My awesome output!') + + async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): + """Output upload gracefully fallback if there is no key entry in the response body.""" + self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + + def test_prepare_input(self): + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual(self.cog.prepare_input(case), expected) + + def test_get_results_message(self): + """Return error and message according to the eval result.""" + cases = ( + ('ERROR', None, ('Your eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), + ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) + def test_get_results_message_invalid_signal(self, mock_Signals: Mock): + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127', '') + ) + + @patch('bot.cogs.snekbox.Signals') + def test_get_results_message_valid_signal(self, mock_Signals: Mock): + mock_Signals.return_value.name = 'SIGTEST' + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127 (SIGTEST)', '') + ) + + def test_get_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... (truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'), + ( + '\u202E\u202E\u202E', + ('Code block escape attempt detected; will not output result', None), + 'Detect RIGHT-TO-LEFT OVERRIDE' + ), + ( + '\u200B\u200B\u200B', + ('Code block escape attempt detected; will not output result', None), + 'Detect ZERO WIDTH SPACE' + ), + ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'), + ( + 'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd', + (too_many_lines, 'https://testificate.com/'), + '12 lines output' + ), + ( + 'verylongbeard' * 100, + ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'), + '1300 characters output' + ), + ( + ('verylongbeard' * 10 + '\n') * 15, + (too_long_too_many_lines, 'https://testificate.com/'), + '15 lines, 1965 characters output' + ), + ) + for case, expected, testname in cases: + with self.subTest(msg=testname, case=case, expected=expected): + self.assertEqual(await self.cog.format_output(case), expected) + + async def test_eval_command_evaluate_once(self): + """Test the eval command procedure.""" + ctx = MockContext() + response = MockMessage() + self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.send_eval = AsyncMock(return_value=response) + self.cog.continue_eval = AsyncMock(return_value=None) + + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') + self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') + self.cog.continue_eval.assert_called_once_with(ctx, response) + + async def test_eval_command_evaluate_twice(self): + """Test the eval and re-eval command procedure.""" + ctx = MockContext() + response = MockMessage() + self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.send_eval = AsyncMock(return_value=response) + self.cog.continue_eval = AsyncMock() + self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) + self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') + self.cog.continue_eval.assert_called_with(ctx, response) + + async def test_eval_command_reject_two_eval_at_the_same_time(self): + """Test if the eval command rejects an eval if the author already have a running eval.""" + ctx = MockContext() + ctx.author.id = 42 + ctx.author.mention = '@LemonLemonishBeard#0042' + ctx.send = AsyncMock() + self.cog.jobs = (42,) + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + ctx.send.assert_called_once_with( + "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" + ) + + async def test_eval_command_call_help(self): + """Test if the eval command call the help command if no code is provided.""" + ctx = MockContext() + ctx.invoke = AsyncMock() + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') + ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") + + async def test_send_eval(self): + """Test the send_eval function.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + + self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) + self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + + await self.cog.send_eval(ctx, 'MyAwesomeCode') + ctx.send.assert_called_once_with( + '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' + ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.format_output.assert_called_once_with('') + + async def test_send_eval_with_paste_link(self): + """Test the send_eval function with a too long output that generate a paste link.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + + self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) + self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) + + await self.cog.send_eval(ctx, 'MyAwesomeCode') + ctx.send.assert_called_once_with( + '@LemonLemonishBeard#0042 :yay!: Return code 0.' + '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com' + ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.format_output.assert_called_once_with('Way too long beard') + + async def test_send_eval_with_non_zero_eval(self): + """Test the send_eval function with a code returning a non-zero code.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) + self.cog.get_status_emoji = MagicMock(return_value=':nope!:') + self.cog.format_output = AsyncMock() # This function isn't called + + await self.cog.send_eval(ctx, 'MyAwesomeCode') + ctx.send.assert_called_once_with( + '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' + ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.format_output.assert_not_called() + + @patch("bot.cogs.snekbox.partial") + async def test_continue_eval_does_continue(self, partial_mock): + """Test that the continue_eval function does continue if required conditions are met.""" + ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) + response = MockMessage(delete=AsyncMock()) + new_msg = MockMessage(content='!e NewCode') + self.bot.wait_for.side_effect = ((None, new_msg), None) + + actual = await self.cog.continue_eval(ctx, response) + self.assertEqual(actual, 'NewCode') + self.bot.wait_for.assert_has_awaits( + ( + call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), + call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + ) + ) + ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.clear_reactions.assert_called_once() + response.delete.assert_called_once() + + async def test_continue_eval_does_not_continue(self): + ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) + self.bot.wait_for.side_effect = asyncio.TimeoutError + + actual = await self.cog.continue_eval(ctx, MockMessage()) + self.assertEqual(actual, None) + ctx.message.clear_reactions.assert_called_once() + + def test_predicate_eval_message_edit(self): + """Test the predicate_eval_message_edit function.""" + msg0 = MockMessage(id=1, content='abc') + msg1 = MockMessage(id=2, content='abcdef') + msg2 = MockMessage(id=1, content='abcdef') + + cases = ( + (msg0, msg0, False, 'same ID, same content'), + (msg0, msg1, False, 'different ID, different content'), + (msg0, msg2, True, 'same ID, different content') + ) + for ctx_msg, new_msg, expected, testname in cases: + with self.subTest(msg=f'Messages with {testname} return {expected}'): + ctx = MockContext(message=ctx_msg) + actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + self.assertEqual(actual, expected) + + def test_predicate_eval_emoji_reaction(self): + """Test the predicate_eval_emoji_reaction function.""" + valid_reaction = MockReaction(message=MockMessage(id=1)) + valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) + valid_user = MockUser(id=2) + + invalid_reaction_id = MockReaction(message=MockMessage(id=42)) + invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_user_id = MockUser(id=42) + invalid_reaction_str = MockReaction(message=MockMessage(id=1)) + invalid_reaction_str.__str__.return_value = ':longbeard:' + + cases = ( + (invalid_reaction_id, valid_user, False, 'invalid reaction ID'), + (valid_reaction, invalid_user_id, False, 'invalid user ID'), + (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'), + (valid_reaction, valid_user, True, 'matching attributes') + ) + for reaction, user, expected, testname in cases: + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + self.assertEqual(actual, expected) + + +class SnekboxSetupTests(unittest.TestCase): + """Tests setup of the `Snekbox` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + snekbox.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a54b839d7..33d1ec170 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,7 +1,7 @@ import asyncio import logging import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock from discord import Colour @@ -11,7 +11,7 @@ from bot.cogs.token_remover import ( setup as setup_cog, ) from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock, MockBot, MockMessage +from tests.helpers import MockBot, MockMessage class TokenRemoverTests(unittest.TestCase): diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index 36c986fe1..0d570f5a3 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple): n_violations: int -class RuleTest(unittest.TestCase, metaclass=ABCMeta): +class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): """ Abstract class for antispam rule test cases. @@ -68,9 +68,9 @@ class RuleTest(unittest.TestCase, metaclass=ABCMeta): @abstractmethod def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: """Give expected relevant messages for `case`.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover @abstractmethod def get_report(self, case: DisallowedCase) -> str: """Give expected error report for `case`.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index e54b4b5b8..d7e779221 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import attachments from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, total_attachments: int) -> MockMessage: @@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest): self.apply = attachments.apply self.config = {"max": 5, "interval": 10} - @async_test async def test_allows_messages_without_too_many_attachments(self): """Messages without too many attachments are allowed as-is.""" cases = ( @@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_with_too_many_attachments(self): """Messages with too many attachments trigger the rule.""" cases = ( diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py index 72f0be0c7..03682966b 100644 --- a/tests/bot/rules/test_burst.py +++ b/tests/bot/rules/test_burst.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import burst from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest): self.apply = burst.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases which do not violate the rule.""" cases = ( @@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases where the amount of messages exceeds the limit, triggering the rule.""" cases = ( diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py index 47367a5f8..3275143d5 100644 --- a/tests/bot/rules/test_burst_shared.py +++ b/tests/bot/rules/test_burst_shared.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import burst_shared from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest): self.apply = burst_shared.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """ Cases that do not violate the rule. @@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases where the amount of messages exceeds the limit, triggering the rule.""" cases = ( diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py index 7cc36f49e..f1e3c76a7 100644 --- a/tests/bot/rules/test_chars.py +++ b/tests/bot/rules/test_chars.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import chars from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, n_chars: int) -> MockMessage: @@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest): "interval": 10, } - @async_test async def test_allows_messages_within_limit(self): """Cases with a total amount of chars within limit.""" cases = ( @@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases where the total amount of chars exceeds the limit, triggering the rule.""" cases = ( diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 0239b0b00..9a72723e2 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import discord_emojis from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> @@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest): self.apply = discord_emojis.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases with a total amount of discord emojis within limit.""" cases = ( @@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases with more than the allowed amount of discord emojis.""" cases = ( diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py index 59e0fb6ef..9bd886a77 100644 --- a/tests/bot/rules/test_duplicates.py +++ b/tests/bot/rules/test_duplicates.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import duplicates from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, content: str) -> MockMessage: @@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest): self.apply = duplicates.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases which do not violate the rule.""" cases = ( @@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases with too many duplicate messages from the same author.""" cases = ( diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 3c3f90e5f..b091bd9d7 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import links from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, total_links: int) -> MockMessage: @@ -21,7 +21,6 @@ class LinksTests(RuleTest): "interval": 10 } - @async_test async def test_links_within_limit(self): """Messages with an allowed amount of links.""" cases = ( @@ -34,7 +33,6 @@ class LinksTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_links_exceeding_limit(self): """Messages with a a higher than allowed amount of links.""" cases = ( diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ebcdabac6..6444532f2 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, total_mentions: int) -> MockMessage: @@ -20,7 +20,6 @@ class TestMentions(RuleTest): "interval": 10, } - @async_test async def test_mentions_within_limit(self): """Messages with an allowed amount of mentions.""" cases = ( @@ -32,7 +31,6 @@ class TestMentions(RuleTest): await self.run_allowed(cases) - @async_test async def test_mentions_exceeding_limit(self): """Messages with a higher than allowed amount of mentions.""" cases = ( diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py index d61c4609d..e35377773 100644 --- a/tests/bot/rules/test_newlines.py +++ b/tests/bot/rules/test_newlines.py @@ -2,7 +2,7 @@ from typing import Iterable, List from bot.rules import newlines from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, newline_groups: List[int]) -> MockMessage: @@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest): "interval": 10, } - @async_test async def test_allows_messages_within_limit(self): """Cases which do not violate the rule.""" cases = ( @@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_total(self): """Cases which violate the rule by having too many newlines in total.""" cases = ( @@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest): self.apply = newlines.apply self.config = {"max": 5, "max_consecutive": 3, "interval": 10} - @async_test async def test_disallows_messages_consecutive(self): """Cases which violate the rule due to having too many consecutive newlines.""" cases = ( diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py index b339cccf7..26c05d527 100644 --- a/tests/bot/rules/test_role_mentions.py +++ b/tests/bot/rules/test_role_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import role_mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, n_mentions: int) -> MockMessage: @@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest): self.apply = role_mentions.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases with a total amount of role mentions within limit.""" cases = ( @@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases with more than the allowed amount of role mentions.""" cases = ( diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index bdfcc73e4..99e942813 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -2,10 +2,9 @@ import unittest from unittest.mock import MagicMock from bot import api -from tests.helpers import async_test -class APIClientTests(unittest.TestCase): +class APIClientTests(unittest.IsolatedAsyncioTestCase): """Tests for the bot's API client.""" @classmethod @@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase): """The event loop should not be running by default.""" self.assertFalse(api.loop_is_running()) - @async_test async def test_loop_is_running_in_async_context(self): """The event loop should be running in an async context.""" self.assertTrue(api.loop_is_running()) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index b2b78d9dd..1e5ca62ae 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -68,7 +68,7 @@ class ConverterTests(unittest.TestCase): ('👋', "Don't be ridiculous, you can't use that character!"), ('', "Tag names should not be empty, or filled with whitespace."), (' ', "Tag names should not be empty, or filled with whitespace."), - ('42', "Tag names can't be numbers."), + ('42', "Tag names must contain at least one letter."), ('x' * 128, "Are you insane? That's way too long!"), ) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase): instance = utils.CaseInsensitiveDict() instance.update({'FOO': 'bar'}) self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): - """Tests the `chunk` method.""" - - def test_empty_chunking(self): - """Tests chunking on an empty iterable.""" - generator = utils.chunks(iterable=[], size=5) - self.assertEqual(list(generator), []) - - def test_list_chunking(self): - """Tests chunking a non-empty list.""" - iterable = [1, 2, 3, 4, 5] - generator = utils.chunks(iterable=iterable, size=2) - self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 69f35f2f5..694d3a40f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,12 +1,11 @@ import asyncio import unittest from datetime import datetime, timezone -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from dateutil.relativedelta import relativedelta from bot.utils import time -from tests.helpers import AsyncMock class TimeTests(unittest.TestCase): @@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase): for max_units in test_cases: with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - self.assertEqual(str(error), 'max_units must be positive') + self.assertEqual(str(error.exception), 'max_units must be positive') def test_parse_rfc1123(self): """Testing parse_rfc1123.""" |