aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Sebastiaan Zeeff <[email protected]>2020-05-30 15:32:33 +0200
committerGravatar GitHub <[email protected]>2020-05-30 15:32:33 +0200
commitf66e5ab4303c7752a614c6aa923edda04796fdbb (patch)
tree03e4756bd9b67dfb92b33791065190106dfc5492
parentMerge pull request #972 from Numerlor/tag-encoding (diff)
parentMerge branch 'master' into test_antimalware (diff)
Merge pull request #930 from MrGrote/test_antimalware
Add tests for the antimalware cog
-rw-r--r--bot/cogs/antimalware.py55
-rw-r--r--tests/bot/cogs/test_antimalware.py159
2 files changed, 194 insertions, 20 deletions
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 66b5073e8..ea257442e 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -1,4 +1,5 @@
import logging
+import typing as t
from os.path import splitext
from discord import Embed, Message, NotFound
@@ -9,6 +10,27 @@ from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLE
log = logging.getLogger(__name__)
+PY_EMBED_DESCRIPTION = (
+ "It looks like you tried to attach a Python file - "
+ f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
+)
+
+TXT_EMBED_DESCRIPTION = (
+ "**Uh-oh!** It looks like your message got zapped by our spam filter. "
+ "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n"
+ "• If you attempted to send a message longer than 2000 characters, try shortening your message "
+ "to fit within the character limit or use a pasting service (see below) \n\n"
+ "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in "
+ "{cmd_channel_mention} for more information) or use a pasting service like: "
+ f"\n\n{URLs.site_schema}{URLs.site_paste}"
+)
+
+DISALLOWED_EMBED_DESCRIPTION = (
+ "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
+ f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n"
+ "Feel free to ask in {meta_channel_mention} if you think this is a mistake."
+)
+
class AntiMalware(Cog):
"""Delete messages which contain attachments with non-whitelisted file extensions."""
@@ -29,34 +51,20 @@ class AntiMalware(Cog):
return
embed = Embed()
- file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
- extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
+ extensions_blocked = self.get_disallowed_extensions(message)
blocked_extensions_str = ', '.join(extensions_blocked)
if ".py" in extensions_blocked:
# Short-circuit on *.py files to provide a pastebin link
- embed.description = (
- "It looks like you tried to attach a Python file - "
- f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
- )
+ embed.description = PY_EMBED_DESCRIPTION
elif ".txt" in extensions_blocked:
# Work around Discord AutoConversion of messages longer than 2000 chars to .txt
cmd_channel = self.bot.get_channel(Channels.bot_commands)
- embed.description = (
- "**Uh-oh!** It looks like your message got zapped by our spam filter. "
- "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n"
- "• If you attempted to send a message longer than 2000 characters, try shortening your message "
- "to fit within the character limit or use a pasting service (see below) \n\n"
- "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in "
- f"{cmd_channel.mention} for more information) or use a pasting service like: "
- f"\n\n{URLs.site_schema}{URLs.site_paste}"
- )
+ embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention)
elif extensions_blocked:
- whitelisted_types = ', '.join(AntiMalwareConfig.whitelist)
meta_channel = self.bot.get_channel(Channels.meta)
- embed.description = (
- f"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
- f"We currently allow the following file types: **{whitelisted_types}**.\n\n"
- f"Feel free to ask in {meta_channel.mention} if you think this is a mistake."
+ embed.description = DISALLOWED_EMBED_DESCRIPTION.format(
+ blocked_extensions_str=blocked_extensions_str,
+ meta_channel_mention=meta_channel.mention,
)
if embed.description:
@@ -73,6 +81,13 @@ class AntiMalware(Cog):
except NotFound:
log.info(f"Tried to delete message `{message.id}`, but message could not be found.")
+ @classmethod
+ def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]:
+ """Get an iterable containing all the disallowed extensions of attachments."""
+ file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
+ extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
+ return extensions_blocked
+
def setup(bot: Bot) -> None:
"""Load the AntiMalware cog."""
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py
new file mode 100644
index 000000000..f219fc1ba
--- /dev/null
+++ b/tests/bot/cogs/test_antimalware.py
@@ -0,0 +1,159 @@
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from discord import NotFound
+
+from bot.cogs import antimalware
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES
+from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
+
+MODULE = "bot.cogs.antimalware"
+
+
+@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])
+class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
+ """Test the AntiMalware cog."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.bot = MockBot()
+ self.cog = antimalware.AntiMalware(self.bot)
+ self.message = MockMessage()
+
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should not be deleted"""
+ attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_message_without_attachment(self):
+ """Messages without attachments should result in no action."""
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_direct_message_with_attachment(self):
+ """Direct messages should have no action taken."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.guild = None
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_message_with_illegal_extension_gets_deleted(self):
+ """A message containing an illegal extension should send an embed."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_called_once()
+
+ async def test_message_send_by_staff(self):
+ """A message send by a member of staff should be ignored."""
+ staff_role = MockRole(id=STAFF_ROLES[0])
+ self.message.author.roles.append(staff_role)
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site"""
+ attachment = MockAttachment(filename="python.py")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+
+ self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
+
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt file should result in the correct embed."""
+ attachment = MockAttachment(filename="python.txt")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.TXT_EMBED_DESCRIPTION = Mock()
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ cmd_channel = self.bot.get_channel(Channels.bot_commands)
+
+ self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
+ antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
+
+ async def test_other_disallowed_extention_embed_description(self):
+ """Test the description for a non .py/.txt disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ meta_channel = self.bot.get_channel(Channels.meta)
+
+ self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
+ blocked_extensions_str=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+
+ async def test_removing_deleted_message_logs(self):
+ """Removing an already deleted message logs the correct message"""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_called_once()
+
+ async def test_message_with_illegal_attachment_logs(self):
+ """Deleting a message with an illegal attachment should result in a log."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (AntiMalwareConfig.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], [".disallowed"]),
+ ([".disallowed"], [".disallowed"]),
+ ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
+ disallowed_extensions = self.cog.get_disallowed_extensions(self.message)
+ self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
+
+
+class AntiMalwareSetupTests(unittest.TestCase):
+ """Tests setup of the `AntiMalware` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ antimalware.setup(bot)
+ bot.add_cog.assert_called_once()