diff options
Diffstat (limited to 'tests/bot/rules/test_mentions.py')
-rw-r--r-- | tests/bot/rules/test_mentions.py | 90 |
1 files changed, 31 insertions, 59 deletions
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ad49ead32..ebcdabac6 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,95 +1,67 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest from tests.helpers import MockMessage, async_test -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_mentions: int - - -def msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_mentions: int) -> MockMessage: """Makes a message with `total_mentions` mentions.""" return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest): """Tests applying the `mentions` antispam rule.""" def setUp(self): + self.apply = mentions.apply self.config = { "max": 2, - "interval": 10 + "interval": 10, } @async_test async def test_mentions_within_limit(self): """Messages with an allowed amount of mentions.""" cases = ( - [msg("bob", 0)], - [msg("bob", 2)], - [msg("bob", 1), msg("bob", 1)], - [msg("bob", 1), msg("alice", 2)] + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 1), make_msg("alice", 2)], ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await mentions.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) @async_test async def test_mentions_exceeding_limit(self): """Messages with a higher than allowed amount of mentions.""" cases = ( - Case( - [msg("bob", 3)], + DisallowedCase( + [make_msg("bob", 3)], ("bob",), - 3 + 3, ), - Case( - [msg("alice", 2), msg("alice", 0), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], ("alice",), - 3 + 3, ), - Case( - [msg("bob", 2), msg("alice", 3), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], ("bob",), - 4 + 4, ) ) - for recent_messages, culprit, total_mentions in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if msg.author == last_message.author - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - culprit=culprit, - total_mentions=total_mentions, - cofig=self.config - ): - desired_output = ( - f"sent {total_mentions} mentions in {self.config['interval']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await mentions.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} mentions in {self.config['interval']}s" |