aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/filters/antimalware.py11
-rw-r--r--tests/bot/exts/filters/test_antimalware.py45
2 files changed, 38 insertions, 18 deletions
diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py
index 26f00e91f..89e539e7b 100644
--- a/bot/exts/filters/antimalware.py
+++ b/bot/exts/filters/antimalware.py
@@ -15,9 +15,11 @@ PY_EMBED_DESCRIPTION = (
f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
)
+TXT_LIKE_FILES = {".txt", ".csv", ".json"}
TXT_EMBED_DESCRIPTION = (
"**Uh-oh!** It looks like your message got zapped by our spam filter. "
- "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n"
+ "We currently don't allow `{blocked_extension}` attachments, "
+ "so here are some tips to help you travel safely: \n\n"
"• If you attempted to send a message longer than 2000 characters, try shortening your message "
"to fit within the character limit or use a pasting service (see below) \n\n"
"• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in "
@@ -70,10 +72,13 @@ class AntiMalware(Cog):
if ".py" in extensions_blocked:
# Short-circuit on *.py files to provide a pastebin link
embed.description = PY_EMBED_DESCRIPTION
- elif ".txt" in extensions_blocked:
+ elif extensions := TXT_LIKE_FILES.intersection(extensions_blocked):
# Work around Discord AutoConversion of messages longer than 2000 chars to .txt
cmd_channel = self.bot.get_channel(Channels.bot_commands)
- embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention)
+ embed.description = TXT_EMBED_DESCRIPTION.format(
+ blocked_extension=extensions.pop(),
+ cmd_channel_mention=cmd_channel.mention
+ )
elif extensions_blocked:
meta_channel = self.bot.get_channel(Channels.meta)
embed.description = DISALLOWED_EMBED_DESCRIPTION.format(
diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py
index 3393c6cdc..06d78de9d 100644
--- a/tests/bot/exts/filters/test_antimalware.py
+++ b/tests/bot/exts/filters/test_antimalware.py
@@ -104,24 +104,39 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
async def test_txt_file_redirect_embed_description(self):
- """A message containing a .txt file should result in the correct embed."""
- attachment = MockAttachment(filename="python.txt")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
- antimalware.TXT_EMBED_DESCRIPTION = Mock()
- antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
- cmd_channel = self.bot.get_channel(Channels.bot_commands)
+ """A message containing a .txt/.json/.csv file should result in the correct embed."""
+ test_values = (
+ ("text", ".txt"),
+ ("json", ".json"),
+ ("csv", ".csv"),
+ )
- self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
- antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
+ for file_name, disallowed_extension in test_values:
+ with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension):
+
+ attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.TXT_EMBED_DESCRIPTION = Mock()
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ cmd_channel = self.bot.get_channel(Channels.bot_commands)
+
+ self.assertEqual(
+ embed.description,
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value
+ )
+ antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(
+ blocked_extension=disallowed_extension,
+ cmd_channel_mention=cmd_channel.mention
+ )
async def test_other_disallowed_extension_embed_description(self):
- """Test the description for a non .py/.txt disallowed extension."""
+ """Test the description for a non .py/.txt/.json/.csv disallowed extension."""
attachment = MockAttachment(filename="python.disallowed")
self.message.attachments = [attachment]
self.message.channel.send = AsyncMock()