diff options
Diffstat (limited to '')
| -rw-r--r-- | tests/bot/rules/test_attachments.py | 110 | ||||
| -rw-r--r-- | tests/bot/rules/test_links.py | 26 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 95 | 
3 files changed, 184 insertions, 47 deletions
| diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index 4bb0acf7c..d7187f315 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,52 +1,98 @@ -import asyncio  import unittest -from dataclasses import dataclass -from typing import Any, List +from typing import List, NamedTuple, Tuple  from bot.rules import attachments +from tests.helpers import MockMessage, async_test -# Using `MagicMock` sadly doesn't work for this usecase -# since it's __eq__ compares the MagicMock's ID. We just -# want to compare the actual attributes we set. -@dataclass -class FakeMessage: -    author: str -    attachments: List[Any] +class Case(NamedTuple): +    recent_messages: List[MockMessage] +    culprit: Tuple[str] +    total_attachments: int -def msg(total_attachments: int) -> FakeMessage: -    return FakeMessage(author='lemon', attachments=list(range(total_attachments))) +def msg(author: str, total_attachments: int) -> MockMessage: +    """Builds a message with `total_attachments` attachments.""" +    return MockMessage(author=author, attachments=list(range(total_attachments)))  class AttachmentRuleTests(unittest.TestCase): -    """Tests applying the `attachment` antispam rule.""" +    """Tests applying the `attachments` antispam rule.""" -    def test_allows_messages_without_too_many_attachments(self): +    def setUp(self): +        self.config = {"max": 5} + +    @async_test +    async def test_allows_messages_without_too_many_attachments(self):          """Messages without too many attachments are allowed as-is."""          cases = ( -            (msg(0), msg(0), msg(0)), -            (msg(2), msg(2)), -            (msg(0),), +            [msg("bob", 0), msg("bob", 0), msg("bob", 0)], +            [msg("bob", 2), msg("bob", 2)], +            [msg("bob", 2), msg("alice", 2), msg("bob", 2)],          ) -        for last_message, *recent_messages in cases: -            with self.subTest(last_message=last_message, recent_messages=recent_messages): -                coro = attachments.apply(last_message, recent_messages, {'max': 5}) -                self.assertIsNone(asyncio.run(coro)) +        for recent_messages in cases: +            last_message = recent_messages[0] + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                config=self.config +            ): +                self.assertIsNone( +                    await attachments.apply(last_message, recent_messages, self.config) +                ) -    def test_disallows_messages_with_too_many_attachments(self): +    @async_test +    async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( -            ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), -            ((msg(6),), [msg(6)], 6), -            ((msg(1),) * 6, [msg(1)] * 6, 6), +            Case( +                [msg("bob", 4), msg("bob", 0), msg("bob", 6)], +                ("bob",), +                10 +            ), +            Case( +                [msg("bob", 4), msg("alice", 6), msg("bob", 2)], +                ("bob",), +                6 +            ), +            Case( +                [msg("alice", 6)], +                ("alice",), +                6 +            ), +            ( +                [msg("alice", 1) for _ in range(6)], +                ("alice",), +                6 +            ),          ) -        for messages, relevant_messages, total in cases: -            with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total): -                last_message, *recent_messages = messages -                coro = attachments.apply(last_message, recent_messages, {'max': 5}) -                self.assertEqual( -                    asyncio.run(coro), -                    (f"sent {total} attachments in 5s", ('lemon',), relevant_messages) + +        for recent_messages, culprit, total_attachments in cases: +            last_message = recent_messages[0] +            relevant_messages = tuple( +                msg +                for msg in recent_messages +                if ( +                    msg.author == last_message.author +                    and len(msg.attachments) > 0 +                ) +            ) + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                relevant_messages=relevant_messages, +                total_attachments=total_attachments, +                config=self.config +            ): +                desired_output = ( +                    f"sent {total_attachments} attachments in {self.config['max']}s", +                    culprit, +                    relevant_messages +                ) +                self.assertTupleEqual( +                    await attachments.apply(last_message, recent_messages, self.config), +                    desired_output                  ) diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index be832843b..02a5d5501 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -2,25 +2,19 @@ import unittest  from typing import List, NamedTuple, Tuple  from bot.rules import links -from tests.helpers import async_test - - -class FakeMessage(NamedTuple): -    author: str -    content: str +from tests.helpers import MockMessage, async_test  class Case(NamedTuple): -    recent_messages: List[FakeMessage] -    relevant_messages: Tuple[FakeMessage] +    recent_messages: List[MockMessage]      culprit: Tuple[str]      total_links: int -def msg(author: str, total_links: int) -> FakeMessage: -    """Makes a message with *total_links* links.""" +def msg(author: str, total_links: int) -> MockMessage: +    """Makes a message with `total_links` links."""      content = " ".join(["https://pydis.com"] * total_links) -    return FakeMessage(author=author, content=content) +    return MockMessage(author=author, content=content)  class LinksTests(unittest.TestCase): @@ -61,26 +55,28 @@ class LinksTests(unittest.TestCase):          cases = (              Case(                  [msg("bob", 1), msg("bob", 2)], -                (msg("bob", 1), msg("bob", 2)),                  ("bob",),                  3              ),              Case(                  [msg("alice", 1), msg("alice", 1), msg("alice", 1)], -                (msg("alice", 1), msg("alice", 1), msg("alice", 1)),                  ("alice",),                  3              ),              Case(                  [msg("alice", 2), msg("bob", 3), msg("alice", 1)], -                (msg("alice", 2), msg("alice", 1)),                  ("alice",),                  3              )          ) -        for recent_messages, relevant_messages, culprit, total_links in cases: +        for recent_messages, culprit, total_links in cases:              last_message = recent_messages[0] +            relevant_messages = tuple( +                msg +                for msg in recent_messages +                if msg.author == last_message.author +            )              with self.subTest(                  last_message=last_message, diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py new file mode 100644 index 000000000..ad49ead32 --- /dev/null +++ b/tests/bot/rules/test_mentions.py @@ -0,0 +1,95 @@ +import unittest +from typing import List, NamedTuple, Tuple + +from bot.rules import mentions +from tests.helpers import MockMessage, async_test + + +class Case(NamedTuple): +    recent_messages: List[MockMessage] +    culprit: Tuple[str] +    total_mentions: int + + +def msg(author: str, total_mentions: int) -> MockMessage: +    """Makes a message with `total_mentions` mentions.""" +    return MockMessage(author=author, mentions=list(range(total_mentions))) + + +class TestMentions(unittest.TestCase): +    """Tests applying the `mentions` antispam rule.""" + +    def setUp(self): +        self.config = { +            "max": 2, +            "interval": 10 +        } + +    @async_test +    async def test_mentions_within_limit(self): +        """Messages with an allowed amount of mentions.""" +        cases = ( +            [msg("bob", 0)], +            [msg("bob", 2)], +            [msg("bob", 1), msg("bob", 1)], +            [msg("bob", 1), msg("alice", 2)] +        ) + +        for recent_messages in cases: +            last_message = recent_messages[0] + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                config=self.config +            ): +                self.assertIsNone( +                    await mentions.apply(last_message, recent_messages, self.config) +                ) + +    @async_test +    async def test_mentions_exceeding_limit(self): +        """Messages with a higher than allowed amount of mentions.""" +        cases = ( +            Case( +                [msg("bob", 3)], +                ("bob",), +                3 +            ), +            Case( +                [msg("alice", 2), msg("alice", 0), msg("alice", 1)], +                ("alice",), +                3 +            ), +            Case( +                [msg("bob", 2), msg("alice", 3), msg("bob", 2)], +                ("bob",), +                4 +            ) +        ) + +        for recent_messages, culprit, total_mentions in cases: +            last_message = recent_messages[0] +            relevant_messages = tuple( +                msg +                for msg in recent_messages +                if msg.author == last_message.author +            ) + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                relevant_messages=relevant_messages, +                culprit=culprit, +                total_mentions=total_mentions, +                cofig=self.config +            ): +                desired_output = ( +                    f"sent {total_mentions} mentions in {self.config['interval']}s", +                    culprit, +                    relevant_messages +                ) +                self.assertTupleEqual( +                    await mentions.apply(last_message, recent_messages, self.config), +                    desired_output +                ) | 
