aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar mbaruh <[email protected]>2021-08-21 01:28:07 +0300
committerGravatar mbaruh <[email protected]>2021-08-21 01:28:07 +0300
commit34d7f7cc61be8b76bbc32fce1d6c8cbfc97b4cb8 (patch)
tree8a88dc1ab613d5a1cd4248872c2ff8f806596dc5
parentMove max_interval to init (diff)
AntiSpam modified to work with cache
The anti-spam cog now uses a cache instead of reading channel history. The cache is for all channels in the guild, and does not remove deleted messages. That means that the anti-spam logic now works cross-channel and counts deleted messages. The size of the cache is determined via a new field in the config YAML file. The cache was implemented as a separate class, MessageCache, which uses circular buffer logic. This allows for constant time addition and removal form either side, and lookup. The cache does not support removal from the middle of the cache. The cache additionally stores a mapping from message ID's to the index of the message in the cache, to allow constant time lookup by message ID. The commit additionally adds accompanying tests, and renames `cache.py` to `caching.py` to better distinguish it from the new `message_cache.py` and convey that it's for general caching utilities.
-rw-r--r--bot/constants.py2
-rw-r--r--bot/exts/filters/antispam.py24
-rw-r--r--bot/exts/info/pep.py2
-rw-r--r--bot/utils/caching.py (renamed from bot/utils/cache.py)0
-rw-r--r--bot/utils/message_cache.py177
-rw-r--r--config-default.yml2
-rw-r--r--tests/bot/utils/test_message_cache.py208
7 files changed, 406 insertions, 9 deletions
diff --git a/bot/constants.py b/bot/constants.py
index 500803f33..34a814035 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter):
class AntiSpam(metaclass=YAMLGetter):
section = 'anti_spam'
+ cache_size: int
+
clean_offending: bool
ping_everyone: bool
diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py
index da4583e76..58cd8dec4 100644
--- a/bot/exts/filters/antispam.py
+++ b/bot/exts/filters/antispam.py
@@ -3,6 +3,7 @@ import logging
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
+from itertools import takewhile
from operator import attrgetter, itemgetter
from typing import Dict, Iterable, List, Set
@@ -20,6 +21,7 @@ from bot.converters import Duration
from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
from bot.exts.moderation.modlog import ModLog
from bot.utils import lock, scheduling
+from bot.utils.message_cache import MessageCache
from bot.utils.messages import format_user, send_attachments
@@ -122,6 +124,7 @@ class AntiSpam(Cog):
key=itemgetter('interval')
)
self.max_interval = max_interval_config['interval']
+ self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True)
self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")
@@ -162,12 +165,11 @@ class AntiSpam(Cog):
):
return
+ self.cache.append(message)
+
# Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls.
earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval)
- relevant_messages = [
- msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False)
- if not msg.author.bot
- ]
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache))
for rule_name in AntiSpamConfig.rules:
rule_config = AntiSpamConfig.rules[rule_name]
@@ -175,9 +177,10 @@ class AntiSpam(Cog):
# Create a list of messages that were sent in the interval that the rule cares about.
latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval'])
- messages_for_rule = [
- msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp
- ]
+ messages_for_rule = list(
+ takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages)
+ )
+
result = await rule_function(message, messages_for_rule, rule_config)
# If the rule returns `None`, that means the message didn't violate it.
@@ -212,7 +215,7 @@ class AntiSpam(Cog):
name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"
)
- await self.maybe_delete_messages(channel, relevant_messages)
+ await self.maybe_delete_messages(channel, messages_for_rule)
break
@lock.lock_arg("antispam.punish", "member", attrgetter("id"))
@@ -264,6 +267,11 @@ class AntiSpam(Cog):
deletion_context = self.message_deletion_queue.pop(context_id)
await deletion_context.upload_messages(self.bot.user.id, self.mod_log)
+ @Cog.listener()
+ async def on_message_edit(self, before: Message, after: Message) -> None:
+ """Updates the message in the cache, if it's cached."""
+ self.cache.update(after)
+
def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:
"""Validates the antispam configs."""
diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py
index 8ac96bbdb..b11b34db0 100644
--- a/bot/exts/info/pep.py
+++ b/bot/exts/info/pep.py
@@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command
from bot.bot import Bot
from bot.constants import Keys
-from bot.utils.cache import AsyncCache
+from bot.utils.caching import AsyncCache
log = logging.getLogger(__name__)
diff --git a/bot/utils/cache.py b/bot/utils/caching.py
index 68ce15607..68ce15607 100644
--- a/bot/utils/cache.py
+++ b/bot/utils/caching.py
diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py
new file mode 100644
index 000000000..b2f8f66bf
--- /dev/null
+++ b/bot/utils/message_cache.py
@@ -0,0 +1,177 @@
+from __future__ import annotations
+
+import typing as t
+from math import ceil
+
+from discord import Message
+
+
+class MessageCache:
+ """
+ A data structure for caching messages.
+
+ The cache is implemented as a circular buffer to allow constant time append, prepend, push, pop,
+ and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be
+ implemented to work in linear time relative to the maximum size).
+
+ The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message
+ is stored, to allow for constant time lookup by message ID.
+
+ The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those
+ of a deque.
+
+ The implementation is transparent to the user: to the user the first element is always at index 0, and there are
+ only as many elements as were inserted (meaning, without any pre-allocated placeholder values).
+ """
+
+ def __init__(self, maxlen: int, *, newest_first: bool = False):
+ if maxlen <= 0:
+ raise ValueError("maxlen must be positive")
+ self.maxlen = maxlen
+ self.newest_first = newest_first
+
+ self._start = 0
+ self._end = 0
+
+ self._messages: list[t.Optional[Message]] = [None] * self.maxlen
+ self._message_id_mapping = {}
+
+ def append(self, message: Message) -> None:
+ """Add the received message to the cache, depending on the order of messages defined by `newest_first`."""
+ if self.newest_first:
+ self._appendleft(message)
+ else:
+ self._appendright(message)
+
+ def _appendright(self, message: Message) -> None:
+ """Add the received message to the end of the cache."""
+ if self._is_full():
+ del self._message_id_mapping[self._messages[self._start].id]
+ self._start = (self._start + 1) % self.maxlen
+
+ self._messages[self._end] = message
+ self._message_id_mapping[message.id] = self._end
+ self._end = (self._end + 1) % self.maxlen
+
+ def _appendleft(self, message: Message) -> None:
+ """Add the received message to the beginning of the cache."""
+ if self._is_full():
+ self._end = (self._end - 1) % self.maxlen
+ del self._message_id_mapping[self._messages[self._end].id]
+
+ self._start = (self._start - 1) % self.maxlen
+ self._messages[self._start] = message
+ self._message_id_mapping[message.id] = self._start
+
+ def pop(self) -> Message:
+ """Remove the last message in the cache and return it."""
+ if self._is_empty():
+ raise IndexError("pop from an empty cache")
+
+ self._end = (self._end - 1) % self.maxlen
+ message = self._messages[self._end]
+ del self._message_id_mapping[message.id]
+ self._messages[self._end] = None
+
+ return message
+
+ def popleft(self) -> Message:
+ """Return the first message in the cache and return it."""
+ if self._is_empty():
+ raise IndexError("pop from an empty cache")
+
+ message = self._messages[self._start]
+ del self._message_id_mapping[message.id]
+ self._messages[self._start] = None
+ self._start = (self._start + 1) % self.maxlen
+
+ return message
+
+ def clear(self) -> None:
+ """Remove all messages from the cache."""
+ self._messages: list[t.Optional[Message]] = [None] * self.maxlen
+ self._message_id_mapping = {}
+
+ self._start = 0
+ self._end = 0
+
+ def get_message(self, message_id: int) -> t.Optional[Message]:
+ """Return the message that has the given message ID, if it is cached."""
+ index = self._message_id_mapping.get(message_id, None)
+ return self._messages[index] if index is not None else None
+
+ def update(self, message: Message) -> bool:
+ """
+ Update a cached message with new contents.
+
+ Return True if the given message had a matching ID in the cache.
+ """
+ index = self._message_id_mapping.get(message.id, None)
+ if index is None:
+ return False
+ self._messages[index] = message
+ return True
+
+ def __contains__(self, message_id: int) -> bool:
+ """Return True if the cache contains a message with the given ID ."""
+ return message_id in self._message_id_mapping
+
+ def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]:
+ """
+ Return the message(s) in the index or slice provided.
+
+ This method makes the circular buffer implementation transparent to the user.
+ Providing 0 will return the message at the position perceived by the user to be the beginning of the cache,
+ meaning at `self._start`.
+ """
+ if isinstance(item, int):
+ if item >= len(self) or item < -len(self):
+ raise IndexError("cache index out of range")
+ return self._messages[(item + self._start) % self.maxlen]
+
+ elif isinstance(item, slice):
+ length = len(self)
+ start, stop, step = item.indices(length)
+
+ # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state.
+ if (start >= stop and step >= 0) or (start <= stop and step <= 0):
+ return []
+
+ start = (start + self._start) % self.maxlen
+ stop = (stop + self._start) % self.maxlen
+
+ # Having empty cells is an implementation detail. To the user the cache contains as many elements as they
+ # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail.
+ if (
+ (self._start < self._end and not self._start < stop <= self._end)
+ or (self._start > self._end and self._end < stop <= self._start)
+ ):
+ stop = self._end
+
+ if (start < stop and step > 0) or (start > stop and step < 0):
+ return self._messages[start:stop:step]
+ # step != 1 may require a start offset in the second slicing.
+ if step > 0:
+ offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen
+ else:
+ offset = self.maxlen - ((start + 1) % step)
+ return self._messages[start::step] + self._messages[offset:stop:step]
+
+ else:
+ raise TypeError(f"cache indices must be integers or slices, not {type(item)}")
+
+ def __len__(self):
+ """Get the number of non-empty cells in the cache."""
+ if self._is_empty():
+ return 0
+ if self._end > self._start:
+ return self._end - self._start
+ return self.maxlen - self._start + self._end
+
+ def _is_empty(self) -> bool:
+ """Return True if the cache has no messages."""
+ return self._messages[self._start] is None
+
+ def _is_full(self) -> bool:
+ """Return True if every cell in the cache already contains a message."""
+ return self._messages[self._end] is not None
diff --git a/config-default.yml b/config-default.yml
index 881a7df76..2412a7016 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -377,6 +377,8 @@ urls:
anti_spam:
+ cache_size: 100
+
# Clean messages that violate a rule.
clean_offending: true
ping_everyone: true
diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py
new file mode 100644
index 000000000..ff313c6d6
--- /dev/null
+++ b/tests/bot/utils/test_message_cache.py
@@ -0,0 +1,208 @@
+import unittest
+
+from bot.utils.message_cache import MessageCache
+from tests.helpers import MockMessage
+
+
+# noinspection SpellCheckingInspection
+class TestMessageCache(unittest.TestCase):
+ """Tests for the MessageCache class in the `bot.utils.caching` module."""
+
+ def test_first_append_sets_the_first_value(self):
+ """Test if the first append adds the message to the first cell."""
+ cache = MessageCache(maxlen=10)
+ message = MockMessage()
+
+ cache.append(message)
+
+ self.assertEqual(cache[0], message)
+
+ def test_append_adds_in_the_right_order(self):
+ """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise."""
+ messages = [MockMessage(), MockMessage()]
+
+ cache = MessageCache(maxlen=10, newest_first=False)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages, list(cache))
+
+ cache = MessageCache(maxlen=10, newest_first=True)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages[::-1], list(cache))
+
+ def test_appending_over_maxlen_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages."""
+ cache = MessageCache(maxlen=2)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_appending_over_maxlen_with_newest_first_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True."""
+ cache = MessageCache(maxlen=2, newest_first=True)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[:0:-1], list(cache))
+
+ def test_pop_removes_from_the_end(self):
+ """Test if a pop removes the right-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.pop()
+
+ self.assertEqual(msg, messages[-1])
+ self.assertListEqual(messages[:-1], list(cache))
+
+ def test_popleft_removes_from_the_beginning(self):
+ """Test if a popleft removes the left-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.popleft()
+
+ self.assertEqual(msg, messages[0])
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_clear(self):
+ """Test if a clear makes the cache empty."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ cache.clear()
+
+ self.assertListEqual(list(cache), [])
+ self.assertEqual(len(cache), 0)
+
+ def test_get_message_returns_the_message(self):
+ """Test if get_message returns the cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertEqual(cache.get_message(1234), message)
+
+ def test_get_message_returns_none(self):
+ """Test if get_message returns None for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIsNone(cache.get_message(4321))
+
+ def test_update_replaces_old_element(self):
+ """Test if an update replaced the old message with the same ID."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+ message = MockMessage(id=1234)
+ cache.update(message)
+
+ self.assertIs(cache.get_message(1234), message)
+ self.assertEqual(len(cache), 1)
+
+ def test_contains_returns_true_for_cached_message(self):
+ """Test if contains returns True for an ID of a cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIn(1234, cache)
+
+ def test_contains_returns_false_for_non_cached_message(self):
+ """Test if contains returns False for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertNotIn(4321, cache)
+
+ def test_indexing(self):
+ """Test if the cache returns the correct messages by index."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(5)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in range(-5, 5):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(cache[current_loop], messages[current_loop])
+
+ def test_bad_index_raises_index_error(self):
+ """Test if the cache raises IndexError for invalid indices."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+ test_cases = (-10, -4, 3, 4, 5)
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with self.assertRaises(IndexError):
+ cache[current_loop]
+
+ def test_slicing_with_unfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache is not yet fully filled."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(4)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ test_cases = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3)
+ )
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ self.assertListEqual(cache[current_loop], messages[current_loop])
+
+ def test_slicing_with_overfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache was appended with more messages it can contain."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(8)]
+
+ for msg in messages:
+ cache.append(msg)
+ messages = messages[3:]
+
+ test_cases = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3)
+ )
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ self.assertListEqual(cache[current_loop], messages[current_loop])
+
+ def test_length(self):
+ """Test if len returns the correct number of items in the cache."""
+ cache = MessageCache(maxlen=5)
+
+ for current_loop in range(10):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(len(cache), min(current_loop, 5))
+ cache.append(MockMessage())