From 8ec6f0880be0ca876b9a255a0fd5e0090695d5b8 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 15 Jul 2021 22:31:01 -0700 Subject: Fix get_active_infraction test --- tests/bot/exts/moderation/infraction/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 50a717bb5..f3af7bea9 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -94,8 +94,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"]) test_cases = [ test_case([], None, None, True), - test_case([{"id": 123987}], {"id": 123987}, "123987", False), - test_case([{"id": 123987}], {"id": 123987}, "123987", True) + test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False), + test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True) ] for case in test_cases: -- cgit v1.2.3 From 24e4a9d2f4c037f7652b7300772ec0c6c6ab8d3c Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 15 Jul 2021 22:44:42 -0700 Subject: Remove redundant parameter from pardon_voice_ban --- bot/exts/moderation/infraction/infractions.py | 4 ++-- tests/bot/exts/moderation/infraction/test_infractions.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index eaf718af4..dbf56d6bb 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -455,7 +455,7 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: + async def pardon_voice_ban(self, user_id: int, guild: discord.Guild) -> t.Dict[str, str]: """Add Voice Verified role back to user, DM them a notification, and return a log dict.""" user = guild.get_member(user_id) log_text = {} @@ -491,7 +491,7 @@ class Infractions(InfractionScheduler, commands.Cog): elif infraction["type"] == "ban": return await self.pardon_ban(user_id, guild, reason) elif infraction["type"] == "voice_ban": - return await self.pardon_voice_ban(user_id, guild, reason) + return await self.pardon_voice_ban(user_id, guild) # endregion diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b9d527770..f844a9181 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -195,7 +195,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): async def test_voice_unban_user_not_found(self): """Should include info to return dict when user was not found from guild.""" self.guild.get_member.return_value = None - result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + result = await self.cog.pardon_voice_ban(self.user.id, self.guild) self.assertEqual(result, {"Info": "User was not found in the guild."}) @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @@ -206,7 +206,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): notify_pardon_mock.return_value = True format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + result = await self.cog.pardon_voice_ban(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "Sent" @@ -221,7 +221,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): notify_pardon_mock.return_value = False format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + result = await self.cog.pardon_voice_ban(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "**Failed**" -- cgit v1.2.3 From 34d7f7cc61be8b76bbc32fce1d6c8cbfc97b4cb8 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Sat, 21 Aug 2021 01:28:07 +0300 Subject: AntiSpam modified to work with cache The anti-spam cog now uses a cache instead of reading channel history. The cache is for all channels in the guild, and does not remove deleted messages. That means that the anti-spam logic now works cross-channel and counts deleted messages. The size of the cache is determined via a new field in the config YAML file. The cache was implemented as a separate class, MessageCache, which uses circular buffer logic. This allows for constant time addition and removal form either side, and lookup. The cache does not support removal from the middle of the cache. The cache additionally stores a mapping from message ID's to the index of the message in the cache, to allow constant time lookup by message ID. The commit additionally adds accompanying tests, and renames `cache.py` to `caching.py` to better distinguish it from the new `message_cache.py` and convey that it's for general caching utilities. --- bot/constants.py | 2 + bot/exts/filters/antispam.py | 24 ++-- bot/exts/info/pep.py | 2 +- bot/utils/cache.py | 41 ------- bot/utils/caching.py | 41 +++++++ bot/utils/message_cache.py | 177 +++++++++++++++++++++++++++++ config-default.yml | 2 + tests/bot/utils/test_message_cache.py | 208 ++++++++++++++++++++++++++++++++++ 8 files changed, 447 insertions(+), 50 deletions(-) delete mode 100644 bot/utils/cache.py create mode 100644 bot/utils/caching.py create mode 100644 bot/utils/message_cache.py create mode 100644 tests/bot/utils/test_message_cache.py (limited to 'tests') diff --git a/bot/constants.py b/bot/constants.py index 500803f33..34a814035 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter): class AntiSpam(metaclass=YAMLGetter): section = 'anti_spam' + cache_size: int + clean_offending: bool ping_everyone: bool diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index da4583e76..58cd8dec4 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -3,6 +3,7 @@ import logging from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta +from itertools import takewhile from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set @@ -20,6 +21,7 @@ from bot.converters import Duration from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.utils import lock, scheduling +from bot.utils.message_cache import MessageCache from bot.utils.messages import format_user, send_attachments @@ -122,6 +124,7 @@ class AntiSpam(Cog): key=itemgetter('interval') ) self.max_interval = max_interval_config['interval'] + self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error") @@ -162,12 +165,11 @@ class AntiSpam(Cog): ): return + self.cache.append(message) + # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache)) for rule_name in AntiSpamConfig.rules: rule_config = AntiSpamConfig.rules[rule_name] @@ -175,9 +177,10 @@ class AntiSpam(Cog): # Create a list of messages that were sent in the interval that the rule cares about. latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] + messages_for_rule = list( + takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) + ) + result = await rule_function(message, messages_for_rule, rule_config) # If the rule returns `None`, that means the message didn't violate it. @@ -212,7 +215,7 @@ class AntiSpam(Cog): name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})" ) - await self.maybe_delete_messages(channel, relevant_messages) + await self.maybe_delete_messages(channel, messages_for_rule) break @lock.lock_arg("antispam.punish", "member", attrgetter("id")) @@ -264,6 +267,11 @@ class AntiSpam(Cog): deletion_context = self.message_deletion_queue.pop(context_id) await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Updates the message in the cache, if it's cached.""" + self.cache.update(after) + def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: """Validates the antispam configs.""" diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 8ac96bbdb..b11b34db0 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command from bot.bot import Bot from bot.constants import Keys -from bot.utils.cache import AsyncCache +from bot.utils.caching import AsyncCache log = logging.getLogger(__name__) diff --git a/bot/utils/cache.py b/bot/utils/cache.py deleted file mode 100644 index 68ce15607..000000000 --- a/bot/utils/cache.py +++ /dev/null @@ -1,41 +0,0 @@ -import functools -from collections import OrderedDict -from typing import Any, Callable - - -class AsyncCache: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - - def __init__(self, max_size: int = 128): - self._cache = OrderedDict() - self._max_size = max_size - - def __call__(self, arg_offset: int = 0) -> Callable: - """Decorator for async cache.""" - - def decorator(function: Callable) -> Callable: - """Define the async cache decorator.""" - - @functools.wraps(function) - async def wrapper(*args) -> Any: - """Decorator wrapper for the caching logic.""" - key = args[arg_offset:] - - if key not in self._cache: - if len(self._cache) > self._max_size: - self._cache.popitem(last=False) - - self._cache[key] = await function(*args) - return self._cache[key] - return wrapper - return decorator - - def clear(self) -> None: - """Clear cache instance.""" - self._cache.clear() diff --git a/bot/utils/caching.py b/bot/utils/caching.py new file mode 100644 index 000000000..68ce15607 --- /dev/null +++ b/bot/utils/caching.py @@ -0,0 +1,41 @@ +import functools +from collections import OrderedDict +from typing import Any, Callable + + +class AsyncCache: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + + def __init__(self, max_size: int = 128): + self._cache = OrderedDict() + self._max_size = max_size + + def __call__(self, arg_offset: int = 0) -> Callable: + """Decorator for async cache.""" + + def decorator(function: Callable) -> Callable: + """Define the async cache decorator.""" + + @functools.wraps(function) + async def wrapper(*args) -> Any: + """Decorator wrapper for the caching logic.""" + key = args[arg_offset:] + + if key not in self._cache: + if len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + self._cache[key] = await function(*args) + return self._cache[key] + return wrapper + return decorator + + def clear(self) -> None: + """Clear cache instance.""" + self._cache.clear() diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py new file mode 100644 index 000000000..b2f8f66bf --- /dev/null +++ b/bot/utils/message_cache.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import typing as t +from math import ceil + +from discord import Message + + +class MessageCache: + """ + A data structure for caching messages. + + The cache is implemented as a circular buffer to allow constant time append, prepend, push, pop, + and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be + implemented to work in linear time relative to the maximum size). + + The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message + is stored, to allow for constant time lookup by message ID. + + The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those + of a deque. + + The implementation is transparent to the user: to the user the first element is always at index 0, and there are + only as many elements as were inserted (meaning, without any pre-allocated placeholder values). + """ + + def __init__(self, maxlen: int, *, newest_first: bool = False): + if maxlen <= 0: + raise ValueError("maxlen must be positive") + self.maxlen = maxlen + self.newest_first = newest_first + + self._start = 0 + self._end = 0 + + self._messages: list[t.Optional[Message]] = [None] * self.maxlen + self._message_id_mapping = {} + + def append(self, message: Message) -> None: + """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" + if self.newest_first: + self._appendleft(message) + else: + self._appendright(message) + + def _appendright(self, message: Message) -> None: + """Add the received message to the end of the cache.""" + if self._is_full(): + del self._message_id_mapping[self._messages[self._start].id] + self._start = (self._start + 1) % self.maxlen + + self._messages[self._end] = message + self._message_id_mapping[message.id] = self._end + self._end = (self._end + 1) % self.maxlen + + def _appendleft(self, message: Message) -> None: + """Add the received message to the beginning of the cache.""" + if self._is_full(): + self._end = (self._end - 1) % self.maxlen + del self._message_id_mapping[self._messages[self._end].id] + + self._start = (self._start - 1) % self.maxlen + self._messages[self._start] = message + self._message_id_mapping[message.id] = self._start + + def pop(self) -> Message: + """Remove the last message in the cache and return it.""" + if self._is_empty(): + raise IndexError("pop from an empty cache") + + self._end = (self._end - 1) % self.maxlen + message = self._messages[self._end] + del self._message_id_mapping[message.id] + self._messages[self._end] = None + + return message + + def popleft(self) -> Message: + """Return the first message in the cache and return it.""" + if self._is_empty(): + raise IndexError("pop from an empty cache") + + message = self._messages[self._start] + del self._message_id_mapping[message.id] + self._messages[self._start] = None + self._start = (self._start + 1) % self.maxlen + + return message + + def clear(self) -> None: + """Remove all messages from the cache.""" + self._messages: list[t.Optional[Message]] = [None] * self.maxlen + self._message_id_mapping = {} + + self._start = 0 + self._end = 0 + + def get_message(self, message_id: int) -> t.Optional[Message]: + """Return the message that has the given message ID, if it is cached.""" + index = self._message_id_mapping.get(message_id, None) + return self._messages[index] if index is not None else None + + def update(self, message: Message) -> bool: + """ + Update a cached message with new contents. + + Return True if the given message had a matching ID in the cache. + """ + index = self._message_id_mapping.get(message.id, None) + if index is None: + return False + self._messages[index] = message + return True + + def __contains__(self, message_id: int) -> bool: + """Return True if the cache contains a message with the given ID .""" + return message_id in self._message_id_mapping + + def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: + """ + Return the message(s) in the index or slice provided. + + This method makes the circular buffer implementation transparent to the user. + Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, + meaning at `self._start`. + """ + if isinstance(item, int): + if item >= len(self) or item < -len(self): + raise IndexError("cache index out of range") + return self._messages[(item + self._start) % self.maxlen] + + elif isinstance(item, slice): + length = len(self) + start, stop, step = item.indices(length) + + # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state. + if (start >= stop and step >= 0) or (start <= stop and step <= 0): + return [] + + start = (start + self._start) % self.maxlen + stop = (stop + self._start) % self.maxlen + + # Having empty cells is an implementation detail. To the user the cache contains as many elements as they + # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. + if ( + (self._start < self._end and not self._start < stop <= self._end) + or (self._start > self._end and self._end < stop <= self._start) + ): + stop = self._end + + if (start < stop and step > 0) or (start > stop and step < 0): + return self._messages[start:stop:step] + # step != 1 may require a start offset in the second slicing. + if step > 0: + offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen + else: + offset = self.maxlen - ((start + 1) % step) + return self._messages[start::step] + self._messages[offset:stop:step] + + else: + raise TypeError(f"cache indices must be integers or slices, not {type(item)}") + + def __len__(self): + """Get the number of non-empty cells in the cache.""" + if self._is_empty(): + return 0 + if self._end > self._start: + return self._end - self._start + return self.maxlen - self._start + self._end + + def _is_empty(self) -> bool: + """Return True if the cache has no messages.""" + return self._messages[self._start] is None + + def _is_full(self) -> bool: + """Return True if every cell in the cache already contains a message.""" + return self._messages[self._end] is not None diff --git a/config-default.yml b/config-default.yml index 881a7df76..2412a7016 100644 --- a/config-default.yml +++ b/config-default.yml @@ -377,6 +377,8 @@ urls: anti_spam: + cache_size: 100 + # Clean messages that violate a rule. clean_offending: true ping_everyone: true diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..ff313c6d6 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,208 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): + """Tests for the MessageCache class in the `bot.utils.caching` module.""" + + def test_first_append_sets_the_first_value(self): + """Test if the first append adds the message to the first cell.""" + cache = MessageCache(maxlen=10) + message = MockMessage() + + cache.append(message) + + self.assertEqual(cache[0], message) + + def test_append_adds_in_the_right_order(self): + """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" + messages = [MockMessage(), MockMessage()] + + cache = MessageCache(maxlen=10, newest_first=False) + for msg in messages: + cache.append(msg) + self.assertListEqual(messages, list(cache)) + + cache = MessageCache(maxlen=10, newest_first=True) + for msg in messages: + cache.append(msg) + self.assertListEqual(messages[::-1], list(cache)) + + def test_appending_over_maxlen_removes_oldest(self): + """Test if three appends to a 2-cell cache leave the two newest messages.""" + cache = MessageCache(maxlen=2) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + + self.assertListEqual(messages[1:], list(cache)) + + def test_appending_over_maxlen_with_newest_first_removes_oldest(self): + """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" + cache = MessageCache(maxlen=2, newest_first=True) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + + self.assertListEqual(messages[:0:-1], list(cache)) + + def test_pop_removes_from_the_end(self): + """Test if a pop removes the right-most message.""" + cache = MessageCache(maxlen=3) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + msg = cache.pop() + + self.assertEqual(msg, messages[-1]) + self.assertListEqual(messages[:-1], list(cache)) + + def test_popleft_removes_from_the_beginning(self): + """Test if a popleft removes the left-most message.""" + cache = MessageCache(maxlen=3) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + msg = cache.popleft() + + self.assertEqual(msg, messages[0]) + self.assertListEqual(messages[1:], list(cache)) + + def test_clear(self): + """Test if a clear makes the cache empty.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + cache.clear() + + self.assertListEqual(list(cache), []) + self.assertEqual(len(cache), 0) + + def test_get_message_returns_the_message(self): + """Test if get_message returns the cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertEqual(cache.get_message(1234), message) + + def test_get_message_returns_none(self): + """Test if get_message returns None for an ID of a non-cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertIsNone(cache.get_message(4321)) + + def test_update_replaces_old_element(self): + """Test if an update replaced the old message with the same ID.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + message = MockMessage(id=1234) + cache.update(message) + + self.assertIs(cache.get_message(1234), message) + self.assertEqual(len(cache), 1) + + def test_contains_returns_true_for_cached_message(self): + """Test if contains returns True for an ID of a cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertIn(1234, cache) + + def test_contains_returns_false_for_non_cached_message(self): + """Test if contains returns False for an ID of a non-cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertNotIn(4321, cache) + + def test_indexing(self): + """Test if the cache returns the correct messages by index.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(5)] + + for msg in messages: + cache.append(msg) + + for current_loop in range(-5, 5): + with self.subTest(current_loop=current_loop): + self.assertEqual(cache[current_loop], messages[current_loop]) + + def test_bad_index_raises_index_error(self): + """Test if the cache raises IndexError for invalid indices.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(3)] + test_cases = (-10, -4, 3, 4, 5) + + for msg in messages: + cache.append(msg) + + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with self.assertRaises(IndexError): + cache[current_loop] + + def test_slicing_with_unfilled_cache(self): + """Test if slicing returns the correct messages if the cache is not yet fully filled.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(4)] + + for msg in messages: + cache.append(msg) + + test_cases = ( + slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), + slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), + slice(None, None, -3) + ) + + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + self.assertListEqual(cache[current_loop], messages[current_loop]) + + def test_slicing_with_overfilled_cache(self): + """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(8)] + + for msg in messages: + cache.append(msg) + messages = messages[3:] + + test_cases = ( + slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), + slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), + slice(None, None, -3) + ) + + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + self.assertListEqual(cache[current_loop], messages[current_loop]) + + def test_length(self): + """Test if len returns the correct number of items in the cache.""" + cache = MessageCache(maxlen=5) + + for current_loop in range(10): + with self.subTest(current_loop=current_loop): + self.assertEqual(len(cache), min(current_loop, 5)) + cache.append(MockMessage()) -- cgit v1.2.3 From b4ddc0b7fa3601807bb36139efd06693a3bfe9fc Mon Sep 17 00:00:00 2001 From: mbaruh Date: Sat, 21 Aug 2021 15:13:04 +0300 Subject: Fix MessageCache slicing bugs, improve tests --- bot/utils/message_cache.py | 23 ++++++++++++------ tests/bot/utils/test_message_cache.py | 44 ++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 26 deletions(-) (limited to 'tests') diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index b2f8f66bf..67da8ecf3 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -142,20 +142,29 @@ class MessageCache: # Having empty cells is an implementation detail. To the user the cache contains as many elements as they # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. - if ( - (self._start < self._end and not self._start < stop <= self._end) - or (self._start > self._end and self._end < stop <= self._start) - ): - stop = self._end + if step > 0: + if ( + (self._start < self._end and not self._start < stop <= self._end) + or (self._start > self._end and self._end < stop <= self._start) + ): + stop = self._end + else: + lower_boundary = (self._start - 1) % self.maxlen + if ( + (self._start < self._end and not self._start - 1 <= stop < self._end) + or (self._start > self._end and self._end < stop < lower_boundary) + ): + stop = lower_boundary if (start < stop and step > 0) or (start > stop and step < 0): return self._messages[start:stop:step] # step != 1 may require a start offset in the second slicing. if step > 0: offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen + return self._messages[start::step] + self._messages[offset:stop:step] else: - offset = self.maxlen - ((start + 1) % step) - return self._messages[start::step] + self._messages[offset:stop:step] + offset = ceil((start + 1) / -step) * -step - start - 1 + return self._messages[start::step] + self._messages[self.maxlen - 1 - offset:stop:step] else: raise TypeError(f"cache indices must be integers or slices, not {type(item)}") diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py index ff313c6d6..5e871cd19 100644 --- a/tests/bot/utils/test_message_cache.py +++ b/tests/bot/utils/test_message_cache.py @@ -163,40 +163,46 @@ class TestMessageCache(unittest.TestCase): def test_slicing_with_unfilled_cache(self): """Test if slicing returns the correct messages if the cache is not yet fully filled.""" - cache = MessageCache(maxlen=5) - messages = [MockMessage() for _ in range(4)] - - for msg in messages: - cache.append(msg) + sizes = (5, 10, 55, 101) - test_cases = ( + slices = ( slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), slice(None, None, -3) ) - for current_loop in test_cases: - with self.subTest(current_loop=current_loop): - self.assertListEqual(cache[current_loop], messages[current_loop]) + for size in sizes: + cache = MessageCache(maxlen=size) + messages = [MockMessage() for _ in range(size // 3 * 2)] + + for msg in messages: + cache.append(msg) + + for slice_ in slices: + with self.subTest(current_loop=(size, slice_)): + self.assertListEqual(cache[slice_], messages[slice_]) def test_slicing_with_overfilled_cache(self): """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" - cache = MessageCache(maxlen=5) - messages = [MockMessage() for _ in range(8)] - - for msg in messages: - cache.append(msg) - messages = messages[3:] + sizes = (5, 10, 55, 101) - test_cases = ( + slices = ( slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), slice(None, None, -3) ) - for current_loop in test_cases: - with self.subTest(current_loop=current_loop): - self.assertListEqual(cache[current_loop], messages[current_loop]) + for size in sizes: + cache = MessageCache(maxlen=size) + messages = [MockMessage() for _ in range(size * 3 // 2)] + + for msg in messages: + cache.append(msg) + messages = messages[size // 2:] + + for slice_ in slices: + with self.subTest(current_loop=(size, slice_)): + self.assertListEqual(cache[slice_], messages[slice_]) def test_length(self): """Test if len returns the correct number of items in the cache.""" -- cgit v1.2.3 From 0531b1ec1fd55018d358d07f3ab52d0ada3cdaae Mon Sep 17 00:00:00 2001 From: mbaruh Date: Sat, 21 Aug 2021 20:52:12 +0300 Subject: Additional comments and tests for slicing --- bot/utils/message_cache.py | 3 +++ tests/bot/utils/test_message_cache.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index 6d219c313..f5656cdeb 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -122,6 +122,9 @@ class MessageCache: Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, meaning at `self._start`. """ + # Keep in mind that for the modulo operator used throughout this function, Python modulo behaves similarly when + # the left operand is negative. E.g -1 % 5 == 4, because the closest number from the bottom that wholly divides + # by 5 is -5. if isinstance(item, int): if item >= len(self) or item < -len(self): raise IndexError("cache index out of range") diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py index 5e871cd19..04bfd28d1 100644 --- a/tests/bot/utils/test_message_cache.py +++ b/tests/bot/utils/test_message_cache.py @@ -168,7 +168,7 @@ class TestMessageCache(unittest.TestCase): slices = ( slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), - slice(None, None, -3) + slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) ) for size in sizes: @@ -189,7 +189,7 @@ class TestMessageCache(unittest.TestCase): slices = ( slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), - slice(None, None, -3) + slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) ) for size in sizes: -- cgit v1.2.3 From 7daab91633b324901d48653b99fa1bfdf6b093ff Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 23 Aug 2021 18:20:37 +0100 Subject: Fix current tests by using MockMember in mention lists --- tests/bot/rules/test_mentions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index 6444532f2..a5e42d0a9 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,12 +2,14 @@ from typing import Iterable from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage +from tests.helpers import MockMember, MockMessage -def make_msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: """Makes a message with `total_mentions` mentions.""" - return MockMessage(author=author, mentions=list(range(total_mentions))) + user_mentions = [MockMember() for _ in range(total_user_mentions)] + bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] + return MockMessage(author=author, mentions=user_mentions+bot_mentions) class TestMentions(RuleTest): -- cgit v1.2.3 From a6121e6aa6e5aeaa9ae95a8973408e947958c6e0 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 23 Aug 2021 18:33:35 +0100 Subject: Added some more test cases to ensure bot mentions aren't counted --- tests/bot/rules/test_mentions.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index a5e42d0a9..f8805ac48 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -50,11 +50,27 @@ class TestMentions(RuleTest): [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], ("bob",), 4, - ) + ), + DisallowedCase( + [make_msg("bob", 3, 1)], + ("bob",), + 3, + ), ) await self.run_disallowed(cases) + async def test_ignore_bot_mentions(self): + """Messages with an allowed amount of mentions, also containing bot mentions.""" + cases = ( + [make_msg("bob", 0, 3)], + [make_msg("bob", 2, 1)], + [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], + [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] + ) + + await self.run_allowed(cases) + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: last_message = case.recent_messages[0] return tuple( -- cgit v1.2.3