aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar kwzrd <[email protected]>2020-05-30 23:56:21 +0200
committerGravatar GitHub <[email protected]>2020-05-30 23:56:21 +0200
commit31861888a5cdc310028a05ea0ed03dc693bbe7b5 (patch)
tree636f60ecc967faf8aa59cdb43e10528e7121b5a0
parentOops, add the return back. (diff)
parentMerge pull request #864 from ks129/ban-kick-reason-length (diff)
Merge branch 'master' into remove_periodic_ping
-rw-r--r--bot/cogs/antimalware.py55
-rw-r--r--bot/cogs/moderation/infractions.py7
-rw-r--r--bot/cogs/moderation/modlog.py5
-rw-r--r--bot/cogs/moderation/scheduler.py69
-rw-r--r--bot/cogs/moderation/superstarify.py2
-rw-r--r--bot/cogs/moderation/utils.py12
-rw-r--r--bot/cogs/tags.py2
-rw-r--r--bot/cogs/watchchannels/bigbrother.py5
-rw-r--r--bot/cogs/watchchannels/talentpool.py10
-rw-r--r--bot/cogs/watchchannels/watchchannel.py3
-rw-r--r--tests/bot/cogs/moderation/test_infractions.py55
-rw-r--r--tests/bot/cogs/moderation/test_modlog.py29
-rw-r--r--tests/bot/cogs/test_antimalware.py159
-rw-r--r--tests/helpers.py4
14 files changed, 354 insertions, 63 deletions
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 66b5073e8..ea257442e 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -1,4 +1,5 @@
import logging
+import typing as t
from os.path import splitext
from discord import Embed, Message, NotFound
@@ -9,6 +10,27 @@ from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLE
log = logging.getLogger(__name__)
+PY_EMBED_DESCRIPTION = (
+ "It looks like you tried to attach a Python file - "
+ f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
+)
+
+TXT_EMBED_DESCRIPTION = (
+ "**Uh-oh!** It looks like your message got zapped by our spam filter. "
+ "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n"
+ "• If you attempted to send a message longer than 2000 characters, try shortening your message "
+ "to fit within the character limit or use a pasting service (see below) \n\n"
+ "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in "
+ "{cmd_channel_mention} for more information) or use a pasting service like: "
+ f"\n\n{URLs.site_schema}{URLs.site_paste}"
+)
+
+DISALLOWED_EMBED_DESCRIPTION = (
+ "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
+ f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n"
+ "Feel free to ask in {meta_channel_mention} if you think this is a mistake."
+)
+
class AntiMalware(Cog):
"""Delete messages which contain attachments with non-whitelisted file extensions."""
@@ -29,34 +51,20 @@ class AntiMalware(Cog):
return
embed = Embed()
- file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
- extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
+ extensions_blocked = self.get_disallowed_extensions(message)
blocked_extensions_str = ', '.join(extensions_blocked)
if ".py" in extensions_blocked:
# Short-circuit on *.py files to provide a pastebin link
- embed.description = (
- "It looks like you tried to attach a Python file - "
- f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
- )
+ embed.description = PY_EMBED_DESCRIPTION
elif ".txt" in extensions_blocked:
# Work around Discord AutoConversion of messages longer than 2000 chars to .txt
cmd_channel = self.bot.get_channel(Channels.bot_commands)
- embed.description = (
- "**Uh-oh!** It looks like your message got zapped by our spam filter. "
- "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n"
- "• If you attempted to send a message longer than 2000 characters, try shortening your message "
- "to fit within the character limit or use a pasting service (see below) \n\n"
- "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in "
- f"{cmd_channel.mention} for more information) or use a pasting service like: "
- f"\n\n{URLs.site_schema}{URLs.site_paste}"
- )
+ embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention)
elif extensions_blocked:
- whitelisted_types = ', '.join(AntiMalwareConfig.whitelist)
meta_channel = self.bot.get_channel(Channels.meta)
- embed.description = (
- f"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
- f"We currently allow the following file types: **{whitelisted_types}**.\n\n"
- f"Feel free to ask in {meta_channel.mention} if you think this is a mistake."
+ embed.description = DISALLOWED_EMBED_DESCRIPTION.format(
+ blocked_extensions_str=blocked_extensions_str,
+ meta_channel_mention=meta_channel.mention,
)
if embed.description:
@@ -73,6 +81,13 @@ class AntiMalware(Cog):
except NotFound:
log.info(f"Tried to delete message `{message.id}`, but message could not be found.")
+ @classmethod
+ def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]:
+ """Get an iterable containing all the disallowed extensions of attachments."""
+ file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
+ extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)
+ return extensions_blocked
+
def setup(bot: Bot) -> None:
"""Load the AntiMalware cog."""
diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py
index e62a36c43..5bfaad796 100644
--- a/bot/cogs/moderation/infractions.py
+++ b/bot/cogs/moderation/infractions.py
@@ -1,4 +1,5 @@
import logging
+import textwrap
import typing as t
import discord
@@ -225,7 +226,7 @@ class Infractions(InfractionScheduler, commands.Cog):
self.mod_log.ignore(Event.member_remove, user.id)
- action = user.kick(reason=reason)
+ action = user.kick(reason=textwrap.shorten(reason, width=512, placeholder="..."))
await self.apply_infraction(ctx, infraction, user, action)
@respect_role_hierarchy()
@@ -258,7 +259,9 @@ class Infractions(InfractionScheduler, commands.Cog):
self.mod_log.ignore(Event.member_remove, user.id)
- action = ctx.guild.ban(user, reason=reason, delete_message_days=0)
+ truncated_reason = textwrap.shorten(reason, width=512, placeholder="...")
+
+ action = ctx.guild.ban(user, reason=truncated_reason, delete_message_days=0)
await self.apply_infraction(ctx, infraction, user, action)
if infraction.get('expires_at') is not None:
diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py
index beef7a8ef..9d28030d9 100644
--- a/bot/cogs/moderation/modlog.py
+++ b/bot/cogs/moderation/modlog.py
@@ -98,7 +98,10 @@ class ModLog(Cog, name="ModLog"):
footer: t.Optional[str] = None,
) -> Context:
"""Generate log embed and send to logging channel."""
- embed = discord.Embed(description=text)
+ # Truncate string directly here to avoid removing newlines
+ embed = discord.Embed(
+ description=text[:2045] + "..." if len(text) > 2048 else text
+ )
if title and icon_url:
embed.set_author(name=title, icon_url=icon_url)
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py
index 012432e60..f0a3ad1b1 100644
--- a/bot/cogs/moderation/scheduler.py
+++ b/bot/cogs/moderation/scheduler.py
@@ -101,33 +101,17 @@ class InfractionScheduler(Scheduler):
dm_result = ""
dm_log_text = ""
- expiry_log_text = f"Expires: {expiry}" if expiry else ""
+ expiry_log_text = f"\nExpires: {expiry}" if expiry else ""
log_title = "applied"
log_content = None
-
- # DM the user about the infraction if it's not a shadow/hidden infraction.
- if not infraction["hidden"]:
- dm_result = f"{constants.Emojis.failmail} "
- dm_log_text = "\nDM: **Failed**"
-
- # Sometimes user is a discord.Object; make it a proper user.
- try:
- if not isinstance(user, (discord.Member, discord.User)):
- user = await self.bot.fetch_user(user.id)
- except discord.HTTPException as e:
- log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})")
- else:
- # Accordingly display whether the user was successfully notified via DM.
- if await utils.notify_infraction(user, infr_type, expiry, reason, icon):
- dm_result = ":incoming_envelope: "
- dm_log_text = "\nDM: Sent"
+ failed = False
if infraction["actor"] == self.bot.user.id:
log.trace(
f"Infraction #{id_} actor is bot; including the reason in the confirmation message."
)
- end_msg = f" (reason: {infraction['reason']})"
+ end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})"
elif ctx.channel.id not in STAFF_CHANNELS:
log.trace(
f"Infraction #{id_} context is not in a staff channel; omitting infraction count."
@@ -164,12 +148,43 @@ class InfractionScheduler(Scheduler):
log.warning(f"{log_msg}: bot lacks permissions.")
else:
log.exception(log_msg)
+ failed = True
+
+ # DM the user about the infraction if it's not a shadow/hidden infraction.
+ # Don't send DM when applying failed.
+ if not infraction["hidden"] and not failed:
+ dm_result = f"{constants.Emojis.failmail} "
+ dm_log_text = "\nDM: **Failed**"
+
+ # Sometimes user is a discord.Object; make it a proper user.
+ try:
+ if not isinstance(user, (discord.Member, discord.User)):
+ user = await self.bot.fetch_user(user.id)
+ except discord.HTTPException as e:
+ log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})")
+ else:
+ # Accordingly display whether the user was successfully notified via DM.
+ if await utils.notify_infraction(user, infr_type, expiry, reason, icon):
+ dm_result = ":incoming_envelope: "
+ dm_log_text = "\nDM: Sent"
+
+ if failed:
+ dm_log_text = "\nDM: **Canceled**"
+ dm_result = f"{constants.Emojis.failmail} "
+ log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.")
+ try:
+ await self.bot.api_client.delete(f"bot/infractions/{id_}")
+ except ResponseCodeError as e:
+ confirm_msg += " and failed to delete"
+ log_title += " and failed to delete"
+ log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.")
+ infr_message = ""
+ else:
+ infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}"
# Send a confirmation message to the invoking context.
log.trace(f"Sending infraction #{id_} confirmation message.")
- await ctx.send(
- f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}."
- )
+ await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.")
# Send a log message to the mod log.
log.trace(f"Sending apply mod log for infraction #{id_}.")
@@ -180,9 +195,8 @@ class InfractionScheduler(Scheduler):
thumbnail=user.avatar_url_as(static_format="png"),
text=textwrap.dedent(f"""
Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}{dm_log_text}
+ Actor: {ctx.message.author}{dm_log_text}{expiry_log_text}
Reason: {reason}
- {expiry_log_text}
"""),
content=log_content,
footer=f"ID {infraction['id']}"
@@ -294,6 +308,9 @@ class InfractionScheduler(Scheduler):
f"{log_text.get('Failure', '')}"
)
+ # Move reason to end of entry to avoid cutting out some keys
+ log_text["Reason"] = log_text.pop("Reason")
+
# Send a log message to the mod log.
await self.mod_log.send_log_message(
icon_url=utils.INFRACTION_ICONS[infr_type][1],
@@ -407,6 +424,9 @@ class InfractionScheduler(Scheduler):
user = self.bot.get_user(user_id)
avatar = user.avatar_url_as(static_format="png") if user else None
+ # Move reason to end so when reason is too long, this is not gonna cut out required items.
+ log_text["Reason"] = log_text.pop("Reason")
+
log.trace(f"Sending deactivation mod log for infraction #{id_}.")
await self.mod_log.send_log_message(
icon_url=utils.INFRACTION_ICONS[type_][1],
@@ -416,7 +436,6 @@ class InfractionScheduler(Scheduler):
text="\n".join(f"{k}: {v}" for k, v in log_text.items()),
footer=f"ID: {id_}",
content=log_content,
-
)
return log_text
diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py
index 29855c325..45a010f00 100644
--- a/bot/cogs/moderation/superstarify.py
+++ b/bot/cogs/moderation/superstarify.py
@@ -183,10 +183,10 @@ class Superstarify(InfractionScheduler, Cog):
text=textwrap.dedent(f"""
Member: {member.mention} (`{member.id}`)
Actor: {ctx.message.author}
- Reason: {reason}
Expires: {expiry_str}
Old nickname: `{old_nick}`
New nickname: `{forced_nick}`
+ Reason: {reason}
"""),
footer=f"ID {id_}"
)
diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py
index e4e0f1ec2..1b716b2ea 100644
--- a/bot/cogs/moderation/utils.py
+++ b/bot/cogs/moderation/utils.py
@@ -143,12 +143,14 @@ async def notify_infraction(
"""DM a user about their new infraction and return True if the DM is successful."""
log.trace(f"Sending {user} a DM about their {infr_type} infraction.")
+ text = textwrap.dedent(f"""
+ **Type:** {infr_type.capitalize()}
+ **Expires:** {expires_at or "N/A"}
+ **Reason:** {reason or "No reason provided."}
+ """)
+
embed = discord.Embed(
- description=textwrap.dedent(f"""
- **Type:** {infr_type.capitalize()}
- **Expires:** {expires_at or "N/A"}
- **Reason:** {reason or "No reason provided."}
- """),
+ description=textwrap.shorten(text, width=2048, placeholder="..."),
colour=Colours.soft_red
)
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py
index bc7f53f68..6f03a3475 100644
--- a/bot/cogs/tags.py
+++ b/bot/cogs/tags.py
@@ -44,7 +44,7 @@ class Tags(Cog):
tag = {
"title": tag_title,
"embed": {
- "description": file.read_text(),
+ "description": file.read_text(encoding="utf8"),
},
"restricted_to": "developers",
}
diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py
index e4fb173e0..702d371f4 100644
--- a/bot/cogs/watchchannels/bigbrother.py
+++ b/bot/cogs/watchchannels/bigbrother.py
@@ -1,4 +1,5 @@
import logging
+import textwrap
from collections import ChainMap
from discord.ext.commands import Cog, Context, group
@@ -97,8 +98,8 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
if len(history) > 1:
total = f"({len(history) // 2} previous infractions in total)"
- end_reason = history[0]["reason"]
- start_reason = f"Watched: {history[1]['reason']}"
+ end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...")
+ start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}"
msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```"
else:
msg = ":x: Failed to post the infraction: response was empty."
diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py
index cd9c7e555..14547105f 100644
--- a/bot/cogs/watchchannels/talentpool.py
+++ b/bot/cogs/watchchannels/talentpool.py
@@ -106,8 +106,8 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
if history:
total = f"({len(history)} previous nominations in total)"
- start_reason = f"Watched: {history[0]['reason']}"
- end_reason = f"Unwatched: {history[0]['end_reason']}"
+ start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}"
+ end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}"
msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```"
await ctx.send(msg)
@@ -224,7 +224,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
Status: **Active**
Date: {start_date}
Actor: {actor.mention if actor else actor_id}
- Reason: {nomination_object["reason"]}
+ Reason: {textwrap.shorten(nomination_object["reason"], width=200, placeholder="...")}
Nomination ID: `{nomination_object["id"]}`
===============
"""
@@ -237,10 +237,10 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
Status: Inactive
Date: {start_date}
Actor: {actor.mention if actor else actor_id}
- Reason: {nomination_object["reason"]}
+ Reason: {textwrap.shorten(nomination_object["reason"], width=200, placeholder="...")}
End date: {end_date}
- Unwatch reason: {nomination_object["end_reason"]}
+ Unwatch reason: {textwrap.shorten(nomination_object["end_reason"], width=200, placeholder="...")}
Nomination ID: `{nomination_object["id"]}`
===============
"""
diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py
index 643cd46e4..436778c46 100644
--- a/bot/cogs/watchchannels/watchchannel.py
+++ b/bot/cogs/watchchannels/watchchannel.py
@@ -280,8 +280,9 @@ class WatchChannel(metaclass=CogABCMeta):
else:
message_jump = f"in [#{msg.channel.name}]({msg.jump_url})"
+ footer = f"Added {time_delta} by {actor} | Reason: {reason}"
embed = Embed(description=f"{msg.author.mention} {message_jump}")
- embed.set_footer(text=f"Added {time_delta} by {actor} | Reason: {reason}")
+ embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="..."))
await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url)
diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py
new file mode 100644
index 000000000..da4e92ccc
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_infractions.py
@@ -0,0 +1,55 @@
+import textwrap
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from bot.cogs.moderation.infractions import Infractions
+from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole
+
+
+class TruncationTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for ban and kick command reason truncation."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = Infractions(self.bot)
+ self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10))
+ self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0))
+ self.guild = MockGuild(id=4567)
+ self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild)
+
+ @patch("bot.cogs.moderation.utils.get_active_infraction")
+ @patch("bot.cogs.moderation.utils.post_infraction")
+ async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock):
+ """Should truncate reason for `ctx.guild.ban`."""
+ get_active_mock.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.cog.apply_infraction = AsyncMock()
+ self.bot.get_cog.return_value = AsyncMock()
+ self.cog.mod_log.ignore = Mock()
+ self.ctx.guild.ban = Mock()
+
+ await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000)
+ self.ctx.guild.ban.assert_called_once_with(
+ self.target,
+ reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."),
+ delete_message_days=0
+ )
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value
+ )
+
+ @patch("bot.cogs.moderation.utils.post_infraction")
+ async def test_apply_kick_reason_truncation(self, post_infraction_mock):
+ """Should truncate reason for `Member.kick`."""
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.cog.apply_infraction = AsyncMock()
+ self.cog.mod_log.ignore = Mock()
+ self.target.kick = Mock()
+
+ await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000)
+ self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value
+ )
diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py
new file mode 100644
index 000000000..f2809f40a
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_modlog.py
@@ -0,0 +1,29 @@
+import unittest
+
+import discord
+
+from bot.cogs.moderation.modlog import ModLog
+from tests.helpers import MockBot, MockTextChannel
+
+
+class ModLogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for moderation logs."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = ModLog(self.bot)
+ self.channel = MockTextChannel()
+
+ async def test_log_entry_description_truncation(self):
+ """Test that embed description for ModLog entry is truncated."""
+ self.bot.get_channel.return_value = self.channel
+ await self.cog.send_log_message(
+ icon_url="foo",
+ colour=discord.Colour.blue(),
+ title="bar",
+ text="foo bar" * 3000
+ )
+ embed = self.channel.send.call_args[1]["embed"]
+ self.assertEqual(
+ embed.description, ("foo bar" * 3000)[:2045] + "..."
+ )
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py
new file mode 100644
index 000000000..f219fc1ba
--- /dev/null
+++ b/tests/bot/cogs/test_antimalware.py
@@ -0,0 +1,159 @@
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from discord import NotFound
+
+from bot.cogs import antimalware
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES
+from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
+
+MODULE = "bot.cogs.antimalware"
+
+
+@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])
+class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
+ """Test the AntiMalware cog."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.bot = MockBot()
+ self.cog = antimalware.AntiMalware(self.bot)
+ self.message = MockMessage()
+
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should not be deleted"""
+ attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_message_without_attachment(self):
+ """Messages without attachments should result in no action."""
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_direct_message_with_attachment(self):
+ """Direct messages should have no action taken."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.guild = None
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_message_with_illegal_extension_gets_deleted(self):
+ """A message containing an illegal extension should send an embed."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_called_once()
+
+ async def test_message_send_by_staff(self):
+ """A message send by a member of staff should be ignored."""
+ staff_role = MockRole(id=STAFF_ROLES[0])
+ self.message.author.roles.append(staff_role)
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site"""
+ attachment = MockAttachment(filename="python.py")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+
+ self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
+
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt file should result in the correct embed."""
+ attachment = MockAttachment(filename="python.txt")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.TXT_EMBED_DESCRIPTION = Mock()
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ cmd_channel = self.bot.get_channel(Channels.bot_commands)
+
+ self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
+ antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
+
+ async def test_other_disallowed_extention_embed_description(self):
+ """Test the description for a non .py/.txt disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ meta_channel = self.bot.get_channel(Channels.meta)
+
+ self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
+ blocked_extensions_str=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+
+ async def test_removing_deleted_message_logs(self):
+ """Removing an already deleted message logs the correct message"""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_called_once()
+
+ async def test_message_with_illegal_attachment_logs(self):
+ """Deleting a message with an illegal attachment should result in a log."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (AntiMalwareConfig.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], [".disallowed"]),
+ ([".disallowed"], [".disallowed"]),
+ ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
+ disallowed_extensions = self.cog.get_disallowed_extensions(self.message)
+ self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
+
+
+class AntiMalwareSetupTests(unittest.TestCase):
+ """Tests setup of the `AntiMalware` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ antimalware.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/helpers.py b/tests/helpers.py
index 13283339b..faa839370 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -208,6 +208,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
"""Simplified position-based comparisons similar to those of `discord.Role`."""
return self.position < other.position
+ def __ge__(self, other):
+ """Simplified position-based comparisons similar to those of `discord.Role`."""
+ return self.position >= other.position
+
# Create a Member instance to get a realistic Mock of `discord.Member`
member_data = {'user': 'lemon', 'roles': [1]}