aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar mbaruh <[email protected]>2022-11-05 22:49:57 +0200
committerGravatar mbaruh <[email protected]>2022-11-05 22:49:57 +0200
commit48a892fa42e5efe971ab69340fc5d731623020e5 (patch)
tree43c549d1d656a69455fda3e77567529a3f279f50
parentHandle threads in channel_scope (diff)
Add message edit filtering
This edit handler takes into account the filters already triggered for the message and ignores them (as long as they belong to a deny-type list). To that end, the message cache can now hold metadata to accompany each message in the cache.
-rw-r--r--bot/exts/filtering/_filter_context.py11
-rw-r--r--bot/exts/filtering/_filter_lists/antispam.py2
-rw-r--r--bot/exts/filtering/_filter_lists/filter_list.py18
-rw-r--r--bot/exts/filtering/_filters/unique/rich_embed.py6
-rw-r--r--bot/exts/filtering/filtering.py55
-rw-r--r--bot/utils/message_cache.py23
6 files changed, 87 insertions, 28 deletions
diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py
index 3227b333a..4a213535a 100644
--- a/bot/exts/filtering/_filter_context.py
+++ b/bot/exts/filtering/_filter_context.py
@@ -7,6 +7,8 @@ from enum import Enum, auto
from discord import DMChannel, Member, Message, TextChannel, Thread, User
+from bot.utils.message_cache import MessageCache
+
if typing.TYPE_CHECKING:
from bot.exts.filtering._filters.filter import Filter
@@ -29,6 +31,8 @@ class FilterContext:
content: str | Iterable # What actually needs filtering
message: Message | None # The message involved
embeds: list = field(default_factory=list) # Any embeds involved
+ before_message: Message | None = None
+ message_cache: MessageCache | None = None
# Output context
dm_content: str = field(default_factory=str) # The content to DM the invoker
dm_embed: str = field(default_factory=str) # The embed description to DM the invoker
@@ -45,6 +49,13 @@ class FilterContext:
related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set)
attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs.
+ @classmethod
+ def from_message(
+ cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None
+ ) -> FilterContext:
+ """Create a filtering context from the attributes of a message."""
+ return cls(event, message.author, message.channel, message.content, message, message.embeds, before, cache)
+
def replace(self, **changes) -> FilterContext:
"""Return a new context object assigning new values to the specified fields."""
return replace(self, **changes)
diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py
index ed86c92c0..cf5875723 100644
--- a/bot/exts/filtering/_filter_lists/antispam.py
+++ b/bot/exts/filtering/_filter_lists/antispam.py
@@ -68,7 +68,7 @@ class AntispamList(UniquesListBase):
earliest_relevant_at = arrow.utcnow() - timedelta(seconds=max_interval)
relevant_messages = list(
- takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.filtering_cog.message_cache)
+ takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.message_cache)
)
new_ctx = ctx.replace(content=relevant_messages)
triggers = await sublist.filter_list_result(new_ctx)
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py
index fd243a109..b5d6141d7 100644
--- a/bot/exts/filtering/_filter_lists/filter_list.py
+++ b/bot/exts/filtering/_filter_lists/filter_list.py
@@ -47,7 +47,8 @@ def list_type_converter(argument: str) -> ListType:
raise BadArgument(f"No matching list type found for {argument!r}.")
-@dataclass(frozen=True)
+# AtomicList and its subclasses must have eq=False, otherwise the dataclass deco will replace the hash function.
+@dataclass(frozen=True, eq=False)
class AtomicList:
"""
Represents the atomic structure of a single filter list as it appears in the database.
@@ -84,9 +85,8 @@ class AtomicList:
"""
return await self._create_filter_list_result(ctx, self.defaults, self.filters.values())
- @staticmethod
async def _create_filter_list_result(
- ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter]
+ self, ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter]
) -> list[Filter]:
"""A helper function to evaluate the result of `filter_list_result`."""
passed_by_default, failed_by_default = defaults.validations.evaluate(ctx)
@@ -103,6 +103,13 @@ class AtomicList:
if await filter_.triggered_on(ctx):
relevant_filters.append(filter_)
+ if ctx.event == Event.MESSAGE_EDIT and ctx.message and self.list_type == ListType.DENY:
+ previously_triggered = ctx.message_cache.get_message_metadata(ctx.message.id)
+ if previously_triggered and self in previously_triggered:
+ ignore_filters = previously_triggered[self]
+ # This updates the cache. Some filters are ignored, but they're necessary if there's another edit.
+ previously_triggered[self] = relevant_filters
+ relevant_filters = [filter_ for filter_ in relevant_filters if filter_ not in ignore_filters]
return relevant_filters
def default(self, setting_name: str) -> Any:
@@ -144,6 +151,9 @@ class AtomicList:
messages = [f"#{filter_.id} (`{filter_.content}`)" for filter_ in triggers]
return messages
+ def __hash__(self):
+ return hash(id(self))
+
T = typing.TypeVar("T", bound=Filter)
@@ -220,7 +230,7 @@ class FilterList(dict[ListType, AtomicList], typing.Generic[T], FieldRequiring):
return hash(id(self))
-@dataclass(frozen=True)
+@dataclass(frozen=True, eq=False)
class SubscribingAtomicList(AtomicList):
"""
A base class for a list of unique filters.
diff --git a/bot/exts/filtering/_filters/unique/rich_embed.py b/bot/exts/filtering/_filters/unique/rich_embed.py
index 09d513373..5c3517e10 100644
--- a/bot/exts/filtering/_filters/unique/rich_embed.py
+++ b/bot/exts/filtering/_filters/unique/rich_embed.py
@@ -20,6 +20,12 @@ class RichEmbedFilter(UniqueFilter):
async def triggered_on(self, ctx: FilterContext) -> bool:
"""Determine if `msg` contains any rich embeds not auto-generated from a URL."""
if ctx.embeds:
+ if ctx.event == Event.MESSAGE_EDIT:
+ # If the edit delta is less than 100 microseconds, it's probably a double filter trigger.
+ delta = ctx.message.edited_at - (ctx.before_message.edited_at or ctx.before_message.created_at)
+ if delta.total_seconds() < 0.0001:
+ return False
+
for embed in ctx.embeds:
if embed.type == "rich":
urls = URL_RE.findall(ctx.content)
diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py
index c4c118b6f..05b2339b9 100644
--- a/bot/exts/filtering/filtering.py
+++ b/bot/exts/filtering/filtering.py
@@ -180,8 +180,30 @@ class Filtering(Cog):
return
self.message_cache.append(msg)
- ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds)
- result_actions, list_messages, _ = await self._resolve_action(ctx)
+ ctx = FilterContext.from_message(Event.MESSAGE, msg, None, self.message_cache)
+ result_actions, list_messages, triggers = await self._resolve_action(ctx)
+ self.message_cache.update(msg, metadata=triggers)
+ if result_actions:
+ await result_actions.action(ctx)
+ if ctx.send_alert:
+ await self._send_alert(ctx, list_messages)
+
+ @Cog.listener()
+ async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
+ """Filter the contents of an edited message. Don't reinvoke filters already invoked on the `before` version."""
+ # Only check changes to the message contents/attachments and embed additions, not pin status etc.
+ if all((
+ before.content == after.content, # content hasn't changed
+ before.attachments == after.attachments, # attachments haven't changed
+ len(before.embeds) >= len(after.embeds) # embeds haven't been added
+ )):
+ return
+
+ # Update the cache first, it might be used by the antispam filter.
+ # No need to update the triggers, they're going to be updated inside the sublists if necessary.
+ self.message_cache.update(after)
+ ctx = FilterContext.from_message(Event.MESSAGE_EDIT, after, before, self.message_cache)
+ result_actions, list_messages, triggers = await self._resolve_action(ctx)
if result_actions:
await result_actions.action(ctx)
if ctx.send_alert:
@@ -520,21 +542,17 @@ class Filtering(Cog):
raise BadArgument("Please provide input.")
if message:
user = None if no_user else message.author
- filter_ctx = FilterContext(
- Event.MESSAGE, user, message.channel, message.content, message, message.embeds
- )
+ filter_ctx = FilterContext(Event.MESSAGE, user, message.channel, message.content, message, message.embeds)
else:
- filter_ctx = FilterContext(
- Event.MESSAGE, None, ctx.guild.get_channel(Channels.python_general), string, None
- )
+ python_general = ctx.guild.get_channel(Channels.python_general)
+ filter_ctx = FilterContext(Event.MESSAGE, None, python_general, string, None)
_, _, triggers = await self._resolve_action(filter_ctx)
lines = []
- for filter_list, list_triggers in triggers.items():
- for sublist_type, sublist_triggers in list_triggers.items():
- if sublist_triggers:
- triggers_repr = map(str, sublist_triggers)
- lines.extend([f"**{filter_list[sublist_type].label.title()}s**", *triggers_repr, "\n"])
+ for sublist, sublist_triggers in triggers.items():
+ if sublist_triggers:
+ triggers_repr = map(str, sublist_triggers)
+ lines.extend([f"**{sublist.label.title()}s**", *triggers_repr, "\n"])
lines = lines[:-1] # Remove last newline.
embed = Embed(colour=Colour.blue(), title="Match results")
@@ -767,7 +785,7 @@ class Filtering(Cog):
async def _resolve_action(
self, ctx: FilterContext
- ) -> tuple[Optional[ActionSettings], dict[FilterList, list[str]], dict[FilterList, dict[ListType, list[Filter]]]]:
+ ) -> tuple[ActionSettings | None, dict[FilterList, list[str]], dict[AtomicList, list[Filter]]]:
"""
Return the actions that should be taken for all filter lists in the given context.
@@ -778,7 +796,8 @@ class Filtering(Cog):
messages = {}
triggers = {}
for filter_list in self._subscriptions[ctx.event]:
- list_actions, list_message, triggers[filter_list] = await filter_list.actions_for(ctx)
+ list_actions, list_message, list_triggers = await filter_list.actions_for(ctx)
+ triggers.update({filter_list[list_type]: filters for list_type, filters in list_triggers.items()})
if list_actions:
actions.append(list_actions)
if list_message:
@@ -945,11 +964,11 @@ class Filtering(Cog):
"""If the filter is new and applies an auto-infraction, or was edited to apply a different one, log it."""
infraction_type = filter_.overrides[0].get("infraction_type")
if not infraction_type:
- infraction_type = filter_list[list_type].defaults.actions.get_setting("infraction_type")
+ infraction_type = filter_list[list_type].default("infraction_type")
if old_filter:
old_infraction_type = old_filter.overrides[0].get("infraction_type")
if not old_infraction_type:
- old_infraction_type = filter_list[list_type].defaults.actions.get_setting("infraction_type")
+ old_infraction_type = filter_list[list_type].default("infraction_type")
if infraction_type == old_infraction_type:
return
@@ -1146,7 +1165,7 @@ class Filtering(Cog):
# Extract all auto-infraction filters added in the past 7 days from each filter type
for filter_list in self.filter_lists.values():
for sublist in filter_list.values():
- default_infraction_type = sublist.defaults.actions.get_setting("infraction_type")
+ default_infraction_type = sublist.default("infraction_type")
for filter_ in sublist.filters.values():
if max(filter_.created_at, filter_.updated_at) < seven_days_ago:
continue
diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py
index f68d280c9..3e77e6a50 100644
--- a/bot/utils/message_cache.py
+++ b/bot/utils/message_cache.py
@@ -31,20 +31,23 @@ class MessageCache:
self._start = 0
self._end = 0
- self._messages: list[t.Optional[Message]] = [None] * self.maxlen
+ self._messages: list[Message | None] = [None] * self.maxlen
self._message_id_mapping = {}
+ self._message_metadata = {}
- def append(self, message: Message) -> None:
+ def append(self, message: Message, *, metadata: dict | None = None) -> None:
"""Add the received message to the cache, depending on the order of messages defined by `newest_first`."""
if self.newest_first:
self._appendleft(message)
else:
self._appendright(message)
+ self._message_metadata[message.id] = metadata
def _appendright(self, message: Message) -> None:
"""Add the received message to the end of the cache."""
if self._is_full():
del self._message_id_mapping[self._messages[self._start].id]
+ del self._message_metadata[self._messages[self._start].id]
self._start = (self._start + 1) % self.maxlen
self._messages[self._end] = message
@@ -56,6 +59,7 @@ class MessageCache:
if self._is_full():
self._end = (self._end - 1) % self.maxlen
del self._message_id_mapping[self._messages[self._end].id]
+ del self._message_metadata[self._messages[self._end].id]
self._start = (self._start - 1) % self.maxlen
self._messages[self._start] = message
@@ -69,6 +73,7 @@ class MessageCache:
self._end = (self._end - 1) % self.maxlen
message = self._messages[self._end]
del self._message_id_mapping[message.id]
+ del self._message_metadata[message.id]
self._messages[self._end] = None
return message
@@ -80,6 +85,7 @@ class MessageCache:
message = self._messages[self._start]
del self._message_id_mapping[message.id]
+ del self._message_metadata[message.id]
self._messages[self._start] = None
self._start = (self._start + 1) % self.maxlen
@@ -89,16 +95,21 @@ class MessageCache:
"""Remove all messages from the cache."""
self._messages = [None] * self.maxlen
self._message_id_mapping = {}
+ self._message_metadata = {}
self._start = 0
self._end = 0
- def get_message(self, message_id: int) -> t.Optional[Message]:
+ def get_message(self, message_id: int) -> Message | None:
"""Return the message that has the given message ID, if it is cached."""
index = self._message_id_mapping.get(message_id, None)
return self._messages[index] if index is not None else None
- def update(self, message: Message) -> bool:
+ def get_message_metadata(self, message_id: int) -> dict:
+ """Return the metadata of the message that has the given message ID, if it is cached."""
+ return self._message_metadata.get(message_id, None)
+
+ def update(self, message: Message, *, metadata: dict | None = None) -> bool:
"""
Update a cached message with new contents.
@@ -108,13 +119,15 @@ class MessageCache:
if index is None:
return False
self._messages[index] = message
+ if metadata is not None:
+ self._message_metadata[message.id] = metadata
return True
def __contains__(self, message_id: int) -> bool:
"""Return True if the cache contains a message with the given ID ."""
return message_id in self._message_id_mapping
- def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]:
+ def __getitem__(self, item: int | slice) -> Message | list[Message]:
"""
Return the message(s) in the index or slice provided.