diff options
author | 2020-05-30 23:56:21 +0200 | |
---|---|---|
committer | 2020-05-30 23:56:21 +0200 | |
commit | 31861888a5cdc310028a05ea0ed03dc693bbe7b5 (patch) | |
tree | 636f60ecc967faf8aa59cdb43e10528e7121b5a0 | |
parent | Oops, add the return back. (diff) | |
parent | Merge pull request #864 from ks129/ban-kick-reason-length (diff) |
Merge branch 'master' into remove_periodic_ping
-rw-r--r-- | bot/cogs/antimalware.py | 55 | ||||
-rw-r--r-- | bot/cogs/moderation/infractions.py | 7 | ||||
-rw-r--r-- | bot/cogs/moderation/modlog.py | 5 | ||||
-rw-r--r-- | bot/cogs/moderation/scheduler.py | 69 | ||||
-rw-r--r-- | bot/cogs/moderation/superstarify.py | 2 | ||||
-rw-r--r-- | bot/cogs/moderation/utils.py | 12 | ||||
-rw-r--r-- | bot/cogs/tags.py | 2 | ||||
-rw-r--r-- | bot/cogs/watchchannels/bigbrother.py | 5 | ||||
-rw-r--r-- | bot/cogs/watchchannels/talentpool.py | 10 | ||||
-rw-r--r-- | bot/cogs/watchchannels/watchchannel.py | 3 | ||||
-rw-r--r-- | tests/bot/cogs/moderation/test_infractions.py | 55 | ||||
-rw-r--r-- | tests/bot/cogs/moderation/test_modlog.py | 29 | ||||
-rw-r--r-- | tests/bot/cogs/test_antimalware.py | 159 | ||||
-rw-r--r-- | tests/helpers.py | 4 |
14 files changed, 354 insertions, 63 deletions
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 66b5073e8..ea257442e 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,4 +1,5 @@ import logging +import typing as t from os.path import splitext from discord import Embed, Message, NotFound @@ -9,6 +10,27 @@ from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLE log = logging.getLogger(__name__) +PY_EMBED_DESCRIPTION = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + "{cmd_channel_mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( + "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n" + "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + class AntiMalware(Cog): """Delete messages which contain attachments with non-whitelisted file extensions.""" @@ -29,34 +51,20 @@ class AntiMalware(Cog): return embed = Embed() - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} - extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) + extensions_blocked = self.get_disallowed_extensions(message) blocked_extensions_str = ', '.join(extensions_blocked) if ".py" in extensions_blocked: # Short-circuit on *.py files to provide a pastebin link - embed.description = ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" - ) + embed.description = PY_EMBED_DESCRIPTION elif ".txt" in extensions_blocked: # Work around Discord AutoConversion of messages longer than 2000 chars to .txt cmd_channel = self.bot.get_channel(Channels.bot_commands) - embed.description = ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - f"{cmd_channel.mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" - ) + embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) elif extensions_blocked: - whitelisted_types = ', '.join(AntiMalwareConfig.whitelist) meta_channel = self.bot.get_channel(Channels.meta) - embed.description = ( - f"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - f"We currently allow the following file types: **{whitelisted_types}**.\n\n" - f"Feel free to ask in {meta_channel.mention} if you think this is a mistake." + embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + blocked_extensions_str=blocked_extensions_str, + meta_channel_mention=meta_channel.mention, ) if embed.description: @@ -73,6 +81,13 @@ class AntiMalware(Cog): except NotFound: log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + @classmethod + def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]: + """Get an iterable containing all the disallowed extensions of attachments.""" + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) + return extensions_blocked + def setup(bot: Bot) -> None: """Load the AntiMalware cog.""" diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index e62a36c43..5bfaad796 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -1,4 +1,5 @@ import logging +import textwrap import typing as t import discord @@ -225,7 +226,7 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_remove, user.id) - action = user.kick(reason=reason) + action = user.kick(reason=textwrap.shorten(reason, width=512, placeholder="...")) await self.apply_infraction(ctx, infraction, user, action) @respect_role_hierarchy() @@ -258,7 +259,9 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_remove, user.id) - action = ctx.guild.ban(user, reason=reason, delete_message_days=0) + truncated_reason = textwrap.shorten(reason, width=512, placeholder="...") + + action = ctx.guild.ban(user, reason=truncated_reason, delete_message_days=0) await self.apply_infraction(ctx, infraction, user, action) if infraction.get('expires_at') is not None: diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index beef7a8ef..9d28030d9 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -98,7 +98,10 @@ class ModLog(Cog, name="ModLog"): footer: t.Optional[str] = None, ) -> Context: """Generate log embed and send to logging channel.""" - embed = discord.Embed(description=text) + # Truncate string directly here to avoid removing newlines + embed = discord.Embed( + description=text[:2045] + "..." if len(text) > 2048 else text + ) if title and icon_url: embed.set_author(name=title, icon_url=icon_url) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 012432e60..f0a3ad1b1 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -101,33 +101,17 @@ class InfractionScheduler(Scheduler): dm_result = "" dm_log_text = "" - expiry_log_text = f"Expires: {expiry}" if expiry else "" + expiry_log_text = f"\nExpires: {expiry}" if expiry else "" log_title = "applied" log_content = None - - # DM the user about the infraction if it's not a shadow/hidden infraction. - if not infraction["hidden"]: - dm_result = f"{constants.Emojis.failmail} " - dm_log_text = "\nDM: **Failed**" - - # Sometimes user is a discord.Object; make it a proper user. - try: - if not isinstance(user, (discord.Member, discord.User)): - user = await self.bot.fetch_user(user.id) - except discord.HTTPException as e: - log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") - else: - # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry, reason, icon): - dm_result = ":incoming_envelope: " - dm_log_text = "\nDM: Sent" + failed = False if infraction["actor"] == self.bot.user.id: log.trace( f"Infraction #{id_} actor is bot; including the reason in the confirmation message." ) - end_msg = f" (reason: {infraction['reason']})" + end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" elif ctx.channel.id not in STAFF_CHANNELS: log.trace( f"Infraction #{id_} context is not in a staff channel; omitting infraction count." @@ -164,12 +148,43 @@ class InfractionScheduler(Scheduler): log.warning(f"{log_msg}: bot lacks permissions.") else: log.exception(log_msg) + failed = True + + # DM the user about the infraction if it's not a shadow/hidden infraction. + # Don't send DM when applying failed. + if not infraction["hidden"] and not failed: + dm_result = f"{constants.Emojis.failmail} " + dm_log_text = "\nDM: **Failed**" + + # Sometimes user is a discord.Object; make it a proper user. + try: + if not isinstance(user, (discord.Member, discord.User)): + user = await self.bot.fetch_user(user.id) + except discord.HTTPException as e: + log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") + else: + # Accordingly display whether the user was successfully notified via DM. + if await utils.notify_infraction(user, infr_type, expiry, reason, icon): + dm_result = ":incoming_envelope: " + dm_log_text = "\nDM: Sent" + + if failed: + dm_log_text = "\nDM: **Canceled**" + dm_result = f"{constants.Emojis.failmail} " + log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") + try: + await self.bot.api_client.delete(f"bot/infractions/{id_}") + except ResponseCodeError as e: + confirm_msg += " and failed to delete" + log_title += " and failed to delete" + log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") + infr_message = "" + else: + infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" # Send a confirmation message to the invoking context. log.trace(f"Sending infraction #{id_} confirmation message.") - await ctx.send( - f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}." - ) + await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") # Send a log message to the mod log. log.trace(f"Sending apply mod log for infraction #{id_}.") @@ -180,9 +195,8 @@ class InfractionScheduler(Scheduler): thumbnail=user.avatar_url_as(static_format="png"), text=textwrap.dedent(f""" Member: {user.mention} (`{user.id}`) - Actor: {ctx.message.author}{dm_log_text} + Actor: {ctx.message.author}{dm_log_text}{expiry_log_text} Reason: {reason} - {expiry_log_text} """), content=log_content, footer=f"ID {infraction['id']}" @@ -294,6 +308,9 @@ class InfractionScheduler(Scheduler): f"{log_text.get('Failure', '')}" ) + # Move reason to end of entry to avoid cutting out some keys + log_text["Reason"] = log_text.pop("Reason") + # Send a log message to the mod log. await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS[infr_type][1], @@ -407,6 +424,9 @@ class InfractionScheduler(Scheduler): user = self.bot.get_user(user_id) avatar = user.avatar_url_as(static_format="png") if user else None + # Move reason to end so when reason is too long, this is not gonna cut out required items. + log_text["Reason"] = log_text.pop("Reason") + log.trace(f"Sending deactivation mod log for infraction #{id_}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS[type_][1], @@ -416,7 +436,6 @@ class InfractionScheduler(Scheduler): text="\n".join(f"{k}: {v}" for k, v in log_text.items()), footer=f"ID: {id_}", content=log_content, - ) return log_text diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 29855c325..45a010f00 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -183,10 +183,10 @@ class Superstarify(InfractionScheduler, Cog): text=textwrap.dedent(f""" Member: {member.mention} (`{member.id}`) Actor: {ctx.message.author} - Reason: {reason} Expires: {expiry_str} Old nickname: `{old_nick}` New nickname: `{forced_nick}` + Reason: {reason} """), footer=f"ID {id_}" ) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index e4e0f1ec2..1b716b2ea 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -143,12 +143,14 @@ async def notify_infraction( """DM a user about their new infraction and return True if the DM is successful.""" log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + text = textwrap.dedent(f""" + **Type:** {infr_type.capitalize()} + **Expires:** {expires_at or "N/A"} + **Reason:** {reason or "No reason provided."} + """) + embed = discord.Embed( - description=textwrap.dedent(f""" - **Type:** {infr_type.capitalize()} - **Expires:** {expires_at or "N/A"} - **Reason:** {reason or "No reason provided."} - """), + description=textwrap.shorten(text, width=2048, placeholder="..."), colour=Colours.soft_red ) diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index bc7f53f68..6f03a3475 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -44,7 +44,7 @@ class Tags(Cog): tag = { "title": tag_title, "embed": { - "description": file.read_text(), + "description": file.read_text(encoding="utf8"), }, "restricted_to": "developers", } diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index e4fb173e0..702d371f4 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -1,4 +1,5 @@ import logging +import textwrap from collections import ChainMap from discord.ext.commands import Cog, Context, group @@ -97,8 +98,8 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): if len(history) > 1: total = f"({len(history) // 2} previous infractions in total)" - end_reason = history[0]["reason"] - start_reason = f"Watched: {history[1]['reason']}" + end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") + start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}" msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" else: msg = ":x: Failed to post the infraction: response was empty." diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index cd9c7e555..14547105f 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -106,8 +106,8 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): if history: total = f"({len(history)} previous nominations in total)" - start_reason = f"Watched: {history[0]['reason']}" - end_reason = f"Unwatched: {history[0]['end_reason']}" + start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" + end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}" msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```" await ctx.send(msg) @@ -224,7 +224,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): Status: **Active** Date: {start_date} Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} + Reason: {textwrap.shorten(nomination_object["reason"], width=200, placeholder="...")} Nomination ID: `{nomination_object["id"]}` =============== """ @@ -237,10 +237,10 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): Status: Inactive Date: {start_date} Actor: {actor.mention if actor else actor_id} - Reason: {nomination_object["reason"]} + Reason: {textwrap.shorten(nomination_object["reason"], width=200, placeholder="...")} End date: {end_date} - Unwatch reason: {nomination_object["end_reason"]} + Unwatch reason: {textwrap.shorten(nomination_object["end_reason"], width=200, placeholder="...")} Nomination ID: `{nomination_object["id"]}` =============== """ diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 643cd46e4..436778c46 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -280,8 +280,9 @@ class WatchChannel(metaclass=CogABCMeta): else: message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + footer = f"Added {time_delta} by {actor} | Reason: {reason}" embed = Embed(description=f"{msg.author.mention} {message_jump}") - embed.set_footer(text=f"Added {time_delta} by {actor} | Reason: {reason}") + embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="...")) await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py new file mode 100644 index 000000000..da4e92ccc --- /dev/null +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason truncation.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Infractions(self.bot) + self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) + self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) + self.guild = MockGuild(id=4567) + self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + + @patch("bot.cogs.moderation.utils.get_active_infraction") + @patch("bot.cogs.moderation.utils.post_infraction") + async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): + """Should truncate reason for `ctx.guild.ban`.""" + get_active_mock.return_value = None + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.bot.get_cog.return_value = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.ctx.guild.ban = Mock() + + await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) + self.ctx.guild.ban.assert_called_once_with( + self.target, + reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), + delete_message_days=0 + ) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value + ) + + @patch("bot.cogs.moderation.utils.post_infraction") + async def test_apply_kick_reason_truncation(self, post_infraction_mock): + """Should truncate reason for `Member.kick`.""" + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.cog.mod_log.ignore = Mock() + self.target.kick = Mock() + + await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) + self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + ) diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py new file mode 100644 index 000000000..f2809f40a --- /dev/null +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.cogs.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): + """Tests for moderation logs.""" + + def setUp(self): + self.bot = MockBot() + self.cog = ModLog(self.bot) + self.channel = MockTextChannel() + + async def test_log_entry_description_truncation(self): + """Test that embed description for ModLog entry is truncated.""" + self.bot.get_channel.return_value = self.channel + await self.cog.send_log_message( + icon_url="foo", + colour=discord.Colour.blue(), + title="bar", + text="foo bar" * 3000 + ) + embed = self.channel.send.call_args[1]["embed"] + self.assertEqual( + embed.description, ("foo bar" * 3000)[:2045] + "..." + ) diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..f219fc1ba --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from discord import NotFound + +from bot.cogs import antimalware +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + +MODULE = "bot.cogs.antimalware" + + +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + await self.cog.on_message(self.message) + self.message.delete.assert_not_called() + + async def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.guild = None + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_called_once() + + async def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extention_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + + async def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + self.message.delete.assert_called_once() + + async def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + with self.assertLogs(logger=antimalware.log, level="INFO"): + await self.cog.on_message(self.message) + + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (AntiMalwareConfig.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog.get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/helpers.py b/tests/helpers.py index 13283339b..faa839370 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -208,6 +208,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position < other.position + def __ge__(self, other): + """Simplified position-based comparisons similar to those of `discord.Role`.""" + return self.position >= other.position + # Create a Member instance to get a realistic Mock of `discord.Member` member_data = {'user': 'lemon', 'roles': [1]} |