diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/bot/exts/filtering/test_extension_filter.py | 30 | ||||
-rw-r--r-- | tests/bot/exts/utils/snekbox/test_snekbox.py | 8 |
2 files changed, 19 insertions, 19 deletions
diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py index 0ad41116d..351daa0b4 100644 --- a/tests/bot/exts/filtering/test_extension_filter.py +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -45,9 +45,9 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - result = await self.filter_list.actions_for(self.ctx) + result = await self.filter_list.actions_for(ctx) self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) @@ -62,9 +62,9 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_illegal_extension(self): """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - result = await self.filter_list.actions_for(self.ctx) + result = await self.filter_list.actions_for(ctx) self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) @@ -72,11 +72,11 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_python_file_redirect_embed_description(self): """A message containing a .py file should result in an embed redirecting the user to our paste site.""" attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) - self.assertEqual(self.ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) @patch("bot.instance", BOT) async def test_txt_file_redirect_embed_description(self): @@ -91,12 +91,12 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) self.assertEqual( - self.ctx.dm_embed, + ctx.dm_embed, extension.TXT_EMBED_DESCRIPTION.format( blocked_extension=disallowed_extension, ) @@ -106,13 +106,13 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_other_disallowed_extension_embed_description(self): """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) meta_channel = BOT.get_channel(Channels.meta) self.assertEqual( - self.ctx.dm_embed, + ctx.dm_embed, extension.DISALLOWED_EMBED_DESCRIPTION.format( joined_whitelist=", ".join(self.whitelist), blocked_extensions_str=".disallowed", @@ -134,6 +134,6 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): for extensions, expected_disallowed_extensions in test_values: with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{ext}") for ext in extensions] - result = await self.filter_list.actions_for(self.ctx) + ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions]) + result = await self.filter_list.actions_for(ctx) self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9dcf7fd8c..79ac8ea2c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # Should not be called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code('MyAwesomeCode') @@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") |