diff options
| author | 2020-02-02 18:31:50 +0100 | |
|---|---|---|
| committer | 2020-02-02 18:31:50 +0100 | |
| commit | b89f9c55329aa44448d55963e22ce4f7a6ec0ff6 (patch) | |
| tree | 906d68dc862565e650832e15d6e9599aba253ad8 /tests | |
| parent | Implement RuleTest ABC (diff) | |
Adjust existing tests to inherit from RuleTest ABC
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/rules/test_attachments.py | 79 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst.py | 44 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst_shared.py | 40 | ||||
| -rw-r--r-- | tests/bot/rules/test_chars.py | 53 | ||||
| -rw-r--r-- | tests/bot/rules/test_discord_emojis.py | 44 | ||||
| -rw-r--r-- | tests/bot/rules/test_links.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 76 | ||||
| -rw-r--r-- | tests/bot/rules/test_role_mentions.py | 49 | 
8 files changed, 157 insertions, 294 deletions
| diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index 419336417..e54b4b5b8 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,25 +1,20 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import attachments +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_attachments: int - -  def make_msg(author: str, total_attachments: int) -> MockMessage:      """Builds a message with `total_attachments` attachments."""      return MockMessage(author=author, attachments=list(range(total_attachments))) -class AttachmentRuleTests(unittest.TestCase): +class AttachmentRuleTests(RuleTest):      """Tests applying the `attachments` antispam rule."""      def setUp(self): +        self.apply = attachments.apply          self.config = {"max": 5, "interval": 10}      @async_test @@ -31,68 +26,46 @@ class AttachmentRuleTests(unittest.TestCase):              [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await attachments.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( -            Case( +            DisallowedCase(                  [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],                  ("bob",), -                10 +                10,              ), -            Case( +            DisallowedCase(                  [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],                  ("bob",), -                6 +                6,              ), -            Case( +            DisallowedCase(                  [make_msg("alice", 6)],                  ("alice",), -                6 +                6,              ), -            ( +            DisallowedCase(                  [make_msg("alice", 1) for _ in range(6)],                  ("alice",), -                6 +                6,              ),          ) -        for recent_messages, culprit, total_attachments in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if ( -                    msg.author == last_message.author -                    and len(msg.attachments) > 0 -                ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and len(msg.attachments) > 0              ) +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                total_attachments=total_attachments, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_attachments} attachments in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await attachments.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py index afcc5554d..72f0be0c7 100644 --- a/tests/bot/rules/test_burst.py +++ b/tests/bot/rules/test_burst.py @@ -1,6 +1,7 @@ -import unittest +from typing import Iterable  from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test @@ -13,10 +14,11 @@ def make_msg(author: str) -> MockMessage:      return MockMessage(author=author) -class BurstRuleTests(unittest.TestCase): +class BurstRuleTests(RuleTest):      """Tests the `burst` antispam rule."""      def setUp(self): +        self.apply = burst.apply          self.config = {"max": 2, "interval": 10}      @async_test @@ -27,44 +29,28 @@ class BurstRuleTests(unittest.TestCase):              [make_msg("bob"), make_msg("alice"), make_msg("bob")],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config): -                self.assertIsNone(await burst.apply(last_message, recent_messages, self.config)) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the amount of messages exceeds the limit, triggering the rule."""          cases = ( -            ( +            DisallowedCase(                  [make_msg("bob"), make_msg("bob"), make_msg("bob")], -                "bob", +                ("bob",),                  3,              ), -            ( +            DisallowedCase(                  [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], -                "bob", +                ("bob",),                  3,              ),          ) -        for recent_messages, culprit, total_msgs in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit) -            expected_output = ( -                f"sent {total_msgs} messages in {self.config['interval']}s", -                (culprit,), -                relevant_messages, -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -                expected_output=expected_output, -            ): -                self.assertTupleEqual( -                    await burst.apply(last_message, recent_messages, self.config), -                    expected_output, -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py index 401e0b666..47367a5f8 100644 --- a/tests/bot/rules/test_burst_shared.py +++ b/tests/bot/rules/test_burst_shared.py @@ -1,6 +1,7 @@ -import unittest +from typing import Iterable  from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test @@ -13,10 +14,11 @@ def make_msg(author: str) -> MockMessage:      return MockMessage(author=author) -class BurstSharedRuleTests(unittest.TestCase): +class BurstSharedRuleTests(RuleTest):      """Tests the `burst_shared` antispam rule."""      def setUp(self): +        self.apply = burst_shared.apply          self.config = {"max": 2, "interval": 10}      @async_test @@ -26,42 +28,32 @@ class BurstSharedRuleTests(unittest.TestCase):          There really isn't more to test here than a single case.          """ -        recent_messages = [make_msg("spongebob"), make_msg("patrick")] -        last_message = recent_messages[0] +        cases = ( +            [make_msg("spongebob"), make_msg("patrick")], +        ) -        self.assertIsNone(await burst_shared.apply(last_message, recent_messages, self.config)) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the amount of messages exceeds the limit, triggering the rule."""          cases = ( -            ( +            DisallowedCase(                  [make_msg("bob"), make_msg("bob"), make_msg("bob")],                  {"bob"},                  3,              ), -            ( +            DisallowedCase(                  [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],                  {"bob", "alice"},                  4,              ),          ) -        for recent_messages, culprits, total_msgs in cases: -            last_message = recent_messages[0] -            expected_output = ( -                f"sent {total_msgs} messages in {self.config['interval']}s", -                culprits, -                recent_messages, -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return case.recent_messages -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -                expected_output=expected_output, -            ): -                self.assertTupleEqual( -                    await burst_shared.apply(last_message, recent_messages, self.config), -                    expected_output, -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py index f466a898e..7cc36f49e 100644 --- a/tests/bot/rules/test_chars.py +++ b/tests/bot/rules/test_chars.py @@ -1,6 +1,7 @@ -import unittest +from typing import Iterable  from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test @@ -9,10 +10,11 @@ def make_msg(author: str, n_chars: int) -> MockMessage:      return MockMessage(author=author, content="A" * n_chars) -class CharsRuleTests(unittest.TestCase): +class CharsRuleTests(RuleTest):      """Tests the `chars` antispam rule."""      def setUp(self): +        self.apply = chars.apply          self.config = {              "max": 20,  # Max allowed sum of chars per user              "interval": 10, @@ -27,49 +29,38 @@ class CharsRuleTests(unittest.TestCase):              [make_msg("bob", 15), make_msg("alice", 15)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config): -                self.assertIsNone(await chars.apply(last_message, recent_messages, self.config)) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases where the total amount of chars exceeds the limit, triggering the rule."""          cases = ( -            ( +            DisallowedCase(                  [make_msg("bob", 21)], -                "bob", +                ("bob",),                  21,              ), -            ( +            DisallowedCase(                  [make_msg("bob", 15), make_msg("bob", 15)], -                "bob", +                ("bob",),                  30,              ), -            ( +            DisallowedCase(                  [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], -                "alice", +                ("alice",),                  30,              ),          ) -        for recent_messages, culprit, total_chars in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit) -            expected_output = ( -                f"sent {total_chars} characters in {self.config['interval']}s", -                (culprit,), -                relevant_messages, -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -                expected_output=expected_output, -            ): -                self.assertTupleEqual( -                    await chars.apply(last_message, recent_messages, self.config), -                    expected_output, -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 1c56c9563..0239b0b00 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -1,6 +1,7 @@ -import unittest +from typing import Iterable  from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test  discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> @@ -11,10 +12,11 @@ def make_msg(author: str, n_emojis: int) -> MockMessage:      return MockMessage(author=author, content=discord_emoji * n_emojis) -class DiscordEmojisRuleTests(unittest.TestCase): +class DiscordEmojisRuleTests(RuleTest):      """Tests for the `discord_emojis` antispam rule."""      def setUp(self): +        self.apply = discord_emojis.apply          self.config = {"max": 2, "interval": 10}      @async_test @@ -25,44 +27,28 @@ class DiscordEmojisRuleTests(unittest.TestCase):              [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config): -                self.assertIsNone(await discord_emojis.apply(last_message, recent_messages, self.config)) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with more than the allowed amount of discord emojis."""          cases = ( -            ( +            DisallowedCase(                  [make_msg("bob", 3)], -                "bob", +                ("bob",),                  3,              ), -            ( +            DisallowedCase(                  [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], -                "alice", +                ("alice",),                  4,              ),          ) -        for recent_messages, culprit, total_emojis in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit) -            expected_output = ( -                f"sent {total_emojis} emojis in {self.config['interval']}s", -                (culprit,), -                relevant_messages, -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -                expected_output=expected_output, -            ): -                self.assertTupleEqual( -                    await discord_emojis.apply(last_message, recent_messages, self.config), -                    expected_output, -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index b77e01c84..3c3f90e5f 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,26 +1,21 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import links +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_links: int - -  def make_msg(author: str, total_links: int) -> MockMessage:      """Makes a message with `total_links` links."""      content = " ".join(["https://pydis.com"] * total_links)      return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest):      """Tests applying the `links` rule."""      def setUp(self): +        self.apply = links.apply          self.config = {              "max": 2,              "interval": 10 @@ -37,61 +32,38 @@ class LinksTests(unittest.TestCase):              [make_msg("bob", 2), make_msg("alice", 2)]  # Only messages from latest author count          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await links.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_links_exceeding_limit(self):          """Messages with a a higher than allowed amount of links."""          cases = ( -            Case( +            DisallowedCase(                  [make_msg("bob", 1), make_msg("bob", 2)],                  ("bob",),                  3              ), -            Case( +            DisallowedCase(                  [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],                  ("alice",),                  3              ), -            Case( +            DisallowedCase(                  [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],                  ("alice",),                  3              )          ) -        for recent_messages, culprit, total_links in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_links=total_links, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_links} links in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await links.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index 43211f097..ebcdabac6 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,28 +1,23 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_mentions: int - -  def make_msg(author: str, total_mentions: int) -> MockMessage:      """Makes a message with `total_mentions` mentions."""      return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest):      """Tests applying the `mentions` antispam rule."""      def setUp(self): +        self.apply = mentions.apply          self.config = {              "max": 2, -            "interval": 10 +            "interval": 10,          }      @async_test @@ -32,64 +27,41 @@ class TestMentions(unittest.TestCase):              [make_msg("bob", 0)],              [make_msg("bob", 2)],              [make_msg("bob", 1), make_msg("bob", 1)], -            [make_msg("bob", 1), make_msg("alice", 2)] +            [make_msg("bob", 1), make_msg("alice", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await mentions.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_mentions_exceeding_limit(self):          """Messages with a higher than allowed amount of mentions."""          cases = ( -            Case( +            DisallowedCase(                  [make_msg("bob", 3)],                  ("bob",), -                3 +                3,              ), -            Case( +            DisallowedCase(                  [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],                  ("alice",), -                3 +                3,              ), -            Case( +            DisallowedCase(                  [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],                  ("bob",), -                4 +                4,              )          ) -        for recent_messages, culprit, total_mentions in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_mentions=total_mentions, -                cofig=self.config -            ): -                desired_output = ( -                    f"sent {total_mentions} mentions in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await mentions.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py index 6377ffbc8..b339cccf7 100644 --- a/tests/bot/rules/test_role_mentions.py +++ b/tests/bot/rules/test_role_mentions.py @@ -1,6 +1,7 @@ -import unittest +from typing import Iterable  from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test @@ -9,10 +10,11 @@ def make_msg(author: str, n_mentions: int) -> MockMessage:      return MockMessage(author=author, role_mentions=[None] * n_mentions) -class RoleMentionsRuleTests(unittest.TestCase): +class RoleMentionsRuleTests(RuleTest):      """Tests for the `role_mentions` antispam rule."""      def setUp(self): +        self.apply = role_mentions.apply          self.config = {"max": 2, "interval": 10}      @async_test @@ -23,44 +25,33 @@ class RoleMentionsRuleTests(unittest.TestCase):              [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config): -                self.assertIsNone(await role_mentions.apply(last_message, recent_messages, self.config)) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_beyond_limit(self):          """Cases with more than the allowed amount of role mentions."""          cases = ( -            ( +            DisallowedCase(                  [make_msg("bob", 3)], -                "bob", +                ("bob",),                  3,              ), -            ( +            DisallowedCase(                  [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], -                "alice", +                ("alice",),                  4,              ),          ) -        for recent_messages, culprit, total_mentions in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit) -            expected_output = ( -                f"sent {total_mentions} role mentions in {self.config['interval']}s", -                (culprit,), -                relevant_messages, -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -                expected_output=expected_output, -            ): -                self.assertTupleEqual( -                    await role_mentions.apply(last_message, recent_messages, self.config), -                    expected_output, -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} role mentions in {self.config['interval']}s" | 
