diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/cogs/test_antimalware.py | 29 | 
1 files changed, 23 insertions, 6 deletions
| diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 6e06df0a8..78ad996f2 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -19,10 +19,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          self.bot = MockBot()          self.cog = antimalware.AntiMalware(self.bot)          self.message = MockMessage() +        AntiMalwareConfig.whitelist = [".first", ".second", ".third"]      async def test_message_with_allowed_attachment(self):          """Messages with allowed extensions should not be deleted""" -        attachment = MockAttachment(filename=f"python.{AntiMalwareConfig.whitelist[0]}") +        attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")          self.message.attachments = [attachment]          await self.cog.on_message(self.message) @@ -35,7 +36,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):      async def test_direct_message_with_attachment(self):          """Direct messages should have no action taken.""" -        attachment = MockAttachment(filename="python.asdfsff") +        attachment = MockAttachment(filename="python.disallowed")          self.message.attachments = [attachment]          self.message.guild = None @@ -45,7 +46,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):      async def test_message_with_illegal_extension_gets_deleted(self):          """A message containing an illegal extension should send an embed.""" -        attachment = MockAttachment(filename="python.asdfsff") +        attachment = MockAttachment(filename="python.disallowed")          self.message.attachments = [attachment]          await self.cog.on_message(self.message) @@ -56,7 +57,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):          """A message send by a member of staff should be ignored."""          staff_role = MockRole(id=STAFF_ROLES[0])          self.message.author.roles.append(staff_role) -        attachment = MockAttachment(filename="python.asdfsff") +        attachment = MockAttachment(filename="python.disallowed")          self.message.attachments = [attachment]          await self.cog.on_message(self.message) @@ -103,7 +104,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):      async def test_removing_deleted_message_logs(self):          """Removing an already deleted message logs the correct message""" -        attachment = MockAttachment(filename="python.asdfsff") +        attachment = MockAttachment(filename="python.disallowed")          self.message.attachments = [attachment]          self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) @@ -117,7 +118,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):      async def test_message_with_illegal_attachment_logs(self):          """Deleting a message with an illegal attachment should result in a log.""" -        attachment = MockAttachment(filename="python.asdfsff") +        attachment = MockAttachment(filename="python.disallowed")          self.message.attachments = [attachment]          file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments} @@ -135,6 +136,22 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):              logs.output          ) +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            (AntiMalwareConfig.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], [".disallowed"]), +            ([".disallowed"], [".disallowed"]), +            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] +                disallowed_extensions = self.cog.get_disallowed_extensions(self.message) +                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) +  class AntiMalwareSetupTests(unittest.TestCase):      """Tests setup of the `AntiMalware` cog.""" | 
