diff options
author | 2022-11-05 22:49:57 +0200 | |
---|---|---|
committer | 2022-11-05 22:49:57 +0200 | |
commit | 48a892fa42e5efe971ab69340fc5d731623020e5 (patch) | |
tree | 43c549d1d656a69455fda3e77567529a3f279f50 | |
parent | Handle threads in channel_scope (diff) |
Add message edit filtering
This edit handler takes into account filters already triggered for the message and ignores them (as long as it's a denied type)
To that end the message cache can now hold metadata to accompany each message in the cache.
-rw-r--r-- | bot/exts/filtering/_filter_context.py | 11 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/antispam.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 18 | ||||
-rw-r--r-- | bot/exts/filtering/_filters/unique/rich_embed.py | 6 | ||||
-rw-r--r-- | bot/exts/filtering/filtering.py | 55 | ||||
-rw-r--r-- | bot/utils/message_cache.py | 23 |
6 files changed, 87 insertions, 28 deletions
diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py index 3227b333a..4a213535a 100644 --- a/bot/exts/filtering/_filter_context.py +++ b/bot/exts/filtering/_filter_context.py @@ -7,6 +7,8 @@ from enum import Enum, auto from discord import DMChannel, Member, Message, TextChannel, Thread, User +from bot.utils.message_cache import MessageCache + if typing.TYPE_CHECKING: from bot.exts.filtering._filters.filter import Filter @@ -29,6 +31,8 @@ class FilterContext: content: str | Iterable # What actually needs filtering message: Message | None # The message involved embeds: list = field(default_factory=list) # Any embeds involved + before_message: Message | None = None + message_cache: MessageCache | None = None # Output context dm_content: str = field(default_factory=str) # The content to DM the invoker dm_embed: str = field(default_factory=str) # The embed description to DM the invoker @@ -45,6 +49,13 @@ class FilterContext: related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set) attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs. + @classmethod + def from_message( + cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None + ) -> FilterContext: + """Create a filtering context from the attributes of a message.""" + return cls(event, message.author, message.channel, message.content, message, message.embeds, before, cache) + def replace(self, **changes) -> FilterContext: """Return a new context object assigning new values to the specified fields.""" return replace(self, **changes) diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py index ed86c92c0..cf5875723 100644 --- a/bot/exts/filtering/_filter_lists/antispam.py +++ b/bot/exts/filtering/_filter_lists/antispam.py @@ -68,7 +68,7 @@ class AntispamList(UniquesListBase): earliest_relevant_at = arrow.utcnow() - timedelta(seconds=max_interval) relevant_messages = list( - takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.filtering_cog.message_cache) + takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.message_cache) ) new_ctx = ctx.replace(content=relevant_messages) triggers = await sublist.filter_list_result(new_ctx) diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index fd243a109..b5d6141d7 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -47,7 +47,8 @@ def list_type_converter(argument: str) -> ListType: raise BadArgument(f"No matching list type found for {argument!r}.") -@dataclass(frozen=True) +# AtomicList and its subclasses must have eq=False, otherwise the dataclass deco will replace the hash function. +@dataclass(frozen=True, eq=False) class AtomicList: """ Represents the atomic structure of a single filter list as it appears in the database. @@ -84,9 +85,8 @@ class AtomicList: """ return await self._create_filter_list_result(ctx, self.defaults, self.filters.values()) - @staticmethod async def _create_filter_list_result( - ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter] + self, ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter] ) -> list[Filter]: """A helper function to evaluate the result of `filter_list_result`.""" passed_by_default, failed_by_default = defaults.validations.evaluate(ctx) @@ -103,6 +103,13 @@ class AtomicList: if await filter_.triggered_on(ctx): relevant_filters.append(filter_) + if ctx.event == Event.MESSAGE_EDIT and ctx.message and self.list_type == ListType.DENY: + previously_triggered = ctx.message_cache.get_message_metadata(ctx.message.id) + if previously_triggered and self in previously_triggered: + ignore_filters = previously_triggered[self] + # This updates the cache. Some filters are ignored, but they're necessary if there's another edit. + previously_triggered[self] = relevant_filters + relevant_filters = [filter_ for filter_ in relevant_filters if filter_ not in ignore_filters] return relevant_filters def default(self, setting_name: str) -> Any: @@ -144,6 +151,9 @@ class AtomicList: messages = [f"#{filter_.id} (`{filter_.content}`)" for filter_ in triggers] return messages + def __hash__(self): + return hash(id(self)) + T = typing.TypeVar("T", bound=Filter) @@ -220,7 +230,7 @@ class FilterList(dict[ListType, AtomicList], typing.Generic[T], FieldRequiring): return hash(id(self)) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SubscribingAtomicList(AtomicList): """ A base class for a list of unique filters. diff --git a/bot/exts/filtering/_filters/unique/rich_embed.py b/bot/exts/filtering/_filters/unique/rich_embed.py index 09d513373..5c3517e10 100644 --- a/bot/exts/filtering/_filters/unique/rich_embed.py +++ b/bot/exts/filtering/_filters/unique/rich_embed.py @@ -20,6 +20,12 @@ class RichEmbedFilter(UniqueFilter): async def triggered_on(self, ctx: FilterContext) -> bool: """Determine if `msg` contains any rich embeds not auto-generated from a URL.""" if ctx.embeds: + if ctx.event == Event.MESSAGE_EDIT: + # If the edit delta is less than 100 microseconds, it's probably a double filter trigger. + delta = ctx.message.edited_at - (ctx.before_message.edited_at or ctx.before_message.created_at) + if delta.total_seconds() < 0.0001: + return False + for embed in ctx.embeds: if embed.type == "rich": urls = URL_RE.findall(ctx.content) diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index c4c118b6f..05b2339b9 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -180,8 +180,30 @@ class Filtering(Cog): return self.message_cache.append(msg) - ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds) - result_actions, list_messages, _ = await self._resolve_action(ctx) + ctx = FilterContext.from_message(Event.MESSAGE, msg, None, self.message_cache) + result_actions, list_messages, triggers = await self._resolve_action(ctx) + self.message_cache.update(msg, metadata=triggers) + if result_actions: + await result_actions.action(ctx) + if ctx.send_alert: + await self._send_alert(ctx, list_messages) + + @Cog.listener() + async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: + """Filter the contents of an edited message. Don't reinvoke filters already invoked on the `before` version.""" + # Only check changes to the message contents/attachments and embed additions, not pin status etc. + if all(( + before.content == after.content, # content hasn't changed + before.attachments == after.attachments, # attachments haven't changed + len(before.embeds) >= len(after.embeds) # embeds haven't been added + )): + return + + # Update the cache first, it might be used by the antispam filter. + # No need to update the triggers, they're going to be updated inside the sublists if necessary. + self.message_cache.update(after) + ctx = FilterContext.from_message(Event.MESSAGE_EDIT, after, before, self.message_cache) + result_actions, list_messages, triggers = await self._resolve_action(ctx) if result_actions: await result_actions.action(ctx) if ctx.send_alert: @@ -520,21 +542,17 @@ class Filtering(Cog): raise BadArgument("Please provide input.") if message: user = None if no_user else message.author - filter_ctx = FilterContext( - Event.MESSAGE, user, message.channel, message.content, message, message.embeds - ) + filter_ctx = FilterContext(Event.MESSAGE, user, message.channel, message.content, message, message.embeds) else: - filter_ctx = FilterContext( - Event.MESSAGE, None, ctx.guild.get_channel(Channels.python_general), string, None - ) + python_general = ctx.guild.get_channel(Channels.python_general) + filter_ctx = FilterContext(Event.MESSAGE, None, python_general, string, None) _, _, triggers = await self._resolve_action(filter_ctx) lines = [] - for filter_list, list_triggers in triggers.items(): - for sublist_type, sublist_triggers in list_triggers.items(): - if sublist_triggers: - triggers_repr = map(str, sublist_triggers) - lines.extend([f"**{filter_list[sublist_type].label.title()}s**", *triggers_repr, "\n"]) + for sublist, sublist_triggers in triggers.items(): + if sublist_triggers: + triggers_repr = map(str, sublist_triggers) + lines.extend([f"**{sublist.label.title()}s**", *triggers_repr, "\n"]) lines = lines[:-1] # Remove last newline. embed = Embed(colour=Colour.blue(), title="Match results") @@ -767,7 +785,7 @@ class Filtering(Cog): async def _resolve_action( self, ctx: FilterContext - ) -> tuple[Optional[ActionSettings], dict[FilterList, list[str]], dict[FilterList, dict[ListType, list[Filter]]]]: + ) -> tuple[ActionSettings | None, dict[FilterList, list[str]], dict[AtomicList, list[Filter]]]: """ Return the actions that should be taken for all filter lists in the given context. @@ -778,7 +796,8 @@ class Filtering(Cog): messages = {} triggers = {} for filter_list in self._subscriptions[ctx.event]: - list_actions, list_message, triggers[filter_list] = await filter_list.actions_for(ctx) + list_actions, list_message, list_triggers = await filter_list.actions_for(ctx) + triggers.update({filter_list[list_type]: filters for list_type, filters in list_triggers.items()}) if list_actions: actions.append(list_actions) if list_message: @@ -945,11 +964,11 @@ class Filtering(Cog): """If the filter is new and applies an auto-infraction, or was edited to apply a different one, log it.""" infraction_type = filter_.overrides[0].get("infraction_type") if not infraction_type: - infraction_type = filter_list[list_type].defaults.actions.get_setting("infraction_type") + infraction_type = filter_list[list_type].default("infraction_type") if old_filter: old_infraction_type = old_filter.overrides[0].get("infraction_type") if not old_infraction_type: - old_infraction_type = filter_list[list_type].defaults.actions.get_setting("infraction_type") + old_infraction_type = filter_list[list_type].default("infraction_type") if infraction_type == old_infraction_type: return @@ -1146,7 +1165,7 @@ class Filtering(Cog): # Extract all auto-infraction filters added in the past 7 days from each filter type for filter_list in self.filter_lists.values(): for sublist in filter_list.values(): - default_infraction_type = sublist.defaults.actions.get_setting("infraction_type") + default_infraction_type = sublist.default("infraction_type") for filter_ in sublist.filters.values(): if max(filter_.created_at, filter_.updated_at) < seven_days_ago: continue diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index f68d280c9..3e77e6a50 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -31,20 +31,23 @@ class MessageCache: self._start = 0 self._end = 0 - self._messages: list[t.Optional[Message]] = [None] * self.maxlen + self._messages: list[Message | None] = [None] * self.maxlen self._message_id_mapping = {} + self._message_metadata = {} - def append(self, message: Message) -> None: + def append(self, message: Message, *, metadata: dict | None = None) -> None: """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" if self.newest_first: self._appendleft(message) else: self._appendright(message) + self._message_metadata[message.id] = metadata def _appendright(self, message: Message) -> None: """Add the received message to the end of the cache.""" if self._is_full(): del self._message_id_mapping[self._messages[self._start].id] + del self._message_metadata[self._messages[self._start].id] self._start = (self._start + 1) % self.maxlen self._messages[self._end] = message @@ -56,6 +59,7 @@ class MessageCache: if self._is_full(): self._end = (self._end - 1) % self.maxlen del self._message_id_mapping[self._messages[self._end].id] + del self._message_metadata[self._messages[self._end].id] self._start = (self._start - 1) % self.maxlen self._messages[self._start] = message @@ -69,6 +73,7 @@ class MessageCache: self._end = (self._end - 1) % self.maxlen message = self._messages[self._end] del self._message_id_mapping[message.id] + del self._message_metadata[message.id] self._messages[self._end] = None return message @@ -80,6 +85,7 @@ class MessageCache: message = self._messages[self._start] del self._message_id_mapping[message.id] + del self._message_metadata[message.id] self._messages[self._start] = None self._start = (self._start + 1) % self.maxlen @@ -89,16 +95,21 @@ class MessageCache: """Remove all messages from the cache.""" self._messages = [None] * self.maxlen self._message_id_mapping = {} + self._message_metadata = {} self._start = 0 self._end = 0 - def get_message(self, message_id: int) -> t.Optional[Message]: + def get_message(self, message_id: int) -> Message | None: """Return the message that has the given message ID, if it is cached.""" index = self._message_id_mapping.get(message_id, None) return self._messages[index] if index is not None else None - def update(self, message: Message) -> bool: + def get_message_metadata(self, message_id: int) -> dict: + """Return the metadata of the message that has the given message ID, if it is cached.""" + return self._message_metadata.get(message_id, None) + + def update(self, message: Message, *, metadata: dict | None = None) -> bool: """ Update a cached message with new contents. @@ -108,13 +119,15 @@ class MessageCache: if index is None: return False self._messages[index] = message + if metadata is not None: + self._message_metadata[message.id] = metadata return True def __contains__(self, message_id: int) -> bool: """Return True if the cache contains a message with the given ID .""" return message_id in self._message_id_mapping - def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: + def __getitem__(self, item: int | slice) -> Message | list[Message]: """ Return the message(s) in the index or slice provided. |