diff options
| author | 2021-08-23 21:55:53 +0300 | |
|---|---|---|
| committer | 2021-08-23 21:55:53 +0300 | |
| commit | 697c0da9b781cc71acca24b0c1b56ae10a7959cf (patch) | |
| tree | 7a84dbf7cbd2e689916db42a023159f30b586d89 | |
| parent | Merge pull request #1682 from python-discord/feat/mod/1665/override-auto-mute (diff) | |
| parent | Merge branch 'main' into mbaruh/anti-spam (diff) | |
Merge pull request #1760 from python-discord/mbaruh/anti-spam
Cross-channel and deleted messages anti-spam
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | bot/exts/filters/antispam.py | 84 | ||||
| -rw-r--r-- | bot/exts/info/pep.py | 2 | ||||
| -rw-r--r-- | bot/utils/caching.py (renamed from bot/utils/cache.py) | 0 | ||||
| -rw-r--r-- | bot/utils/message_cache.py | 197 | ||||
| -rw-r--r-- | config-default.yml | 2 | ||||
| -rw-r--r-- | tests/bot/utils/test_message_cache.py | 214 | 
7 files changed, 465 insertions, 36 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 12b5c02e5..407646b28 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter):  class AntiSpam(metaclass=YAMLGetter):      section = 'anti_spam' +    cache_size: int +      clean_offending: bool      ping_everyone: bool diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 226da2790..987060779 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -1,8 +1,10 @@  import asyncio  import logging +from collections import defaultdict  from collections.abc import Mapping  from dataclasses import dataclass, field  from datetime import datetime, timedelta +from itertools import takewhile  from operator import attrgetter, itemgetter  from typing import Dict, Iterable, List, Set @@ -20,6 +22,7 @@ from bot.converters import Duration  from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME  from bot.exts.moderation.modlog import ModLog  from bot.utils import lock, scheduling +from bot.utils.message_cache import MessageCache  from bot.utils.messages import format_user, send_attachments @@ -44,19 +47,18 @@ RULE_FUNCTION_MAPPING = {  class DeletionContext:      """Represents a Deletion Context for a single spam event.""" -    channel: TextChannel -    members: Dict[int, Member] = field(default_factory=dict) +    members: frozenset[Member] +    triggered_in: TextChannel +    channels: set[TextChannel] = field(default_factory=set)      rules: Set[str] = field(default_factory=set)      messages: Dict[int, Message] = field(default_factory=dict)      attachments: List[List[str]] = field(default_factory=list) -    async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: +    async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None:          """Adds new rule violation events to the deletion context."""          self.rules.add(rule_name) -        for member in members: -            if member.id not in self.members: -                self.members[member.id] = member +        self.channels.update(channels)          for message in messages:              if message.id not in self.messages: @@ -69,11 +71,14 @@ class DeletionContext:      async def upload_messages(self, actor_id: int, modlog: ModLog) -> None:          """Method that takes care of uploading the queue and posting modlog alert.""" -        triggered_by_users = ", ".join(format_user(m) for m in self.members.values()) +        triggered_by_users = ", ".join(format_user(m) for m in self.members) +        triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else "" +        channels_description = ", ".join(channel.mention for channel in self.channels)          mod_alert_message = (              f"**Triggered by:** {triggered_by_users}\n" -            f"**Channel:** {self.channel.mention}\n" +            f"{triggered_in_channel}" +            f"**Channels:** {channels_description}\n"              f"**Rules:** {', '.join(rule for rule in self.rules)}\n"          ) @@ -116,6 +121,14 @@ class AntiSpam(Cog):          self.message_deletion_queue = dict() +        # Fetch the rule configuration with the highest rule interval. +        max_interval_config = max( +            AntiSpamConfig.rules.values(), +            key=itemgetter('interval') +        ) +        self.max_interval = max_interval_config['interval'] +        self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) +          self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")      @property @@ -155,19 +168,10 @@ class AntiSpam(Cog):          ):              return -        # Fetch the rule configuration with the highest rule interval. -        max_interval_config = max( -            AntiSpamConfig.rules.values(), -            key=itemgetter('interval') -        ) -        max_interval = max_interval_config['interval'] +        self.cache.append(message) -        # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. -        earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) -        relevant_messages = [ -            msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) -            if not msg.author.bot -        ] +        earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache))          for rule_name in AntiSpamConfig.rules:              rule_config = AntiSpamConfig.rules[rule_name] @@ -175,9 +179,10 @@ class AntiSpam(Cog):              # Create a list of messages that were sent in the interval that the rule cares about.              latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) -            messages_for_rule = [ -                msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp -            ] +            messages_for_rule = list( +                takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) +            ) +              result = await rule_function(message, messages_for_rule, rule_config)              # If the rule returns `None`, that means the message didn't violate it. @@ -190,19 +195,19 @@ class AntiSpam(Cog):                  full_reason = f"`{rule_name}` rule: {reason}"                  # If there's no spam event going on for this channel, start a new Message Deletion Context -                channel = message.channel -                if channel.id not in self.message_deletion_queue: -                    log.trace(f"Creating queue for channel `{channel.id}`") -                    self.message_deletion_queue[message.channel.id] = DeletionContext(channel) +                authors_set = frozenset(members) +                if authors_set not in self.message_deletion_queue: +                    log.trace(f"Creating queue for members `{authors_set}`") +                    self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel)                      scheduling.create_task( -                        self._process_deletion_context(message.channel.id), -                        name=f"AntiSpam._process_deletion_context({message.channel.id})" +                        self._process_deletion_context(authors_set), +                        name=f"AntiSpam._process_deletion_context({authors_set})"                      )                  # Add the relevant of this trigger to the Deletion Context -                await self.message_deletion_queue[message.channel.id].add( +                await self.message_deletion_queue[authors_set].add(                      rule_name=rule_name, -                    members=members, +                    channels=set(message.channel for message in messages_for_rule),                      messages=relevant_messages                  ) @@ -212,7 +217,7 @@ class AntiSpam(Cog):                          name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"                      ) -                await self.maybe_delete_messages(channel, relevant_messages) +                await self.maybe_delete_messages(messages_for_rule)                  break      @lock.lock_arg("antispam.punish", "member", attrgetter("id")) @@ -234,14 +239,18 @@ class AntiSpam(Cog):                  reason=reason              ) -    async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: +    async def maybe_delete_messages(self, messages: List[Message]) -> None:          """Cleans the messages if cleaning is configured."""          if AntiSpamConfig.clean_offending:              # If we have more than one message, we can use bulk delete.              if len(messages) > 1:                  message_ids = [message.id for message in messages]                  self.mod_log.ignore(Event.message_delete, *message_ids) -                await channel.delete_messages(messages) +                channel_messages = defaultdict(list) +                for message in messages: +                    channel_messages[message.channel].append(message) +                for channel, messages in channel_messages.items(): +                    await channel.delete_messages(messages)              # Otherwise, the bulk delete endpoint will throw up.              # Delete the message directly instead. @@ -252,7 +261,7 @@ class AntiSpam(Cog):                  except NotFound:                      log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") -    async def _process_deletion_context(self, context_id: int) -> None: +    async def _process_deletion_context(self, context_id: frozenset) -> None:          """Processes the Deletion Context queue."""          log.trace("Sleeping before processing message deletion queue.")          await asyncio.sleep(10) @@ -264,6 +273,11 @@ class AntiSpam(Cog):          deletion_context = self.message_deletion_queue.pop(context_id)          await deletion_context.upload_messages(self.bot.user.id, self.mod_log) +    @Cog.listener() +    async def on_message_edit(self, before: Message, after: Message) -> None: +        """Updates the message in the cache, if it's cached.""" +        self.cache.update(after) +  def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:      """Validates the antispam configs.""" diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 8ac96bbdb..b11b34db0 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command  from bot.bot import Bot  from bot.constants import Keys -from bot.utils.cache import AsyncCache +from bot.utils.caching import AsyncCache  log = logging.getLogger(__name__) diff --git a/bot/utils/cache.py b/bot/utils/caching.py index 68ce15607..68ce15607 100644 --- a/bot/utils/cache.py +++ b/bot/utils/caching.py diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py new file mode 100644 index 000000000..f68d280c9 --- /dev/null +++ b/bot/utils/message_cache.py @@ -0,0 +1,197 @@ +import typing as t +from math import ceil + +from discord import Message + + +class MessageCache: +    """ +    A data structure for caching messages. + +    The cache is implemented as a circular buffer to allow constant time append, prepend, pop from either side, +    and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be +    implemented to work in linear time relative to the maximum size). + +    The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message +    is stored, to allow for constant time lookup by message ID. + +    The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those +    of a deque. + +    The implementation is transparent to the user: to the user the first element is always at index 0, and there are +    only as many elements as were inserted (meaning, without any pre-allocated placeholder values). +    """ + +    def __init__(self, maxlen: int, *, newest_first: bool = False): +        if maxlen <= 0: +            raise ValueError("maxlen must be positive") +        self.maxlen = maxlen +        self.newest_first = newest_first + +        self._start = 0 +        self._end = 0 + +        self._messages: list[t.Optional[Message]] = [None] * self.maxlen +        self._message_id_mapping = {} + +    def append(self, message: Message) -> None: +        """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" +        if self.newest_first: +            self._appendleft(message) +        else: +            self._appendright(message) + +    def _appendright(self, message: Message) -> None: +        """Add the received message to the end of the cache.""" +        if self._is_full(): +            del self._message_id_mapping[self._messages[self._start].id] +            self._start = (self._start + 1) % self.maxlen + +        self._messages[self._end] = message +        self._message_id_mapping[message.id] = self._end +        self._end = (self._end + 1) % self.maxlen + +    def _appendleft(self, message: Message) -> None: +        """Add the received message to the beginning of the cache.""" +        if self._is_full(): +            self._end = (self._end - 1) % self.maxlen +            del self._message_id_mapping[self._messages[self._end].id] + +        self._start = (self._start - 1) % self.maxlen +        self._messages[self._start] = message +        self._message_id_mapping[message.id] = self._start + +    def pop(self) -> Message: +        """Remove the last message in the cache and return it.""" +        if self._is_empty(): +            raise IndexError("pop from an empty cache") + +        self._end = (self._end - 1) % self.maxlen +        message = self._messages[self._end] +        del self._message_id_mapping[message.id] +        self._messages[self._end] = None + +        return message + +    def popleft(self) -> Message: +        """Return the first message in the cache and return it.""" +        if self._is_empty(): +            raise IndexError("pop from an empty cache") + +        message = self._messages[self._start] +        del self._message_id_mapping[message.id] +        self._messages[self._start] = None +        self._start = (self._start + 1) % self.maxlen + +        return message + +    def clear(self) -> None: +        """Remove all messages from the cache.""" +        self._messages = [None] * self.maxlen +        self._message_id_mapping = {} + +        self._start = 0 +        self._end = 0 + +    def get_message(self, message_id: int) -> t.Optional[Message]: +        """Return the message that has the given message ID, if it is cached.""" +        index = self._message_id_mapping.get(message_id, None) +        return self._messages[index] if index is not None else None + +    def update(self, message: Message) -> bool: +        """ +        Update a cached message with new contents. + +        Return True if the given message had a matching ID in the cache. +        """ +        index = self._message_id_mapping.get(message.id, None) +        if index is None: +            return False +        self._messages[index] = message +        return True + +    def __contains__(self, message_id: int) -> bool: +        """Return True if the cache contains a message with the given ID .""" +        return message_id in self._message_id_mapping + +    def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: +        """ +        Return the message(s) in the index or slice provided. + +        This method makes the circular buffer implementation transparent to the user. +        Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, +        meaning at `self._start`. +        """ +        # Keep in mind that for the modulo operator used throughout this function, Python modulo behaves similarly when +        # the left operand is negative. E.g -1 % 5 == 4, because the closest number from the bottom that wholly divides +        # by 5 is -5. +        if isinstance(item, int): +            if item >= len(self) or item < -len(self): +                raise IndexError("cache index out of range") +            return self._messages[(item + self._start) % self.maxlen] + +        elif isinstance(item, slice): +            length = len(self) +            start, stop, step = item.indices(length) + +            # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state. +            if (start >= stop and step >= 0) or (start <= stop and step <= 0): +                return [] + +            start = (start + self._start) % self.maxlen +            stop = (stop + self._start) % self.maxlen + +            # Having empty cells is an implementation detail. To the user the cache contains as many elements as they +            # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. +            if step > 0: +                if ( +                    (self._start < self._end and not self._start < stop <= self._end) +                    or (self._start > self._end and self._end < stop <= self._start) +                ): +                    stop = self._end +            else: +                lower_boundary = (self._start - 1) % self.maxlen +                if ( +                    (self._start < self._end and not self._start - 1 <= stop < self._end) +                    or (self._start > self._end and self._end < stop < lower_boundary) +                ): +                    stop = lower_boundary + +            if (start < stop and step > 0) or (start > stop and step < 0): +                return self._messages[start:stop:step] +            # step != 1 may require a start offset in the second slicing. +            if step > 0: +                offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen +                return self._messages[start::step] + self._messages[offset:stop:step] +            else: +                offset = ceil((start + 1) / -step) * -step - start - 1 +                return self._messages[start::step] + self._messages[self.maxlen - 1 - offset:stop:step] + +        else: +            raise TypeError(f"cache indices must be integers or slices, not {type(item)}") + +    def __iter__(self) -> t.Iterator[Message]: +        if self._is_empty(): +            return + +        if self._start < self._end: +            yield from self._messages[self._start:self._end] +        else: +            yield from self._messages[self._start:] +            yield from self._messages[:self._end] + +    def __len__(self): +        """Get the number of non-empty cells in the cache.""" +        if self._is_empty(): +            return 0 +        if self._end > self._start: +            return self._end - self._start +        return self.maxlen - self._start + self._end + +    def _is_empty(self) -> bool: +        """Return True if the cache has no messages.""" +        return self._messages[self._start] is None + +    def _is_full(self) -> bool: +        """Return True if every cell in the cache already contains a message.""" +        return self._messages[self._end] is not None diff --git a/config-default.yml b/config-default.yml index 79828dd77..eaf8e0ad7 100644 --- a/config-default.yml +++ b/config-default.yml @@ -377,6 +377,8 @@ urls:  anti_spam: +    cache_size: 100 +      # Clean messages that violate a rule.      clean_offending: true      ping_everyone: true diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..04bfd28d1 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,214 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): +    """Tests for the MessageCache class in the `bot.utils.caching` module.""" + +    def test_first_append_sets_the_first_value(self): +        """Test if the first append adds the message to the first cell.""" +        cache = MessageCache(maxlen=10) +        message = MockMessage() + +        cache.append(message) + +        self.assertEqual(cache[0], message) + +    def test_append_adds_in_the_right_order(self): +        """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" +        messages = [MockMessage(), MockMessage()] + +        cache = MessageCache(maxlen=10, newest_first=False) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages, list(cache)) + +        cache = MessageCache(maxlen=10, newest_first=True) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages[::-1], list(cache)) + +    def test_appending_over_maxlen_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages.""" +        cache = MessageCache(maxlen=2) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[1:], list(cache)) + +    def test_appending_over_maxlen_with_newest_first_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" +        cache = MessageCache(maxlen=2, newest_first=True) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[:0:-1], list(cache)) + +    def test_pop_removes_from_the_end(self): +        """Test if a pop removes the right-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.pop() + +        self.assertEqual(msg, messages[-1]) +        self.assertListEqual(messages[:-1], list(cache)) + +    def test_popleft_removes_from_the_beginning(self): +        """Test if a popleft removes the left-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.popleft() + +        self.assertEqual(msg, messages[0]) +        self.assertListEqual(messages[1:], list(cache)) + +    def test_clear(self): +        """Test if a clear makes the cache empty.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        cache.clear() + +        self.assertListEqual(list(cache), []) +        self.assertEqual(len(cache), 0) + +    def test_get_message_returns_the_message(self): +        """Test if get_message returns the cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertEqual(cache.get_message(1234), message) + +    def test_get_message_returns_none(self): +        """Test if get_message returns None for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIsNone(cache.get_message(4321)) + +    def test_update_replaces_old_element(self): +        """Test if an update replaced the old message with the same ID.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) +        message = MockMessage(id=1234) +        cache.update(message) + +        self.assertIs(cache.get_message(1234), message) +        self.assertEqual(len(cache), 1) + +    def test_contains_returns_true_for_cached_message(self): +        """Test if contains returns True for an ID of a cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIn(1234, cache) + +    def test_contains_returns_false_for_non_cached_message(self): +        """Test if contains returns False for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertNotIn(4321, cache) + +    def test_indexing(self): +        """Test if the cache returns the correct messages by index.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(5)] + +        for msg in messages: +            cache.append(msg) + +        for current_loop in range(-5, 5): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(cache[current_loop], messages[current_loop]) + +    def test_bad_index_raises_index_error(self): +        """Test if the cache raises IndexError for invalid indices.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] +        test_cases = (-10, -4, 3, 4, 5) + +        for msg in messages: +            cache.append(msg) + +        for current_loop in test_cases: +            with self.subTest(current_loop=current_loop): +                with self.assertRaises(IndexError): +                    cache[current_loop] + +    def test_slicing_with_unfilled_cache(self): +        """Test if slicing returns the correct messages if the cache is not yet fully filled.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size // 3 * 2)] + +            for msg in messages: +                cache.append(msg) + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_slicing_with_overfilled_cache(self): +        """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size * 3 // 2)] + +            for msg in messages: +                cache.append(msg) +            messages = messages[size // 2:] + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_length(self): +        """Test if len returns the correct number of items in the cache.""" +        cache = MessageCache(maxlen=5) + +        for current_loop in range(10): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(len(cache), min(current_loop, 5)) +                cache.append(MockMessage()) | 
