diff options
author | 2021-08-24 17:08:12 +0100 | |
---|---|---|
committer | 2021-08-24 17:08:12 +0100 | |
commit | cf305038399fe151873f87a7f9d9d590e0c8822c (patch) | |
tree | 104fedd86b4cbf04d92d852dd5dae4add041ad28 | |
parent | Fix linting (diff) | |
parent | Merge branch 'main' into community-partners-access (diff) |
Merge remote-tracking branch 'origin/community-partners-access' into community-partners-access
-rw-r--r-- | bot/constants.py | 4 | ||||
-rw-r--r-- | bot/exts/filters/antispam.py | 84 | ||||
-rw-r--r-- | bot/exts/info/information.py | 2 | ||||
-rw-r--r-- | bot/exts/info/pep.py | 2 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 25 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/_utils.py | 13 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 79 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 25 | ||||
-rw-r--r-- | bot/exts/recruitment/talentpool/_cog.py | 6 | ||||
-rw-r--r-- | bot/exts/utils/internal.py | 7 | ||||
-rw-r--r-- | bot/exts/utils/reminders.py | 9 | ||||
-rw-r--r-- | bot/exts/utils/utils.py | 4 | ||||
-rw-r--r-- | bot/rules/mentions.py | 6 | ||||
-rw-r--r-- | bot/utils/caching.py (renamed from bot/utils/cache.py) | 0 | ||||
-rw-r--r-- | bot/utils/message_cache.py | 197 | ||||
-rw-r--r-- | config-default.yml | 5 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 6 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 4 | ||||
-rw-r--r-- | tests/bot/rules/test_mentions.py | 26 | ||||
-rw-r--r-- | tests/bot/utils/test_message_cache.py | 214 |
20 files changed, 610 insertions, 108 deletions
diff --git a/bot/constants.py b/bot/constants.py index 5b629a735..80e01b174 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter): class AntiSpam(metaclass=YAMLGetter): section = 'anti_spam' + cache_size: int + clean_offending: bool ping_everyone: bool @@ -687,7 +689,7 @@ class VideoPermission(metaclass=YAMLGetter): # Debug mode -DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") +DEBUG_MODE: bool = _CONFIG_YAML["debug"] == "true" # Paths BOT_DIR = os.path.dirname(__file__) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 226da2790..987060779 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -1,8 +1,10 @@ import asyncio import logging +from collections import defaultdict from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta +from itertools import takewhile from operator import attrgetter, itemgetter from typing import Dict, Iterable, List, Set @@ -20,6 +22,7 @@ from bot.converters import Duration from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.moderation.modlog import ModLog from bot.utils import lock, scheduling +from bot.utils.message_cache import MessageCache from bot.utils.messages import format_user, send_attachments @@ -44,19 +47,18 @@ RULE_FUNCTION_MAPPING = { class DeletionContext: """Represents a Deletion Context for a single spam event.""" - channel: TextChannel - members: Dict[int, Member] = field(default_factory=dict) + members: frozenset[Member] + triggered_in: TextChannel + channels: set[TextChannel] = field(default_factory=set) rules: Set[str] = field(default_factory=set) messages: Dict[int, Message] = field(default_factory=dict) attachments: List[List[str]] = field(default_factory=list) - async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: + async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None: """Adds new rule violation events to the deletion context.""" self.rules.add(rule_name) - for member in members: - if member.id not in self.members: - self.members[member.id] = member + self.channels.update(channels) for message in messages: if message.id not in self.messages: @@ -69,11 +71,14 @@ class DeletionContext: async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: """Method that takes care of uploading the queue and posting modlog alert.""" - triggered_by_users = ", ".join(format_user(m) for m in self.members.values()) + triggered_by_users = ", ".join(format_user(m) for m in self.members) + triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else "" + channels_description = ", ".join(channel.mention for channel in self.channels) mod_alert_message = ( f"**Triggered by:** {triggered_by_users}\n" - f"**Channel:** {self.channel.mention}\n" + f"{triggered_in_channel}" + f"**Channels:** {channels_description}\n" f"**Rules:** {', '.join(rule for rule in self.rules)}\n" ) @@ -116,6 +121,14 @@ class AntiSpam(Cog): self.message_deletion_queue = dict() + # Fetch the rule configuration with the highest rule interval. + max_interval_config = max( + AntiSpamConfig.rules.values(), + key=itemgetter('interval') + ) + self.max_interval = max_interval_config['interval'] + self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) + self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error") @property @@ -155,19 +168,10 @@ class AntiSpam(Cog): ): return - # Fetch the rule configuration with the highest rule interval. - max_interval_config = max( - AntiSpamConfig.rules.values(), - key=itemgetter('interval') - ) - max_interval = max_interval_config['interval'] + self.cache.append(message) - # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. - earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) - relevant_messages = [ - msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) - if not msg.author.bot - ] + earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache)) for rule_name in AntiSpamConfig.rules: rule_config = AntiSpamConfig.rules[rule_name] @@ -175,9 +179,10 @@ class AntiSpam(Cog): # Create a list of messages that were sent in the interval that the rule cares about. latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) - messages_for_rule = [ - msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp - ] + messages_for_rule = list( + takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) + ) + result = await rule_function(message, messages_for_rule, rule_config) # If the rule returns `None`, that means the message didn't violate it. @@ -190,19 +195,19 @@ class AntiSpam(Cog): full_reason = f"`{rule_name}` rule: {reason}" # If there's no spam event going on for this channel, start a new Message Deletion Context - channel = message.channel - if channel.id not in self.message_deletion_queue: - log.trace(f"Creating queue for channel `{channel.id}`") - self.message_deletion_queue[message.channel.id] = DeletionContext(channel) + authors_set = frozenset(members) + if authors_set not in self.message_deletion_queue: + log.trace(f"Creating queue for members `{authors_set}`") + self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel) scheduling.create_task( - self._process_deletion_context(message.channel.id), - name=f"AntiSpam._process_deletion_context({message.channel.id})" + self._process_deletion_context(authors_set), + name=f"AntiSpam._process_deletion_context({authors_set})" ) # Add the relevant of this trigger to the Deletion Context - await self.message_deletion_queue[message.channel.id].add( + await self.message_deletion_queue[authors_set].add( rule_name=rule_name, - members=members, + channels=set(message.channel for message in messages_for_rule), messages=relevant_messages ) @@ -212,7 +217,7 @@ class AntiSpam(Cog): name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})" ) - await self.maybe_delete_messages(channel, relevant_messages) + await self.maybe_delete_messages(messages_for_rule) break @lock.lock_arg("antispam.punish", "member", attrgetter("id")) @@ -234,14 +239,18 @@ class AntiSpam(Cog): reason=reason ) - async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: + async def maybe_delete_messages(self, messages: List[Message]) -> None: """Cleans the messages if cleaning is configured.""" if AntiSpamConfig.clean_offending: # If we have more than one message, we can use bulk delete. if len(messages) > 1: message_ids = [message.id for message in messages] self.mod_log.ignore(Event.message_delete, *message_ids) - await channel.delete_messages(messages) + channel_messages = defaultdict(list) + for message in messages: + channel_messages[message.channel].append(message) + for channel, messages in channel_messages.items(): + await channel.delete_messages(messages) # Otherwise, the bulk delete endpoint will throw up. # Delete the message directly instead. @@ -252,7 +261,7 @@ class AntiSpam(Cog): except NotFound: log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - async def _process_deletion_context(self, context_id: int) -> None: + async def _process_deletion_context(self, context_id: frozenset) -> None: """Processes the Deletion Context queue.""" log.trace("Sleeping before processing message deletion queue.") await asyncio.sleep(10) @@ -264,6 +273,11 @@ class AntiSpam(Cog): deletion_context = self.message_deletion_queue.pop(context_id) await deletion_context.upload_messages(self.bot.user.id, self.mod_log) + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """Updates the message in the cache, if it's cached.""" + self.cache.update(after) + def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: """Validates the antispam configs.""" diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 664b6cb13..38b436b7d 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -8,6 +8,7 @@ from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union import rapidfuzz from discord import AllowedMentions, Colour, Embed, Guild, Message, Role from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role +from discord.utils import escape_markdown from bot import constants from bot.api import ResponseCodeError @@ -244,6 +245,7 @@ class Information(Cog): name = str(user) if on_server and user.nick: name = f"{user.nick} ({name})" + name = escape_markdown(name) if user.public_flags.verified_bot: name += f" {constants.Emojis.verified_bot}" diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 8ac96bbdb..b11b34db0 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command from bot.bot import Bot from bot.constants import Keys -from bot.utils.cache import AsyncCache +from bot.utils.caching import AsyncCache log = logging.getLogger(__name__) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 3c5e5d3bf..6ba4e74e9 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -258,13 +258,17 @@ class InfractionScheduler: ctx: Context, infr_type: str, user: MemberOrUser, - send_msg: bool = True + *, + send_msg: bool = True, + notify: bool = True ) -> None: """ Prematurely end an infraction for a user and log the action in the mod log. If `send_msg` is True, then a pardoning confirmation message will be sent to - the context channel. Otherwise, no such message will be sent. + the context channel. Otherwise, no such message will be sent. + + If `notify` is True, notify the user of the pardon via DM where applicable. """ log.trace(f"Pardoning {infr_type} infraction for {user}.") @@ -285,7 +289,7 @@ class InfractionScheduler: return # Deactivate the infraction and cancel its scheduled expiration task. - log_text = await self.deactivate_infraction(response[0], send_log=False) + log_text = await self.deactivate_infraction(response[0], send_log=False, notify=notify) log_text["Member"] = messages.format_user(user) log_text["Actor"] = ctx.author.mention @@ -338,7 +342,9 @@ class InfractionScheduler: async def deactivate_infraction( self, infraction: _utils.Infraction, - send_log: bool = True + *, + send_log: bool = True, + notify: bool = True ) -> t.Dict[str, str]: """ Deactivate an active infraction and return a dictionary of lines to send in a mod log. @@ -347,6 +353,8 @@ class InfractionScheduler: expiration task cancelled. If `send_log` is True, a mod log is sent for the deactivation of the infraction. + If `notify` is True, notify the user of the pardon via DM where applicable. + Infractions of unsupported types will raise a ValueError. """ guild = self.bot.get_guild(constants.Guild.id) @@ -373,7 +381,7 @@ class InfractionScheduler: try: log.trace("Awaiting the pardon action coroutine.") - returned_log = await self._pardon_action(infraction) + returned_log = await self._pardon_action(infraction, notify) if returned_log is not None: log_text = {**log_text, **returned_log} # Merge the logs together @@ -461,10 +469,15 @@ class InfractionScheduler: return log_text @abstractmethod - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + async def _pardon_action( + self, + infraction: _utils.Infraction, + notify: bool + ) -> t.Optional[t.Dict[str, str]]: """ Execute deactivation steps specific to the infraction's type and return a log dict. + If `notify` is True, notify the user of the pardon via DM where applicable. If an infraction type is unsupported, return None instead. """ raise NotImplementedError diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 9d94bca2d..b20ef1d06 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -139,15 +139,20 @@ async def get_active_infraction( # Checks to see if the moderator should be told there is an active infraction if send_msg: log.trace(f"{user} has active infractions of type {infr_type}.") - await ctx.send( - f":x: According to my records, this user already has a {infr_type} infraction. " - f"See infraction **#{active_infractions[0]['id']}**." - ) + await send_active_infraction_message(ctx, active_infractions[0]) return active_infractions[0] else: log.trace(f"{user} does not have active infractions of type {infr_type}.") +async def send_active_infraction_message(ctx: Context, infraction: Infraction) -> None: + """Send a message stating that the given infraction is active.""" + await ctx.send( + f":x: According to my records, this user already has a {infraction['type']} infraction. " + f"See infraction **#{infraction['id']}**." + ) + + async def notify_infraction( user: MemberOrUser, infr_type: str, diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 48ffbd773..2f9083c29 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -279,8 +279,19 @@ class Infractions(InfractionScheduler, commands.Cog): async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: """Apply a mute infraction with kwargs passed to `post_infraction`.""" - if await _utils.get_active_infraction(ctx, user, "mute"): - return + if active := await _utils.get_active_infraction(ctx, user, "mute", send_msg=False): + if active["actor"] != self.bot.user.id: + await _utils.send_active_infraction_message(ctx, active) + return + + # Allow the current mute attempt to override an automatically triggered mute. + log_text = await self.deactivate_infraction(active, notify=False) + if "Failure" in log_text: + await ctx.send( + f":x: can't override infraction **mute** for {user.mention}: " + f"failed to deactivate. {log_text['Failure']}" + ) + return infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs) if infraction is None: @@ -344,7 +355,7 @@ class Infractions(InfractionScheduler, commands.Cog): return log.trace("Old tempban is being replaced by new permaban.") - await self.pardon_infraction(ctx, "ban", user, is_temporary) + await self.pardon_infraction(ctx, "ban", user, send_msg=is_temporary) infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs) if infraction is None: @@ -402,8 +413,15 @@ class Infractions(InfractionScheduler, commands.Cog): # endregion # region: Base pardon functions - async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Remove a user's muted role, DM them a notification, and return a log dict.""" + async def pardon_mute( + self, + user_id: int, + guild: discord.Guild, + reason: t.Optional[str], + *, + notify: bool = True + ) -> t.Dict[str, str]: + """Remove a user's muted role, optionally DM them a notification, and return a log dict.""" user = guild.get_member(user_id) log_text = {} @@ -412,16 +430,17 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_update, user.id) await user.remove_roles(self._muted_role, reason=reason) - # DM the user about the expiration. - notified = await _utils.notify_pardon( - user=user, - title="You have been unmuted", - content="You may now send messages in the server.", - icon_url=_utils.INFRACTION_ICONS["mute"][1] - ) + if notify: + # DM the user about the expiration. + notified = await _utils.notify_pardon( + user=user, + title="You have been unmuted", + content="You may now send messages in the server.", + icon_url=_utils.INFRACTION_ICONS["mute"][1] + ) + log_text["DM"] = "Sent" if notified else "**Failed**" log_text["Member"] = format_user(user) - log_text["DM"] = "Sent" if notified else "**Failed**" else: log.info(f"Failed to unmute user {user_id}: user not found") log_text["Failure"] = "User was not found in the guild." @@ -443,31 +462,39 @@ class Infractions(InfractionScheduler, commands.Cog): return log_text - async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: - """Add Voice Verified role back to user, DM them a notification, and return a log dict.""" + async def pardon_voice_ban( + self, + user_id: int, + guild: discord.Guild, + *, + notify: bool = True + ) -> t.Dict[str, str]: + """Optionally DM the user a pardon notification and return a log dict.""" user = guild.get_member(user_id) log_text = {} if user: - # DM user about infraction expiration - notified = await _utils.notify_pardon( - user=user, - title="Voice ban ended", - content="You have been unbanned and can verify yourself again in the server.", - icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] - ) + if notify: + # DM user about infraction expiration + notified = await _utils.notify_pardon( + user=user, + title="Voice ban ended", + content="You have been unbanned and can verify yourself again in the server.", + icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] + ) + log_text["DM"] = "Sent" if notified else "**Failed**" log_text["Member"] = format_user(user) - log_text["DM"] = "Sent" if notified else "**Failed**" else: log_text["Info"] = "User was not found in the guild." return log_text - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: + async def _pardon_action(self, infraction: _utils.Infraction, notify: bool) -> t.Optional[t.Dict[str, str]]: """ Execute deactivation steps specific to the infraction's type and return a log dict. + If `notify` is True, notify the user of the pardon via DM where applicable. If an infraction type is unsupported, return None instead. """ guild = self.bot.get_guild(constants.Guild.id) @@ -475,11 +502,11 @@ class Infractions(InfractionScheduler, commands.Cog): reason = f"Infraction #{infraction['id']} expired or was pardoned." if infraction["type"] == "mute": - return await self.pardon_mute(user_id, guild, reason) + return await self.pardon_mute(user_id, guild, reason, notify=notify) elif infraction["type"] == "ban": return await self.pardon_ban(user_id, guild, reason) elif infraction["type"] == "voice_ban": - return await self.pardon_voice_ban(user_id, guild, reason) + return await self.pardon_voice_ban(user_id, guild, notify=notify) # endregion diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 07e79b9fe..05a2bbe10 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -192,8 +192,8 @@ class Superstarify(InfractionScheduler, Cog): """Remove the superstarify infraction and allow the user to change their nickname.""" await self.pardon_infraction(ctx, "superstar", member) - async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: - """Pardon a superstar infraction and return a log dict.""" + async def _pardon_action(self, infraction: _utils.Infraction, notify: bool) -> t.Optional[t.Dict[str, str]]: + """Pardon a superstar infraction, optionally notify the user via DM, and return a log dict.""" if infraction["type"] != "superstar": return @@ -208,18 +208,19 @@ class Superstarify(InfractionScheduler, Cog): ) return {} + log_text = {"Member": format_user(user)} + # DM the user about the expiration. - notified = await _utils.notify_pardon( - user=user, - title="You are no longer superstarified", - content="You may now change your nickname on the server.", - icon_url=_utils.INFRACTION_ICONS["superstar"][1] - ) + if notify: + notified = await _utils.notify_pardon( + user=user, + title="You are no longer superstarified", + content="You may now change your nickname on the server.", + icon_url=_utils.INFRACTION_ICONS["superstar"][1] + ) + log_text["DM"] = "Sent" if notified else "**Failed**" - return { - "Member": format_user(user), - "DM": "Sent" if notified else "**Failed**" - } + return log_text @staticmethod def get_nick(infraction_id: int, member_id: int) -> str: diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 5c1a1cd3f..c297f70c2 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -263,7 +263,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): } ) - msg = f"✅ The nomination for {user} has been added to the talent pool" + msg = f"✅ The nomination for {user.mention} has been added to the talent pool" if history: msg += f"\n\n({len(history)} previous nominations in total)" @@ -311,7 +311,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): return if await self.unwatch(user.id, reason): - await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") + await ctx.send(f":white_check_mark: Messages sent by {user.mention} will no longer be relayed") else: await ctx.send(":x: The specified user does not have an active nomination") @@ -344,7 +344,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): return if not any(entry["actor"] == actor.id for entry in nomination["entries"]): - await ctx.send(f":x: {actor} doesn't have an entry in this nomination.") + await ctx.send(f":x: {actor.mention} doesn't have an entry in this nomination.") return self.log.trace(f"Changing reason for nomination with id {nomination_id} of actor {actor} to {repr(reason)}") diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index 6f2da3131..5d2cd7611 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -11,10 +11,10 @@ from io import StringIO from typing import Any, Optional, Tuple import discord -from discord.ext.commands import Cog, Context, group, has_any_role +from discord.ext.commands import Cog, Context, group, has_any_role, is_owner from bot.bot import Bot -from bot.constants import Roles +from bot.constants import DEBUG_MODE, Roles from bot.utils import find_nth_occurrence, send_to_paste_service log = logging.getLogger(__name__) @@ -33,6 +33,9 @@ class Internal(Cog): self.socket_event_total = 0 self.socket_events = Counter() + if DEBUG_MODE: + self.eval.add_check(is_owner().predicate) + @Cog.listener() async def on_socket_response(self, msg: dict) -> None: """When a websocket event is received, increase our counters.""" diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 2496ce5cc..90ad7ef2e 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -15,7 +15,7 @@ from bot.constants import ( Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES ) -from bot.converters import Duration +from bot.converters import Duration, UserMentionOrID from bot.pagination import LinePaginator from bot.utils.checks import has_any_role_check, has_no_roles_check from bot.utils.lock import lock_arg @@ -30,6 +30,7 @@ WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 Mentionable = t.Union[discord.Member, discord.Role] +ReminderMention = t.Union[UserMentionOrID, discord.Role] class Reminders(Cog): @@ -214,14 +215,14 @@ class Reminders(Cog): @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) async def remind_group( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + self, ctx: Context, mentions: Greedy[ReminderMention], expiration: Duration, *, content: str ) -> None: """Commands for managing your reminders.""" await self.new_reminder(ctx, mentions=mentions, expiration=expiration, content=content) @remind_group.command(name="new", aliases=("add", "create")) async def new_reminder( - self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str + self, ctx: Context, mentions: Greedy[ReminderMention], expiration: Duration, *, content: str ) -> None: """ Set yourself a simple reminder. @@ -367,7 +368,7 @@ class Reminders(Cog): await self.edit_reminder(ctx, id_, {"content": content}) @edit_reminder_group.command(name="mentions", aliases=("pings",)) - async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: + async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[ReminderMention]) -> None: """Edit one of your reminder's mentions.""" # Remove duplicate mentions mentions = set(mentions) diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index c4a466943..f91a9fee6 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -14,7 +14,6 @@ from bot.converters import Snowflake from bot.decorators import in_whitelist from bot.pagination import LinePaginator from bot.utils import messages -from bot.utils.checks import has_no_roles_check from bot.utils.time import time_since log = logging.getLogger(__name__) @@ -162,9 +161,6 @@ class Utils(Cog): @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_PARTNERS_COMMUNITY_ROLES) async def snowflake(self, ctx: Context, *snowflakes: Snowflake) -> None: """Get Discord snowflake creation time.""" - if len(snowflakes) > 1 and await has_no_roles_check(ctx, *STAFF_ROLES): - raise BadArgument("Cannot process more than one snowflake in one invocation.") - if not snowflakes: raise BadArgument("At least one snowflake must be provided.") diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 79725a4b1..6f5addad1 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -13,7 +13,11 @@ async def apply( if msg.author == last_message.author ) - total_recent_mentions = sum(len(msg.mentions) for msg in relevant_messages) + total_recent_mentions = sum( + not user.bot + for msg in relevant_messages + for user in msg.mentions + ) if total_recent_mentions > config['max']: return ( diff --git a/bot/utils/cache.py b/bot/utils/caching.py index 68ce15607..68ce15607 100644 --- a/bot/utils/cache.py +++ b/bot/utils/caching.py diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py new file mode 100644 index 000000000..f68d280c9 --- /dev/null +++ b/bot/utils/message_cache.py @@ -0,0 +1,197 @@ +import typing as t +from math import ceil + +from discord import Message + + +class MessageCache: + """ + A data structure for caching messages. + + The cache is implemented as a circular buffer to allow constant time append, prepend, pop from either side, + and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be + implemented to work in linear time relative to the maximum size). + + The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message + is stored, to allow for constant time lookup by message ID. + + The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those + of a deque. + + The implementation is transparent to the user: to the user the first element is always at index 0, and there are + only as many elements as were inserted (meaning, without any pre-allocated placeholder values). + """ + + def __init__(self, maxlen: int, *, newest_first: bool = False): + if maxlen <= 0: + raise ValueError("maxlen must be positive") + self.maxlen = maxlen + self.newest_first = newest_first + + self._start = 0 + self._end = 0 + + self._messages: list[t.Optional[Message]] = [None] * self.maxlen + self._message_id_mapping = {} + + def append(self, message: Message) -> None: + """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" + if self.newest_first: + self._appendleft(message) + else: + self._appendright(message) + + def _appendright(self, message: Message) -> None: + """Add the received message to the end of the cache.""" + if self._is_full(): + del self._message_id_mapping[self._messages[self._start].id] + self._start = (self._start + 1) % self.maxlen + + self._messages[self._end] = message + self._message_id_mapping[message.id] = self._end + self._end = (self._end + 1) % self.maxlen + + def _appendleft(self, message: Message) -> None: + """Add the received message to the beginning of the cache.""" + if self._is_full(): + self._end = (self._end - 1) % self.maxlen + del self._message_id_mapping[self._messages[self._end].id] + + self._start = (self._start - 1) % self.maxlen + self._messages[self._start] = message + self._message_id_mapping[message.id] = self._start + + def pop(self) -> Message: + """Remove the last message in the cache and return it.""" + if self._is_empty(): + raise IndexError("pop from an empty cache") + + self._end = (self._end - 1) % self.maxlen + message = self._messages[self._end] + del self._message_id_mapping[message.id] + self._messages[self._end] = None + + return message + + def popleft(self) -> Message: + """Return the first message in the cache and return it.""" + if self._is_empty(): + raise IndexError("pop from an empty cache") + + message = self._messages[self._start] + del self._message_id_mapping[message.id] + self._messages[self._start] = None + self._start = (self._start + 1) % self.maxlen + + return message + + def clear(self) -> None: + """Remove all messages from the cache.""" + self._messages = [None] * self.maxlen + self._message_id_mapping = {} + + self._start = 0 + self._end = 0 + + def get_message(self, message_id: int) -> t.Optional[Message]: + """Return the message that has the given message ID, if it is cached.""" + index = self._message_id_mapping.get(message_id, None) + return self._messages[index] if index is not None else None + + def update(self, message: Message) -> bool: + """ + Update a cached message with new contents. + + Return True if the given message had a matching ID in the cache. + """ + index = self._message_id_mapping.get(message.id, None) + if index is None: + return False + self._messages[index] = message + return True + + def __contains__(self, message_id: int) -> bool: + """Return True if the cache contains a message with the given ID .""" + return message_id in self._message_id_mapping + + def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: + """ + Return the message(s) in the index or slice provided. + + This method makes the circular buffer implementation transparent to the user. + Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, + meaning at `self._start`. + """ + # Keep in mind that for the modulo operator used throughout this function, Python modulo behaves similarly when + # the left operand is negative. E.g -1 % 5 == 4, because the closest number from the bottom that wholly divides + # by 5 is -5. + if isinstance(item, int): + if item >= len(self) or item < -len(self): + raise IndexError("cache index out of range") + return self._messages[(item + self._start) % self.maxlen] + + elif isinstance(item, slice): + length = len(self) + start, stop, step = item.indices(length) + + # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state. + if (start >= stop and step >= 0) or (start <= stop and step <= 0): + return [] + + start = (start + self._start) % self.maxlen + stop = (stop + self._start) % self.maxlen + + # Having empty cells is an implementation detail. To the user the cache contains as many elements as they + # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. + if step > 0: + if ( + (self._start < self._end and not self._start < stop <= self._end) + or (self._start > self._end and self._end < stop <= self._start) + ): + stop = self._end + else: + lower_boundary = (self._start - 1) % self.maxlen + if ( + (self._start < self._end and not self._start - 1 <= stop < self._end) + or (self._start > self._end and self._end < stop < lower_boundary) + ): + stop = lower_boundary + + if (start < stop and step > 0) or (start > stop and step < 0): + return self._messages[start:stop:step] + # step != 1 may require a start offset in the second slicing. + if step > 0: + offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen + return self._messages[start::step] + self._messages[offset:stop:step] + else: + offset = ceil((start + 1) / -step) * -step - start - 1 + return self._messages[start::step] + self._messages[self.maxlen - 1 - offset:stop:step] + + else: + raise TypeError(f"cache indices must be integers or slices, not {type(item)}") + + def __iter__(self) -> t.Iterator[Message]: + if self._is_empty(): + return + + if self._start < self._end: + yield from self._messages[self._start:self._end] + else: + yield from self._messages[self._start:] + yield from self._messages[:self._end] + + def __len__(self): + """Get the number of non-empty cells in the cache.""" + if self._is_empty(): + return 0 + if self._end > self._start: + return self._end - self._start + return self.maxlen - self._start + self._end + + def _is_empty(self) -> bool: + """Return True if the cache has no messages.""" + return self._messages[self._start] is None + + def _is_full(self) -> bool: + """Return True if every cell in the cache already contains a message.""" + return self._messages[self._end] is not None diff --git a/config-default.yml b/config-default.yml index 79828dd77..8e0b97a51 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,3 +1,6 @@ +debug: !ENV ["BOT_DEBUG", "true"] + + bot: prefix: "!" sentry_dsn: !ENV "BOT_SENTRY_DSN" @@ -377,6 +380,8 @@ urls: anti_spam: + cache_size: 100 + # Clean messages that violate a rule. clean_offending: true ping_everyone: true diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b9d527770..f844a9181 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -195,7 +195,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): async def test_voice_unban_user_not_found(self): """Should include info to return dict when user was not found from guild.""" self.guild.get_member.return_value = None - result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + result = await self.cog.pardon_voice_ban(self.user.id, self.guild) self.assertEqual(result, {"Info": "User was not found in the guild."}) @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @@ -206,7 +206,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): notify_pardon_mock.return_value = True format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + result = await self.cog.pardon_voice_ban(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "Sent" @@ -221,7 +221,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): notify_pardon_mock.return_value = False format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") + result = await self.cog.pardon_voice_ban(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "**Failed**" diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5f95ced9f..eb256f1fd 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -94,8 +94,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"]) test_cases = [ test_case([], None, None, True), - test_case([{"id": 123987}], {"id": 123987}, "123987", False), - test_case([{"id": 123987}], {"id": 123987}, "123987", True) + test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False), + test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True) ] for case in test_cases: diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index 6444532f2..f8805ac48 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,12 +2,14 @@ from typing import Iterable from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage +from tests.helpers import MockMember, MockMessage -def make_msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: """Makes a message with `total_mentions` mentions.""" - return MockMessage(author=author, mentions=list(range(total_mentions))) + user_mentions = [MockMember() for _ in range(total_user_mentions)] + bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] + return MockMessage(author=author, mentions=user_mentions+bot_mentions) class TestMentions(RuleTest): @@ -48,11 +50,27 @@ class TestMentions(RuleTest): [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], ("bob",), 4, - ) + ), + DisallowedCase( + [make_msg("bob", 3, 1)], + ("bob",), + 3, + ), ) await self.run_disallowed(cases) + async def test_ignore_bot_mentions(self): + """Messages with an allowed amount of mentions, also containing bot mentions.""" + cases = ( + [make_msg("bob", 0, 3)], + [make_msg("bob", 2, 1)], + [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], + [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] + ) + + await self.run_allowed(cases) + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: last_message = case.recent_messages[0] return tuple( diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..04bfd28d1 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,214 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): + """Tests for the MessageCache class in the `bot.utils.caching` module.""" + + def test_first_append_sets_the_first_value(self): + """Test if the first append adds the message to the first cell.""" + cache = MessageCache(maxlen=10) + message = MockMessage() + + cache.append(message) + + self.assertEqual(cache[0], message) + + def test_append_adds_in_the_right_order(self): + """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" + messages = [MockMessage(), MockMessage()] + + cache = MessageCache(maxlen=10, newest_first=False) + for msg in messages: + cache.append(msg) + self.assertListEqual(messages, list(cache)) + + cache = MessageCache(maxlen=10, newest_first=True) + for msg in messages: + cache.append(msg) + self.assertListEqual(messages[::-1], list(cache)) + + def test_appending_over_maxlen_removes_oldest(self): + """Test if three appends to a 2-cell cache leave the two newest messages.""" + cache = MessageCache(maxlen=2) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + + self.assertListEqual(messages[1:], list(cache)) + + def test_appending_over_maxlen_with_newest_first_removes_oldest(self): + """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" + cache = MessageCache(maxlen=2, newest_first=True) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + + self.assertListEqual(messages[:0:-1], list(cache)) + + def test_pop_removes_from_the_end(self): + """Test if a pop removes the right-most message.""" + cache = MessageCache(maxlen=3) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + msg = cache.pop() + + self.assertEqual(msg, messages[-1]) + self.assertListEqual(messages[:-1], list(cache)) + + def test_popleft_removes_from_the_beginning(self): + """Test if a popleft removes the left-most message.""" + cache = MessageCache(maxlen=3) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + msg = cache.popleft() + + self.assertEqual(msg, messages[0]) + self.assertListEqual(messages[1:], list(cache)) + + def test_clear(self): + """Test if a clear makes the cache empty.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(3)] + + for msg in messages: + cache.append(msg) + cache.clear() + + self.assertListEqual(list(cache), []) + self.assertEqual(len(cache), 0) + + def test_get_message_returns_the_message(self): + """Test if get_message returns the cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertEqual(cache.get_message(1234), message) + + def test_get_message_returns_none(self): + """Test if get_message returns None for an ID of a non-cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertIsNone(cache.get_message(4321)) + + def test_update_replaces_old_element(self): + """Test if an update replaced the old message with the same ID.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + message = MockMessage(id=1234) + cache.update(message) + + self.assertIs(cache.get_message(1234), message) + self.assertEqual(len(cache), 1) + + def test_contains_returns_true_for_cached_message(self): + """Test if contains returns True for an ID of a cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertIn(1234, cache) + + def test_contains_returns_false_for_non_cached_message(self): + """Test if contains returns False for an ID of a non-cached message.""" + cache = MessageCache(maxlen=5) + message = MockMessage(id=1234) + + cache.append(message) + + self.assertNotIn(4321, cache) + + def test_indexing(self): + """Test if the cache returns the correct messages by index.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(5)] + + for msg in messages: + cache.append(msg) + + for current_loop in range(-5, 5): + with self.subTest(current_loop=current_loop): + self.assertEqual(cache[current_loop], messages[current_loop]) + + def test_bad_index_raises_index_error(self): + """Test if the cache raises IndexError for invalid indices.""" + cache = MessageCache(maxlen=5) + messages = [MockMessage() for _ in range(3)] + test_cases = (-10, -4, 3, 4, 5) + + for msg in messages: + cache.append(msg) + + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with self.assertRaises(IndexError): + cache[current_loop] + + def test_slicing_with_unfilled_cache(self): + """Test if slicing returns the correct messages if the cache is not yet fully filled.""" + sizes = (5, 10, 55, 101) + + slices = ( + slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), + slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), + slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) + ) + + for size in sizes: + cache = MessageCache(maxlen=size) + messages = [MockMessage() for _ in range(size // 3 * 2)] + + for msg in messages: + cache.append(msg) + + for slice_ in slices: + with self.subTest(current_loop=(size, slice_)): + self.assertListEqual(cache[slice_], messages[slice_]) + + def test_slicing_with_overfilled_cache(self): + """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" + sizes = (5, 10, 55, 101) + + slices = ( + slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), + slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), + slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) + ) + + for size in sizes: + cache = MessageCache(maxlen=size) + messages = [MockMessage() for _ in range(size * 3 // 2)] + + for msg in messages: + cache.append(msg) + messages = messages[size // 2:] + + for slice_ in slices: + with self.subTest(current_loop=(size, slice_)): + self.assertListEqual(cache[slice_], messages[slice_]) + + def test_length(self): + """Test if len returns the correct number of items in the cache.""" + cache = MessageCache(maxlen=5) + + for current_loop in range(10): + with self.subTest(current_loop=current_loop): + self.assertEqual(len(cache), min(current_loop, 5)) + cache.append(MockMessage()) |