diff options
author | 2020-05-11 20:24:39 +0200 | |
---|---|---|
committer | 2020-05-11 20:24:39 +0200 | |
commit | 148b12603f4ad8799d135ec9956d1841cf1c7bf7 (patch) | |
tree | 4ae79aab4098d891154902068d026e893cb95a04 | |
parent | AntiMalware Tests - extracted the method for determining disallowed extension... (diff) |
AntiMalware Tests - extracted the method for determining disallowed extensions and added a test for it.
-rw-r--r-- | bot/cogs/antimalware.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 66b5073e8..f5fd5e2d9 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,4 +1,5 @@ import logging +import typing as t from os.path import splitext from discord import Embed, Message, NotFound @@ -29,8 +30,7 @@ class AntiMalware(Cog): return embed = Embed() - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) + extensions_blocked = self.get_disallowed_extensions(message) blocked_extensions_str = ', '.join(extensions_blocked) if ".py" in extensions_blocked: # Short-circuit on *.py files to provide a pastebin link @@ -73,6 +73,13 @@ class AntiMalware(Cog): except NotFound: log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + @classmethod + def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) + return extensions_blocked + def setup(bot: Bot) -> None: """Load the AntiMalware cog.""" |