aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar kwzrd <[email protected]>2020-02-02 18:31:50 +0100
committerGravatar kwzrd <[email protected]>2020-02-02 18:31:50 +0100
commitb89f9c55329aa44448d55963e22ce4f7a6ec0ff6 (patch)
tree906d68dc862565e650832e15d6e9599aba253ad8
parentImplement RuleTest ABC (diff)
Adjust existing tests to inherit from RuleTest ABC
-rw-r--r--tests/bot/rules/test_attachments.py79
-rw-r--r--tests/bot/rules/test_burst.py44
-rw-r--r--tests/bot/rules/test_burst_shared.py40
-rw-r--r--tests/bot/rules/test_chars.py53
-rw-r--r--tests/bot/rules/test_discord_emojis.py44
-rw-r--r--tests/bot/rules/test_links.py66
-rw-r--r--tests/bot/rules/test_mentions.py76
-rw-r--r--tests/bot/rules/test_role_mentions.py49
8 files changed, 157 insertions, 294 deletions
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index 419336417..e54b4b5b8 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -1,25 +1,20 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import attachments
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_attachments: int
-
-
def make_msg(author: str, total_attachments: int) -> MockMessage:
"""Builds a message with `total_attachments` attachments."""
return MockMessage(author=author, attachments=list(range(total_attachments)))
-class AttachmentRuleTests(unittest.TestCase):
+class AttachmentRuleTests(RuleTest):
"""Tests applying the `attachments` antispam rule."""
def setUp(self):
+ self.apply = attachments.apply
self.config = {"max": 5, "interval": 10}
@async_test
@@ -31,68 +26,46 @@ class AttachmentRuleTests(unittest.TestCase):
[make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await attachments.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
- Case(
+ DisallowedCase(
[make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
("bob",),
- 10
+ 10,
),
- Case(
+ DisallowedCase(
[make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
("bob",),
- 6
+ 6,
),
- Case(
+ DisallowedCase(
[make_msg("alice", 6)],
("alice",),
- 6
+ 6,
),
- (
+ DisallowedCase(
[make_msg("alice", 1) for _ in range(6)],
("alice",),
- 6
+ 6,
),
)
- for recent_messages, culprit, total_attachments in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and len(msg.attachments) > 0
)
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- total_attachments=total_attachments,
- config=self.config
- ):
- desired_output = (
- f"sent {total_attachments} attachments in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await attachments.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
index afcc5554d..72f0be0c7 100644
--- a/tests/bot/rules/test_burst.py
+++ b/tests/bot/rules/test_burst.py
@@ -1,6 +1,7 @@
-import unittest
+from typing import Iterable
from bot.rules import burst
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
@@ -13,10 +14,11 @@ def make_msg(author: str) -> MockMessage:
return MockMessage(author=author)
-class BurstRuleTests(unittest.TestCase):
+class BurstRuleTests(RuleTest):
"""Tests the `burst` antispam rule."""
def setUp(self):
+ self.apply = burst.apply
self.config = {"max": 2, "interval": 10}
@async_test
@@ -27,44 +29,28 @@ class BurstRuleTests(unittest.TestCase):
[make_msg("bob"), make_msg("alice"), make_msg("bob")],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config):
- self.assertIsNone(await burst.apply(last_message, recent_messages, self.config))
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
- (
+ DisallowedCase(
[make_msg("bob"), make_msg("bob"), make_msg("bob")],
- "bob",
+ ("bob",),
3,
),
- (
+ DisallowedCase(
[make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
- "bob",
+ ("bob",),
3,
),
)
- for recent_messages, culprit, total_msgs in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit)
- expected_output = (
- f"sent {total_msgs} messages in {self.config['interval']}s",
- (culprit,),
- relevant_messages,
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- expected_output=expected_output,
- ):
- self.assertTupleEqual(
- await burst.apply(last_message, recent_messages, self.config),
- expected_output,
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
index 401e0b666..47367a5f8 100644
--- a/tests/bot/rules/test_burst_shared.py
+++ b/tests/bot/rules/test_burst_shared.py
@@ -1,6 +1,7 @@
-import unittest
+from typing import Iterable
from bot.rules import burst_shared
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
@@ -13,10 +14,11 @@ def make_msg(author: str) -> MockMessage:
return MockMessage(author=author)
-class BurstSharedRuleTests(unittest.TestCase):
+class BurstSharedRuleTests(RuleTest):
"""Tests the `burst_shared` antispam rule."""
def setUp(self):
+ self.apply = burst_shared.apply
self.config = {"max": 2, "interval": 10}
@async_test
@@ -26,42 +28,32 @@ class BurstSharedRuleTests(unittest.TestCase):
There really isn't more to test here than a single case.
"""
- recent_messages = [make_msg("spongebob"), make_msg("patrick")]
- last_message = recent_messages[0]
+ cases = (
+ [make_msg("spongebob"), make_msg("patrick")],
+ )
- self.assertIsNone(await burst_shared.apply(last_message, recent_messages, self.config))
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
- (
+ DisallowedCase(
[make_msg("bob"), make_msg("bob"), make_msg("bob")],
{"bob"},
3,
),
- (
+ DisallowedCase(
[make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
{"bob", "alice"},
4,
),
)
- for recent_messages, culprits, total_msgs in cases:
- last_message = recent_messages[0]
- expected_output = (
- f"sent {total_msgs} messages in {self.config['interval']}s",
- culprits,
- recent_messages,
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return case.recent_messages
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- expected_output=expected_output,
- ):
- self.assertTupleEqual(
- await burst_shared.apply(last_message, recent_messages, self.config),
- expected_output,
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
index f466a898e..7cc36f49e 100644
--- a/tests/bot/rules/test_chars.py
+++ b/tests/bot/rules/test_chars.py
@@ -1,6 +1,7 @@
-import unittest
+from typing import Iterable
from bot.rules import chars
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
@@ -9,10 +10,11 @@ def make_msg(author: str, n_chars: int) -> MockMessage:
return MockMessage(author=author, content="A" * n_chars)
-class CharsRuleTests(unittest.TestCase):
+class CharsRuleTests(RuleTest):
"""Tests the `chars` antispam rule."""
def setUp(self):
+ self.apply = chars.apply
self.config = {
"max": 20, # Max allowed sum of chars per user
"interval": 10,
@@ -27,49 +29,38 @@ class CharsRuleTests(unittest.TestCase):
[make_msg("bob", 15), make_msg("alice", 15)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config):
- self.assertIsNone(await chars.apply(last_message, recent_messages, self.config))
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the total amount of chars exceeds the limit, triggering the rule."""
cases = (
- (
+ DisallowedCase(
[make_msg("bob", 21)],
- "bob",
+ ("bob",),
21,
),
- (
+ DisallowedCase(
[make_msg("bob", 15), make_msg("bob", 15)],
- "bob",
+ ("bob",),
30,
),
- (
+ DisallowedCase(
[make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
- "alice",
+ ("alice",),
30,
),
)
- for recent_messages, culprit, total_chars in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit)
- expected_output = (
- f"sent {total_chars} characters in {self.config['interval']}s",
- (culprit,),
- relevant_messages,
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- expected_output=expected_output,
- ):
- self.assertTupleEqual(
- await chars.apply(last_message, recent_messages, self.config),
- expected_output,
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
index 1c56c9563..0239b0b00 100644
--- a/tests/bot/rules/test_discord_emojis.py
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -1,6 +1,7 @@
-import unittest
+from typing import Iterable
from bot.rules import discord_emojis
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
@@ -11,10 +12,11 @@ def make_msg(author: str, n_emojis: int) -> MockMessage:
return MockMessage(author=author, content=discord_emoji * n_emojis)
-class DiscordEmojisRuleTests(unittest.TestCase):
+class DiscordEmojisRuleTests(RuleTest):
"""Tests for the `discord_emojis` antispam rule."""
def setUp(self):
+ self.apply = discord_emojis.apply
self.config = {"max": 2, "interval": 10}
@async_test
@@ -25,44 +27,28 @@ class DiscordEmojisRuleTests(unittest.TestCase):
[make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config):
- self.assertIsNone(await discord_emojis.apply(last_message, recent_messages, self.config))
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of discord emojis."""
cases = (
- (
+ DisallowedCase(
[make_msg("bob", 3)],
- "bob",
+ ("bob",),
3,
),
- (
+ DisallowedCase(
[make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
- "alice",
+ ("alice",),
4,
),
)
- for recent_messages, culprit, total_emojis in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit)
- expected_output = (
- f"sent {total_emojis} emojis in {self.config['interval']}s",
- (culprit,),
- relevant_messages,
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- expected_output=expected_output,
- ):
- self.assertTupleEqual(
- await discord_emojis.apply(last_message, recent_messages, self.config),
- expected_output,
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index b77e01c84..3c3f90e5f 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -1,26 +1,21 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import links
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_links: int
-
-
def make_msg(author: str, total_links: int) -> MockMessage:
"""Makes a message with `total_links` links."""
content = " ".join(["https://pydis.com"] * total_links)
return MockMessage(author=author, content=content)
-class LinksTests(unittest.TestCase):
+class LinksTests(RuleTest):
"""Tests applying the `links` rule."""
def setUp(self):
+ self.apply = links.apply
self.config = {
"max": 2,
"interval": 10
@@ -37,61 +32,38 @@ class LinksTests(unittest.TestCase):
[make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await links.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
- Case(
+ DisallowedCase(
[make_msg("bob", 1), make_msg("bob", 2)],
("bob",),
3
),
- Case(
+ DisallowedCase(
[make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
("alice",),
3
),
- Case(
+ DisallowedCase(
[make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
("alice",),
3
)
)
- for recent_messages, culprit, total_links in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_links=total_links,
- config=self.config
- ):
- desired_output = (
- f"sent {total_links} links in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await links.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index 43211f097..ebcdabac6 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -1,28 +1,23 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import mentions
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_mentions: int
-
-
def make_msg(author: str, total_mentions: int) -> MockMessage:
"""Makes a message with `total_mentions` mentions."""
return MockMessage(author=author, mentions=list(range(total_mentions)))
-class TestMentions(unittest.TestCase):
+class TestMentions(RuleTest):
"""Tests applying the `mentions` antispam rule."""
def setUp(self):
+ self.apply = mentions.apply
self.config = {
"max": 2,
- "interval": 10
+ "interval": 10,
}
@async_test
@@ -32,64 +27,41 @@ class TestMentions(unittest.TestCase):
[make_msg("bob", 0)],
[make_msg("bob", 2)],
[make_msg("bob", 1), make_msg("bob", 1)],
- [make_msg("bob", 1), make_msg("alice", 2)]
+ [make_msg("bob", 1), make_msg("alice", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await mentions.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
- Case(
+ DisallowedCase(
[make_msg("bob", 3)],
("bob",),
- 3
+ 3,
),
- Case(
+ DisallowedCase(
[make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
("alice",),
- 3
+ 3,
),
- Case(
+ DisallowedCase(
[make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
("bob",),
- 4
+ 4,
)
)
- for recent_messages, culprit, total_mentions in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_mentions=total_mentions,
- cofig=self.config
- ):
- desired_output = (
- f"sent {total_mentions} mentions in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await mentions.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
index 6377ffbc8..b339cccf7 100644
--- a/tests/bot/rules/test_role_mentions.py
+++ b/tests/bot/rules/test_role_mentions.py
@@ -1,6 +1,7 @@
-import unittest
+from typing import Iterable
from bot.rules import role_mentions
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
@@ -9,10 +10,11 @@ def make_msg(author: str, n_mentions: int) -> MockMessage:
return MockMessage(author=author, role_mentions=[None] * n_mentions)
-class RoleMentionsRuleTests(unittest.TestCase):
+class RoleMentionsRuleTests(RuleTest):
"""Tests for the `role_mentions` antispam rule."""
def setUp(self):
+ self.apply = role_mentions.apply
self.config = {"max": 2, "interval": 10}
@async_test
@@ -23,44 +25,33 @@ class RoleMentionsRuleTests(unittest.TestCase):
[make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(last_message=last_message, recent_messages=recent_messages, config=self.config):
- self.assertIsNone(await role_mentions.apply(last_message, recent_messages, self.config))
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of role mentions."""
cases = (
- (
+ DisallowedCase(
[make_msg("bob", 3)],
- "bob",
+ ("bob",),
3,
),
- (
+ DisallowedCase(
[make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
- "alice",
+ ("alice",),
4,
),
)
- for recent_messages, culprit, total_mentions in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(msg for msg in recent_messages if msg.author == culprit)
- expected_output = (
- f"sent {total_mentions} role mentions in {self.config['interval']}s",
- (culprit,),
- relevant_messages,
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- expected_output=expected_output,
- ):
- self.assertTupleEqual(
- await role_mentions.apply(last_message, recent_messages, self.config),
- expected_output,
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} role mentions in {self.config['interval']}s"