diff options
| author | 2020-07-18 16:54:01 +0200 | |
|---|---|---|
| committer | 2020-07-18 16:54:01 +0200 | |
| commit | 1c569f2f38fe18d6210deec001046cf9ee68ea53 (patch) | |
| tree | c8ac90eea18379b2eaea75387bf1e890d741dc72 | |
| parent | Remove Filtering constants, use cache data. (diff) | |
Remove AntiMalWare constants, use cache data.
Also updates the tests for this cog.
| -rw-r--r-- | bot/bot.py | 2 | ||||
| -rw-r--r-- | bot/cogs/antimalware.py | 24 | ||||
| -rw-r--r-- | bot/constants.py | 6 | ||||
| -rw-r--r-- | config-default.yml | 29 | ||||
| -rw-r--r-- | tests/bot/cogs/test_antimalware.py | 24 |
5 files changed, 30 insertions, 55 deletions
diff --git a/bot/bot.py b/bot/bot.py index 6c02e72a7..962c8dd93 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -34,6 +34,7 @@ class Bot(commands.Bot): self.redis_ready = asyncio.Event() self.redis_closed = False self.api_client = api.APIClient(loop=self.loop) + self.allow_deny_list_cache = {} self._connector = None self._resolver = None @@ -52,7 +53,6 @@ class Bot(commands.Bot): async def _cache_allow_deny_list_data(self) -> None: """Cache all the data in the AllowDenyList on the site.""" full_cache = await self.api_client.get('bot/allow_deny_lists') - self.allow_deny_list_cache = {} for item in full_cache: type_ = item.get("type") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index ea257442e..38ff1133d 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -6,7 +6,7 @@ from discord import Embed, Message, NotFound from discord.ext.commands import Cog from bot.bot import Bot -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs +from bot.constants import Channels, STAFF_ROLES, URLs log = logging.getLogger(__name__) @@ -27,7 +27,7 @@ TXT_EMBED_DESCRIPTION = ( DISALLOWED_EMBED_DESCRIPTION = ( "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n" + "We currently allow the following file types: **{joined_whitelist}**.\n\n" "Feel free to ask in {meta_channel_mention} if you think this is a mistake." ) @@ -38,6 +38,16 @@ class AntiMalware(Cog): def __init__(self, bot: Bot): self.bot = bot + def _get_whitelisted_file_formats(self) -> list: + """Get the file formats currently on the whitelist.""" + return [item.get('content') for item in self.bot.allow_deny_list_cache['file_format.True']] + + def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) + return extensions_blocked + @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" @@ -51,7 +61,7 @@ class AntiMalware(Cog): return embed = Embed() - extensions_blocked = self.get_disallowed_extensions(message) + extensions_blocked = self._get_disallowed_extensions(message) blocked_extensions_str = ', '.join(extensions_blocked) if ".py" in extensions_blocked: # Short-circuit on *.py files to provide a pastebin link @@ -63,6 +73,7 @@ class AntiMalware(Cog): elif extensions_blocked: meta_channel = self.bot.get_channel(Channels.meta) embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=', '.join(self._get_whitelisted_file_formats()), blocked_extensions_str=blocked_extensions_str, meta_channel_mention=meta_channel.mention, ) @@ -81,13 +92,6 @@ class AntiMalware(Cog): except NotFound: log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - @classmethod - def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]: - """Get an iterable containing all the disallowed extensions of attachments.""" - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) - return extensions_blocked - def setup(bot: Bot) -> None: """Load the AntiMalware cog.""" diff --git a/bot/constants.py b/bot/constants.py index f5245ca50..857e6c4f0 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -527,12 +527,6 @@ class AntiSpam(metaclass=YAMLGetter): rules: Dict[str, Dict[str, int]] -class AntiMalware(metaclass=YAMLGetter): - section = "anti_malware" - - whitelist: list - - class BigBrother(metaclass=YAMLGetter): section = 'big_brother' diff --git a/config-default.yml b/config-default.yml index 81c8c40d5..503cc2b52 100644 --- a/config-default.yml +++ b/config-default.yml @@ -386,35 +386,6 @@ anti_spam: max: 3 -anti_malware: - whitelist: - - '.3gp' - - '.3g2' - - '.avi' - - '.bmp' - - '.gif' - - '.h264' - - '.jpg' - - '.jpeg' - - '.m4v' - - '.mkv' - - '.mov' - - '.mp4' - - '.mpeg' - - '.mpg' - - '.png' - - '.tiff' - - '.wmv' - - '.svg' - - '.psd' # Photoshop - - '.ai' # Illustrator - - '.aep' # After Effects - - '.xcf' # GIMP - - '.mp3' - - '.wav' - - '.ogg' - - reddit: subreddits: - 'r/Python' diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index f219fc1ba..1e010d2ce 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,28 +1,33 @@ import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from bot.constants import Channels, STAFF_ROLES from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole -MODULE = "bot.cogs.antimalware" - -@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Test the AntiMalware cog.""" def setUp(self): """Sets up fresh objects for each test.""" self.bot = MockBot() + self.bot.allow_deny_list_cache = { + "file_format.True": [ + {"content": ".first"}, + {"content": ".second"}, + {"content": ".third"} + ] + } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") + attachment = MockAttachment(filename="python.first") self.message.attachments = [attachment] await self.cog.on_message(self.message) @@ -93,7 +98,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - async def test_other_disallowed_extention_embed_description(self): + async def test_other_disallowed_extension_embed_description(self): """Test the description for a non .py/.txt disallowed extension.""" attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] @@ -109,6 +114,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), blocked_extensions_str=".disallowed", meta_channel_mention=meta_channel.mention ) @@ -135,7 +141,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """The return value should include all non-whitelisted extensions.""" test_values = ( ([], []), - (AntiMalwareConfig.whitelist, []), + (self.whitelist, []), ([".first"], []), ([".first", ".disallowed"], [".disallowed"]), ([".disallowed"], [".disallowed"]), @@ -145,7 +151,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): for extensions, expected_disallowed_extensions in test_values: with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog.get_disallowed_extensions(self.message) + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) |