aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Jannes Jonkers <[email protected]>2020-05-11 20:24:39 +0200
committerGravatar Jannes Jonkers <[email protected]>2020-05-11 20:24:39 +0200
commit148b12603f4ad8799d135ec9956d1841cf1c7bf7 (patch)
tree4ae79aab4098d891154902068d026e893cb95a04
parentAntiMalware Tests - extracted the method for determining disallowed extension... (diff)
AntiMalware Tests - extracted the method for determining disallowed extensions and added a test for it.
-rw-r--r--bot/cogs/antimalware.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 66b5073e8..f5fd5e2d9 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -1,4 +1,5 @@
import logging
+import typing as t
from os.path import splitext
from discord import Embed, Message, NotFound
@@ -29,8 +30,7 @@ class AntiMalware(Cog):
return
embed = Embed()
- file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
- extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
+ extensions_blocked = self.get_disallowed_extensions(message)
blocked_extensions_str = ', '.join(extensions_blocked)
if ".py" in extensions_blocked:
# Short-circuit on *.py files to provide a pastebin link
@@ -73,6 +73,13 @@ class AntiMalware(Cog):
except NotFound:
log.info(f"Tried to delete message `{message.id}`, but message could not be found.")
+ @classmethod
+ def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]:
+ """Get an iterable containing all the disallowed extensions of attachments."""
+ file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
+ extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
+ return extensions_blocked
+
def setup(bot: Bot) -> None:
"""Load the AntiMalware cog."""