diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/rules/__init__.py | 76 | ||||
| -rw-r--r-- | tests/bot/rules/test_attachments.py | 97 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst.py | 56 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst_shared.py | 59 | ||||
| -rw-r--r-- | tests/bot/rules/test_chars.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_discord_emojis.py | 54 | ||||
| -rw-r--r-- | tests/bot/rules/test_duplicates.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_links.py | 84 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 90 | ||||
| -rw-r--r-- | tests/bot/rules/test_newlines.py | 105 | ||||
| -rw-r--r-- | tests/bot/rules/test_role_mentions.py | 57 | ||||
| -rw-r--r-- | tests/bot/test_api.py | 64 | 
12 files changed, 635 insertions, 239 deletions
| diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index e69de29bb..36c986fe1 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -0,0 +1,76 @@ +import unittest +from abc import ABCMeta, abstractmethod +from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple + +from tests.helpers import MockMessage + + +class DisallowedCase(NamedTuple): +    """Encapsulation for test cases expected to fail.""" +    recent_messages: List[MockMessage] +    culprits: Iterable[str] +    n_violations: int + + +class RuleTest(unittest.TestCase, metaclass=ABCMeta): +    """ +    Abstract class for antispam rule test cases. + +    Tests for specific rules should inherit from `RuleTest` and implement +    `relevant_messages` and `get_report`. Each instance should also set the +    `apply` and `config` attributes as necessary. + +    The execution of test cases can then be delegated to the `run_allowed` +    and `run_disallowed` methods. +    """ + +    apply: Callable  # The tested rule's apply function +    config: Dict[str, int] + +    async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: +        """Run all `cases` against `self.apply` expecting them to pass.""" +        for recent_messages in cases: +            last_message = recent_messages[0] + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                config=self.config, +            ): +                self.assertIsNone( +                    await self.apply(last_message, recent_messages, self.config) +                ) + +    async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: +        """Run all `cases` against `self.apply` expecting them to fail.""" +        for case in cases: +            recent_messages, culprits, n_violations = case +            last_message = recent_messages[0] +            relevant_messages = self.relevant_messages(case) +            desired_output = ( +                self.get_report(case), +                culprits, +                relevant_messages, +            ) + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                relevant_messages=relevant_messages, +                n_violations=n_violations, +                config=self.config, +            ): +                self.assertTupleEqual( +                    await self.apply(last_message, recent_messages, self.config), +                    desired_output, +                ) + +    @abstractmethod +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        """Give expected relevant messages for `case`.""" +        raise NotImplementedError + +    @abstractmethod +    def get_report(self, case: DisallowedCase) -> str: +        """Give expected error report for `case`.""" +        raise NotImplementedError diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index d7187f315..e54b4b5b8 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,98 +1,71 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import attachments +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_attachments: int - - -def msg(author: str, total_attachments: int) -> MockMessage: +def make_msg(author: str, total_attachments: int) -> MockMessage:      """Builds a message with `total_attachments` attachments."""      return MockMessage(author=author, attachments=list(range(total_attachments))) -class AttachmentRuleTests(unittest.TestCase): +class AttachmentRuleTests(RuleTest):      """Tests applying the `attachments` antispam rule."""      def setUp(self): -        self.config = {"max": 5} +        self.apply = attachments.apply +        self.config = {"max": 5, "interval": 10}      @async_test      async def test_allows_messages_without_too_many_attachments(self):          """Messages without too many attachments are allowed as-is."""          cases = ( -            [msg("bob", 0), msg("bob", 0), msg("bob", 0)], -            [msg("bob", 2), msg("bob", 2)], -            [msg("bob", 2), msg("alice", 2), msg("bob", 2)], +            [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], +            [make_msg("bob", 2), make_msg("bob", 2)], +            [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await attachments.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( -            Case( -                [msg("bob", 4), msg("bob", 0), msg("bob", 6)], +            DisallowedCase( +                [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],                  ("bob",), -                10 +                10,              ), -            Case( -                [msg("bob", 4), msg("alice", 6), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],                  ("bob",), -                6 +                6,              ), -            Case( -                [msg("alice", 6)], +            DisallowedCase( +                [make_msg("alice", 6)],                  ("alice",), -                6 +                6,              ), -            ( -                [msg("alice", 1) for _ in range(6)], +            DisallowedCase( +                [make_msg("alice", 1) for _ in range(6)],                  ("alice",), -                6 +                6,              ),          ) -        for recent_messages, culprit, total_attachments in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if ( -                    msg.author == last_message.author -                    and len(msg.attachments) > 0 -                ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and len(msg.attachments) > 0              ) +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                total_attachments=total_attachments, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_attachments} attachments in {self.config['max']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await attachments.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py new file mode 100644 index 000000000..72f0be0c7 --- /dev/null +++ b/tests/bot/rules/test_burst.py @@ -0,0 +1,56 @@ +from typing import Iterable + +from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: +    """ +    Init a MockMessage instance with author set to `author`. + +    This serves as a shorthand / alias to keep the test cases visually clean. +    """ +    return MockMessage(author=author) + + +class BurstRuleTests(RuleTest): +    """Tests the `burst` antispam rule.""" + +    def setUp(self): +        self.apply = burst.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("bob"), make_msg("bob")], +            [make_msg("bob"), make_msg("alice"), make_msg("bob")], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the amount of messages exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("bob")], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], +                ("bob",), +                3, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py new file mode 100644 index 000000000..47367a5f8 --- /dev/null +++ b/tests/bot/rules/test_burst_shared.py @@ -0,0 +1,59 @@ +from typing import Iterable + +from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: +    """ +    Init a MockMessage instance with the passed arg. + +    This serves as a shorthand / alias to keep the test cases visually clean. +    """ +    return MockMessage(author=author) + + +class BurstSharedRuleTests(RuleTest): +    """Tests the `burst_shared` antispam rule.""" + +    def setUp(self): +        self.apply = burst_shared.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """ +        Cases that do not violate the rule. + +        There really isn't more to test here than a single case. +        """ +        cases = ( +            [make_msg("spongebob"), make_msg("patrick")], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the amount of messages exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("bob")], +                {"bob"}, +                3, +            ), +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], +                {"bob", "alice"}, +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return case.recent_messages + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py new file mode 100644 index 000000000..7cc36f49e --- /dev/null +++ b/tests/bot/rules/test_chars.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_chars: int) -> MockMessage: +    """Build a message with arbitrary content of `n_chars` length.""" +    return MockMessage(author=author, content="A" * n_chars) + + +class CharsRuleTests(RuleTest): +    """Tests the `chars` antispam rule.""" + +    def setUp(self): +        self.apply = chars.apply +        self.config = { +            "max": 20,  # Max allowed sum of chars per user +            "interval": 10, +        } + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of chars within limit.""" +        cases = ( +            [make_msg("bob", 0)], +            [make_msg("bob", 20)], +            [make_msg("bob", 15), make_msg("alice", 15)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the total amount of chars exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 21)], +                ("bob",), +                21, +            ), +            DisallowedCase( +                [make_msg("bob", 15), make_msg("bob", 15)], +                ("bob",), +                30, +            ), +            DisallowedCase( +                [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], +                ("alice",), +                30, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py new file mode 100644 index 000000000..0239b0b00 --- /dev/null +++ b/tests/bot/rules/test_discord_emojis.py @@ -0,0 +1,54 @@ +from typing import Iterable + +from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + +discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> + + +def make_msg(author: str, n_emojis: int) -> MockMessage: +    """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" +    return MockMessage(author=author, content=discord_emoji * n_emojis) + + +class DiscordEmojisRuleTests(RuleTest): +    """Tests for the `discord_emojis` antispam rule.""" + +    def setUp(self): +        self.apply = discord_emojis.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of discord emojis within limit.""" +        cases = ( +            [make_msg("bob", 2)], +            [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with more than the allowed amount of discord emojis.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 3)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py new file mode 100644 index 000000000..59e0fb6ef --- /dev/null +++ b/tests/bot/rules/test_duplicates.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import duplicates +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, content: str) -> MockMessage: +    """Give a MockMessage instance with `author` and `content` attrs.""" +    return MockMessage(author=author, content=content) + + +class DuplicatesRuleTests(RuleTest): +    """Tests the `duplicates` antispam rule.""" + +    def setUp(self): +        self.apply = duplicates.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("alice", "A"), make_msg("alice", "A")], +            [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")],  # Non-duplicate +            [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")],  # Different author +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with too many duplicate messages from the same author.""" +        cases = ( +            DisallowedCase( +                [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], +                ("alice",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], +                ("bob",), +                3,  # 4 duplicate messages, but only 3 from bob +            ), +            DisallowedCase( +                [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], +                ("bob",), +                3,  # 4 message from bob, but only 3 duplicates +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and msg.content == last_message.content +            ) +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 02a5d5501..3c3f90e5f 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,26 +1,21 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import links +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_links: int - - -def msg(author: str, total_links: int) -> MockMessage: +def make_msg(author: str, total_links: int) -> MockMessage:      """Makes a message with `total_links` links."""      content = " ".join(["https://pydis.com"] * total_links)      return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest):      """Tests applying the `links` rule."""      def setUp(self): +        self.apply = links.apply          self.config = {              "max": 2,              "interval": 10 @@ -30,68 +25,45 @@ class LinksTests(unittest.TestCase):      async def test_links_within_limit(self):          """Messages with an allowed amount of links."""          cases = ( -            [msg("bob", 0)], -            [msg("bob", 2)], -            [msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 -            [msg("bob", 1), msg("bob", 1)], -            [msg("bob", 2), msg("alice", 2)]  # Only messages from latest author count +            [make_msg("bob", 0)], +            [make_msg("bob", 2)], +            [make_msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 +            [make_msg("bob", 1), make_msg("bob", 1)], +            [make_msg("bob", 2), make_msg("alice", 2)]  # Only messages from latest author count          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await links.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_links_exceeding_limit(self):          """Messages with a a higher than allowed amount of links."""          cases = ( -            Case( -                [msg("bob", 1), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 1), make_msg("bob", 2)],                  ("bob",),                  3              ), -            Case( -                [msg("alice", 1), msg("alice", 1), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],                  ("alice",),                  3              ), -            Case( -                [msg("alice", 2), msg("bob", 3), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],                  ("alice",),                  3              )          ) -        for recent_messages, culprit, total_links in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_links=total_links, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_links} links in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await links.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ad49ead32..ebcdabac6 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,95 +1,67 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_mentions: int - - -def msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_mentions: int) -> MockMessage:      """Makes a message with `total_mentions` mentions."""      return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest):      """Tests applying the `mentions` antispam rule."""      def setUp(self): +        self.apply = mentions.apply          self.config = {              "max": 2, -            "interval": 10 +            "interval": 10,          }      @async_test      async def test_mentions_within_limit(self):          """Messages with an allowed amount of mentions."""          cases = ( -            [msg("bob", 0)], -            [msg("bob", 2)], -            [msg("bob", 1), msg("bob", 1)], -            [msg("bob", 1), msg("alice", 2)] +            [make_msg("bob", 0)], +            [make_msg("bob", 2)], +            [make_msg("bob", 1), make_msg("bob", 1)], +            [make_msg("bob", 1), make_msg("alice", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await mentions.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_mentions_exceeding_limit(self):          """Messages with a higher than allowed amount of mentions."""          cases = ( -            Case( -                [msg("bob", 3)], +            DisallowedCase( +                [make_msg("bob", 3)],                  ("bob",), -                3 +                3,              ), -            Case( -                [msg("alice", 2), msg("alice", 0), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],                  ("alice",), -                3 +                3,              ), -            Case( -                [msg("bob", 2), msg("alice", 3), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],                  ("bob",), -                4 +                4,              )          ) -        for recent_messages, culprit, total_mentions in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_mentions=total_mentions, -                cofig=self.config -            ): -                desired_output = ( -                    f"sent {total_mentions} mentions in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await mentions.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py new file mode 100644 index 000000000..d61c4609d --- /dev/null +++ b/tests/bot/rules/test_newlines.py @@ -0,0 +1,105 @@ +from typing import Iterable, List + +from bot.rules import newlines +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, newline_groups: List[int]) -> MockMessage: +    """Init a MockMessage instance with `author` and content configured by `newline_groups". + +    Configure content by passing a list of ints, where each int `n` will generate +    a separate group of `n` newlines. + +    Example: +        newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" +    """ +    content = " ".join("\n" * n for n in newline_groups) +    return MockMessage(author=author, content=content) + + +class TotalNewlinesRuleTests(RuleTest): +    """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" + +    def setUp(self): +        self.apply = newlines.apply +        self.config = { +            "max": 5,  # Max sum of newlines in relevant messages +            "max_consecutive": 3,  # Max newlines in one group, in one message +            "interval": 10, +        } + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("alice", [])],  # Single message with no newlines +            [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])],  # 5 newlines in 2 messages +            [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])],  # 5 newlines from each author +            [make_msg("bob", [1]), make_msg("alice", [5])],  # Alice breaks the rule, but only bob is relevant +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_total(self): +        """Cases which violate the rule by having too many newlines in total.""" +        cases = ( +            DisallowedCase(  # Alice sends a total of 6 newlines (disallowed) +                [make_msg("alice", [2, 2]), make_msg("alice", [2])], +                ("alice",), +                6, +            ), +            DisallowedCase(  # Here we test that only alice's newlines count in the sum +                [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], +                ("alice",), +                7, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_author = case.recent_messages[0].author +        return tuple(msg for msg in case.recent_messages if msg.author == last_author) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} newlines in {self.config['interval']}s" + + +class GroupNewlinesRuleTests(RuleTest): +    """ +    Tests the `newlines` antispam rule against max consecutive newline violations. + +    As these violations yield a different error report, they require a different +    `get_report` implementation. +    """ + +    def setUp(self): +        self.apply = newlines.apply +        self.config = {"max": 5, "max_consecutive": 3, "interval": 10} + +    @async_test +    async def test_disallows_messages_consecutive(self): +        """Cases which violate the rule due to having too many consecutive newlines.""" +        cases = ( +            DisallowedCase(  # Bob sends a group of newlines too large +                [make_msg("bob", [4])], +                ("bob",), +                4, +            ), +            DisallowedCase(  # Alice sends 5 in total (allowed), but 4 in one group (disallowed) +                [make_msg("alice", [1]), make_msg("alice", [4])], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_author = case.recent_messages[0].author +        return tuple(msg for msg in case.recent_messages if msg.author == last_author) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py new file mode 100644 index 000000000..b339cccf7 --- /dev/null +++ b/tests/bot/rules/test_role_mentions.py @@ -0,0 +1,57 @@ +from typing import Iterable + +from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_mentions: int) -> MockMessage: +    """Build a MockMessage instance with `n_mentions` role mentions.""" +    return MockMessage(author=author, role_mentions=[None] * n_mentions) + + +class RoleMentionsRuleTests(RuleTest): +    """Tests for the `role_mentions` antispam rule.""" + +    def setUp(self): +        self.apply = role_mentions.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of role mentions within limit.""" +        cases = ( +            [make_msg("bob", 2)], +            [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with more than the allowed amount of role mentions.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 3)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 5a88adc5c..bdfcc73e4 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -1,9 +1,7 @@ -import logging  import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock  from bot import api -from tests.base import LoggingTestCase  from tests.helpers import async_test @@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase):          self.assertEqual(error.response_text, "")          self.assertIs(error.response, self.error_api_response) -    def test_responde_code_error_string_representation_default_initialization(self): +    def test_response_code_error_string_representation_default_initialization(self):          """Test the string representation of `ResponseCodeError` initialized without text or json."""          error = api.ResponseCodeError(response=self.error_api_response)          self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") @@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase):              response_text=text_data          )          self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") - - -class LoggingHandlerTests(LoggingTestCase): -    """Tests the bot's API Log Handler.""" - -    @classmethod -    def setUpClass(cls): -        cls.debug_log_record = logging.LogRecord( -            name='my.logger', level=logging.DEBUG, -            pathname='my/logger.py', lineno=666, -            msg="Lemon wins", args=(), -            exc_info=None -        ) - -        cls.trace_log_record = logging.LogRecord( -            name='my.logger', level=logging.TRACE, -            pathname='my/logger.py', lineno=666, -            msg="This will not be logged", args=(), -            exc_info=None -        ) - -    def setUp(self): -        self.log_handler = api.APILoggingHandler(None) - -    def test_emit_appends_to_queue_with_stopped_event_loop(self): -        """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" -        with patch("bot.api.APILoggingHandler.ship_off") as ship_off: -            # Patch `ship_off` to ease testing against the return value of this coroutine. -            ship_off.return_value = 42 -            self.log_handler.emit(self.debug_log_record) - -        self.assertListEqual(self.log_handler.queue, [42]) - -    def test_emit_ignores_less_than_debug(self): -        """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" -        self.log_handler.emit(self.trace_log_record) -        self.assertListEqual(self.log_handler.queue, []) - -    def test_schedule_queued_tasks_for_empty_queue(self): -        """`APILoggingHandler` should not schedule anything when the queue is empty.""" -        with self.assertNotLogs(level=logging.DEBUG): -            self.log_handler.schedule_queued_tasks() - -    def test_schedule_queued_tasks_for_nonempty_queue(self): -        """`APILoggingHandler` should schedule logs when the queue is not empty.""" -        log = logging.getLogger("bot.api") - -        with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: -            self.log_handler.queue = [555] -            self.log_handler.schedule_queued_tasks() -            self.assertListEqual(self.log_handler.queue, []) -            create_task.assert_called_once_with(555) - -            [record] = logs.records -            self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") -            self.assertEqual(record.levelno, logging.DEBUG) -            self.assertEqual(record.name, 'bot.api') -            self.assertIn('via_handler', record.__dict__) | 
