aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2020-07-18 16:54:01 +0200
committerGravatar Leon Sandøy <[email protected]>2020-07-18 16:54:01 +0200
commit1c569f2f38fe18d6210deec001046cf9ee68ea53 (patch)
treec8ac90eea18379b2eaea75387bf1e890d741dc72
parentRemove Filtering constants, use cache data. (diff)
Remove AntiMalWare constants, use cache data.
Also updates the tests for this cog.
-rw-r--r--bot/bot.py2
-rw-r--r--bot/cogs/antimalware.py24
-rw-r--r--bot/constants.py6
-rw-r--r--config-default.yml29
-rw-r--r--tests/bot/cogs/test_antimalware.py24
5 files changed, 30 insertions, 55 deletions
diff --git a/bot/bot.py b/bot/bot.py
index 6c02e72a7..962c8dd93 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -34,6 +34,7 @@ class Bot(commands.Bot):
self.redis_ready = asyncio.Event()
self.redis_closed = False
self.api_client = api.APIClient(loop=self.loop)
+ self.allow_deny_list_cache = {}
self._connector = None
self._resolver = None
@@ -52,7 +53,6 @@ class Bot(commands.Bot):
async def _cache_allow_deny_list_data(self) -> None:
"""Cache all the data in the AllowDenyList on the site."""
full_cache = await self.api_client.get('bot/allow_deny_lists')
- self.allow_deny_list_cache = {}
for item in full_cache:
type_ = item.get("type")
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index ea257442e..38ff1133d 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -6,7 +6,7 @@ from discord import Embed, Message, NotFound
from discord.ext.commands import Cog
from bot.bot import Bot
-from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs
+from bot.constants import Channels, STAFF_ROLES, URLs
log = logging.getLogger(__name__)
@@ -27,7 +27,7 @@ TXT_EMBED_DESCRIPTION = (
DISALLOWED_EMBED_DESCRIPTION = (
"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
- f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n"
+ "We currently allow the following file types: **{joined_whitelist}**.\n\n"
"Feel free to ask in {meta_channel_mention} if you think this is a mistake."
)
@@ -38,6 +38,16 @@ class AntiMalware(Cog):
def __init__(self, bot: Bot):
self.bot = bot
+ def _get_whitelisted_file_formats(self) -> list:
+ """Get the file formats currently on the whitelist."""
+ return [item.get('content') for item in self.bot.allow_deny_list_cache['file_format.True']]
+
+ def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]:
+ """Get an iterable containing all the disallowed extensions of attachments."""
+ file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
+ extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats())
+ return extensions_blocked
+
@Cog.listener()
async def on_message(self, message: Message) -> None:
"""Identify messages with prohibited attachments."""
@@ -51,7 +61,7 @@ class AntiMalware(Cog):
return
embed = Embed()
- extensions_blocked = self.get_disallowed_extensions(message)
+ extensions_blocked = self._get_disallowed_extensions(message)
blocked_extensions_str = ', '.join(extensions_blocked)
if ".py" in extensions_blocked:
# Short-circuit on *.py files to provide a pastebin link
@@ -63,6 +73,7 @@ class AntiMalware(Cog):
elif extensions_blocked:
meta_channel = self.bot.get_channel(Channels.meta)
embed.description = DISALLOWED_EMBED_DESCRIPTION.format(
+ joined_whitelist=', '.join(self._get_whitelisted_file_formats()),
blocked_extensions_str=blocked_extensions_str,
meta_channel_mention=meta_channel.mention,
)
@@ -81,13 +92,6 @@ class AntiMalware(Cog):
except NotFound:
log.info(f"Tried to delete message `{message.id}`, but message could not be found.")
- @classmethod
- def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]:
- """Get an iterable containing all the disallowed extensions of attachments."""
- file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
- extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
- return extensions_blocked
-
def setup(bot: Bot) -> None:
"""Load the AntiMalware cog."""
diff --git a/bot/constants.py b/bot/constants.py
index f5245ca50..857e6c4f0 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -527,12 +527,6 @@ class AntiSpam(metaclass=YAMLGetter):
rules: Dict[str, Dict[str, int]]
-class AntiMalware(metaclass=YAMLGetter):
- section = "anti_malware"
-
- whitelist: list
-
-
class BigBrother(metaclass=YAMLGetter):
section = 'big_brother'
diff --git a/config-default.yml b/config-default.yml
index 81c8c40d5..503cc2b52 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -386,35 +386,6 @@ anti_spam:
max: 3
-anti_malware:
- whitelist:
- - '.3gp'
- - '.3g2'
- - '.avi'
- - '.bmp'
- - '.gif'
- - '.h264'
- - '.jpg'
- - '.jpeg'
- - '.m4v'
- - '.mkv'
- - '.mov'
- - '.mp4'
- - '.mpeg'
- - '.mpg'
- - '.png'
- - '.tiff'
- - '.wmv'
- - '.svg'
- - '.psd' # Photoshop
- - '.ai' # Illustrator
- - '.aep' # After Effects
- - '.xcf' # GIMP
- - '.mp3'
- - '.wav'
- - '.ogg'
-
-
reddit:
subreddits:
- 'r/Python'
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py
index f219fc1ba..1e010d2ce 100644
--- a/tests/bot/cogs/test_antimalware.py
+++ b/tests/bot/cogs/test_antimalware.py
@@ -1,28 +1,33 @@
import unittest
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, Mock
from discord import NotFound
from bot.cogs import antimalware
-from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES
+from bot.constants import Channels, STAFF_ROLES
from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
-MODULE = "bot.cogs.antimalware"
-
-@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])
class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
"""Test the AntiMalware cog."""
def setUp(self):
"""Sets up fresh objects for each test."""
self.bot = MockBot()
+ self.bot.allow_deny_list_cache = {
+ "file_format.True": [
+ {"content": ".first"},
+ {"content": ".second"},
+ {"content": ".third"}
+ ]
+ }
self.cog = antimalware.AntiMalware(self.bot)
self.message = MockMessage()
+ self.whitelist = [".first", ".second", ".third"]
async def test_message_with_allowed_attachment(self):
"""Messages with allowed extensions should not be deleted"""
- attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")
+ attachment = MockAttachment(filename="python.first")
self.message.attachments = [attachment]
await self.cog.on_message(self.message)
@@ -93,7 +98,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
- async def test_other_disallowed_extention_embed_description(self):
+ async def test_other_disallowed_extension_embed_description(self):
"""Test the description for a non .py/.txt disallowed extension."""
attachment = MockAttachment(filename="python.disallowed")
self.message.attachments = [attachment]
@@ -109,6 +114,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
+ joined_whitelist=", ".join(self.whitelist),
blocked_extensions_str=".disallowed",
meta_channel_mention=meta_channel.mention
)
@@ -135,7 +141,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
"""The return value should include all non-whitelisted extensions."""
test_values = (
([], []),
- (AntiMalwareConfig.whitelist, []),
+ (self.whitelist, []),
([".first"], []),
([".first", ".disallowed"], [".disallowed"]),
([".disallowed"], [".disallowed"]),
@@ -145,7 +151,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
for extensions, expected_disallowed_extensions in test_values:
with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
- disallowed_extensions = self.cog.get_disallowed_extensions(self.message)
+ disallowed_extensions = self.cog._get_disallowed_extensions(self.message)
self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)