-rw-r--r--  bot/constants.py                                        |   2
-rw-r--r--  bot/exts/filters/antispam.py                            |  84
-rw-r--r--  bot/exts/info/pep.py                                    |   2
-rw-r--r--  bot/utils/caching.py (renamed from bot/utils/cache.py)  |   0
-rw-r--r--  bot/utils/message_cache.py                              | 197
-rw-r--r--  config-default.yml                                      |   2
-rw-r--r--  tests/bot/utils/test_message_cache.py                   | 214
7 files changed, 465 insertions, 36 deletions
diff --git a/bot/constants.py b/bot/constants.py index 12b5c02e5..407646b28 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter): class AntiSpam(metaclass=YAMLGetter): section = 'anti_spam' + cache_size: int + clean_offending: bool ping_everyone: bool diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 226da2790..987060779 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -1,8 +1,10 @@ import asyncio import logging +from collections import defaultdict from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta +from itertools import takewhile from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set @@ -20,6 +22,7 @@ from bot.converters import Duration from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.utils import lock, scheduling +from bot.utils.message_cache import MessageCache from bot.utils.messages import format_user, send_attachments @@ -44,19 +47,18 @@ RULE_FUNCTION_MAPPING = { class DeletionContext: """Represents a Deletion Context for a single spam event.""" - channel: TextChannel - members: Dict[int, Member] = field(default_factory=dict) + members: frozenset[Member] + triggered_in: TextChannel + channels: set[TextChannel] = field(default_factory=set) rules: Set[str] = field(default_factory=set) messages: Dict[int, Message] = field(default_factory=dict) attachments: List[List[str]] = field(default_factory=list) - async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: + async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None: """Adds new rule violation events to the deletion context.""" self.rules.add(rule_name) - for member in members: - if member.id not in self.members: - self.members[member.id] = member + self.channels.update(channels) for message in messages: if message.id not in self.messages: @@ -69,11 +71,14 @@ class DeletionContext: async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: """Method that takes care of uploading the queue and posting modlog alert.""" - triggered_by_users = ", ".join(format_user(m) for m in self.members.values()) + triggered_by_users = ", ".join(format_user(m) for m in self.members) + triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else "" + channels_description = ", ".join(channel.mention for channel in self.channels) mod_alert_message = ( f"**Triggered by:** {triggered_by_users}\n" - f"**Channel:** {self.channel.mention}\n" + f"{triggered_in_channel}" + f"**Channels:** {channels_description}\n" f"**Rules:** {', '.join(rule for rule in self.rules)}\n" ) @@ -116,6 +121,14 @@ class AntiSpam(Cog): self.message_deletion_queue = dict() + # Fetch the rule configuration with the highest rule interval. + max_interval_config = max( + AntiSpamConfig.rules.values(), + key=itemgetter('interval') + ) + self.max_interval = max_interval_config['interval'] + self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) + self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error") @property @@ -155,19 +168,10 @@ class AntiSpam(Cog): ): return - # Fetch the rule configuration with the highest rule interval. 
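The max() scan over the rule configuration, removed from the message handler below, now runs once in AntiSpam.__init__ as shown in the hunk above. A minimal sketch of that lookup, using hypothetical rule names and interval values (the real ones live under anti_spam.rules in config-default.yml and may differ):

    from operator import itemgetter

    # Hypothetical rule configuration mirroring the shape of anti_spam.rules.
    rules = {
        "burst": {"interval": 10, "max": 7},
        "duplicates": {"interval": 10, "max": 3},
        "mentions": {"interval": 20, "max": 5},
    }

    # Pick the rule config with the largest interval and keep only that interval.
    max_interval = max(rules.values(), key=itemgetter("interval"))["interval"]
    assert max_interval == 20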
- max_interval_config = max( - AntiSpamConfig.rules.values(), - key=itemgetter('interval') - ) - max_interval = max_interval_config['interval'] + self.cache.append(message) - # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. - earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] + earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache)) for rule_name in AntiSpamConfig.rules: rule_config = AntiSpamConfig.rules[rule_name] @@ -175,9 +179,10 @@ class AntiSpam(Cog): # Create a list of messages that were sent in the interval that the rule cares about. latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] + messages_for_rule = list( + takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) + ) + result = await rule_function(message, messages_for_rule, rule_config) # If the rule returns `None`, that means the message didn't violate it. @@ -190,19 +195,19 @@ class AntiSpam(Cog): full_reason = f"`{rule_name}` rule: {reason}" # If there's no spam event going on for this channel, start a new Message Deletion Context - channel = message.channel - if channel.id not in self.message_deletion_queue: - log.trace(f"Creating queue for channel `{channel.id}`") - self.message_deletion_queue[message.channel.id] = DeletionContext(channel) + authors_set = frozenset(members) + if authors_set not in self.message_deletion_queue: + log.trace(f"Creating queue for members `{authors_set}`") + self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel) scheduling.create_task( - self._process_deletion_context(message.channel.id), - name=f"AntiSpam._process_deletion_context({message.channel.id})" + self._process_deletion_context(authors_set), + name=f"AntiSpam._process_deletion_context({authors_set})" ) # Add the relevant of this trigger to the Deletion Context - await self.message_deletion_queue[message.channel.id].add( + await self.message_deletion_queue[authors_set].add( rule_name=rule_name, - members=members, + channels=set(message.channel for message in messages_for_rule), messages=relevant_messages ) @@ -212,7 +217,7 @@ class AntiSpam(Cog): name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})" ) - await self.maybe_delete_messages(channel, relevant_messages) + await self.maybe_delete_messages(messages_for_rule) break @lock.lock_arg("antispam.punish", "member", attrgetter("id")) @@ -234,14 +239,18 @@ class AntiSpam(Cog): reason=reason ) - async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: + async def maybe_delete_messages(self, messages: List[Message]) -> None: """Cleans the messages if cleaning is configured.""" if AntiSpamConfig.clean_offending: # If we have more than one message, we can use bulk delete. 
if len(messages) > 1: message_ids = [message.id for message in messages] self.mod_log.ignore(Event.message_delete, *message_ids) - await channel.delete_messages(messages) + channel_messages = defaultdict(list) + for message in messages: + channel_messages[message.channel].append(message) + for channel, messages in channel_messages.items(): + await channel.delete_messages(messages) # Otherwise, the bulk delete endpoint will throw up. # Delete the message directly instead. @@ -252,7 +261,7 @@ class AntiSpam(Cog): except NotFound: log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - async def _process_deletion_context(self, context_id: int) -> None: + async def _process_deletion_context(self, context_id: frozenset) -> None: """Processes the Deletion Context queue.""" log.trace("Sleeping before processing message deletion queue.") await asyncio.sleep(10) @@ -264,6 +273,11 @@ class AntiSpam(Cog): deletion_context = self.message_deletion_queue.pop(context_id) await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Updates the message in the cache, if it's cached.""" + self.cache.update(after) + def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: """Validates the antispam configs.""" diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 8ac96bbdb..b11b34db0 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command from bot.bot import Bot from bot.constants import Keys -from bot.utils.cache import AsyncCache +from bot.utils.caching import AsyncCache log = logging.getLogger(__name__) diff --git a/bot/utils/cache.py b/bot/utils/caching.py index 68ce15607..68ce15607 100644 --- a/bot/utils/cache.py +++ b/bot/utils/caching.py diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py new file mode 100644 index 000000000..f68d280c9 --- /dev/null +++ b/bot/utils/message_cache.py @@ -0,0 +1,197 @@ +import typing as t +from math import ceil + +from discord import Message + + +class MessageCache: + """ + A data structure for caching messages. + + The cache is implemented as a circular buffer to allow constant time append, prepend, pop from either side, + and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be + implemented to work in linear time relative to the maximum size). + + The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message + is stored, to allow for constant time lookup by message ID. + + The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those + of a deque. + + The implementation is transparent to the user: to the user the first element is always at index 0, and there are + only as many elements as were inserted (meaning, without any pre-allocated placeholder values). 
+ """ + + def __init__(self, maxlen: int, *, newest_first: bool = False): + if maxlen <= 0: + raise ValueError("maxlen must be positive") + self.maxlen = maxlen + self.newest_first = newest_first + + self._start = 0 + self._end = 0 + + self._messages: list[t.Optional[Message]] = [None] * self.maxlen + self._message_id_mapping = {} + + def append(self, message: Message) -> None: + """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" + if self.newest_first: + self._appendleft(message) + else: + self._appendright(message) + + def _appendright(self, message: Message) -> None: + """Add the received message to the end of the cache.""" + if self._is_full(): + del self._message_id_mapping[self._messages[self._start].id] + self._start = (self._start + 1) % self.maxlen + + self._messages[self._end] = message + self._message_id_mapping[message.id] = self._end + self._end = (self._end + 1) % self.maxlen + + def _appendleft(self, message: Message) -> None: + """Add the received message to the beginning of the cache.""" + if self._is_full(): + self._end = (self._end - 1) % self.maxlen + del self._message_id_mapping[self._messages[self._end].id] + + self._start = (self._start - 1) % self.maxlen + self._messages[self._start] = message + self._message_id_mapping[message.id] = self._start + + def pop(self) -> Message: + """Remove the last message in the cache and return it.""" + if self._is_empty(): + raise IndexError("pop from an empty cache") + + self._end = (self._end - 1) % self.maxlen + message = self._messages[self._end] + del self._message_id_mapping[message.id] + self._messages[self._end] = None + + return message + + def popleft(self) -> Message: + """Return the first message in the cache and return it.""" + if self._is_empty(): + raise IndexError("pop from an empty cache") + + message = self._messages[self._start] + del self._message_id_mapping[message.id] + self._messages[self._start] = None + self._start = (self._start + 1) % self.maxlen + + return message + + def clear(self) -> None: + """Remove all messages from the cache.""" + self._messages = [None] * self.maxlen + self._message_id_mapping = {} + + self._start = 0 + self._end = 0 + + def get_message(self, message_id: int) -> t.Optional[Message]: + """Return the message that has the given message ID, if it is cached.""" + index = self._message_id_mapping.get(message_id, None) + return self._messages[index] if index is not None else None + + def update(self, message: Message) -> bool: + """ + Update a cached message with new contents. + + Return True if the given message had a matching ID in the cache. + """ + index = self._message_id_mapping.get(message.id, None) + if index is None: + return False + self._messages[index] = message + return True + + def __contains__(self, message_id: int) -> bool: + """Return True if the cache contains a message with the given ID .""" + return message_id in self._message_id_mapping + + def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: + """ + Return the message(s) in the index or slice provided. + + This method makes the circular buffer implementation transparent to the user. + Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, + meaning at `self._start`. + """ + # Keep in mind that for the modulo operator used throughout this function, Python modulo behaves similarly when + # the left operand is negative. 
E.g -1 % 5 == 4, because the closest number from the bottom that wholly divides + # by 5 is -5. + if isinstance(item, int): + if item >= len(self) or item < -len(self): + raise IndexError("cache index out of range") + return self._messages[(item + self._start) % self.maxlen] + + elif isinstance(item, slice): + length = len(self) + start, stop, step = item.indices(length) + + # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state. + if (start >= stop and step >= 0) or (start <= stop and step <= 0): + return [] + + start = (start + self._start) % self.maxlen + stop = (stop + self._start) % self.maxlen + + # Having empty cells is an implementation detail. To the user the cache contains as many elements as they + # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. + if step > 0: + if ( + (self._start < self._end and not self._start < stop <= self._end) + or (self._start > self._end and self._end < stop <= self._start) + ): + stop = self._end + else: + lower_boundary = (self._start - 1) % self.maxlen + if ( + (self._start < self._end and not self._start - 1 <= stop < self._end) + or (self._start > self._end and self._end < stop < lower_boundary) + ): + stop = lower_boundary + + if (start < stop and step > 0) or (start > stop and step < 0): + return self._messages[start:stop:step] + # step != 1 may require a start offset in the second slicing. + if step > 0: + offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen + return self._messages[start::step] + self._messages[offset:stop:step] + else: + offset = ceil((start + 1) / -step) * -step - start - 1 + return self._messages[start::step] + self._messages[self.maxlen - 1 - offset:stop:step] + + else: + raise TypeError(f"cache indices must be integers or slices, not {type(item)}") + + def __iter__(self) -> t.Iterator[Message]: + if self._is_empty(): + return + + if self._start < self._end: + yield from self._messages[self._start:self._end] + else: + yield from self._messages[self._start:] + yield from self._messages[:self._end] + + def __len__(self): + """Get the number of non-empty cells in the cache.""" + if self._is_empty(): + return 0 + if self._end > self._start: + return self._end - self._start + return self.maxlen - self._start + self._end + + def _is_empty(self) -> bool: + """Return True if the cache has no messages.""" + return self._messages[self._start] is None + + def _is_full(self) -> bool: + """Return True if every cell in the cache already contains a message.""" + return self._messages[self._end] is not None diff --git a/config-default.yml b/config-default.yml index 79828dd77..eaf8e0ad7 100644 --- a/config-default.yml +++ b/config-default.yml @@ -377,6 +377,8 @@ urls: anti_spam: + cache_size: 100 + # Clean messages that violate a rule. 
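The new cache_size key (together with the existing clean_offending flag that follows) reaches the cog through the AntiSpam class in bot/constants.py, whose annotations are extended at the top of this diff. A hedged sketch of wiring the value up, loading the YAML directly as a stand-in for that config machinery:

    import yaml

    from bot.utils.message_cache import MessageCache

    # Stand-in only: the bot reads these values through constants.AntiSpam
    # (metaclass YAMLGetter, section 'anti_spam'), not by loading the file itself.
    with open("config-default.yml") as f:
        anti_spam = yaml.safe_load(f)["anti_spam"]

    cache = MessageCache(anti_spam["cache_size"], newest_first=True)  # maxlen 100
    clean_offending = anti_spam["clean_offending"]                    # True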
clean_offending: true ping_everyone: true diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..04bfd28d1 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,214 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): + """Tests for the MessageCache class in the `bot.utils.caching` module.""" + + def test_first_append_sets_the_first_value(self): + """Test if the first append adds the message to the first cell.""" + cache = MessageCache(maxlen=10) + message = MockMessage() + + cache.append(message) + + self.assertEqual(cache[0], message) + + def test_append_adds_in_the_right_order(self): + """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" + messages = [MockMessage(), MockMessage()] + + cache = MessageCache(maxlen=10, newest_first=False) + for msg in messages: + cache.append(msg) + self.assertListEqual(messages, list(cache)) + + cache = MessageCache(maxlen=10, newest_first=True) + for msg in messages: + cache.append(msg) + self.assertListEqual(messages[::-1], list(cache)) + + def test_appending_over_maxlen_removes_oldest(self): + """Test if three appends to a 2-cell cache leave the two newest messages.""" + cache = MessageCache(maxlen=2) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + + self.assertListEqual(messages[1:], list(cache)) + + def test_appending_over_maxlen_with_newest_first_removes_oldest(self): + """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" + cache = MessageCache(maxlen=2, newest_first=True) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + + self.assertListEqual(messages[:0:-1], list(cache)) + + def test_pop_removes_from_the_end(self): + """Test if a pop removes the right-most message.""" + cache = MessageCache(maxlen=3) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + msg = cache.pop() + + self.assertEqual(msg, messages[-1]) + self.assertListEqual(messages[:-1], list(cache)) + + def test_popleft_removes_from_the_beginning(self): + """Test if a popleft removes the left-most message.""" + cache = MessageCache(maxlen=3) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + msg = cache.popleft() + + self.assertEqual(msg, messages[0]) + self.assertListEqual(messages[1:], list(cache)) + + def test_clear(self): + """Test if a clear makes the cache empty.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + cache.clear() + + self.assertListEqual(list(cache), []) + self.assertEqual(len(cache), 0) + + def test_get_message_returns_the_message(self): + """Test if get_message returns the cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertEqual(cache.get_message(1234), message) + + def test_get_message_returns_none(self): + """Test if get_message returns None for an ID of a non-cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertIsNone(cache.get_message(4321)) + + def test_update_replaces_old_element(self): + """Test if an update replaced the old 
message with the same ID.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + message = MockMessage(id=1234) + cache.update(message) + + self.assertIs(cache.get_message(1234), message) + self.assertEqual(len(cache), 1) + + def test_contains_returns_true_for_cached_message(self): + """Test if contains returns True for an ID of a cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertIn(1234, cache) + + def test_contains_returns_false_for_non_cached_message(self): + """Test if contains returns False for an ID of a non-cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertNotIn(4321, cache) + + def test_indexing(self): + """Test if the cache returns the correct messages by index.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(5)] + + for msg in messages: + cache.append(msg) + + for current_loop in range(-5, 5): + with self.subTest(current_loop=current_loop): + self.assertEqual(cache[current_loop], messages[current_loop]) + + def test_bad_index_raises_index_error(self): + """Test if the cache raises IndexError for invalid indices.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(3)] + test_cases = (-10, -4, 3, 4, 5) + + for msg in messages: + cache.append(msg) + + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with self.assertRaises(IndexError): + cache[current_loop] + + def test_slicing_with_unfilled_cache(self): + """Test if slicing returns the correct messages if the cache is not yet fully filled.""" + sizes = (5, 10, 55, 101) + + slices = ( + slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), + slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), + slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) + ) + + for size in sizes: + cache = MessageCache(maxlen=size) + messages = [MockMessage() for _ in range(size // 3 * 2)] + + for msg in messages: + cache.append(msg) + + for slice_ in slices: + with self.subTest(current_loop=(size, slice_)): + self.assertListEqual(cache[slice_], messages[slice_]) + + def test_slicing_with_overfilled_cache(self): + """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" + sizes = (5, 10, 55, 101) + + slices = ( + slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), + slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), + slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) + ) + + for size in sizes: + cache = MessageCache(maxlen=size) + messages = [MockMessage() for _ in range(size * 3 // 2)] + + for msg in messages: + cache.append(msg) + messages = messages[size // 2:] + + for slice_ in slices: + with self.subTest(current_loop=(size, slice_)): + self.assertListEqual(cache[slice_], messages[slice_]) + + def test_length(self): + """Test if len returns the correct number of items in the cache.""" + cache = MessageCache(maxlen=5) + + for current_loop in range(10): + with self.subTest(current_loop=current_loop): + self.assertEqual(len(cache), min(current_loop, 5)) + cache.append(MockMessage()) |
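As a complement to the tests, a short usage sketch of the cache exercising the behaviour asserted above; it reuses the MockMessage helper from tests.helpers, though real discord.Message objects behave the same way:

    from bot.utils.message_cache import MessageCache
    from tests.helpers import MockMessage

    cache = MessageCache(maxlen=3, newest_first=False)
    first, second, third, fourth = (MockMessage(id=i) for i in range(1, 5))

    for message in (first, second, third):
        cache.append(message)

    assert cache[0] is first               # index 0 is the oldest message
    assert cache.get_message(2) is second  # constant-time lookup by message ID
    assert list(cache) == [first, second, third]

    cache.append(fourth)                   # over maxlen: the oldest entry is evicted
    assert 1 not in cache
    assert len(cache) == 3
    assert list(cache) == [second, third, fourth]

    edited = MockMessage(id=3)
    cache.update(edited)                   # what the on_message_edit listener calls
    assert cache.get_message(3) is edited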