diff options
-rw-r--r-- | bot/rules/attachments.py | 8 | ||||
-rw-r--r-- | tests/rules/__init__.py | 0 | ||||
-rw-r--r-- | tests/rules/test_attachments.py | 52 |
3 files changed, 55 insertions, 5 deletions
diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index e71b96183..9cf9877fd 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -6,19 +6,17 @@ from discord import Member, Message async def apply( - last_message: Message, - recent_messages: List[Message], - config: Dict[str, int] + last_message: Message, recent_messages: List[Message], config: Dict[str, int] ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: """Apply attachment spam detection filter.""" - relevant_messages = tuple( + relevant_messages = [last_message] + [ msg for msg in recent_messages if ( msg.author == last_message.author and len(msg.attachments) > 0 ) - ) + ] total_recent_attachments = sum(len(msg.attachments) for msg in relevant_messages) if total_recent_attachments > config['max']: diff --git a/tests/rules/__init__.py b/tests/rules/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/rules/__init__.py diff --git a/tests/rules/test_attachments.py b/tests/rules/test_attachments.py new file mode 100644 index 000000000..6f025b3cb --- /dev/null +++ b/tests/rules/test_attachments.py @@ -0,0 +1,52 @@ +import asyncio +from dataclasses import dataclass +from typing import Any, List + +import pytest + +from bot.rules import attachments + + +# Using `MagicMock` sadly doesn't work for this usecase +# since it's __eq__ compares the MagicMock's ID. We just +# want to compare the actual attributes we set. +@dataclass +class FakeMessage: + author: str + attachments: List[Any] + + +def msg(total_attachments: int): + return FakeMessage(author='lemon', attachments=list(range(total_attachments))) + + + 'messages', + ( + (msg(0), msg(0), msg(0)), + (msg(2), msg(2)), + (msg(0),), + ) +) +def test_allows_messages_without_too_many_attachments(messages): + last_message, *recent_messages = messages + coro = attachments.apply(last_message, recent_messages, {'max': 5}) + assert asyncio.run(coro) is None + + + ('messages', 'relevant_messages', 'total'), + ( + ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), + ((msg(6),), [msg(6)], 6), + ((msg(1),) * 6, [msg(1)] * 6, 6), + ) +) +def test_disallows_messages_with_too_many_attachments(messages, relevant_messages, total): + last_message, *recent_messages = messages + coro = attachments.apply(last_message, recent_messages, {'max': 5}) + assert asyncio.run(coro) == ( + f"sent {total} attachments in 5s", + ('lemon',), + relevant_messages + ) |