diff options
author | 2020-05-30 15:32:33 +0200 | |
---|---|---|
committer | 2020-05-30 15:32:33 +0200 | |
commit | f66e5ab4303c7752a614c6aa923edda04796fdbb (patch) | |
tree | 03e4756bd9b67dfb92b33791065190106dfc5492 /tests | |
parent | Merge pull request #972 from Numerlor/tag-encoding (diff) | |
parent | Merge branch 'master' into test_antimalware (diff) |
Merge pull request #930 from MrGrote/test_antimalware
Add tests for the antimalware cog
Diffstat (limited to 'tests')
-rw-r--r-- | tests/bot/cogs/test_antimalware.py | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..f219fc1ba --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from discord import NotFound + +from bot.cogs import antimalware +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + +MODULE = "bot.cogs.antimalware" + + +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.guild = None + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_called_once() + + async def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extention_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + + async def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + self.message.delete.assert_called_once() + + async def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (AntiMalwareConfig.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog.get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() |