diff options
| -rw-r--r-- | bot/exts/filtering/_filter_context.py | 11 | ||||
| -rw-r--r-- | bot/exts/filtering/_filter_lists/antispam.py | 2 | ||||
| -rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 18 | ||||
| -rw-r--r-- | bot/exts/filtering/_filters/unique/rich_embed.py | 6 | ||||
| -rw-r--r-- | bot/exts/filtering/filtering.py | 55 | ||||
| -rw-r--r-- | bot/utils/message_cache.py | 23 | 
6 files changed, 87 insertions, 28 deletions
diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py index 3227b333a..4a213535a 100644 --- a/bot/exts/filtering/_filter_context.py +++ b/bot/exts/filtering/_filter_context.py @@ -7,6 +7,8 @@ from enum import Enum, auto  from discord import DMChannel, Member, Message, TextChannel, Thread, User +from bot.utils.message_cache import MessageCache +  if typing.TYPE_CHECKING:      from bot.exts.filtering._filters.filter import Filter @@ -29,6 +31,8 @@ class FilterContext:      content: str | Iterable  # What actually needs filtering      message: Message | None  # The message involved      embeds: list = field(default_factory=list)  # Any embeds involved +    before_message: Message | None = None +    message_cache: MessageCache | None = None      # Output context      dm_content: str = field(default_factory=str)  # The content to DM the invoker      dm_embed: str = field(default_factory=str)  # The embed description to DM the invoker @@ -45,6 +49,13 @@ class FilterContext:      related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set)      attachments: dict[int, list[str]] = field(default_factory=dict)  # Message ID to attachment URLs. +    @classmethod +    def from_message( +        cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None +    ) -> FilterContext: +        """Create a filtering context from the attributes of a message.""" +        return cls(event, message.author, message.channel, message.content, message, message.embeds, before, cache) +      def replace(self, **changes) -> FilterContext:          """Return a new context object assigning new values to the specified fields."""          return replace(self, **changes) diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py index ed86c92c0..cf5875723 100644 --- a/bot/exts/filtering/_filter_lists/antispam.py +++ b/bot/exts/filtering/_filter_lists/antispam.py @@ -68,7 +68,7 @@ class AntispamList(UniquesListBase):          earliest_relevant_at = arrow.utcnow() - timedelta(seconds=max_interval)          relevant_messages = list( -            takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.filtering_cog.message_cache) +            takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.message_cache)          )          new_ctx = ctx.replace(content=relevant_messages)          triggers = await sublist.filter_list_result(new_ctx) diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index fd243a109..b5d6141d7 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -47,7 +47,8 @@ def list_type_converter(argument: str) -> ListType:      raise BadArgument(f"No matching list type found for {argument!r}.") -@dataclass(frozen=True) +# AtomicList and its subclasses must have eq=False, otherwise the dataclass deco will replace the hash function. +@dataclass(frozen=True, eq=False)  class AtomicList:      """      Represents the atomic structure of a single filter list as it appears in the database. @@ -84,9 +85,8 @@ class AtomicList:          """          return await self._create_filter_list_result(ctx, self.defaults, self.filters.values()) -    @staticmethod      async def _create_filter_list_result( -        ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter] +        self, ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter]      ) -> list[Filter]:          """A helper function to evaluate the result of `filter_list_result`."""          passed_by_default, failed_by_default = defaults.validations.evaluate(ctx) @@ -103,6 +103,13 @@ class AtomicList:                      if await filter_.triggered_on(ctx):                          relevant_filters.append(filter_) +        if ctx.event == Event.MESSAGE_EDIT and ctx.message and self.list_type == ListType.DENY: +            previously_triggered = ctx.message_cache.get_message_metadata(ctx.message.id) +            if previously_triggered and self in previously_triggered: +                ignore_filters = previously_triggered[self] +                # This updates the cache. Some filters are ignored, but they're necessary if there's another edit. +                previously_triggered[self] = relevant_filters +                relevant_filters = [filter_ for filter_ in relevant_filters if filter_ not in ignore_filters]          return relevant_filters      def default(self, setting_name: str) -> Any: @@ -144,6 +151,9 @@ class AtomicList:              messages = [f"#{filter_.id} (`{filter_.content}`)" for filter_ in triggers]          return messages +    def __hash__(self): +        return hash(id(self)) +  T = typing.TypeVar("T", bound=Filter) @@ -220,7 +230,7 @@ class FilterList(dict[ListType, AtomicList], typing.Generic[T], FieldRequiring):          return hash(id(self)) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False)  class SubscribingAtomicList(AtomicList):      """      A base class for a list of unique filters. diff --git a/bot/exts/filtering/_filters/unique/rich_embed.py b/bot/exts/filtering/_filters/unique/rich_embed.py index 09d513373..5c3517e10 100644 --- a/bot/exts/filtering/_filters/unique/rich_embed.py +++ b/bot/exts/filtering/_filters/unique/rich_embed.py @@ -20,6 +20,12 @@ class RichEmbedFilter(UniqueFilter):      async def triggered_on(self, ctx: FilterContext) -> bool:          """Determine if `msg` contains any rich embeds not auto-generated from a URL."""          if ctx.embeds: +            if ctx.event == Event.MESSAGE_EDIT: +                # If the edit delta is less than 100 microseconds, it's probably a double filter trigger. +                delta = ctx.message.edited_at - (ctx.before_message.edited_at or ctx.before_message.created_at) +                if delta.total_seconds() < 0.0001: +                    return False +              for embed in ctx.embeds:                  if embed.type == "rich":                      urls = URL_RE.findall(ctx.content) diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index c4c118b6f..05b2339b9 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -180,8 +180,30 @@ class Filtering(Cog):              return          self.message_cache.append(msg) -        ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds) -        result_actions, list_messages, _ = await self._resolve_action(ctx) +        ctx = FilterContext.from_message(Event.MESSAGE, msg, None, self.message_cache) +        result_actions, list_messages, triggers = await self._resolve_action(ctx) +        self.message_cache.update(msg, metadata=triggers) +        if result_actions: +            await result_actions.action(ctx) +        if ctx.send_alert: +            await self._send_alert(ctx, list_messages) + +    @Cog.listener() +    async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: +        """Filter the contents of an edited message. Don't reinvoke filters already invoked on the `before` version.""" +        # Only check changes to the message contents/attachments and embed additions, not pin status etc. +        if all(( +            before.content == after.content,  # content hasn't changed +            before.attachments == after.attachments,  # attachments haven't changed +            len(before.embeds) >= len(after.embeds)  # embeds haven't been added +        )): +            return + +        # Update the cache first, it might be used by the antispam filter. +        # No need to update the triggers, they're going to be updated inside the sublists if necessary. +        self.message_cache.update(after) +        ctx = FilterContext.from_message(Event.MESSAGE_EDIT, after, before, self.message_cache) +        result_actions, list_messages, triggers = await self._resolve_action(ctx)          if result_actions:              await result_actions.action(ctx)          if ctx.send_alert: @@ -520,21 +542,17 @@ class Filtering(Cog):              raise BadArgument("Please provide input.")          if message:              user = None if no_user else message.author -            filter_ctx = FilterContext( -                Event.MESSAGE, user, message.channel, message.content, message, message.embeds -            ) +            filter_ctx = FilterContext(Event.MESSAGE, user, message.channel, message.content, message, message.embeds)          else: -            filter_ctx = FilterContext( -                Event.MESSAGE, None, ctx.guild.get_channel(Channels.python_general), string, None -            ) +            python_general = ctx.guild.get_channel(Channels.python_general) +            filter_ctx = FilterContext(Event.MESSAGE, None, python_general, string, None)          _, _, triggers = await self._resolve_action(filter_ctx)          lines = [] -        for filter_list, list_triggers in triggers.items(): -            for sublist_type, sublist_triggers in list_triggers.items(): -                if sublist_triggers: -                    triggers_repr = map(str, sublist_triggers) -                    lines.extend([f"**{filter_list[sublist_type].label.title()}s**", *triggers_repr, "\n"]) +        for sublist, sublist_triggers in triggers.items(): +            if sublist_triggers: +                triggers_repr = map(str, sublist_triggers) +                lines.extend([f"**{sublist.label.title()}s**", *triggers_repr, "\n"])          lines = lines[:-1]  # Remove last newline.          embed = Embed(colour=Colour.blue(), title="Match results") @@ -767,7 +785,7 @@ class Filtering(Cog):      async def _resolve_action(          self, ctx: FilterContext -    ) -> tuple[Optional[ActionSettings], dict[FilterList, list[str]], dict[FilterList, dict[ListType, list[Filter]]]]: +    ) -> tuple[ActionSettings | None, dict[FilterList, list[str]], dict[AtomicList, list[Filter]]]:          """          Return the actions that should be taken for all filter lists in the given context. @@ -778,7 +796,8 @@ class Filtering(Cog):          messages = {}          triggers = {}          for filter_list in self._subscriptions[ctx.event]: -            list_actions, list_message, triggers[filter_list] = await filter_list.actions_for(ctx) +            list_actions, list_message, list_triggers = await filter_list.actions_for(ctx) +            triggers.update({filter_list[list_type]: filters for list_type, filters in list_triggers.items()})              if list_actions:                  actions.append(list_actions)              if list_message: @@ -945,11 +964,11 @@ class Filtering(Cog):          """If the filter is new and applies an auto-infraction, or was edited to apply a different one, log it."""          infraction_type = filter_.overrides[0].get("infraction_type")          if not infraction_type: -            infraction_type = filter_list[list_type].defaults.actions.get_setting("infraction_type") +            infraction_type = filter_list[list_type].default("infraction_type")          if old_filter:              old_infraction_type = old_filter.overrides[0].get("infraction_type")              if not old_infraction_type: -                old_infraction_type = filter_list[list_type].defaults.actions.get_setting("infraction_type") +                old_infraction_type = filter_list[list_type].default("infraction_type")              if infraction_type == old_infraction_type:                  return @@ -1146,7 +1165,7 @@ class Filtering(Cog):          # Extract all auto-infraction filters added in the past 7 days from each filter type          for filter_list in self.filter_lists.values():              for sublist in filter_list.values(): -                default_infraction_type = sublist.defaults.actions.get_setting("infraction_type") +                default_infraction_type = sublist.default("infraction_type")                  for filter_ in sublist.filters.values():                      if max(filter_.created_at, filter_.updated_at) < seven_days_ago:                          continue diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index f68d280c9..3e77e6a50 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -31,20 +31,23 @@ class MessageCache:          self._start = 0          self._end = 0 -        self._messages: list[t.Optional[Message]] = [None] * self.maxlen +        self._messages: list[Message | None] = [None] * self.maxlen          self._message_id_mapping = {} +        self._message_metadata = {} -    def append(self, message: Message) -> None: +    def append(self, message: Message, *, metadata: dict | None = None) -> None:          """Add the received message to the cache, depending on the order of messages defined by `newest_first`."""          if self.newest_first:              self._appendleft(message)          else:              self._appendright(message) +        self._message_metadata[message.id] = metadata      def _appendright(self, message: Message) -> None:          """Add the received message to the end of the cache."""          if self._is_full():              del self._message_id_mapping[self._messages[self._start].id] +            del self._message_metadata[self._messages[self._start].id]              self._start = (self._start + 1) % self.maxlen          self._messages[self._end] = message @@ -56,6 +59,7 @@ class MessageCache:          if self._is_full():              self._end = (self._end - 1) % self.maxlen              del self._message_id_mapping[self._messages[self._end].id] +            del self._message_metadata[self._messages[self._end].id]          self._start = (self._start - 1) % self.maxlen          self._messages[self._start] = message @@ -69,6 +73,7 @@ class MessageCache:          self._end = (self._end - 1) % self.maxlen          message = self._messages[self._end]          del self._message_id_mapping[message.id] +        del self._message_metadata[message.id]          self._messages[self._end] = None          return message @@ -80,6 +85,7 @@ class MessageCache:          message = self._messages[self._start]          del self._message_id_mapping[message.id] +        del self._message_metadata[message.id]          self._messages[self._start] = None          self._start = (self._start + 1) % self.maxlen @@ -89,16 +95,21 @@ class MessageCache:          """Remove all messages from the cache."""          self._messages = [None] * self.maxlen          self._message_id_mapping = {} +        self._message_metadata = {}          self._start = 0          self._end = 0 -    def get_message(self, message_id: int) -> t.Optional[Message]: +    def get_message(self, message_id: int) -> Message | None:          """Return the message that has the given message ID, if it is cached."""          index = self._message_id_mapping.get(message_id, None)          return self._messages[index] if index is not None else None -    def update(self, message: Message) -> bool: +    def get_message_metadata(self, message_id: int) -> dict: +        """Return the metadata of the message that has the given message ID, if it is cached.""" +        return self._message_metadata.get(message_id, None) + +    def update(self, message: Message, *, metadata: dict | None = None) -> bool:          """          Update a cached message with new contents. @@ -108,13 +119,15 @@ class MessageCache:          if index is None:              return False          self._messages[index] = message +        if metadata is not None: +            self._message_metadata[message.id] = metadata          return True      def __contains__(self, message_id: int) -> bool:          """Return True if the cache contains a message with the given ID ."""          return message_id in self._message_id_mapping -    def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: +    def __getitem__(self, item: int | slice) -> Message | list[Message]:          """          Return the message(s) in the index or slice provided.  |