diff options
| author | 2021-11-22 07:48:28 +0000 | |
|---|---|---|
| committer | 2021-11-22 07:48:28 +0000 | |
| commit | 52fdc53e6282ed819e3aed67ba9dc4f6fe90abc2 (patch) | |
| tree | a3a16202d452bdb56d0794cd5c3e6352ba76890c /tests | |
| parent | Add ability to reply to message for `!remind` (diff) | |
| parent | Merge pull request #1966 from python-discord/disable-file-logs (diff) | |
Merge branch 'main' into allow-reply-in-remind
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 94 | 
1 files changed, 92 insertions, 2 deletions
diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index ccc842050..cfe0c4b03 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -3,13 +3,16 @@ import enum  import logging  import typing as t  import unittest +from unittest import mock  from unittest.mock import AsyncMock, MagicMock, Mock, call, patch  import aiohttp  import discord +from async_rediscache import RedisSession  from bot.constants import Colours  from bot.exts.moderation import incidents +from bot.utils.messages import format_user  from tests.helpers import (      MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,      MockUser @@ -276,6 +279,22 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):      the instance as they wish.      """ +    session = None + +    async def flush(self): +        """Flush everything from the database to prevent carry-overs between tests.""" +        with await self.session.pool as connection: +            await connection.flushall() + +    async def asyncSetUp(self):  # noqa: N802 +        self.session = RedisSession(use_fakeredis=True) +        await self.session.connect() +        await self.flush() + +    async def asyncTearDown(self):  # noqa: N802 +        if self.session: +            await self.session.close() +      def setUp(self):          """          Prepare a fresh `Incidents` instance for each test. @@ -506,7 +525,7 @@ class TestProcessEvent(TestIncidents):          with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):              await self.cog_instance.process_event(                  reaction=incidents.Signal.ACTIONED.value, -                incident=MockMessage(), +                incident=MockMessage(id=123),                  member=MockMember(roles=[MockRole(id=1)])              ) @@ -526,7 +545,7 @@ class TestProcessEvent(TestIncidents):              with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):                  await self.cog_instance.process_event(                      reaction=incidents.Signal.ACTIONED.value, -                    incident=MockMessage(), +                    incident=MockMessage(id=123),                      member=MockMember(roles=[MockRole(id=1)])                  )          except asyncio.TimeoutError: @@ -761,3 +780,74 @@ class TestOnMessage(TestIncidents):              await self.cog_instance.on_message(MockMessage())          mock_add_signals.assert_not_called() + + +class TestMessageLinkEmbeds(TestIncidents): +    """Tests for `extract_message_links` coroutine.""" + +    async def test_shorten_text(self): +        """Test all cases of text shortening by mocking messages.""" +        tests = { +            "thisisasingleword"*10: "thisisasinglewordthisisasinglewordthisisasinglewor...", + +            "\n".join("Lets make a new line test".split()): "Lets\nmake\na...", + +            'Hello, World!' * 300: ( +                "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!" +                "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!" +                "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!" +                "Hello, World!Hello, World!H..." +            ) +        } + +        for content, expected_conversion in tests.items(): +            with self.subTest(content=content, expected_conversion=expected_conversion): +                conversion = incidents.shorten_text(content) +                self.assertEqual(conversion, expected_conversion) + +    async def extract_and_form_message_link_embeds(self): +        """ +        Extract message links from a mocked message and form the message link embed. + +        Considers all types of message links, discord supports. +        """ +        self.guild_id_patcher = mock.patch("bot.exts.backend.sync._cog.constants.Guild.id", 5) +        self.guild_id = self.guild_id_patcher.start() + +        msg = MockMessage(id=555, content="Hello, World!" * 3000) +        msg.channel.mention = "#lemonade-stand" + +        msg_links = [ +            # Valid Message links +            f"https://discord.com/channels/{self.guild_id}/{msg.channel.discord_id}/{msg.discord_id}", +            f"http://canary.discord.com/channels/{self.guild_id}/{msg.channel.discord_id}/{msg.discord_id}", + +            # Invalid Message links +            f"https://discord.com/channels/{msg.channel.discord_id}/{msg.discord_id}", +            f"https://discord.com/channels/{self.guild_id}/{msg.channel.discord_id}000/{msg.discord_id}", +        ] + +        incident_msg = MockMessage( +            id=777, +            content=( +                f"I would like to report the following messages, " +                f"as they break our rules: \n{', '.join(msg_links)}" +            ) +        ) + +        with patch( +                "bot.exts.moderation.incidents.Incidents.extract_message_links", AsyncMock() +        ) as mock_extract_message_links: +            embeds = mock_extract_message_links(incident_msg) +            description = ( +                f"**Author:** {format_user(msg.author)}\n" +                f"**Channel:** {msg.channel.mention} ({msg.channel.category}/#{msg.channel.name})\n" +                f"**Content:** {('Hello, World!' * 3000)[:300] + '...'}\n" +            ) + +            # Check number of embeds returned with number of valid links +            self.assertEqual(len(embeds), 2) + +            # Check for the embed descriptions +            for embed in embeds: +                self.assertEqual(embed.description, description)  |