about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- bot/constants.py 2
-rw-r--r-- bot/exts/filters/antispam.py 24
-rw-r--r-- bot/exts/info/pep.py 2
-rw-r--r-- bot/utils/caching.py (renamed from bot/utils/cache.py) 0
-rw-r--r-- bot/utils/message_cache.py 177
-rw-r--r-- config-default.yml 2
-rw-r--r-- tests/bot/utils/test_message_cache.py 208
7 files changed, 406 insertions, 9 deletions
diff --git a/bot/constants.py b/bot/constants.py
index 500803f33..34a814035 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter):
class AntiSpam(metaclass=YAMLGetter):
section = 'anti_spam'
+ cache_size: int
+
clean_offending: bool
ping_everyone: bool
diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py
index da4583e76..58cd8dec4 100644
--- a/bot/exts/filters/antispam.py
+++ b/bot/exts/filters/antispam.py
@@ -3,6 +3,7 @@ import logging
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
+from itertools import takewhile
from operator import attrgetter, itemgetter
from typing import Dict, Iterable, List, Set
@@ -20,6 +21,7 @@ from bot.converters import Duration
from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
from bot.exts.moderation.modlog import ModLog
from bot.utils import lock, scheduling
+from bot.utils.message_cache import MessageCache
from bot.utils.messages import format_user, send_attachments
@@ -122,6 +124,7 @@ class AntiSpam(Cog):
key=itemgetter('interval')
)
self.max_interval = max_interval_config['interval']
+ self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True)
self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")
@@ -162,12 +165,11 @@ class AntiSpam(Cog):
):
return
+ self.cache.append(message)
+
# Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls.
earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval)
- relevant_messages = [
- msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False)
- if not msg.author.bot
- ]
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache))
for rule_name in AntiSpamConfig.rules:
rule_config = AntiSpamConfig.rules[rule_name]
@@ -175,9 +177,10 @@ class AntiSpam(Cog):
# Create a list of messages that were sent in the interval that the rule cares about.
latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval'])
- messages_for_rule = [
- msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp
- ]
+ messages_for_rule = list(
+ takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages)
+ )
+
result = await rule_function(message, messages_for_rule, rule_config)
# If the rule returns `None`, that means the message didn't violate it.
@@ -212,7 +215,7 @@ class AntiSpam(Cog):
name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"
)
- await self.maybe_delete_messages(channel, relevant_messages)
+ await self.maybe_delete_messages(channel, messages_for_rule)
break
@lock.lock_arg("antispam.punish", "member", attrgetter("id"))
@@ -264,6 +267,11 @@ class AntiSpam(Cog):
deletion_context = self.message_deletion_queue.pop(context_id)
await deletion_context.upload_messages(self.bot.user.id, self.mod_log)
+ @Cog.listener()
+ async def on_message_edit(self, before: Message, after: Message) -> None:
+ """Updates the message in the cache, if it's cached."""
+ self.cache.update(after)
+
def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:
"""Validates the antispam configs."""
diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py
index 8ac96bbdb..b11b34db0 100644
--- a/bot/exts/info/pep.py
+++ b/bot/exts/info/pep.py
@@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command
from bot.bot import Bot
from bot.constants import Keys
-from bot.utils.cache import AsyncCache
+from bot.utils.caching import AsyncCache
log = logging.getLogger(__name__)
diff --git a/bot/utils/cache.py b/bot/utils/caching.py
index 68ce15607..68ce15607 100644
--- a/bot/utils/cache.py
+++ b/bot/utils/caching.py
diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py
new file mode 100644
index 000000000..b2f8f66bf
--- /dev/null
+++ b/bot/utils/message_cache.py
@@ -0,0 +1,177 @@
+from __future__ import annotations
+
+import typing as t
+from math import ceil
+
+from discord import Message
+
+
+class MessageCache:
+ """
+ A data structure for caching messages.
+
+ The cache is implemented as a circular buffer to allow constant time append, prepend, push, pop,
+ and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be
+ implemented to work in linear time relative to the maximum size).
+
+ The object additionally holds a mapping from Discord message IDs to the index in which the corresponding message
+ is stored, to allow for constant time lookup by message ID.
+
+ The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those
+ of a deque.
+
+ The implementation is transparent to the user: to the user the first element is always at index 0, and there are
+ only as many elements as were inserted (meaning, without any pre-allocated placeholder values).
+ """
+
+ def __init__(self, maxlen: int, *, newest_first: bool = False):
+ if maxlen <= 0:
+ raise ValueError("maxlen must be positive")
+ self.maxlen = maxlen
+ self.newest_first = newest_first
+
+ self._start = 0
+ self._end = 0
+
+ self._messages: list[t.Optional[Message]] = [None] * self.maxlen
+ self._message_id_mapping = {}
+
+ def append(self, message: Message) -> None:
+ """Add the received message to the cache, depending on the order of messages defined by `newest_first`."""
+ if self.newest_first:
+ self._appendleft(message)
+ else:
+ self._appendright(message)
+
+ def _appendright(self, message: Message) -> None:
+ """Add the received message to the end of the cache."""
+ if self._is_full():
+ del self._message_id_mapping[self._messages[self._start].id]
+ self._start = (self._start + 1) % self.maxlen
+
+ self._messages[self._end] = message
+ self._message_id_mapping[message.id] = self._end
+ self._end = (self._end + 1) % self.maxlen
+
+ def _appendleft(self, message: Message) -> None:
+ """Add the received message to the beginning of the cache."""
+ if self._is_full():
+ self._end = (self._end - 1) % self.maxlen
+ del self._message_id_mapping[self._messages[self._end].id]
+
+ self._start = (self._start - 1) % self.maxlen
+ self._messages[self._start] = message
+ self._message_id_mapping[message.id] = self._start
+
+ def pop(self) -> Message:
+ """Remove the last message in the cache and return it."""
+ if self._is_empty():
+ raise IndexError("pop from an empty cache")
+
+ self._end = (self._end - 1) % self.maxlen
+ message = self._messages[self._end]
+ del self._message_id_mapping[message.id]
+ self._messages[self._end] = None
+
+ return message
+
+ def popleft(self) -> Message:
+ """Remove the first message in the cache and return it."""
+ if self._is_empty():
+ raise IndexError("pop from an empty cache")
+
+ message = self._messages[self._start]
+ del self._message_id_mapping[message.id]
+ self._messages[self._start] = None
+ self._start = (self._start + 1) % self.maxlen
+
+ return message
+
+ def clear(self) -> None:
+ """Remove all messages from the cache."""
+ self._messages: list[t.Optional[Message]] = [None] * self.maxlen
+ self._message_id_mapping = {}
+
+ self._start = 0
+ self._end = 0
+
+ def get_message(self, message_id: int) -> t.Optional[Message]:
+ """Return the message that has the given message ID, if it is cached."""
+ index = self._message_id_mapping.get(message_id, None)
+ return self._messages[index] if index is not None else None
+
+ def update(self, message: Message) -> bool:
+ """
+ Update a cached message with new contents.
+
+ Return True if the given message had a matching ID in the cache.
+ """
+ index = self._message_id_mapping.get(message.id, None)
+ if index is None:
+ return False
+ self._messages[index] = message
+ return True
+
+ def __contains__(self, message_id: int) -> bool:
+ """Return True if the cache contains a message with the given ID."""
+ return message_id in self._message_id_mapping
+
+ def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]:
+ """
+ Return the message(s) in the index or slice provided.
+
+ This method makes the circular buffer implementation transparent to the user.
+ Providing 0 will return the message at the position perceived by the user to be the beginning of the cache,
+ meaning at `self._start`.
+ """
+ if isinstance(item, int):
+ if item >= len(self) or item < -len(self):
+ raise IndexError("cache index out of range")
+ return self._messages[(item + self._start) % self.maxlen]
+
+ elif isinstance(item, slice):
+ length = len(self)
+ start, stop, step = item.indices(length)
+
+ # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state.
+ if (start >= stop and step >= 0) or (start <= stop and step <= 0):
+ return []
+
+ start = (start + self._start) % self.maxlen
+ stop = (stop + self._start) % self.maxlen
+
+ # Having empty cells is an implementation detail. To the user the cache contains as many elements as they
+ # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail.
+ if (
+ (self._start < self._end and not self._start < stop <= self._end)
+ or (self._start > self._end and self._end < stop <= self._start)
+ ):
+ stop = self._end
+
+ if (start < stop and step > 0) or (start > stop and step < 0):
+ return self._messages[start:stop:step]
+ # step != 1 may require a start offset in the second slicing.
+ if step > 0:
+ offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen
+ else:
+ offset = self.maxlen - ((start + 1) % step)
+ return self._messages[start::step] + self._messages[offset:stop:step]
+
+ else:
+ raise TypeError(f"cache indices must be integers or slices, not {type(item)}")
+
+ def __len__(self):
+ """Get the number of non-empty cells in the cache."""
+ if self._is_empty():
+ return 0
+ if self._end > self._start:
+ return self._end - self._start
+ return self.maxlen - self._start + self._end
+
+ def _is_empty(self) -> bool:
+ """Return True if the cache has no messages."""
+ return self._messages[self._start] is None
+
+ def _is_full(self) -> bool:
+ """Return True if every cell in the cache already contains a message."""
+ return self._messages[self._end] is not None
diff --git a/config-default.yml b/config-default.yml
index 881a7df76..2412a7016 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -377,6 +377,8 @@ urls:
anti_spam:
+ cache_size: 100
+
# Clean messages that violate a rule.
clean_offending: true
ping_everyone: true
diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py
new file mode 100644
index 000000000..ff313c6d6
--- /dev/null
+++ b/tests/bot/utils/test_message_cache.py
@@ -0,0 +1,208 @@
+import unittest
+
+from bot.utils.message_cache import MessageCache
+from tests.helpers import MockMessage
+
+
+# noinspection SpellCheckingInspection
+class TestMessageCache(unittest.TestCase):
+ """Tests for the MessageCache class in the `bot.utils.message_cache` module."""
+
+ def test_first_append_sets_the_first_value(self):
+ """Test if the first append adds the message to the first cell."""
+ cache = MessageCache(maxlen=10)
+ message = MockMessage()
+
+ cache.append(message)
+
+ self.assertEqual(cache[0], message)
+
+ def test_append_adds_in_the_right_order(self):
+ """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise."""
+ messages = [MockMessage(), MockMessage()]
+
+ cache = MessageCache(maxlen=10, newest_first=False)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages, list(cache))
+
+ cache = MessageCache(maxlen=10, newest_first=True)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages[::-1], list(cache))
+
+ def test_appending_over_maxlen_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages."""
+ cache = MessageCache(maxlen=2)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_appending_over_maxlen_with_newest_first_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True."""
+ cache = MessageCache(maxlen=2, newest_first=True)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[:0:-1], list(cache))
+
+ def test_pop_removes_from_the_end(self):
+ """Test if a pop removes the right-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.pop()
+
+ self.assertEqual(msg, messages[-1])
+ self.assertListEqual(messages[:-1], list(cache))
+
+ def test_popleft_removes_from_the_beginning(self):
+ """Test if a popleft removes the left-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.popleft()
+
+ self.assertEqual(msg, messages[0])
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_clear(self):
+ """Test if a clear makes the cache empty."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ cache.clear()
+
+ self.assertListEqual(list(cache), [])
+ self.assertEqual(len(cache), 0)
+
+ def test_get_message_returns_the_message(self):
+ """Test if get_message returns the cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertEqual(cache.get_message(1234), message)
+
+ def test_get_message_returns_none(self):
+ """Test if get_message returns None for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIsNone(cache.get_message(4321))
+
+ def test_update_replaces_old_element(self):
+ """Test if an update replaced the old message with the same ID."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+ message = MockMessage(id=1234)
+ cache.update(message)
+
+ self.assertIs(cache.get_message(1234), message)
+ self.assertEqual(len(cache), 1)
+
+ def test_contains_returns_true_for_cached_message(self):
+ """Test if contains returns True for an ID of a cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIn(1234, cache)
+
+ def test_contains_returns_false_for_non_cached_message(self):
+ """Test if contains returns False for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertNotIn(4321, cache)
+
+ def test_indexing(self):
+ """Test if the cache returns the correct messages by index."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(5)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in range(-5, 5):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(cache[current_loop], messages[current_loop])
+
+ def test_bad_index_raises_index_error(self):
+ """Test if the cache raises IndexError for invalid indices."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+ test_cases = (-10, -4, 3, 4, 5)
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with self.assertRaises(IndexError):
+ cache[current_loop]
+
+ def test_slicing_with_unfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache is not yet fully filled."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(4)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ test_cases = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3)
+ )
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ self.assertListEqual(cache[current_loop], messages[current_loop])
+
+ def test_slicing_with_overfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache was appended with more messages than it can contain."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(8)]
+
+ for msg in messages:
+ cache.append(msg)
+ messages = messages[3:]
+
+ test_cases = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3)
+ )
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ self.assertListEqual(cache[current_loop], messages[current_loop])
+
+ def test_length(self):
+ """Test if len returns the correct number of items in the cache."""
+ cache = MessageCache(maxlen=5)
+
+ for current_loop in range(10):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(len(cache), min(current_loop, 5))
+ cache.append(MockMessage())