diff options
| author | 2020-05-30 23:56:21 +0200 | |
|---|---|---|
| committer | 2020-05-30 23:56:21 +0200 | |
| commit | 31861888a5cdc310028a05ea0ed03dc693bbe7b5 (patch) | |
| tree | 636f60ecc967faf8aa59cdb43e10528e7121b5a0 | |
| parent | Oops, add the return back. (diff) | |
| parent | Merge pull request #864 from ks129/ban-kick-reason-length (diff) | |
Merge branch 'master' into remove_periodic_ping
| -rw-r--r-- | bot/cogs/antimalware.py | 55 | ||||
| -rw-r--r-- | bot/cogs/moderation/infractions.py | 7 | ||||
| -rw-r--r-- | bot/cogs/moderation/modlog.py | 5 | ||||
| -rw-r--r-- | bot/cogs/moderation/scheduler.py | 69 | ||||
| -rw-r--r-- | bot/cogs/moderation/superstarify.py | 2 | ||||
| -rw-r--r-- | bot/cogs/moderation/utils.py | 12 | ||||
| -rw-r--r-- | bot/cogs/tags.py | 2 | ||||
| -rw-r--r-- | bot/cogs/watchchannels/bigbrother.py | 5 | ||||
| -rw-r--r-- | bot/cogs/watchchannels/talentpool.py | 10 | ||||
| -rw-r--r-- | bot/cogs/watchchannels/watchchannel.py | 3 | ||||
| -rw-r--r-- | tests/bot/cogs/moderation/test_infractions.py | 55 | ||||
| -rw-r--r-- | tests/bot/cogs/moderation/test_modlog.py | 29 | ||||
| -rw-r--r-- | tests/bot/cogs/test_antimalware.py | 159 | ||||
| -rw-r--r-- | tests/helpers.py | 4 | 
14 files changed, 354 insertions, 63 deletions
| diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 66b5073e8..ea257442e 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,4 +1,5 @@  import logging +import typing as t  from os.path import splitext  from discord import Embed, Message, NotFound @@ -9,6 +10,27 @@ from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLE  log = logging.getLogger(__name__) +PY_EMBED_DESCRIPTION = ( +    "It looks like you tried to attach a Python file - " +    f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( +    "**Uh-oh!** It looks like your message got zapped by our spam filter. " +    "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" +    "• If you attempted to send a message longer than 2000 characters, try shortening your message " +    "to fit within the character limit or use a pasting service (see below) \n\n" +    "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " +    "{cmd_channel_mention} for more information) or use a pasting service like: " +    f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( +    "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " +    f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n" +    "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) +  class AntiMalware(Cog):      """Delete messages which contain attachments with non-whitelisted file extensions.""" @@ -29,34 +51,20 @@ class AntiMalware(Cog):              return          embed = Embed() -        file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} -        extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) +        extensions_blocked = self.get_disallowed_extensions(message)          blocked_extensions_str = ', '.join(extensions_blocked)          if ".py" in extensions_blocked:              # Short-circuit on *.py files to provide a pastebin link -            embed.description = ( -                "It looks like you tried to attach a Python file - " -                f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" -            ) +            embed.description = PY_EMBED_DESCRIPTION          elif ".txt" in extensions_blocked:              # Work around Discord AutoConversion of messages longer than 2000 chars to .txt              cmd_channel = self.bot.get_channel(Channels.bot_commands) -            embed.description = ( -                "**Uh-oh!** It looks like your message got zapped by our spam filter. " -                "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" -                "• If you attempted to send a message longer than 2000 characters, try shortening your message " -                "to fit within the character limit or use a pasting service (see below) \n\n" -                "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " -                f"{cmd_channel.mention} for more information) or use a pasting service like: " -                f"\n\n{URLs.site_schema}{URLs.site_paste}" -            ) +            embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention)          elif extensions_blocked: -            whitelisted_types = ', '.join(AntiMalwareConfig.whitelist)              meta_channel = self.bot.get_channel(Channels.meta) -            embed.description = ( -                f"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " -                f"We currently allow the following file types: **{whitelisted_types}**.\n\n" -                f"Feel free to ask in {meta_channel.mention} if you think this is a mistake." +            embed.description = DISALLOWED_EMBED_DESCRIPTION.format( +                blocked_extensions_str=blocked_extensions_str, +                meta_channel_mention=meta_channel.mention,              )          if embed.description: @@ -73,6 +81,13 @@ class AntiMalware(Cog):              except NotFound:                  log.info(f"Tried to delete message `{message.id}`, but message could not be found.") +    @classmethod +    def get_disallowed_extensions(cls, message: Message) -> t.Iterable[str]: +        """Get an iterable containing all the disallowed extensions of attachments.""" +        file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} +        extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) +        return extensions_blocked +  def setup(bot: Bot) -> None:      """Load the AntiMalware cog.""" diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index e62a36c43..5bfaad796 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -1,4 +1,5 @@  import logging +import textwrap  import typing as t  import discord @@ -225,7 +226,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self.mod_log.ignore(Event.member_remove, user.id) -        action = user.kick(reason=reason) +        action = user.kick(reason=textwrap.shorten(reason, width=512, placeholder="..."))          await self.apply_infraction(ctx, infraction, user, action)      @respect_role_hierarchy() @@ -258,7 +259,9 @@ class Infractions(InfractionScheduler, commands.Cog):          self.mod_log.ignore(Event.member_remove, user.id) -        action = ctx.guild.ban(user, reason=reason, delete_message_days=0) +        truncated_reason = textwrap.shorten(reason, width=512, placeholder="...") + +        action = ctx.guild.ban(user, reason=truncated_reason, delete_message_days=0)          await self.apply_infraction(ctx, infraction, user, action)          if infraction.get('expires_at') is not None: diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index beef7a8ef..9d28030d9 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -98,7 +98,10 @@ class ModLog(Cog, name="ModLog"):          footer: t.Optional[str] = None,      ) -> Context:          """Generate log embed and send to logging channel.""" -        embed = discord.Embed(description=text) +        # Truncate string directly here to avoid removing newlines +        embed = discord.Embed( +            description=text[:2045] + "..." if len(text) > 2048 else text +        )          if title and icon_url:              embed.set_author(name=title, icon_url=icon_url) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 012432e60..f0a3ad1b1 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -101,33 +101,17 @@ class InfractionScheduler(Scheduler):          dm_result = ""          dm_log_text = "" -        expiry_log_text = f"Expires: {expiry}" if expiry else "" +        expiry_log_text = f"\nExpires: {expiry}" if expiry else ""          log_title = "applied"          log_content = None - -        # DM the user about the infraction if it's not a shadow/hidden infraction. -        if not infraction["hidden"]: -            dm_result = f"{constants.Emojis.failmail} " -            dm_log_text = "\nDM: **Failed**" - -            # Sometimes user is a discord.Object; make it a proper user. -            try: -                if not isinstance(user, (discord.Member, discord.User)): -                    user = await self.bot.fetch_user(user.id) -            except discord.HTTPException as e: -                log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") -            else: -                # Accordingly display whether the user was successfully notified via DM. -                if await utils.notify_infraction(user, infr_type, expiry, reason, icon): -                    dm_result = ":incoming_envelope: " -                    dm_log_text = "\nDM: Sent" +        failed = False          if infraction["actor"] == self.bot.user.id:              log.trace(                  f"Infraction #{id_} actor is bot; including the reason in the confirmation message."              ) -            end_msg = f" (reason: {infraction['reason']})" +            end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})"          elif ctx.channel.id not in STAFF_CHANNELS:              log.trace(                  f"Infraction #{id_} context is not in a staff channel; omitting infraction count." @@ -164,12 +148,43 @@ class InfractionScheduler(Scheduler):                      log.warning(f"{log_msg}: bot lacks permissions.")                  else:                      log.exception(log_msg) +                failed = True + +        # DM the user about the infraction if it's not a shadow/hidden infraction. +        # Don't send DM when applying failed. +        if not infraction["hidden"] and not failed: +            dm_result = f"{constants.Emojis.failmail} " +            dm_log_text = "\nDM: **Failed**" + +            # Sometimes user is a discord.Object; make it a proper user. +            try: +                if not isinstance(user, (discord.Member, discord.User)): +                    user = await self.bot.fetch_user(user.id) +            except discord.HTTPException as e: +                log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") +            else: +                # Accordingly display whether the user was successfully notified via DM. +                if await utils.notify_infraction(user, infr_type, expiry, reason, icon): +                    dm_result = ":incoming_envelope: " +                    dm_log_text = "\nDM: Sent" + +        if failed: +            dm_log_text = "\nDM: **Canceled**" +            dm_result = f"{constants.Emojis.failmail} " +            log.trace(f"Deleted infraction {infraction['id']} from database because applying infraction failed.") +            try: +                await self.bot.api_client.delete(f"bot/infractions/{id_}") +            except ResponseCodeError as e: +                confirm_msg += " and failed to delete" +                log_title += " and failed to delete" +                log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.") +            infr_message = "" +        else: +            infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}"          # Send a confirmation message to the invoking context.          log.trace(f"Sending infraction #{id_} confirmation message.") -        await ctx.send( -            f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}." -        ) +        await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.")          # Send a log message to the mod log.          log.trace(f"Sending apply mod log for infraction #{id_}.") @@ -180,9 +195,8 @@ class InfractionScheduler(Scheduler):              thumbnail=user.avatar_url_as(static_format="png"),              text=textwrap.dedent(f"""                  Member: {user.mention} (`{user.id}`) -                Actor: {ctx.message.author}{dm_log_text} +                Actor: {ctx.message.author}{dm_log_text}{expiry_log_text}                  Reason: {reason} -                {expiry_log_text}              """),              content=log_content,              footer=f"ID {infraction['id']}" @@ -294,6 +308,9 @@ class InfractionScheduler(Scheduler):                  f"{log_text.get('Failure', '')}"              ) +        # Move reason to end of entry to avoid cutting out some keys +        log_text["Reason"] = log_text.pop("Reason") +          # Send a log message to the mod log.          await self.mod_log.send_log_message(              icon_url=utils.INFRACTION_ICONS[infr_type][1], @@ -407,6 +424,9 @@ class InfractionScheduler(Scheduler):              user = self.bot.get_user(user_id)              avatar = user.avatar_url_as(static_format="png") if user else None +            # Move reason to end so when reason is too long, this is not gonna cut out required items. +            log_text["Reason"] = log_text.pop("Reason") +              log.trace(f"Sending deactivation mod log for infraction #{id_}.")              await self.mod_log.send_log_message(                  icon_url=utils.INFRACTION_ICONS[type_][1], @@ -416,7 +436,6 @@ class InfractionScheduler(Scheduler):                  text="\n".join(f"{k}: {v}" for k, v in log_text.items()),                  footer=f"ID: {id_}",                  content=log_content, -              )          return log_text diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 29855c325..45a010f00 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -183,10 +183,10 @@ class Superstarify(InfractionScheduler, Cog):              text=textwrap.dedent(f"""                  Member: {member.mention} (`{member.id}`)                  Actor: {ctx.message.author} -                Reason: {reason}                  Expires: {expiry_str}                  Old nickname: `{old_nick}`                  New nickname: `{forced_nick}` +                Reason: {reason}              """),              footer=f"ID {id_}"          ) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index e4e0f1ec2..1b716b2ea 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -143,12 +143,14 @@ async def notify_infraction(      """DM a user about their new infraction and return True if the DM is successful."""      log.trace(f"Sending {user} a DM about their {infr_type} infraction.") +    text = textwrap.dedent(f""" +        **Type:** {infr_type.capitalize()} +        **Expires:** {expires_at or "N/A"} +        **Reason:** {reason or "No reason provided."} +    """) +      embed = discord.Embed( -        description=textwrap.dedent(f""" -            **Type:** {infr_type.capitalize()} -            **Expires:** {expires_at or "N/A"} -            **Reason:** {reason or "No reason provided."} -            """), +        description=textwrap.shorten(text, width=2048, placeholder="..."),          colour=Colours.soft_red      ) diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index bc7f53f68..6f03a3475 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -44,7 +44,7 @@ class Tags(Cog):                  tag = {                      "title": tag_title,                      "embed": { -                        "description": file.read_text(), +                        "description": file.read_text(encoding="utf8"),                      },                      "restricted_to": "developers",                  } diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index e4fb173e0..702d371f4 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -1,4 +1,5 @@  import logging +import textwrap  from collections import ChainMap  from discord.ext.commands import Cog, Context, group @@ -97,8 +98,8 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):              if len(history) > 1:                  total = f"({len(history) // 2} previous infractions in total)" -                end_reason = history[0]["reason"] -                start_reason = f"Watched: {history[1]['reason']}" +                end_reason = textwrap.shorten(history[0]["reason"], width=500, placeholder="...") +                start_reason = f"Watched: {textwrap.shorten(history[1]['reason'], width=500, placeholder='...')}"                  msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```"          else:              msg = ":x: Failed to post the infraction: response was empty." diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index cd9c7e555..14547105f 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -106,8 +106,8 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          if history:              total = f"({len(history)} previous nominations in total)" -            start_reason = f"Watched: {history[0]['reason']}" -            end_reason = f"Unwatched: {history[0]['end_reason']}" +            start_reason = f"Watched: {textwrap.shorten(history[0]['reason'], width=500, placeholder='...')}" +            end_reason = f"Unwatched: {textwrap.shorten(history[0]['end_reason'], width=500, placeholder='...')}"              msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```"          await ctx.send(msg) @@ -224,7 +224,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):                  Status: **Active**                  Date: {start_date}                  Actor: {actor.mention if actor else actor_id} -                Reason: {nomination_object["reason"]} +                Reason: {textwrap.shorten(nomination_object["reason"], width=200, placeholder="...")}                  Nomination ID: `{nomination_object["id"]}`                  ===============                  """ @@ -237,10 +237,10 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):                  Status: Inactive                  Date: {start_date}                  Actor: {actor.mention if actor else actor_id} -                Reason: {nomination_object["reason"]} +                Reason: {textwrap.shorten(nomination_object["reason"], width=200, placeholder="...")}                  End date: {end_date} -                Unwatch reason: {nomination_object["end_reason"]} +                Unwatch reason: {textwrap.shorten(nomination_object["end_reason"], width=200, placeholder="...")}                  Nomination ID: `{nomination_object["id"]}`                  ===============                  """ diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 643cd46e4..436778c46 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -280,8 +280,9 @@ class WatchChannel(metaclass=CogABCMeta):          else:              message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" +        footer = f"Added {time_delta} by {actor} | Reason: {reason}"          embed = Embed(description=f"{msg.author.mention} {message_jump}") -        embed.set_footer(text=f"Added {time_delta} by {actor} | Reason: {reason}") +        embed.set_footer(text=textwrap.shorten(footer, width=128, placeholder="..."))          await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py new file mode 100644 index 000000000..da4e92ccc --- /dev/null +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): +    """Tests for ban and kick command reason truncation.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = Infractions(self.bot) +        self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) +        self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) +        self.guild = MockGuild(id=4567) +        self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + +    @patch("bot.cogs.moderation.utils.get_active_infraction") +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): +        """Should truncate reason for `ctx.guild.ban`.""" +        get_active_mock.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.bot.get_cog.return_value = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.ctx.guild.ban = Mock() + +        await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) +        self.ctx.guild.ban.assert_called_once_with( +            self.target, +            reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), +            delete_message_days=0 +        ) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value +        ) + +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_kick_reason_truncation(self, post_infraction_mock): +        """Should truncate reason for `Member.kick`.""" +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.target.kick = Mock() + +        await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) +        self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value +        ) diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py new file mode 100644 index 000000000..f2809f40a --- /dev/null +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.cogs.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for moderation logs.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = ModLog(self.bot) +        self.channel = MockTextChannel() + +    async def test_log_entry_description_truncation(self): +        """Test that embed description for ModLog entry is truncated.""" +        self.bot.get_channel.return_value = self.channel +        await self.cog.send_log_message( +            icon_url="foo", +            colour=discord.Colour.blue(), +            title="bar", +            text="foo bar" * 3000 +        ) +        embed = self.channel.send.call_args[1]["embed"] +        self.assertEqual( +            embed.description, ("foo bar" * 3000)[:2045] + "..." +        ) diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..f219fc1ba --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from discord import NotFound + +from bot.cogs import antimalware +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + +MODULE = "bot.cogs.antimalware" + + +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): +    """Test the AntiMalware cog.""" + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.bot = MockBot() +        self.cog = antimalware.AntiMalware(self.bot) +        self.message = MockMessage() + +    async def test_message_with_allowed_attachment(self): +        """Messages with allowed extensions should not be deleted""" +        attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_message_without_attachment(self): +        """Messages without attachments should result in no action.""" +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_direct_message_with_attachment(self): +        """Direct messages should have no action taken.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.guild = None + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_message_with_illegal_extension_gets_deleted(self): +        """A message containing an illegal extension should send an embed.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_called_once() + +    async def test_message_send_by_staff(self): +        """A message send by a member of staff should be ignored.""" +        staff_role = MockRole(id=STAFF_ROLES[0]) +        self.message.author.roles.append(staff_role) +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_python_file_redirect_embed_description(self): +        """A message containing a .py file should result in an embed redirecting the user to our paste site""" +        attachment = MockAttachment(filename="python.py") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") + +        self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + +    async def test_txt_file_redirect_embed_description(self): +        """A message containing a .txt file should result in the correct embed.""" +        attachment = MockAttachment(filename="python.txt") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.TXT_EMBED_DESCRIPTION = Mock() +        antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        cmd_channel = self.bot.get_channel(Channels.bot_commands) + +        self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) +        antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + +    async def test_other_disallowed_extention_embed_description(self): +        """Test the description for a non .py/.txt disallowed extension.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        meta_channel = self.bot.get_channel(Channels.meta) + +        self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( +            blocked_extensions_str=".disallowed", +            meta_channel_mention=meta_channel.mention +        ) + +    async def test_removing_deleted_message_logs(self): +        """Removing an already deleted message logs the correct message""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) +        self.message.delete.assert_called_once() + +    async def test_message_with_illegal_attachment_logs(self): +        """Deleting a message with an illegal attachment should result in a log.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) + +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            ([], []), +            (AntiMalwareConfig.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], [".disallowed"]), +            ([".disallowed"], [".disallowed"]), +            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] +                disallowed_extensions = self.cog.get_disallowed_extensions(self.message) +                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): +    """Tests setup of the `AntiMalware` cog.""" + +    def test_setup(self): +        """Setup of the extension should call add_cog.""" +        bot = MockBot() +        antimalware.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/helpers.py b/tests/helpers.py index 13283339b..faa839370 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -208,6 +208,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):          """Simplified position-based comparisons similar to those of `discord.Role`."""          return self.position < other.position +    def __ge__(self, other): +        """Simplified position-based comparisons similar to those of `discord.Role`.""" +        return self.position >= other.position +  # Create a Member instance to get a realistic Mock of `discord.Member`  member_data = {'user': 'lemon', 'roles': [1]} | 
