diff options
| -rw-r--r-- | bot/constants.py | 4 | ||||
| -rw-r--r-- | bot/exts/filters/antispam.py | 84 | ||||
| -rw-r--r-- | bot/exts/info/information.py | 2 | ||||
| -rw-r--r-- | bot/exts/info/pep.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 25 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_utils.py | 13 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 79 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 25 | ||||
| -rw-r--r-- | bot/exts/recruitment/talentpool/_cog.py | 6 | ||||
| -rw-r--r-- | bot/exts/utils/internal.py | 7 | ||||
| -rw-r--r-- | bot/exts/utils/reminders.py | 9 | ||||
| -rw-r--r-- | bot/exts/utils/utils.py | 4 | ||||
| -rw-r--r-- | bot/rules/mentions.py | 6 | ||||
| -rw-r--r-- | bot/utils/caching.py (renamed from bot/utils/cache.py) | 0 | ||||
| -rw-r--r-- | bot/utils/message_cache.py | 197 | ||||
| -rw-r--r-- | config-default.yml | 5 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 6 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 4 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 26 | ||||
| -rw-r--r-- | tests/bot/utils/test_message_cache.py | 214 | 
20 files changed, 610 insertions, 108 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 5b629a735..80e01b174 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -575,6 +575,8 @@ class Metabase(metaclass=YAMLGetter):  class AntiSpam(metaclass=YAMLGetter):      section = 'anti_spam' +    cache_size: int +      clean_offending: bool      ping_everyone: bool @@ -687,7 +689,7 @@ class VideoPermission(metaclass=YAMLGetter):  # Debug mode -DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") +DEBUG_MODE: bool = _CONFIG_YAML["debug"] == "true"  # Paths  BOT_DIR = os.path.dirname(__file__) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 226da2790..987060779 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -1,8 +1,10 @@  import asyncio  import logging +from collections import defaultdict  from collections.abc import Mapping  from dataclasses import dataclass, field  from datetime import datetime, timedelta +from itertools import takewhile  from operator import attrgetter, itemgetter  from typing import Dict, Iterable, List, Set @@ -20,6 +22,7 @@ from bot.converters import Duration  from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME  from bot.exts.moderation.modlog import ModLog  from bot.utils import lock, scheduling +from bot.utils.message_cache import MessageCache  from bot.utils.messages import format_user, send_attachments @@ -44,19 +47,18 @@ RULE_FUNCTION_MAPPING = {  class DeletionContext:      """Represents a Deletion Context for a single spam event.""" -    channel: TextChannel -    members: Dict[int, Member] = field(default_factory=dict) +    members: frozenset[Member] +    triggered_in: TextChannel +    channels: set[TextChannel] = field(default_factory=set)      rules: Set[str] = field(default_factory=set)      messages: Dict[int, Message] = field(default_factory=dict)      attachments: List[List[str]] = field(default_factory=list) -    async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: +    async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None:          """Adds new rule violation events to the deletion context."""          self.rules.add(rule_name) -        for member in members: -            if member.id not in self.members: -                self.members[member.id] = member +        self.channels.update(channels)          for message in messages:              if message.id not in self.messages: @@ -69,11 +71,14 @@ class DeletionContext:      async def upload_messages(self, actor_id: int, modlog: ModLog) -> None:          """Method that takes care of uploading the queue and posting modlog alert.""" -        triggered_by_users = ", ".join(format_user(m) for m in self.members.values()) +        triggered_by_users = ", ".join(format_user(m) for m in self.members) +        triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else "" +        channels_description = ", ".join(channel.mention for channel in self.channels)          mod_alert_message = (              f"**Triggered by:** {triggered_by_users}\n" -            f"**Channel:** {self.channel.mention}\n" +            f"{triggered_in_channel}" +            f"**Channels:** {channels_description}\n"              f"**Rules:** {', '.join(rule for rule in self.rules)}\n"          ) @@ -116,6 +121,14 @@ class AntiSpam(Cog):          self.message_deletion_queue = dict() +        # Fetch the rule configuration with the highest rule interval. +        max_interval_config = max( +            AntiSpamConfig.rules.values(), +            key=itemgetter('interval') +        ) +        self.max_interval = max_interval_config['interval'] +        self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) +          self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")      @property @@ -155,19 +168,10 @@ class AntiSpam(Cog):          ):              return -        # Fetch the rule configuration with the highest rule interval. -        max_interval_config = max( -            AntiSpamConfig.rules.values(), -            key=itemgetter('interval') -        ) -        max_interval = max_interval_config['interval'] +        self.cache.append(message) -        # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. -        earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) -        relevant_messages = [ -            msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) -            if not msg.author.bot -        ] +        earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache))          for rule_name in AntiSpamConfig.rules:              rule_config = AntiSpamConfig.rules[rule_name] @@ -175,9 +179,10 @@ class AntiSpam(Cog):              # Create a list of messages that were sent in the interval that the rule cares about.              latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) -            messages_for_rule = [ -                msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp -            ] +            messages_for_rule = list( +                takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) +            ) +              result = await rule_function(message, messages_for_rule, rule_config)              # If the rule returns `None`, that means the message didn't violate it. @@ -190,19 +195,19 @@ class AntiSpam(Cog):                  full_reason = f"`{rule_name}` rule: {reason}"                  # If there's no spam event going on for this channel, start a new Message Deletion Context -                channel = message.channel -                if channel.id not in self.message_deletion_queue: -                    log.trace(f"Creating queue for channel `{channel.id}`") -                    self.message_deletion_queue[message.channel.id] = DeletionContext(channel) +                authors_set = frozenset(members) +                if authors_set not in self.message_deletion_queue: +                    log.trace(f"Creating queue for members `{authors_set}`") +                    self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel)                      scheduling.create_task( -                        self._process_deletion_context(message.channel.id), -                        name=f"AntiSpam._process_deletion_context({message.channel.id})" +                        self._process_deletion_context(authors_set), +                        name=f"AntiSpam._process_deletion_context({authors_set})"                      )                  # Add the relevant of this trigger to the Deletion Context -                await self.message_deletion_queue[message.channel.id].add( +                await self.message_deletion_queue[authors_set].add(                      rule_name=rule_name, -                    members=members, +                    channels=set(message.channel for message in messages_for_rule),                      messages=relevant_messages                  ) @@ -212,7 +217,7 @@ class AntiSpam(Cog):                          name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"                      ) -                await self.maybe_delete_messages(channel, relevant_messages) +                await self.maybe_delete_messages(messages_for_rule)                  break      @lock.lock_arg("antispam.punish", "member", attrgetter("id")) @@ -234,14 +239,18 @@ class AntiSpam(Cog):                  reason=reason              ) -    async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: +    async def maybe_delete_messages(self, messages: List[Message]) -> None:          """Cleans the messages if cleaning is configured."""          if AntiSpamConfig.clean_offending:              # If we have more than one message, we can use bulk delete.              if len(messages) > 1:                  message_ids = [message.id for message in messages]                  self.mod_log.ignore(Event.message_delete, *message_ids) -                await channel.delete_messages(messages) +                channel_messages = defaultdict(list) +                for message in messages: +                    channel_messages[message.channel].append(message) +                for channel, messages in channel_messages.items(): +                    await channel.delete_messages(messages)              # Otherwise, the bulk delete endpoint will throw up.              # Delete the message directly instead. @@ -252,7 +261,7 @@ class AntiSpam(Cog):                  except NotFound:                      log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") -    async def _process_deletion_context(self, context_id: int) -> None: +    async def _process_deletion_context(self, context_id: frozenset) -> None:          """Processes the Deletion Context queue."""          log.trace("Sleeping before processing message deletion queue.")          await asyncio.sleep(10) @@ -264,6 +273,11 @@ class AntiSpam(Cog):          deletion_context = self.message_deletion_queue.pop(context_id)          await deletion_context.upload_messages(self.bot.user.id, self.mod_log) +    @Cog.listener() +    async def on_message_edit(self, before: Message, after: Message) -> None: +        """Updates the message in the cache, if it's cached.""" +        self.cache.update(after) +  def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:      """Validates the antispam configs.""" diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 664b6cb13..38b436b7d 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -8,6 +8,7 @@ from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union  import rapidfuzz  from discord import AllowedMentions, Colour, Embed, Guild, Message, Role  from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role +from discord.utils import escape_markdown  from bot import constants  from bot.api import ResponseCodeError @@ -244,6 +245,7 @@ class Information(Cog):          name = str(user)          if on_server and user.nick:              name = f"{user.nick} ({name})" +        name = escape_markdown(name)          if user.public_flags.verified_bot:              name += f" {constants.Emojis.verified_bot}" diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 8ac96bbdb..b11b34db0 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command  from bot.bot import Bot  from bot.constants import Keys -from bot.utils.cache import AsyncCache +from bot.utils.caching import AsyncCache  log = logging.getLogger(__name__) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 3c5e5d3bf..6ba4e74e9 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -258,13 +258,17 @@ class InfractionScheduler:              ctx: Context,              infr_type: str,              user: MemberOrUser, -            send_msg: bool = True +            *, +            send_msg: bool = True, +            notify: bool = True      ) -> None:          """          Prematurely end an infraction for a user and log the action in the mod log.          If `send_msg` is True, then a pardoning confirmation message will be sent to -        the context channel.  Otherwise, no such message will be sent. +        the context channel. Otherwise, no such message will be sent. + +        If `notify` is True, notify the user of the pardon via DM where applicable.          """          log.trace(f"Pardoning {infr_type} infraction for {user}.") @@ -285,7 +289,7 @@ class InfractionScheduler:              return          # Deactivate the infraction and cancel its scheduled expiration task. -        log_text = await self.deactivate_infraction(response[0], send_log=False) +        log_text = await self.deactivate_infraction(response[0], send_log=False, notify=notify)          log_text["Member"] = messages.format_user(user)          log_text["Actor"] = ctx.author.mention @@ -338,7 +342,9 @@ class InfractionScheduler:      async def deactivate_infraction(          self,          infraction: _utils.Infraction, -        send_log: bool = True +        *, +        send_log: bool = True, +        notify: bool = True      ) -> t.Dict[str, str]:          """          Deactivate an active infraction and return a dictionary of lines to send in a mod log. @@ -347,6 +353,8 @@ class InfractionScheduler:          expiration task cancelled. If `send_log` is True, a mod log is sent for the          deactivation of the infraction. +        If `notify` is True, notify the user of the pardon via DM where applicable. +          Infractions of unsupported types will raise a ValueError.          """          guild = self.bot.get_guild(constants.Guild.id) @@ -373,7 +381,7 @@ class InfractionScheduler:          try:              log.trace("Awaiting the pardon action coroutine.") -            returned_log = await self._pardon_action(infraction) +            returned_log = await self._pardon_action(infraction, notify)              if returned_log is not None:                  log_text = {**log_text, **returned_log}  # Merge the logs together @@ -461,10 +469,15 @@ class InfractionScheduler:          return log_text      @abstractmethod -    async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: +    async def _pardon_action( +        self, +        infraction: _utils.Infraction, +        notify: bool +    ) -> t.Optional[t.Dict[str, str]]:          """          Execute deactivation steps specific to the infraction's type and return a log dict. +        If `notify` is True, notify the user of the pardon via DM where applicable.          If an infraction type is unsupported, return None instead.          """          raise NotImplementedError diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 9d94bca2d..b20ef1d06 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -139,15 +139,20 @@ async def get_active_infraction(          # Checks to see if the moderator should be told there is an active infraction          if send_msg:              log.trace(f"{user} has active infractions of type {infr_type}.") -            await ctx.send( -                f":x: According to my records, this user already has a {infr_type} infraction. " -                f"See infraction **#{active_infractions[0]['id']}**." -            ) +            await send_active_infraction_message(ctx, active_infractions[0])          return active_infractions[0]      else:          log.trace(f"{user} does not have active infractions of type {infr_type}.") +async def send_active_infraction_message(ctx: Context, infraction: Infraction) -> None: +    """Send a message stating that the given infraction is active.""" +    await ctx.send( +        f":x: According to my records, this user already has a {infraction['type']} infraction. " +        f"See infraction **#{infraction['id']}**." +    ) + +  async def notify_infraction(          user: MemberOrUser,          infr_type: str, diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 48ffbd773..2f9083c29 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -279,8 +279,19 @@ class Infractions(InfractionScheduler, commands.Cog):      async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None:          """Apply a mute infraction with kwargs passed to `post_infraction`.""" -        if await _utils.get_active_infraction(ctx, user, "mute"): -            return +        if active := await _utils.get_active_infraction(ctx, user, "mute", send_msg=False): +            if active["actor"] != self.bot.user.id: +                await _utils.send_active_infraction_message(ctx, active) +                return + +            # Allow the current mute attempt to override an automatically triggered mute. +            log_text = await self.deactivate_infraction(active, notify=False) +            if "Failure" in log_text: +                await ctx.send( +                    f":x: can't override infraction **mute** for {user.mention}: " +                    f"failed to deactivate. {log_text['Failure']}" +                ) +                return          infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs)          if infraction is None: @@ -344,7 +355,7 @@ class Infractions(InfractionScheduler, commands.Cog):                  return              log.trace("Old tempban is being replaced by new permaban.") -            await self.pardon_infraction(ctx, "ban", user, is_temporary) +            await self.pardon_infraction(ctx, "ban", user, send_msg=is_temporary)          infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs)          if infraction is None: @@ -402,8 +413,15 @@ class Infractions(InfractionScheduler, commands.Cog):      # endregion      # region: Base pardon functions -    async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: -        """Remove a user's muted role, DM them a notification, and return a log dict.""" +    async def pardon_mute( +        self, +        user_id: int, +        guild: discord.Guild, +        reason: t.Optional[str], +        *, +        notify: bool = True +    ) -> t.Dict[str, str]: +        """Remove a user's muted role, optionally DM them a notification, and return a log dict."""          user = guild.get_member(user_id)          log_text = {} @@ -412,16 +430,17 @@ class Infractions(InfractionScheduler, commands.Cog):              self.mod_log.ignore(Event.member_update, user.id)              await user.remove_roles(self._muted_role, reason=reason) -            # DM the user about the expiration. -            notified = await _utils.notify_pardon( -                user=user, -                title="You have been unmuted", -                content="You may now send messages in the server.", -                icon_url=_utils.INFRACTION_ICONS["mute"][1] -            ) +            if notify: +                # DM the user about the expiration. +                notified = await _utils.notify_pardon( +                    user=user, +                    title="You have been unmuted", +                    content="You may now send messages in the server.", +                    icon_url=_utils.INFRACTION_ICONS["mute"][1] +                ) +                log_text["DM"] = "Sent" if notified else "**Failed**"              log_text["Member"] = format_user(user) -            log_text["DM"] = "Sent" if notified else "**Failed**"          else:              log.info(f"Failed to unmute user {user_id}: user not found")              log_text["Failure"] = "User was not found in the guild." @@ -443,31 +462,39 @@ class Infractions(InfractionScheduler, commands.Cog):          return log_text -    async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: -        """Add Voice Verified role back to user, DM them a notification, and return a log dict.""" +    async def pardon_voice_ban( +        self, +        user_id: int, +        guild: discord.Guild, +        *, +        notify: bool = True +    ) -> t.Dict[str, str]: +        """Optionally DM the user a pardon notification and return a log dict."""          user = guild.get_member(user_id)          log_text = {}          if user: -            # DM user about infraction expiration -            notified = await _utils.notify_pardon( -                user=user, -                title="Voice ban ended", -                content="You have been unbanned and can verify yourself again in the server.", -                icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] -            ) +            if notify: +                # DM user about infraction expiration +                notified = await _utils.notify_pardon( +                    user=user, +                    title="Voice ban ended", +                    content="You have been unbanned and can verify yourself again in the server.", +                    icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] +                ) +                log_text["DM"] = "Sent" if notified else "**Failed**"              log_text["Member"] = format_user(user) -            log_text["DM"] = "Sent" if notified else "**Failed**"          else:              log_text["Info"] = "User was not found in the guild."          return log_text -    async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: +    async def _pardon_action(self, infraction: _utils.Infraction, notify: bool) -> t.Optional[t.Dict[str, str]]:          """          Execute deactivation steps specific to the infraction's type and return a log dict. +        If `notify` is True, notify the user of the pardon via DM where applicable.          If an infraction type is unsupported, return None instead.          """          guild = self.bot.get_guild(constants.Guild.id) @@ -475,11 +502,11 @@ class Infractions(InfractionScheduler, commands.Cog):          reason = f"Infraction #{infraction['id']} expired or was pardoned."          if infraction["type"] == "mute": -            return await self.pardon_mute(user_id, guild, reason) +            return await self.pardon_mute(user_id, guild, reason, notify=notify)          elif infraction["type"] == "ban":              return await self.pardon_ban(user_id, guild, reason)          elif infraction["type"] == "voice_ban": -            return await self.pardon_voice_ban(user_id, guild, reason) +            return await self.pardon_voice_ban(user_id, guild, notify=notify)      # endregion diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 07e79b9fe..05a2bbe10 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -192,8 +192,8 @@ class Superstarify(InfractionScheduler, Cog):          """Remove the superstarify infraction and allow the user to change their nickname."""          await self.pardon_infraction(ctx, "superstar", member) -    async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: -        """Pardon a superstar infraction and return a log dict.""" +    async def _pardon_action(self, infraction: _utils.Infraction, notify: bool) -> t.Optional[t.Dict[str, str]]: +        """Pardon a superstar infraction, optionally notify the user via DM, and return a log dict."""          if infraction["type"] != "superstar":              return @@ -208,18 +208,19 @@ class Superstarify(InfractionScheduler, Cog):              )              return {} +        log_text = {"Member": format_user(user)} +          # DM the user about the expiration. -        notified = await _utils.notify_pardon( -            user=user, -            title="You are no longer superstarified", -            content="You may now change your nickname on the server.", -            icon_url=_utils.INFRACTION_ICONS["superstar"][1] -        ) +        if notify: +            notified = await _utils.notify_pardon( +                user=user, +                title="You are no longer superstarified", +                content="You may now change your nickname on the server.", +                icon_url=_utils.INFRACTION_ICONS["superstar"][1] +            ) +            log_text["DM"] = "Sent" if notified else "**Failed**" -        return { -            "Member": format_user(user), -            "DM": "Sent" if notified else "**Failed**" -        } +        return log_text      @staticmethod      def get_nick(infraction_id: int, member_id: int) -> str: diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 5c1a1cd3f..c297f70c2 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -263,7 +263,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):              }          ) -        msg = f"✅ The nomination for {user} has been added to the talent pool" +        msg = f"✅ The nomination for {user.mention} has been added to the talent pool"          if history:              msg += f"\n\n({len(history)} previous nominations in total)" @@ -311,7 +311,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):              return          if await self.unwatch(user.id, reason): -            await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed") +            await ctx.send(f":white_check_mark: Messages sent by {user.mention} will no longer be relayed")          else:              await ctx.send(":x: The specified user does not have an active nomination") @@ -344,7 +344,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):              return          if not any(entry["actor"] == actor.id for entry in nomination["entries"]): -            await ctx.send(f":x: {actor} doesn't have an entry in this nomination.") +            await ctx.send(f":x: {actor.mention} doesn't have an entry in this nomination.")              return          self.log.trace(f"Changing reason for nomination with id {nomination_id} of actor {actor} to {repr(reason)}") diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index 6f2da3131..5d2cd7611 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -11,10 +11,10 @@ from io import StringIO  from typing import Any, Optional, Tuple  import discord -from discord.ext.commands import Cog, Context, group, has_any_role +from discord.ext.commands import Cog, Context, group, has_any_role, is_owner  from bot.bot import Bot -from bot.constants import Roles +from bot.constants import DEBUG_MODE, Roles  from bot.utils import find_nth_occurrence, send_to_paste_service  log = logging.getLogger(__name__) @@ -33,6 +33,9 @@ class Internal(Cog):          self.socket_event_total = 0          self.socket_events = Counter() +        if DEBUG_MODE: +            self.eval.add_check(is_owner().predicate) +      @Cog.listener()      async def on_socket_response(self, msg: dict) -> None:          """When a websocket event is received, increase our counters.""" diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 847883fc7..1db2de6dc 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -15,7 +15,7 @@ from bot.constants import (      Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES,      Roles, STAFF_PARTNERS_COMMUNITY_ROLES  ) -from bot.converters import Duration +from bot.converters import Duration, UserMentionOrID  from bot.pagination import LinePaginator  from bot.utils.checks import has_any_role_check, has_no_roles_check  from bot.utils.lock import lock_arg @@ -30,6 +30,7 @@ WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5  Mentionable = t.Union[discord.Member, discord.Role] +ReminderMention = t.Union[UserMentionOrID, discord.Role]  class Reminders(Cog): @@ -214,14 +215,14 @@ class Reminders(Cog):      @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True)      async def remind_group( -        self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str +        self, ctx: Context, mentions: Greedy[ReminderMention], expiration: Duration, *, content: str      ) -> None:          """Commands for managing your reminders."""          await self.new_reminder(ctx, mentions=mentions, expiration=expiration, content=content)      @remind_group.command(name="new", aliases=("add", "create"))      async def new_reminder( -        self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str +        self, ctx: Context, mentions: Greedy[ReminderMention], expiration: Duration, *, content: str      ) -> None:          """          Set yourself a simple reminder. @@ -366,7 +367,7 @@ class Reminders(Cog):          await self.edit_reminder(ctx, id_, {"content": content})      @edit_reminder_group.command(name="mentions", aliases=("pings",)) -    async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: +    async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[ReminderMention]) -> None:          """Edit one of your reminder's mentions."""          # Remove duplicate mentions          mentions = set(mentions) diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index c4a466943..f91a9fee6 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -14,7 +14,6 @@ from bot.converters import Snowflake  from bot.decorators import in_whitelist  from bot.pagination import LinePaginator  from bot.utils import messages -from bot.utils.checks import has_no_roles_check  from bot.utils.time import time_since  log = logging.getLogger(__name__) @@ -162,9 +161,6 @@ class Utils(Cog):      @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_PARTNERS_COMMUNITY_ROLES)      async def snowflake(self, ctx: Context, *snowflakes: Snowflake) -> None:          """Get Discord snowflake creation time.""" -        if len(snowflakes) > 1 and await has_no_roles_check(ctx, *STAFF_ROLES): -            raise BadArgument("Cannot process more than one snowflake in one invocation.") -          if not snowflakes:              raise BadArgument("At least one snowflake must be provided.") diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 79725a4b1..6f5addad1 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -13,7 +13,11 @@ async def apply(          if msg.author == last_message.author      ) -    total_recent_mentions = sum(len(msg.mentions) for msg in relevant_messages) +    total_recent_mentions = sum( +        not user.bot +        for msg in relevant_messages +        for user in msg.mentions +    )      if total_recent_mentions > config['max']:          return ( diff --git a/bot/utils/cache.py b/bot/utils/caching.py index 68ce15607..68ce15607 100644 --- a/bot/utils/cache.py +++ b/bot/utils/caching.py diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py new file mode 100644 index 000000000..f68d280c9 --- /dev/null +++ b/bot/utils/message_cache.py @@ -0,0 +1,197 @@ +import typing as t +from math import ceil + +from discord import Message + + +class MessageCache: +    """ +    A data structure for caching messages. + +    The cache is implemented as a circular buffer to allow constant time append, prepend, pop from either side, +    and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be +    implemented to work in linear time relative to the maximum size). + +    The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message +    is stored, to allow for constant time lookup by message ID. + +    The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those +    of a deque. + +    The implementation is transparent to the user: to the user the first element is always at index 0, and there are +    only as many elements as were inserted (meaning, without any pre-allocated placeholder values). +    """ + +    def __init__(self, maxlen: int, *, newest_first: bool = False): +        if maxlen <= 0: +            raise ValueError("maxlen must be positive") +        self.maxlen = maxlen +        self.newest_first = newest_first + +        self._start = 0 +        self._end = 0 + +        self._messages: list[t.Optional[Message]] = [None] * self.maxlen +        self._message_id_mapping = {} + +    def append(self, message: Message) -> None: +        """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" +        if self.newest_first: +            self._appendleft(message) +        else: +            self._appendright(message) + +    def _appendright(self, message: Message) -> None: +        """Add the received message to the end of the cache.""" +        if self._is_full(): +            del self._message_id_mapping[self._messages[self._start].id] +            self._start = (self._start + 1) % self.maxlen + +        self._messages[self._end] = message +        self._message_id_mapping[message.id] = self._end +        self._end = (self._end + 1) % self.maxlen + +    def _appendleft(self, message: Message) -> None: +        """Add the received message to the beginning of the cache.""" +        if self._is_full(): +            self._end = (self._end - 1) % self.maxlen +            del self._message_id_mapping[self._messages[self._end].id] + +        self._start = (self._start - 1) % self.maxlen +        self._messages[self._start] = message +        self._message_id_mapping[message.id] = self._start + +    def pop(self) -> Message: +        """Remove the last message in the cache and return it.""" +        if self._is_empty(): +            raise IndexError("pop from an empty cache") + +        self._end = (self._end - 1) % self.maxlen +        message = self._messages[self._end] +        del self._message_id_mapping[message.id] +        self._messages[self._end] = None + +        return message + +    def popleft(self) -> Message: +        """Return the first message in the cache and return it.""" +        if self._is_empty(): +            raise IndexError("pop from an empty cache") + +        message = self._messages[self._start] +        del self._message_id_mapping[message.id] +        self._messages[self._start] = None +        self._start = (self._start + 1) % self.maxlen + +        return message + +    def clear(self) -> None: +        """Remove all messages from the cache.""" +        self._messages = [None] * self.maxlen +        self._message_id_mapping = {} + +        self._start = 0 +        self._end = 0 + +    def get_message(self, message_id: int) -> t.Optional[Message]: +        """Return the message that has the given message ID, if it is cached.""" +        index = self._message_id_mapping.get(message_id, None) +        return self._messages[index] if index is not None else None + +    def update(self, message: Message) -> bool: +        """ +        Update a cached message with new contents. + +        Return True if the given message had a matching ID in the cache. +        """ +        index = self._message_id_mapping.get(message.id, None) +        if index is None: +            return False +        self._messages[index] = message +        return True + +    def __contains__(self, message_id: int) -> bool: +        """Return True if the cache contains a message with the given ID .""" +        return message_id in self._message_id_mapping + +    def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: +        """ +        Return the message(s) in the index or slice provided. + +        This method makes the circular buffer implementation transparent to the user. +        Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, +        meaning at `self._start`. +        """ +        # Keep in mind that for the modulo operator used throughout this function, Python modulo behaves similarly when +        # the left operand is negative. E.g -1 % 5 == 4, because the closest number from the bottom that wholly divides +        # by 5 is -5. +        if isinstance(item, int): +            if item >= len(self) or item < -len(self): +                raise IndexError("cache index out of range") +            return self._messages[(item + self._start) % self.maxlen] + +        elif isinstance(item, slice): +            length = len(self) +            start, stop, step = item.indices(length) + +            # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state. +            if (start >= stop and step >= 0) or (start <= stop and step <= 0): +                return [] + +            start = (start + self._start) % self.maxlen +            stop = (stop + self._start) % self.maxlen + +            # Having empty cells is an implementation detail. To the user the cache contains as many elements as they +            # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. +            if step > 0: +                if ( +                    (self._start < self._end and not self._start < stop <= self._end) +                    or (self._start > self._end and self._end < stop <= self._start) +                ): +                    stop = self._end +            else: +                lower_boundary = (self._start - 1) % self.maxlen +                if ( +                    (self._start < self._end and not self._start - 1 <= stop < self._end) +                    or (self._start > self._end and self._end < stop < lower_boundary) +                ): +                    stop = lower_boundary + +            if (start < stop and step > 0) or (start > stop and step < 0): +                return self._messages[start:stop:step] +            # step != 1 may require a start offset in the second slicing. +            if step > 0: +                offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen +                return self._messages[start::step] + self._messages[offset:stop:step] +            else: +                offset = ceil((start + 1) / -step) * -step - start - 1 +                return self._messages[start::step] + self._messages[self.maxlen - 1 - offset:stop:step] + +        else: +            raise TypeError(f"cache indices must be integers or slices, not {type(item)}") + +    def __iter__(self) -> t.Iterator[Message]: +        if self._is_empty(): +            return + +        if self._start < self._end: +            yield from self._messages[self._start:self._end] +        else: +            yield from self._messages[self._start:] +            yield from self._messages[:self._end] + +    def __len__(self): +        """Get the number of non-empty cells in the cache.""" +        if self._is_empty(): +            return 0 +        if self._end > self._start: +            return self._end - self._start +        return self.maxlen - self._start + self._end + +    def _is_empty(self) -> bool: +        """Return True if the cache has no messages.""" +        return self._messages[self._start] is None + +    def _is_full(self) -> bool: +        """Return True if every cell in the cache already contains a message.""" +        return self._messages[self._end] is not None diff --git a/config-default.yml b/config-default.yml index 79828dd77..8e0b97a51 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,3 +1,6 @@ +debug: !ENV ["BOT_DEBUG", "true"] + +  bot:      prefix:         "!"      sentry_dsn:     !ENV "BOT_SENTRY_DSN" @@ -377,6 +380,8 @@ urls:  anti_spam: +    cache_size: 100 +      # Clean messages that violate a rule.      clean_offending: true      ping_everyone: true diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b9d527770..f844a9181 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -195,7 +195,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):      async def test_voice_unban_user_not_found(self):          """Should include info to return dict when user was not found from guild."""          self.guild.get_member.return_value = None -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {"Info": "User was not found in the guild."})      @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @@ -206,7 +206,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):          notify_pardon_mock.return_value = True          format_user_mock.return_value = "my-user" -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {              "Member": "my-user",              "DM": "Sent" @@ -221,7 +221,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):          notify_pardon_mock.return_value = False          format_user_mock.return_value = "my-user" -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {              "Member": "my-user",              "DM": "**Failed**" diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5f95ced9f..eb256f1fd 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -94,8 +94,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):          test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"])          test_cases = [              test_case([], None, None, True), -            test_case([{"id": 123987}], {"id": 123987}, "123987", False), -            test_case([{"id": 123987}], {"id": 123987}, "123987", True) +            test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False), +            test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True)          ]          for case in test_cases: diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index 6444532f2..f8805ac48 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,12 +2,14 @@ from typing import Iterable  from bot.rules import mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage +from tests.helpers import MockMember, MockMessage -def make_msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage:      """Makes a message with `total_mentions` mentions.""" -    return MockMessage(author=author, mentions=list(range(total_mentions))) +    user_mentions = [MockMember() for _ in range(total_user_mentions)] +    bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] +    return MockMessage(author=author, mentions=user_mentions+bot_mentions)  class TestMentions(RuleTest): @@ -48,11 +50,27 @@ class TestMentions(RuleTest):                  [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],                  ("bob",),                  4, -            ) +            ), +            DisallowedCase( +                [make_msg("bob", 3, 1)], +                ("bob",), +                3, +            ),          )          await self.run_disallowed(cases) +    async def test_ignore_bot_mentions(self): +        """Messages with an allowed amount of mentions, also containing bot mentions.""" +        cases = ( +            [make_msg("bob", 0, 3)], +            [make_msg("bob", 2, 1)], +            [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], +            [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] +        ) + +        await self.run_allowed(cases) +      def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:          last_message = case.recent_messages[0]          return tuple( diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..04bfd28d1 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,214 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): +    """Tests for the MessageCache class in the `bot.utils.caching` module.""" + +    def test_first_append_sets_the_first_value(self): +        """Test if the first append adds the message to the first cell.""" +        cache = MessageCache(maxlen=10) +        message = MockMessage() + +        cache.append(message) + +        self.assertEqual(cache[0], message) + +    def test_append_adds_in_the_right_order(self): +        """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" +        messages = [MockMessage(), MockMessage()] + +        cache = MessageCache(maxlen=10, newest_first=False) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages, list(cache)) + +        cache = MessageCache(maxlen=10, newest_first=True) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages[::-1], list(cache)) + +    def test_appending_over_maxlen_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages.""" +        cache = MessageCache(maxlen=2) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[1:], list(cache)) + +    def test_appending_over_maxlen_with_newest_first_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" +        cache = MessageCache(maxlen=2, newest_first=True) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[:0:-1], list(cache)) + +    def test_pop_removes_from_the_end(self): +        """Test if a pop removes the right-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.pop() + +        self.assertEqual(msg, messages[-1]) +        self.assertListEqual(messages[:-1], list(cache)) + +    def test_popleft_removes_from_the_beginning(self): +        """Test if a popleft removes the left-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.popleft() + +        self.assertEqual(msg, messages[0]) +        self.assertListEqual(messages[1:], list(cache)) + +    def test_clear(self): +        """Test if a clear makes the cache empty.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        cache.clear() + +        self.assertListEqual(list(cache), []) +        self.assertEqual(len(cache), 0) + +    def test_get_message_returns_the_message(self): +        """Test if get_message returns the cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertEqual(cache.get_message(1234), message) + +    def test_get_message_returns_none(self): +        """Test if get_message returns None for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIsNone(cache.get_message(4321)) + +    def test_update_replaces_old_element(self): +        """Test if an update replaced the old message with the same ID.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) +        message = MockMessage(id=1234) +        cache.update(message) + +        self.assertIs(cache.get_message(1234), message) +        self.assertEqual(len(cache), 1) + +    def test_contains_returns_true_for_cached_message(self): +        """Test if contains returns True for an ID of a cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIn(1234, cache) + +    def test_contains_returns_false_for_non_cached_message(self): +        """Test if contains returns False for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertNotIn(4321, cache) + +    def test_indexing(self): +        """Test if the cache returns the correct messages by index.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(5)] + +        for msg in messages: +            cache.append(msg) + +        for current_loop in range(-5, 5): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(cache[current_loop], messages[current_loop]) + +    def test_bad_index_raises_index_error(self): +        """Test if the cache raises IndexError for invalid indices.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] +        test_cases = (-10, -4, 3, 4, 5) + +        for msg in messages: +            cache.append(msg) + +        for current_loop in test_cases: +            with self.subTest(current_loop=current_loop): +                with self.assertRaises(IndexError): +                    cache[current_loop] + +    def test_slicing_with_unfilled_cache(self): +        """Test if slicing returns the correct messages if the cache is not yet fully filled.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size // 3 * 2)] + +            for msg in messages: +                cache.append(msg) + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_slicing_with_overfilled_cache(self): +        """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size * 3 // 2)] + +            for msg in messages: +                cache.append(msg) +            messages = messages[size // 2:] + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_length(self): +        """Test if len returns the correct number of items in the cache.""" +        cache = MessageCache(maxlen=5) + +        for current_loop in range(10): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(len(cache), min(current_loop, 5)) +                cache.append(MockMessage()) | 
