aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.github/CODEOWNERS6
-rw-r--r--bot/bot.py26
-rw-r--r--bot/constants.py78
-rw-r--r--bot/converters.py50
-rw-r--r--bot/exts/filtering/FILTERS-DEVELOPMENT.md63
-rw-r--r--bot/exts/filtering/__init__.py (renamed from bot/exts/filters/__init__.py)0
-rw-r--r--bot/exts/filtering/_filter_context.py64
-rw-r--r--bot/exts/filtering/_filter_lists/__init__.py9
-rw-r--r--bot/exts/filtering/_filter_lists/antispam.py191
-rw-r--r--bot/exts/filtering/_filter_lists/domain.py64
-rw-r--r--bot/exts/filtering/_filter_lists/extension.py116
-rw-r--r--bot/exts/filtering/_filter_lists/filter_list.py308
-rw-r--r--bot/exts/filtering/_filter_lists/invite.py154
-rw-r--r--bot/exts/filtering/_filter_lists/token.py72
-rw-r--r--bot/exts/filtering/_filter_lists/unique.py39
-rw-r--r--bot/exts/filtering/_filters/__init__.py (renamed from tests/bot/exts/filters/__init__.py)0
-rw-r--r--bot/exts/filtering/_filters/antispam/__init__.py9
-rw-r--r--bot/exts/filtering/_filters/antispam/attachments.py43
-rw-r--r--bot/exts/filtering/_filters/antispam/burst.py41
-rw-r--r--bot/exts/filtering/_filters/antispam/chars.py43
-rw-r--r--bot/exts/filtering/_filters/antispam/duplicates.py44
-rw-r--r--bot/exts/filtering/_filters/antispam/emoji.py53
-rw-r--r--bot/exts/filtering/_filters/antispam/links.py52
-rw-r--r--bot/exts/filtering/_filters/antispam/mentions.py90
-rw-r--r--bot/exts/filtering/_filters/antispam/newlines.py61
-rw-r--r--bot/exts/filtering/_filters/antispam/role_mentions.py42
-rw-r--r--bot/exts/filtering/_filters/domain.py62
-rw-r--r--bot/exts/filtering/_filters/extension.py27
-rw-r--r--bot/exts/filtering/_filters/filter.py94
-rw-r--r--bot/exts/filtering/_filters/invite.py48
-rw-r--r--bot/exts/filtering/_filters/token.py35
-rw-r--r--bot/exts/filtering/_filters/unique/__init__.py9
-rw-r--r--bot/exts/filtering/_filters/unique/discord_token.py (renamed from bot/exts/filters/token_remover.py)182
-rw-r--r--bot/exts/filtering/_filters/unique/everyone.py28
-rw-r--r--bot/exts/filtering/_filters/unique/rich_embed.py51
-rw-r--r--bot/exts/filtering/_filters/unique/webhook.py63
-rw-r--r--bot/exts/filtering/_settings.py233
-rw-r--r--bot/exts/filtering/_settings_types/__init__.py9
-rw-r--r--bot/exts/filtering/_settings_types/actions/__init__.py8
-rw-r--r--bot/exts/filtering/_settings_types/actions/infraction_and_notification.py204
-rw-r--r--bot/exts/filtering/_settings_types/actions/ping.py45
-rw-r--r--bot/exts/filtering/_settings_types/actions/remove_context.py113
-rw-r--r--bot/exts/filtering/_settings_types/actions/send_alert.py23
-rw-r--r--bot/exts/filtering/_settings_types/settings_entry.py90
-rw-r--r--bot/exts/filtering/_settings_types/validations/__init__.py8
-rw-r--r--bot/exts/filtering/_settings_types/validations/bypass_roles.py24
-rw-r--r--bot/exts/filtering/_settings_types/validations/channel_scope.py70
-rw-r--r--bot/exts/filtering/_settings_types/validations/enabled.py19
-rw-r--r--bot/exts/filtering/_settings_types/validations/filter_dm.py20
-rw-r--r--bot/exts/filtering/_ui/__init__.py0
-rw-r--r--bot/exts/filtering/_ui/filter.py464
-rw-r--r--bot/exts/filtering/_ui/filter_list.py271
-rw-r--r--bot/exts/filtering/_ui/search.py365
-rw-r--r--bot/exts/filtering/_ui/ui.py565
-rw-r--r--bot/exts/filtering/_utils.py224
-rw-r--r--bot/exts/filtering/filtering.py1424
-rw-r--r--bot/exts/filters/antimalware.py106
-rw-r--r--bot/exts/filters/antispam.py326
-rw-r--r--bot/exts/filters/filter_lists.py359
-rw-r--r--bot/exts/filters/filtering.py743
-rw-r--r--bot/exts/filters/webhook_remover.py94
-rw-r--r--bot/exts/info/codeblock/_cog.py6
-rw-r--r--bot/exts/moderation/clean.py3
-rw-r--r--bot/exts/moderation/infraction/infractions.py22
-rw-r--r--bot/exts/moderation/modlog.py51
-rw-r--r--bot/exts/moderation/watchchannels/_watchchannel.py6
-rw-r--r--bot/exts/utils/snekbox/_cog.py8
-rw-r--r--bot/pagination.py9
-rw-r--r--bot/rules/__init__.py12
-rw-r--r--bot/rules/attachments.py26
-rw-r--r--bot/rules/burst.py23
-rw-r--r--bot/rules/burst_shared.py18
-rw-r--r--bot/rules/chars.py24
-rw-r--r--bot/rules/discord_emojis.py34
-rw-r--r--bot/rules/duplicates.py28
-rw-r--r--bot/rules/links.py36
-rw-r--r--bot/rules/newlines.py45
-rw-r--r--bot/rules/role_mentions.py24
-rw-r--r--bot/utils/message_cache.py23
-rw-r--r--bot/utils/messages.py58
-rw-r--r--tests/bot/exts/filtering/__init__.py0
-rw-r--r--tests/bot/exts/filtering/test_discord_token_filter.py276
-rw-r--r--tests/bot/exts/filtering/test_extension_filter.py139
-rw-r--r--tests/bot/exts/filtering/test_settings.py20
-rw-r--r--tests/bot/exts/filtering/test_settings_entries.py216
-rw-r--r--tests/bot/exts/filtering/test_token_filter.py49
-rw-r--r--tests/bot/exts/filters/test_antimalware.py202
-rw-r--r--tests/bot/exts/filters/test_antispam.py35
-rw-r--r--tests/bot/exts/filters/test_filtering.py40
-rw-r--r--tests/bot/exts/filters/test_token_remover.py409
-rw-r--r--tests/bot/rules/__init__.py76
-rw-r--r--tests/bot/rules/test_attachments.py69
-rw-r--r--tests/bot/rules/test_burst.py54
-rw-r--r--tests/bot/rules/test_burst_shared.py57
-rw-r--r--tests/bot/rules/test_chars.py64
-rw-r--r--tests/bot/rules/test_discord_emojis.py73
-rw-r--r--tests/bot/rules/test_duplicates.py64
-rw-r--r--tests/bot/rules/test_links.py67
-rw-r--r--tests/bot/rules/test_mentions.py131
-rw-r--r--tests/bot/rules/test_newlines.py102
-rw-r--r--tests/bot/rules/test_role_mentions.py55
-rw-r--r--tests/helpers.py8
102 files changed, 7059 insertions, 3727 deletions
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 7cd00a0d6..816bdf290 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,6 +1,5 @@
# Extensions
**/bot/exts/backend/sync/** @MarkKoz
-**/bot/exts/filters/*token_remover.py @MarkKoz
**/bot/exts/moderation/*silence.py @MarkKoz
bot/exts/info/codeblock/** @MarkKoz
bot/exts/utils/extensions.py @MarkKoz
@@ -8,14 +7,11 @@ bot/exts/utils/snekbox.py @MarkKoz @jb3
bot/exts/moderation/** @mbaruh @Den4200 @ks129 @jb3
bot/exts/info/** @Den4200 @jb3
bot/exts/info/information.py @mbaruh @jb3
-bot/exts/filters/** @mbaruh @jb3
+bot/exts/filtering/** @mbaruh
bot/exts/fun/** @ks129
bot/exts/utils/** @ks129 @jb3
bot/exts/recruitment/** @wookie184
-# Rules
-bot/rules/** @mbaruh
-
# Utils
bot/utils/function.py @MarkKoz
bot/utils/lock.py @MarkKoz
diff --git a/bot/bot.py b/bot/bot.py
index 6164ba9fd..f56aec38e 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,5 +1,4 @@
import asyncio
-from collections import defaultdict
import aiohttp
from pydis_core import BotBase
@@ -27,8 +26,6 @@ class Bot(BotBase):
super().__init__(*args, **kwargs)
- self.filter_list_cache = defaultdict(dict)
-
async def ping_services(self) -> None:
"""A helper to make sure all the services the bot relies on are available on startup."""
# Connect Site/API
@@ -45,33 +42,10 @@ class Bot(BotBase):
raise
await asyncio.sleep(constants.URLs.connect_cooldown)
- def insert_item_into_filter_list_cache(self, item: dict[str, str]) -> None:
- """Add an item to the bots filter_list_cache."""
- type_ = item["type"]
- allowed = item["allowed"]
- content = item["content"]
-
- self.filter_list_cache[f"{type_}.{allowed}"][content] = {
- "id": item["id"],
- "comment": item["comment"],
- "created_at": item["created_at"],
- "updated_at": item["updated_at"],
- }
-
- async def cache_filter_list_data(self) -> None:
- """Cache all the data in the FilterList on the site."""
- full_cache = await self.api_client.get('bot/filter-lists')
-
- for item in full_cache:
- self.insert_item_into_filter_list_cache(item)
-
async def setup_hook(self) -> None:
"""Default async initialisation method for discord.py."""
await super().setup_hook()
- # Build the FilterList cache
- await self.cache_filter_list_data()
-
# This is not awaited to avoid a deadlock with any cogs that have
# wait_until_guild_available in their cog_load method.
scheduling.create_task(self.load_extensions(exts))
diff --git a/bot/constants.py b/bot/constants.py
index 4186472b1..1e6227a94 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -333,43 +333,6 @@ class _Free(EnvConfig):
Free = _Free()
-class Rule(BaseModel):
- interval: int
- max: int
-
-
-# Some help in choosing an appropriate name for this is appreciated
-class ExtendedRule(Rule):
- max_consecutive: int
-
-
-class Rules(BaseModel):
- attachments: Rule = Rule(interval=10, max=6)
- burst: Rule = Rule(interval=10, max=7)
- chars: Rule = Rule(interval=5, max=4_200)
- discord_emojis: Rule = Rule(interval=10, max=20)
- duplicates: Rule = Rule(interval=10, max=3)
- links: Rule = Rule(interval=10, max=10)
- mentions: Rule = Rule(interval=10, max=5)
- newlines: ExtendedRule = ExtendedRule(interval=10, max=100, max_consecutive=10)
- role_mentions: Rule = Rule(interval=10, max=3)
-
-
-class _AntiSpam(EnvConfig):
- EnvConfig.Config.env_prefix = 'anti_spam_'
-
- cache_size = 100
-
- clean_offending = True
- ping_everyone = True
-
- remove_timeout_after = 600
- rules = Rules()
-
-
-AntiSpam = _AntiSpam()
-
-
class _HelpChannels(EnvConfig):
EnvConfig.Config.env_prefix = "help_channels_"
@@ -662,47 +625,6 @@ class _Icons(EnvConfig):
Icons = _Icons()
-class _Filter(EnvConfig):
- EnvConfig.Config.env_prefix = "filters_"
-
- filter_domains = True
- filter_everyone_ping = True
- filter_invites = True
- filter_zalgo = False
- watch_regex = True
- watch_rich_embeds = True
-
- # Notifications are not expected for "watchlist" type filters
-
- notify_user_domains = False
- notify_user_everyone_ping = True
- notify_user_invites = True
- notify_user_zalgo = False
-
- offensive_msg_delete_days = 7
- ping_everyone = True
-
- channel_whitelist = [
- Channels.admins,
- Channels.big_brother,
- Channels.dev_log,
- Channels.message_log,
- Channels.mod_log,
- Channels.staff_lounge
- ]
- role_whitelist = [
- Roles.admins,
- Roles.helpers,
- Roles.moderators,
- Roles.owners,
- Roles.python_community,
- Roles.partners
- ]
-
-
-Filter = _Filter()
-
-
class _Keys(EnvConfig):
EnvConfig.Config.env_prefix = "api_keys_"
diff --git a/bot/converters.py b/bot/converters.py
index 544513c90..21623b597 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -9,7 +9,7 @@ import dateutil.parser
import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
-from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter
+from discord.ext.commands import BadArgument, Context, Converter, IDConverter, MemberConverter, UserConverter
from discord.utils import escape_markdown, snowflake_time
from pydis_core.site_api import ResponseCodeError
from pydis_core.utils import unqualify
@@ -68,54 +68,6 @@ class ValidDiscordServerInvite(Converter):
raise BadArgument("This does not appear to be a valid Discord server invite.")
-class ValidFilterListType(Converter):
- """
- A converter that checks whether the given string is a valid FilterList type.
-
- Raises `BadArgument` if the argument is not a valid FilterList type, and simply
- passes through the given argument otherwise.
- """
-
- @staticmethod
- async def get_valid_types(bot: Bot) -> list:
- """
- Try to get a list of valid filter list types.
-
- Raise a BadArgument if the API can't respond.
- """
- try:
- valid_types = await bot.api_client.get('bot/filter-lists/get-types')
- except ResponseCodeError:
- raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.")
-
- return [enum for enum, classname in valid_types]
-
- async def convert(self, ctx: Context, list_type: str) -> str:
- """Checks whether the given string is a valid FilterList type."""
- valid_types = await self.get_valid_types(ctx.bot)
- list_type = list_type.upper()
-
- if list_type not in valid_types:
-
- # Maybe the user is using the plural form of this type,
- # e.g. "guild_invites" instead of "guild_invite".
- #
- # This code will support the simple plural form (a single 's' at the end),
- # which works for all current list types, but if a list type is added in the future
- # which has an irregular plural form (like 'ies'), this code will need to be
- # refactored to support this.
- if list_type.endswith("S") and list_type[:-1] in valid_types:
- list_type = list_type[:-1]
-
- else:
- valid_types_list = '\n'.join([f"• {type_.lower()}" for type_ in valid_types])
- raise BadArgument(
- f"You have provided an invalid list type!\n\n"
- f"Please provide one of the following: \n{valid_types_list}"
- )
- return list_type
-
-
class Extension(Converter):
"""
Fully qualify the name of an extension and ensure it exists.
diff --git a/bot/exts/filtering/FILTERS-DEVELOPMENT.md b/bot/exts/filtering/FILTERS-DEVELOPMENT.md
new file mode 100644
index 000000000..d5896d556
--- /dev/null
+++ b/bot/exts/filtering/FILTERS-DEVELOPMENT.md
@@ -0,0 +1,63 @@
+# Filters Development
+This file gives a short overview of the extension, and shows how to perform some basic changes/additions to it.
+
+## Overview
+The main idea is that there is a list of filters each deciding whether they apply to the given content.
+For example, there can be a filter that decides it will trigger when the content contains the string "lemon".
+
+There are several types of filters, and two filters of the same type differ by their content.
+For example, filters of type "token" search for a specific token inside the provided string.
+One token filter might look for the string "lemon", while another will look for the string "joe".
+
+Each filter has a set of settings that decide when it triggers (e.g. in which channels), and what happens if it does (e.g. delete the message).
+Filters of a specific type can have additional settings that are special to them.
+
+A list of filters is contained within a filter list.
+The filter list gets content to filter, and dispatches it to each of its filters.
+It takes the answers from its filters and returns a unified response (e.g. if at least one of the filters says it should be deleted, then the filter list response will include it).
+
+A filter list has the same set of possible settings, which act as defaults.
+If a filter in the list doesn't define a value for a setting (meaning it has a value of None), it will use the value of the containing filter list.
+
+The cog receives "filtering events". For example, a new message is sent.
+It creates a "filtering context" with everything a filtering list needs to know to provide an answer for what should be done.
+For example, if the event is a new message, then the content to filter is the content of the message, embeds if any exist, etc.
+
+The cog dispatches the event to each filter list, gets the result from each, compiles them, and takes any action dictated by them.
+For example, if any of the filter lists want the message to be deleted, then the cog will delete it.
+
+## Example Changes
+### Creating a new type of filter list
+1. Head over to `bot.exts.filtering._filter_lists` and create a new Python file.
+2. Subclass the FilterList class in `bot.exts.filtering._filter_lists.filter_list` and implement its abstract methods. Make sure to set the `name` class attribute.
+
+You can now add filter lists to the database with the same name defined in the new FilterList subclass.
+
+### Creating a new type of filter
+1. Head over to `bot.exts.filtering._filters` and create a new Python file.
+2. Subclass the Filter class in `bot.exts.filtering._filters.filter` and implement its abstract methods.
+3. Make sure to set the `name` class attribute, and have one of the FilterList subclasses return this new Filter subclass in `get_filter_type`.
+
+### Creating a new type of setting
+1. Head over to `bot.exts.filtering._settings_types`, and open a new Python file in either `actions` or `validations`, depending on whether you want to subclass `ActionEntry` or `ValidationEntry`.
+2. Subclass one of the aforementioned classes, and implement its abstract methods. Make sure to set the `name` and `description` class attributes.
+
+You can now make the appropriate changes to the site repo:
+1. Add a new field in the `Filter` and `FilterList` models. Make sure that on `Filter` it's nullable, and on `FilterList` it isn't.
+2. In `serializers.py`, add the new field to `SETTINGS_FIELDS`, and to `ALLOW_BLANK_SETTINGS` or `ALLOW_EMPTY_SETTINGS` if appropriate. If it's not a part of any group of settings, add it `BASE_SETTINGS_FIELDS`, otherwise add it to the appropriate group or create a new one.
+3. If you created a new group, make sure it's used in `to_representation`.
+4. Update the docs in the filter viewsets.
+
+You can merge the changes to the bot first - if no such field is loaded from the database it'll just be ignored.
+
+You can define entries that are a group of fields in the database.
+In that case the created subclass should have fields whose names are the names of the fields in the database.
+Then, the description will be a dictionary, whose keys are the names of the fields, and values are the descriptions for each field.
+
+### Creating a new type of filtering event
+1. Head over to `bot.exts.filtering._filter_context` and add a new value to the `Event` enum.
+2. Implement the dispatching and actioning of the new event in the cog, by either adding it to an existing even listener, or creating a new one.
+3. Have the appropriate filter lists subscribe to the event, so they receive it.
+4. Have the appropriate unique filters (currently under `unique` and `antispam` in `bot.exts.filtering._filters`) subscribe to the event, so they receive it.
+
+It should be noted that the filtering events don't need to correspond to Discord events. For example, `nickname` isn't a Discord event and is dispatched when a message is sent.
diff --git a/bot/exts/filters/__init__.py b/bot/exts/filtering/__init__.py
index e69de29bb..e69de29bb 100644
--- a/bot/exts/filters/__init__.py
+++ b/bot/exts/filtering/__init__.py
diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py
new file mode 100644
index 000000000..8e1ed5788
--- /dev/null
+++ b/bot/exts/filtering/_filter_context.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import typing
+from collections.abc import Callable, Coroutine, Iterable
+from dataclasses import dataclass, field, replace
+from enum import Enum, auto
+
+from discord import DMChannel, Embed, Member, Message, TextChannel, Thread, User
+
+from bot.utils.message_cache import MessageCache
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering._filters.filter import Filter
+
+
+class Event(Enum):
+ """Types of events that can trigger filtering. Note this does not have to align with gateway event types."""
+
+ MESSAGE = auto()
+ MESSAGE_EDIT = auto()
+ NICKNAME = auto()
+
+
+@dataclass
+class FilterContext:
+ """A dataclass containing the information that should be filtered, and output information of the filtering."""
+
+ # Input context
+ event: Event # The type of event
+ author: User | Member | None # Who triggered the event
+ channel: TextChannel | Thread | DMChannel | None # The channel involved
+ content: str | Iterable # What actually needs filtering. The Iterable type depends on the filter list.
+ message: Message | None # The message involved
+ embeds: list[Embed] = field(default_factory=list) # Any embeds involved
+ before_message: Message | None = None
+ message_cache: MessageCache | None = None
+ # Output context
+ dm_content: str = "" # The content to DM the invoker
+ dm_embed: str = "" # The embed description to DM the invoker
+ send_alert: bool = False # Whether to send an alert for the moderators
+ alert_content: str = "" # The content of the alert
+ alert_embeds: list[Embed] = field(default_factory=list) # Any embeds to add to the alert
+ action_descriptions: list[str] = field(default_factory=list) # What actions were taken
+ matches: list[str] = field(default_factory=list) # What exactly was found
+ notification_domain: str = "" # A domain to send the user for context
+ filter_info: dict['Filter', str] = field(default_factory=dict) # Additional info from a filter.
+ messages_deletion: bool = False # Whether the messages were deleted. Can't upload deletion log otherwise.
+ # Additional actions to perform
+ additional_actions: list[Callable[[FilterContext], Coroutine]] = field(default_factory=list)
+ related_messages: set[Message] = field(default_factory=set) # Deletion will include these.
+ related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set)
+ attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs.
+ upload_deletion_logs: bool = True # Whether it's allowed to upload deletion logs.
+
+ @classmethod
+ def from_message(
+ cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None
+ ) -> FilterContext:
+ """Create a filtering context from the attributes of a message."""
+ return cls(event, message.author, message.channel, message.content, message, message.embeds, before, cache)
+
+ def replace(self, **changes) -> FilterContext:
+ """Return a new context object assigning new values to the specified fields."""
+ return replace(self, **changes)
diff --git a/bot/exts/filtering/_filter_lists/__init__.py b/bot/exts/filtering/_filter_lists/__init__.py
new file mode 100644
index 000000000..82e0452f9
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/__init__.py
@@ -0,0 +1,9 @@
+from os.path import dirname
+
+from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType, list_type_converter
+from bot.exts.filtering._utils import subclasses_in_package
+
+filter_list_types = subclasses_in_package(dirname(__file__), f"{__name__}.", FilterList)
+filter_list_types = {filter_list.name: filter_list for filter_list in filter_list_types}
+
+__all__ = [filter_list_types, FilterList, ListType, list_type_converter]
diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py
new file mode 100644
index 000000000..0e7ab2bdc
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/antispam.py
@@ -0,0 +1,191 @@
+import asyncio
+import typing
+from collections.abc import Callable, Coroutine
+from dataclasses import dataclass, field
+from datetime import timedelta
+from functools import reduce
+from itertools import takewhile
+from operator import add, or_
+
+import arrow
+from discord import Member
+from pydis_core.utils import scheduling
+from pydis_core.utils.logging import get_logger
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filter_lists.filter_list import ListType, SubscribingAtomicList, UniquesListBase
+from bot.exts.filtering._filters.antispam import antispam_filter_types
+from bot.exts.filtering._filters.filter import Filter, UniqueFilter
+from bot.exts.filtering._settings import ActionSettings
+from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction
+from bot.exts.filtering._ui.ui import AlertView, build_mod_alert
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering.filtering import Filtering
+
+log = get_logger(__name__)
+
+ALERT_DELAY = 6
+
+
+class AntispamList(UniquesListBase):
+ """
+ A list of anti-spam rules.
+
+ Messages from the last X seconds are passed to each rule, which decides whether it triggers across those messages.
+
+ The infraction reason is set dynamically.
+ """
+
+ name = "antispam"
+
+ def __init__(self, filtering_cog: 'Filtering'):
+ super().__init__(filtering_cog)
+ self.message_deletion_queue: dict[Member, DeletionContext] = dict()
+
+ def get_filter_type(self, content: str) -> type[UniqueFilter] | None:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+ try:
+ return antispam_filter_types[content]
+ except KeyError:
+ if content not in self._already_warned:
+ log.warning(f"An antispam filter named {content} was supplied, but no matching implementation found.")
+ self._already_warned.add(content)
+ return None
+
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+ if not ctx.message or not ctx.message_cache:
+ return None, [], {}
+
+ sublist: SubscribingAtomicList = self[ListType.DENY]
+ potential_filters = [sublist.filters[id_] for id_ in sublist.subscriptions[ctx.event]]
+ max_interval = max(filter_.extra_fields.interval for filter_ in potential_filters)
+
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=max_interval)
+ relevant_messages = list(
+ takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.message_cache)
+ )
+ new_ctx = ctx.replace(content=relevant_messages)
+ triggers = await sublist.filter_list_result(new_ctx)
+ if not triggers:
+ return None, [], {}
+
+ if ctx.author not in self.message_deletion_queue:
+ self.message_deletion_queue[ctx.author] = DeletionContext()
+ ctx.additional_actions.append(self._create_deletion_context_handler(ctx.author))
+ ctx.related_channels |= {msg.channel for msg in ctx.related_messages}
+ else: # The additional messages found are already part of a deletion context
+ ctx.related_messages = set()
+ current_infraction = self.message_deletion_queue[ctx.author].current_infraction
+ # In case another filter wants an alert, prevent deleted messages from being uploaded now and also for
+ # the spam alert (upload happens during alerting).
+ # Deleted messages API doesn't accept duplicates and will error.
+ # Additional messages are necessarily part of the deletion.
+ ctx.upload_deletion_logs = False
+ self.message_deletion_queue[ctx.author].add(ctx, triggers)
+
+ current_actions = sublist.merge_actions(triggers)
+ # Don't alert yet.
+ current_actions.pop("ping", None)
+ current_actions.pop("send_alert", None)
+
+ new_infraction = current_actions["infraction_and_notification"].copy()
+ # Smaller infraction value => higher in hierarchy.
+ if not current_infraction or new_infraction.infraction_type.value < current_infraction.value:
+ # Pick the first triggered filter for the reason, there's no good way to decide between them.
+ new_infraction.infraction_reason = (
+ f"{triggers[0].name.replace('_', ' ')} spam – {ctx.filter_info[triggers[0]]}"
+ )
+ current_actions["infraction_and_notification"] = new_infraction
+ self.message_deletion_queue[ctx.author].current_infraction = new_infraction.infraction_type
+ else:
+ current_actions.pop("infraction_and_notification", None)
+
+ # Provide some message in case another filter list wants there to be an alert.
+ return current_actions, ["Handling spam event..."], {ListType.DENY: triggers}
+
+ def _create_deletion_context_handler(self, context_id: Member) -> Callable[[FilterContext], Coroutine]:
+ async def schedule_processing(ctx: FilterContext) -> None:
+ """
+ Schedule a coroutine to process the deletion context.
+
+ It cannot be awaited directly, as it waits ALERT_DELAY seconds, and actioning a filtering context depends on
+ all actions finishing.
+
+ This is async and takes a context to adhere to the type of ctx.additional_actions.
+ """
+ async def process_deletion_context() -> None:
+ """Processes the Deletion Context queue."""
+ log.trace("Sleeping before processing message deletion queue.")
+ await asyncio.sleep(ALERT_DELAY)
+
+ if context_id not in self.message_deletion_queue:
+ log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!")
+ return
+
+ deletion_context = self.message_deletion_queue.pop(context_id)
+ await deletion_context.send_alert(self)
+
+ scheduling.create_task(process_deletion_context())
+
+ return schedule_processing
+
+
+@dataclass
+class DeletionContext:
+ """Represents a Deletion Context for a single spam event."""
+
+ contexts: list[FilterContext] = field(default_factory=list)
+ rules: set[UniqueFilter] = field(default_factory=set)
+ current_infraction: Infraction | None = None
+
+ def add(self, ctx: FilterContext, rules: list[UniqueFilter]) -> None:
+ """Adds new rule violation events to the deletion context."""
+ self.contexts.append(ctx)
+ self.rules.update(rules)
+
+ async def send_alert(self, antispam_list: AntispamList) -> None:
+ """Post the mod alert."""
+ if not self.contexts or not self.rules:
+ return
+
+ webhook = antispam_list.filtering_cog.webhook
+ if not webhook:
+ return
+
+ ctx, *other_contexts = self.contexts
+ new_ctx = FilterContext(ctx.event, ctx.author, ctx.channel, ctx.content, ctx.message)
+ new_ctx.action_descriptions = reduce(
+ add, (other_ctx.action_descriptions for other_ctx in other_contexts), ctx.action_descriptions
+ )
+ # It shouldn't ever come to this, but just in case.
+ if descriptions_num := len(new_ctx.action_descriptions) > 20:
+ new_ctx.action_descriptions = new_ctx.action_descriptions[:20]
+ new_ctx.action_descriptions[-1] += f" (+{descriptions_num - 20} other actions)"
+ new_ctx.related_messages = reduce(
+ or_, (other_ctx.related_messages for other_ctx in other_contexts), ctx.related_messages
+ ) | {ctx.message for ctx in other_contexts}
+ new_ctx.related_channels = reduce(
+ or_, (other_ctx.related_channels for other_ctx in other_contexts), ctx.related_channels
+ ) | {ctx.channel for ctx in other_contexts}
+ new_ctx.attachments = reduce(or_, (other_ctx.attachments for other_ctx in other_contexts), ctx.attachments)
+ new_ctx.upload_deletion_logs = True
+ new_ctx.messages_deletion = all(ctx.messages_deletion for ctx in self.contexts)
+
+ rules = list(self.rules)
+ actions = antispam_list[ListType.DENY].merge_actions(rules)
+ for action in list(actions):
+ if action not in ("ping", "send_alert"):
+ actions.pop(action, None)
+ await actions.action(new_ctx)
+
+ messages = antispam_list[ListType.DENY].format_messages(rules)
+ embed = await build_mod_alert(new_ctx, {antispam_list: messages})
+ if other_contexts:
+ embed.set_footer(
+ text="The list of actions taken includes actions from additional contexts after deletion began."
+ )
+ await webhook.send(username="Anti-Spam", content=ctx.alert_content, embeds=[embed], view=AlertView(new_ctx))
diff --git a/bot/exts/filtering/_filter_lists/domain.py b/bot/exts/filtering/_filter_lists/domain.py
new file mode 100644
index 000000000..f4062edfe
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/domain.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import re
+import typing
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
+from bot.exts.filtering._filters.domain import DomainFilter
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._settings import ActionSettings
+from bot.exts.filtering._utils import clean_input
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering.filtering import Filtering
+
+URL_RE = re.compile(r"https?://(\S+)", flags=re.IGNORECASE)
+
+
+class DomainsList(FilterList[DomainFilter]):
+ """
+ A list of filters, each looking for a specific domain given by URL.
+
+ The blacklist defaults dictate what happens by default when a filter is matched, and can be overridden by
+ individual filters.
+
+ Domains are found by looking for a URL schema (http or https).
+ Filters will also trigger for subdomains.
+ """
+
+ name = "domain"
+
+ def __init__(self, filtering_cog: Filtering):
+ super().__init__()
+ filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT)
+
+ def get_filter_type(self, content: str) -> type[Filter]:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+ return DomainFilter
+
+ @property
+ def filter_types(self) -> set[type[Filter]]:
+ """Return the types of filters used by this list."""
+ return {DomainFilter}
+
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+ text = ctx.content
+ if not text:
+ return None, [], {}
+
+ text = clean_input(text)
+ urls = {match.group(1).lower().rstrip("/") for match in URL_RE.finditer(text)}
+ new_ctx = ctx.replace(content=urls)
+
+ triggers = await self[ListType.DENY].filter_list_result(new_ctx)
+ ctx.notification_domain = new_ctx.notification_domain
+ actions = None
+ messages = []
+ if triggers:
+ actions = self[ListType.DENY].merge_actions(triggers)
+ messages = self[ListType.DENY].format_messages(triggers)
+ return actions, messages, {ListType.DENY: triggers}
diff --git a/bot/exts/filtering/_filter_lists/extension.py b/bot/exts/filtering/_filter_lists/extension.py
new file mode 100644
index 000000000..a739d7191
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/extension.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import typing
+from os.path import splitext
+
+import bot
+from bot.constants import Channels, URLs
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
+from bot.exts.filtering._filters.extension import ExtensionFilter
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._settings import ActionSettings
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering.filtering import Filtering
+
+
+PY_EMBED_DESCRIPTION = (
+ "It looks like you tried to attach a Python file - "
+ f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
+)
+
+TXT_LIKE_FILES = {".txt", ".csv", ".json"}
+TXT_EMBED_DESCRIPTION = (
+ "You either uploaded a `{blocked_extension}` file or entered a message that was too long. "
+ f"Please use our [paste bin]({URLs.site_schema}{URLs.site_paste}) instead."
+)
+
+DISALLOWED_EMBED_DESCRIPTION = (
+ "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
+ "We currently allow the following file types: **{joined_whitelist}**.\n\n"
+ "Feel free to ask in {meta_channel_mention} if you think this is a mistake."
+)
+
+
+class ExtensionsList(FilterList[ExtensionFilter]):
+ """
+ A list of filters, each looking for a file attachment with a specific extension.
+
+ If an extension is not explicitly allowed, it will be blocked.
+
+ Whitelist defaults dictate what happens when an extension is *not* explicitly allowed,
+ and whitelist filters overrides have no effect.
+
+ Items should be added as file extensions preceded by a dot.
+ """
+
+ name = "extension"
+
+ def __init__(self, filtering_cog: Filtering):
+ super().__init__()
+ filtering_cog.subscribe(self, Event.MESSAGE)
+ self._whitelisted_description = None
+
+ def get_filter_type(self, content: str) -> type[Filter]:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+ return ExtensionFilter
+
+ @property
+ def filter_types(self) -> set[type[Filter]]:
+ """Return the types of filters used by this list."""
+ return {ExtensionFilter}
+
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+ # Return early if the message doesn't have attachments.
+ if not ctx.message or not ctx.message.attachments:
+ return None, [], {}
+
+ _, failed = self[ListType.ALLOW].defaults.validations.evaluate(ctx)
+ if failed: # There's no extension filtering in this context.
+ return None, [], {}
+
+ # Find all extensions in the message.
+ all_ext = {
+ (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.message.attachments
+ }
+ new_ctx = ctx.replace(content={ext for ext, _ in all_ext}) # And prepare the context for the filters to read.
+ triggered = [
+ filter_ for filter_ in self[ListType.ALLOW].filters.values() if await filter_.triggered_on(new_ctx)
+ ]
+ allowed_ext = {filter_.content for filter_ in triggered} # Get the extensions in the message that are allowed.
+
+ # See if there are any extensions left which aren't allowed.
+ not_allowed = {ext: filename for ext, filename in all_ext if ext not in allowed_ext}
+
+ if not not_allowed: # Yes, it's a double negative. Meaning all attachments are allowed :)
+ return None, [], {ListType.ALLOW: triggered}
+
+ # Something is disallowed.
+ if ".py" in not_allowed:
+ # Provide a pastebin link for .py files.
+ ctx.dm_embed = PY_EMBED_DESCRIPTION
+ elif txt_extensions := {ext for ext in TXT_LIKE_FILES if ext in not_allowed}:
+ # Work around Discord auto-conversion of messages longer than 2000 chars to .txt
+ cmd_channel = bot.instance.get_channel(Channels.bot_commands)
+ ctx.dm_embed = TXT_EMBED_DESCRIPTION.format(
+ blocked_extension=txt_extensions.pop(),
+ cmd_channel_mention=cmd_channel.mention
+ )
+ else:
+ meta_channel = bot.instance.get_channel(Channels.meta)
+ if not self._whitelisted_description:
+ self._whitelisted_description = ', '.join(
+ filter_.content for filter_ in self[ListType.ALLOW].filters.values()
+ )
+ ctx.dm_embed = DISALLOWED_EMBED_DESCRIPTION.format(
+ joined_whitelist=self._whitelisted_description,
+ blocked_extensions_str=", ".join(not_allowed),
+ meta_channel_mention=meta_channel.mention,
+ )
+
+ ctx.matches += not_allowed.values()
+ return self[ListType.ALLOW].defaults.actions, [f"`{ext}`" for ext in not_allowed], {ListType.ALLOW: triggered}
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py
new file mode 100644
index 000000000..bf02071cf
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/filter_list.py
@@ -0,0 +1,308 @@
+import dataclasses
+import typing
+from abc import abstractmethod
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Any
+
+import arrow
+from discord.ext.commands import BadArgument
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import Filter, UniqueFilter
+from bot.exts.filtering._settings import ActionSettings, Defaults, create_settings
+from bot.exts.filtering._utils import FieldRequiring, past_tense
+from bot.log import get_logger
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering.filtering import Filtering
+
+log = get_logger(__name__)
+
+
+class ListType(Enum):
+ """An enumeration of list types."""
+
+ DENY = 0
+ ALLOW = 1
+
+
+# Alternative names with which each list type can be specified in commands.
+aliases = (
+ (ListType.DENY, {"deny", "blocklist", "blacklist", "denylist", "bl", "dl"}),
+ (ListType.ALLOW, {"allow", "allowlist", "whitelist", "al", "wl"})
+)
+
+
+def list_type_converter(argument: str) -> ListType:
+ """A converter to get the appropriate list type."""
+ argument = argument.lower()
+ for list_type, list_aliases in aliases:
+ if argument in list_aliases or argument in map(past_tense, list_aliases):
+ return list_type
+ raise BadArgument(f"No matching list type found for {argument!r}.")
+
+
+# AtomicList and its subclasses must have eq=False, otherwise the dataclass deco will replace the hash function.
+@dataclass(frozen=True, eq=False)
+class AtomicList:
+ """
+ Represents the atomic structure of a single filter list as it appears in the database.
+
+ This is as opposed to the FilterList class which is a combination of several list types.
+ """
+
+ id: int
+ created_at: arrow.Arrow
+ updated_at: arrow.Arrow
+ name: str
+ list_type: ListType
+ defaults: Defaults
+ filters: dict[int, Filter]
+
+ @property
+ def label(self) -> str:
+ """Provide a short description identifying the list with its name and type."""
+ return f"{past_tense(self.list_type.name.lower())} {self.name.lower()}"
+
+ async def filter_list_result(self, ctx: FilterContext) -> list[Filter]:
+ """
+ Sift through the list of filters, and return only the ones which apply to the given context.
+
+ The strategy is as follows:
+ 1. The default settings are evaluated on the given context. The default answer for whether the filter is
+ relevant in the given context is whether there aren't any validation settings which returned False.
+ 2. For each filter, its overrides are considered:
+ - If there are no overrides, then the filter is relevant if that is the default answer.
+ - Otherwise it is relevant if there are no failed overrides, and any failing default is overridden by a
+ successful override.
+
+ If the filter is relevant in context, see if it actually triggers.
+ """
+ return await self._create_filter_list_result(ctx, self.defaults, self.filters.values())
+
+ async def _create_filter_list_result(
+ self, ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter]
+ ) -> list[Filter]:
+ """A helper function to evaluate the result of `filter_list_result`."""
+ passed_by_default, failed_by_default = defaults.validations.evaluate(ctx)
+ default_answer = not bool(failed_by_default)
+
+ relevant_filters = []
+ for filter_ in filters:
+ if not filter_.validations:
+ if default_answer and await filter_.triggered_on(ctx):
+ relevant_filters.append(filter_)
+ else:
+ passed, failed = filter_.validations.evaluate(ctx)
+ if not failed and failed_by_default < passed:
+ if await filter_.triggered_on(ctx):
+ relevant_filters.append(filter_)
+
+ if ctx.event == Event.MESSAGE_EDIT and ctx.message and self.list_type == ListType.DENY:
+ previously_triggered = ctx.message_cache.get_message_metadata(ctx.message.id)
+ # The message might not be cached.
+ if previously_triggered:
+ ignore_filters = previously_triggered[self]
+ # This updates the cache. Some filters are ignored, but they're necessary if there's another edit.
+ previously_triggered[self] = relevant_filters
+ relevant_filters = [filter_ for filter_ in relevant_filters if filter_ not in ignore_filters]
+ return relevant_filters
+
+ def default(self, setting_name: str) -> Any:
+ """Get the default value of a specific setting."""
+ missing = object()
+ value = self.defaults.actions.get_setting(setting_name, missing)
+ if value is missing:
+ value = self.defaults.validations.get_setting(setting_name, missing)
+ if value is missing:
+ raise ValueError(f"Couldn't find a setting named {setting_name!r}.")
+ return value
+
+ def merge_actions(self, filters: list[Filter]) -> ActionSettings | None:
+ """
+ Merge the settings of the given filters, with the list's defaults as fallback.
+
+ If `merge_default` is True, include it in the merge instead of using it as a fallback.
+ """
+ if not filters: # Nothing to action.
+ return None
+ try:
+ return reduce(
+ ActionSettings.union, (filter_.actions or self.defaults.actions for filter_ in filters)
+ ).fallback_to(self.defaults.actions)
+ except TypeError:
+ # The sequence fed to reduce is empty, meaning none of the filters have actions,
+ # meaning they all use the defaults.
+ return self.defaults.actions
+
+ @staticmethod
+ def format_messages(triggers: list[Filter], *, expand_single_filter: bool = True) -> list[str]:
+ """Convert the filters into strings that can be added to the alert embed."""
+ if len(triggers) == 1 and expand_single_filter:
+ message = f"#{triggers[0].id} (`{triggers[0].content}`)"
+ if triggers[0].description:
+ message += f" - {triggers[0].description}"
+ messages = [message]
+ else:
+ messages = [f"{filter_.id} (`{filter_.content}`)" for filter_ in triggers]
+ return messages
+
+ def __hash__(self):
+ return hash(id(self))
+
+
+T = typing.TypeVar("T", bound=Filter)
+
+
+class FilterList(dict[ListType, AtomicList], typing.Generic[T], FieldRequiring):
+ """Dispatches events to lists of _filters, and aggregates the responses into a single list of actions to take."""
+
+ # Each subclass must define a name matching the filter_list name we're expecting to receive from the database.
+ # Names must be unique across all filter lists.
+ name = FieldRequiring.MUST_SET_UNIQUE
+
+ _already_warned = set()
+
+ def add_list(self, list_data: dict) -> AtomicList:
+ """Add a new type of list (such as a whitelist or a blacklist) this filter list."""
+ actions, validations = create_settings(list_data["settings"], keep_empty=True)
+ list_type = ListType(list_data["list_type"])
+ defaults = Defaults(actions, validations)
+
+ filters = {}
+ for filter_data in list_data["filters"]:
+ new_filter = self._create_filter(filter_data, defaults)
+ if new_filter:
+ filters[filter_data["id"]] = new_filter
+
+ self[list_type] = AtomicList(
+ list_data["id"],
+ arrow.get(list_data["created_at"]),
+ arrow.get(list_data["updated_at"]),
+ self.name,
+ list_type,
+ defaults,
+ filters
+ )
+ return self[list_type]
+
+ def add_filter(self, list_type: ListType, filter_data: dict) -> T | None:
+ """Add a filter to the list of the specified type."""
+ new_filter = self._create_filter(filter_data, self[list_type].defaults)
+ if new_filter:
+ self[list_type].filters[filter_data["id"]] = new_filter
+ return new_filter
+
+ @abstractmethod
+ def get_filter_type(self, content: str) -> type[T]:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+
+ @property
+ @abstractmethod
+ def filter_types(self) -> set[type[T]]:
+ """Return the types of filters used by this list."""
+
+ @abstractmethod
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+
+ def _create_filter(self, filter_data: dict, defaults: Defaults) -> T | None:
+ """Create a filter from the given data."""
+ try:
+ content = filter_data["content"]
+ filter_type = self.get_filter_type(content)
+ if filter_type:
+ return filter_type(filter_data, defaults)
+ elif content not in self._already_warned:
+ log.warning(f"A filter named {content} was supplied, but no matching implementation found.")
+ self._already_warned.add(content)
+ return None
+ except TypeError as e:
+ log.warning(e)
+
+ def __hash__(self):
+ return hash(id(self))
+
+
+@dataclass(frozen=True, eq=False)
+class SubscribingAtomicList(AtomicList):
+ """
+ A base class for a list of unique filters.
+
+ Unique filters are ones that should only be run once in a given context.
+ Each unique filter is subscribed to a subset of events to respond to.
+ """
+
+ subscriptions: defaultdict[Event, list[int]] = dataclasses.field(default_factory=lambda: defaultdict(list))
+
+ def subscribe(self, filter_: UniqueFilter, *events: Event) -> None:
+ """
+ Subscribe a unique filter to the given events.
+
+ The filter is added to a list for each event. When the event is triggered, the filter context will be
+ dispatched to the subscribed filters.
+ """
+ for event in events:
+ if filter_ not in self.subscriptions[event]:
+ self.subscriptions[event].append(filter_.id)
+
+ async def filter_list_result(self, ctx: FilterContext) -> list[Filter]:
+ """Sift through the list of filters, and return only the ones which apply to the given context."""
+ event_filters = [self.filters[id_] for id_ in self.subscriptions[ctx.event]]
+ return await self._create_filter_list_result(ctx, self.defaults, event_filters)
+
+
+class UniquesListBase(FilterList[UniqueFilter]):
+ """
+ A list of unique filters.
+
+ Unique filters are ones that should only be run once in a given context.
+ Each unique filter subscribes to a subset of events to respond to.
+ """
+
+ def __init__(self, filtering_cog: 'Filtering'):
+ super().__init__()
+ self.filtering_cog = filtering_cog
+ self.loaded_types: dict[str, type[UniqueFilter]] = {}
+
+ def add_list(self, list_data: dict) -> SubscribingAtomicList:
+ """Add a new type of list (such as a whitelist or a blacklist) this filter list."""
+ actions, validations = create_settings(list_data["settings"], keep_empty=True)
+ list_type = ListType(list_data["list_type"])
+ defaults = Defaults(actions, validations)
+ new_list = SubscribingAtomicList(
+ list_data["id"],
+ arrow.get(list_data["created_at"]),
+ arrow.get(list_data["updated_at"]),
+ self.name,
+ list_type,
+ defaults,
+ {}
+ )
+ self[list_type] = new_list
+
+ filters = {}
+ events = set()
+ for filter_data in list_data["filters"]:
+ new_filter = self._create_filter(filter_data, defaults)
+ if new_filter:
+ new_list.subscribe(new_filter, *new_filter.events)
+ filters[filter_data["id"]] = new_filter
+ self.loaded_types[new_filter.name] = type(new_filter)
+ events.update(new_filter.events)
+
+ new_list.filters.update(filters)
+ if hasattr(self.filtering_cog, "subscribe"): # Subscribe the filter list to any new events found.
+ self.filtering_cog.subscribe(self, *events)
+ return new_list
+
+ @property
+ def filter_types(self) -> set[type[UniqueFilter]]:
+ """Return the types of filters used by this list."""
+ return set(self.loaded_types.values())
diff --git a/bot/exts/filtering/_filter_lists/invite.py b/bot/exts/filtering/_filter_lists/invite.py
new file mode 100644
index 000000000..bd0eaa122
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/invite.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import typing
+
+from discord import Embed, Invite
+from discord.errors import NotFound
+from pydis_core.utils.regex import DISCORD_INVITE
+
+import bot
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._filters.invite import InviteFilter
+from bot.exts.filtering._settings import ActionSettings
+from bot.exts.filtering._utils import clean_input
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering.filtering import Filtering
+
+
+class InviteList(FilterList[InviteFilter]):
+ """
+ A list of filters, each looking for guild invites to a specific guild.
+
+ If the invite is not whitelisted, it will be blocked. Partnered and verified servers are allowed unless blacklisted.
+
+ Whitelist defaults dictate what happens when an invite is *not* explicitly allowed,
+ and whitelist filters overrides have no effect.
+
+ Blacklist defaults dictate what happens by default when an explicitly blocked invite is found.
+
+ Items in the list are added through invites for the purpose of fetching the guild info.
+ Items are stored as guild IDs, guild invites are *not* stored.
+ """
+
+ name = "invite"
+
+ def __init__(self, filtering_cog: Filtering):
+ super().__init__()
+ filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT)
+
+ def get_filter_type(self, content: str) -> type[Filter]:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+ return InviteFilter
+
+ @property
+ def filter_types(self) -> set[type[Filter]]:
+ """Return the types of filters used by this list."""
+ return {InviteFilter}
+
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+ text = clean_input(ctx.content)
+
+ # Avoid escape characters
+ text = text.replace("\\", "")
+
+ matches = list(DISCORD_INVITE.finditer(text))
+ invite_codes = {m.group("invite") for m in matches}
+ if not invite_codes:
+ return None, [], {}
+ all_triggers = {}
+
+ _, failed = self[ListType.ALLOW].defaults.validations.evaluate(ctx)
+ # If the allowed list doesn't operate in the context, unknown invites are allowed.
+ check_if_allowed = not failed
+
+ # Sort the invites into two categories:
+ invites_for_inspection = dict() # Found guild invites requiring further inspection.
+ unknown_invites = dict() # Either don't resolve or group DMs.
+ for invite_code in invite_codes:
+ try:
+ invite = await bot.instance.fetch_invite(invite_code)
+ except NotFound:
+ if check_if_allowed:
+ unknown_invites[invite_code] = None
+ else:
+ if invite.guild:
+ invites_for_inspection[invite_code] = invite
+ elif check_if_allowed: # Group DM
+ unknown_invites[invite_code] = invite
+
+ # Find any blocked invites
+ new_ctx = ctx.replace(content={invite.guild.id for invite in invites_for_inspection.values()})
+ triggered = await self[ListType.DENY].filter_list_result(new_ctx)
+ blocked_guilds = {filter_.content for filter_ in triggered}
+ blocked_invites = {
+ code: invite for code, invite in invites_for_inspection.items() if invite.guild.id in blocked_guilds
+ }
+
+ # Remove the ones which are already confirmed as blocked, or otherwise ones which are partnered or verified.
+ invites_for_inspection = {
+ code: invite for code, invite in invites_for_inspection.items()
+ if invite.guild.id not in blocked_guilds
+ and "PARTNERED" not in invite.guild.features and "VERIFIED" not in invite.guild.features
+ }
+
+ # Remove any remaining invites which are allowed
+ guilds_for_inspection = {invite.guild.id for invite in invites_for_inspection.values()}
+
+ if check_if_allowed: # Whether unknown invites need to be checked.
+ new_ctx = ctx.replace(content=guilds_for_inspection)
+ all_triggers[ListType.ALLOW] = [
+ filter_ for filter_ in self[ListType.ALLOW].filters.values()
+ if await filter_.triggered_on(new_ctx)
+ ]
+ allowed = {filter_.content for filter_ in all_triggers[ListType.ALLOW]}
+ unknown_invites.update({
+ code: invite for code, invite in invites_for_inspection.items() if invite.guild.id not in allowed
+ })
+
+ if not triggered and not unknown_invites:
+ return None, [], all_triggers
+
+ actions = None
+ if unknown_invites: # There are invites which weren't allowed but aren't explicitly blocked.
+ actions = self[ListType.ALLOW].defaults.actions
+ # Blocked invites come second so that their actions have preference.
+ if triggered:
+ if actions:
+ actions = actions.union(self[ListType.DENY].merge_actions(triggered))
+ else:
+ actions = self[ListType.DENY].merge_actions(triggered)
+ all_triggers[ListType.DENY] = triggered
+
+ blocked_invites |= unknown_invites
+ ctx.matches += {match[0] for match in matches if match.group("invite") in blocked_invites}
+ ctx.alert_embeds += (self._guild_embed(invite) for invite in blocked_invites.values() if invite)
+ messages = self[ListType.DENY].format_messages(triggered)
+ messages += [
+ f"`{code} - {invite.guild.id}`" if invite else f"`{code}`" for code, invite in unknown_invites.items()
+ ]
+ return actions, messages, all_triggers
+
+ @staticmethod
+ def _guild_embed(invite: Invite) -> Embed:
+ """Return an embed representing the guild invites to."""
+ embed = Embed()
+ if invite.guild:
+ embed.title = invite.guild.name
+ embed.set_thumbnail(url=invite.guild.icon.url)
+ embed.set_footer(text=f"Guild ID: {invite.guild.id}")
+ else:
+ embed.title = "Group DM"
+
+ embed.description = (
+ f"**Invite Code:** {invite.code}\n"
+ f"**Members:** {invite.approximate_member_count}\n"
+ f"**Active:** {invite.approximate_presence_count}"
+ )
+
+ return embed
diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py
new file mode 100644
index 000000000..f5da28bb5
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/token.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import re
+import typing
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._filters.token import TokenFilter
+from bot.exts.filtering._settings import ActionSettings
+from bot.exts.filtering._utils import clean_input
+
+if typing.TYPE_CHECKING:
+ from bot.exts.filtering.filtering import Filtering
+
+SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL)
+
+
+class TokensList(FilterList[TokenFilter]):
+ """
+ A list of filters, each looking for a specific token in the given content given as regex.
+
+ The blacklist defaults dictate what happens by default when a filter is matched, and can be overridden by
+ individual filters.
+
+ Usually, if blocking literal strings, the literals themselves can be specified as the filter's value.
+ But since this is a list of regex patterns, be careful of the items added. For example, a dot needs to be escaped
+ to function as a literal dot.
+ """
+
+ name = "token"
+
+ def __init__(self, filtering_cog: Filtering):
+ super().__init__()
+ filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.NICKNAME)
+
+ def get_filter_type(self, content: str) -> type[Filter]:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+ return TokenFilter
+
+ @property
+ def filter_types(self) -> set[type[Filter]]:
+ """Return the types of filters used by this list."""
+ return {TokenFilter}
+
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+ text = ctx.content
+ if not text:
+ return None, [], {}
+ if SPOILER_RE.search(text):
+ text = self._expand_spoilers(text)
+ text = clean_input(text)
+ ctx = ctx.replace(content=text)
+
+ triggers = await self[ListType.DENY].filter_list_result(ctx)
+ actions = None
+ messages = []
+ if triggers:
+ actions = self[ListType.DENY].merge_actions(triggers)
+ messages = self[ListType.DENY].format_messages(triggers)
+ return actions, messages, {ListType.DENY: triggers}
+
+ @staticmethod
+ def _expand_spoilers(text: str) -> str:
+ """Return a string containing all interpretations of a spoilered message."""
+ split_text = SPOILER_RE.split(text)
+ return ''.join(
+ split_text[0::2] + split_text[1::2] + split_text
+ )
diff --git a/bot/exts/filtering/_filter_lists/unique.py b/bot/exts/filtering/_filter_lists/unique.py
new file mode 100644
index 000000000..a5a04d25a
--- /dev/null
+++ b/bot/exts/filtering/_filter_lists/unique.py
@@ -0,0 +1,39 @@
+from pydis_core.utils.logging import get_logger
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filter_lists.filter_list import ListType, UniquesListBase
+from bot.exts.filtering._filters.filter import Filter, UniqueFilter
+from bot.exts.filtering._filters.unique import unique_filter_types
+from bot.exts.filtering._settings import ActionSettings
+
+log = get_logger(__name__)
+
+
+class UniquesList(UniquesListBase):
+ """
+ A list of unique filters.
+
+ Unique filters are ones that should only be run once in a given context.
+ Each unique filter subscribes to a subset of events to respond to.
+ """
+
+ name = "unique"
+
+ def get_filter_type(self, content: str) -> type[UniqueFilter] | None:
+ """Get a subclass of filter matching the filter list and the filter's content."""
+ try:
+ return unique_filter_types[content]
+ except KeyError:
+ return None
+
+ async def actions_for(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
+ """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
+ triggers = await self[ListType.DENY].filter_list_result(ctx)
+ actions = None
+ messages = []
+ if triggers:
+ actions = self[ListType.DENY].merge_actions(triggers)
+ messages = self[ListType.DENY].format_messages(triggers)
+ return actions, messages, {ListType.DENY: triggers}
diff --git a/tests/bot/exts/filters/__init__.py b/bot/exts/filtering/_filters/__init__.py
index e69de29bb..e69de29bb 100644
--- a/tests/bot/exts/filters/__init__.py
+++ b/bot/exts/filtering/_filters/__init__.py
diff --git a/bot/exts/filtering/_filters/antispam/__init__.py b/bot/exts/filtering/_filters/antispam/__init__.py
new file mode 100644
index 000000000..637bcd410
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/__init__.py
@@ -0,0 +1,9 @@
+from os.path import dirname
+
+from bot.exts.filtering._filters.filter import UniqueFilter
+from bot.exts.filtering._utils import subclasses_in_package
+
+antispam_filter_types = subclasses_in_package(dirname(__file__), f"{__name__}.", UniqueFilter)
+antispam_filter_types = {filter_.name: filter_ for filter_ in antispam_filter_types}
+
+__all__ = [antispam_filter_types]
diff --git a/bot/exts/filtering/_filters/antispam/attachments.py b/bot/exts/filtering/_filters/antispam/attachments.py
new file mode 100644
index 000000000..216d9b886
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/attachments.py
@@ -0,0 +1,43 @@
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+
+class ExtraAttachmentsSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of attachments before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 6
+
+
+class AttachmentsFilter(UniqueFilter):
+ """Detects too many attachments sent by a single user."""
+
+ name = "attachments"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraAttachmentsSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author and len(msg.attachments) > 0}
+ total_recent_attachments = sum(len(msg.attachments) for msg in detected_messages)
+
+ if total_recent_attachments > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_recent_attachments} attachments"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/burst.py b/bot/exts/filtering/_filters/antispam/burst.py
new file mode 100644
index 000000000..d78107d0a
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/burst.py
@@ -0,0 +1,41 @@
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+
+class ExtraBurstSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of messages before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 7
+
+
+class BurstFilter(UniqueFilter):
+ """Detects too many messages sent by a single user."""
+
+ name = "burst"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraBurstSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+ if len(detected_messages) > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {len(detected_messages)} messages"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/chars.py b/bot/exts/filtering/_filters/antispam/chars.py
new file mode 100644
index 000000000..5c4fa201c
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/chars.py
@@ -0,0 +1,43 @@
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+
+class ExtraCharsSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of characters before the filter is triggered."
+
+ interval: int = 5
+ threshold: int = 4_200
+
+
+class CharsFilter(UniqueFilter):
+ """Detects too many characters sent by a single user."""
+
+ name = "chars"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraCharsSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+ total_recent_chars = sum(len(msg.content) for msg in relevant_messages)
+
+ if total_recent_chars > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_recent_chars} characters"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/duplicates.py b/bot/exts/filtering/_filters/antispam/duplicates.py
new file mode 100644
index 000000000..60d5c322c
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/duplicates.py
@@ -0,0 +1,44 @@
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+
+class ExtraDuplicatesSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of duplicate messages before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 3
+
+
+class DuplicatesFilter(UniqueFilter):
+ """Detects duplicated messages sent by a single user."""
+
+ name = "duplicates"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraDuplicatesSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+
+ detected_messages = {
+ msg for msg in relevant_messages
+ if msg.author == ctx.author and msg.content == ctx.message.content and msg.content
+ }
+ if len(detected_messages) > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {len(detected_messages)} duplicate messages"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/emoji.py b/bot/exts/filtering/_filters/antispam/emoji.py
new file mode 100644
index 000000000..0511e4a7b
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/emoji.py
@@ -0,0 +1,53 @@
+import re
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from emoji import demojize
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:")
+CODE_BLOCK_RE = re.compile(r"```.*?```", flags=re.DOTALL)
+
+
+class ExtraEmojiSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of emojis before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 20
+
+
+class EmojiFilter(UniqueFilter):
+ """Detects too many emojis sent by a single user."""
+
+ name = "emoji"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraEmojiSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+
+ # Get rid of code blocks in the message before searching for emojis.
+ # Convert Unicode emojis to :emoji: format to get their count.
+ total_emojis = sum(
+ len(DISCORD_EMOJI_RE.findall(demojize(CODE_BLOCK_RE.sub("", msg.content))))
+ for msg in relevant_messages
+ )
+
+ if total_emojis > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_emojis} emojis"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/links.py b/bot/exts/filtering/_filters/antispam/links.py
new file mode 100644
index 000000000..76fe53e70
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/links.py
@@ -0,0 +1,52 @@
+import re
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+LINK_RE = re.compile(r"(https?://\S+)")
+
+
+class ExtraLinksSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of links before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 10
+
+
+class DuplicatesFilter(UniqueFilter):
+ """Detects too many links sent by a single user."""
+
+ name = "links"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraLinksSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+
+ total_links = 0
+ messages_with_links = 0
+ for msg in relevant_messages:
+ total_matches = len(LINK_RE.findall(msg.content))
+ if total_matches:
+ messages_with_links += 1
+ total_links += total_matches
+
+ if total_links > self.extra_fields.threshold and messages_with_links > 1:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_links} links"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/mentions.py b/bot/exts/filtering/_filters/antispam/mentions.py
new file mode 100644
index 000000000..f3c945e16
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/mentions.py
@@ -0,0 +1,90 @@
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from discord import DeletedReferencedMessage, MessageType, NotFound
+from pydantic import BaseModel
+from pydis_core.utils.logging import get_logger
+
+import bot
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+log = get_logger(__name__)
+
+
+class ExtraMentionsSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of distinct mentions before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 5
+
+
+class DuplicatesFilter(UniqueFilter):
+ """
+ Detects total mentions exceeding the limit sent by a single user.
+
+ Excludes mentions that are bots, themselves, or replied users.
+
+ In very rare cases, may not be able to determine a
+ mention was to a reply, in which case it is not ignored.
+ """
+
+ name = "mentions"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraMentionsSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+
+ # We use `msg.mentions` here as that is supplied by the api itself, to determine who was mentioned.
+ # Additionally, `msg.mentions` includes the user replied to, even if the mention doesn't occur in the body.
+ # In order to exclude users who are mentioned as a reply, we check if the msg has a reference
+ #
+ # While we could use regex to parse the message content, and get a list of
+ # the mentions, that solution is very prone to breaking.
+ # We would need to deal with codeblocks, escaping markdown, and any discrepancies between
+ # our implementation and discord's Markdown parser which would cause false positives or false negatives.
+ total_recent_mentions = 0
+ for msg in relevant_messages:
+ # We check if the message is a reply, and if it is try to get the author
+ # since we ignore mentions of a user that we're replying to
+ reply_author = None
+
+ if msg.type == MessageType.reply:
+ ref = msg.reference
+
+ if not (resolved := ref.resolved):
+ # It is possible, in a very unusual situation, for a message to have a reference
+ # that is both not in the cache, and deleted while running this function.
+ # In such a situation, this will throw an error which we catch.
+ try:
+ resolved = await bot.instance.get_partial_messageable(resolved.channel_id).fetch_message(
+ resolved.message_id
+ )
+ except NotFound:
+ log.info('Could not fetch the reference message as it has been deleted.')
+
+ if resolved and not isinstance(resolved, DeletedReferencedMessage):
+ reply_author = resolved.author
+
+ for user in msg.mentions:
+ # Don't count bot or self mentions, or the user being replied to (if applicable)
+ if user.bot or user in {msg.author, reply_author}:
+ continue
+ total_recent_mentions += 1
+
+ if total_recent_mentions > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_recent_mentions} mentions"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/newlines.py b/bot/exts/filtering/_filters/antispam/newlines.py
new file mode 100644
index 000000000..b15a35219
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/newlines.py
@@ -0,0 +1,61 @@
+import re
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+NEWLINES = re.compile(r"(\n+)")
+
+
+class ExtraNewlinesSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of newlines before the filter is triggered."
+ consecutive_threshold_description: ClassVar[str] = (
+ "Maximum number of consecutive newlines before the filter is triggered."
+ )
+
+ interval: int = 10
+ threshold: int = 100
+ consecutive_threshold: int = 10
+
+
+class NewlinesFilter(UniqueFilter):
+ """Detects too many newlines sent by a single user."""
+
+ name = "newlines"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraNewlinesSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+
+ # Identify groups of newline characters and get group & total counts
+ newline_counts = []
+ for msg in relevant_messages:
+ newline_counts += [len(group) for group in NEWLINES.findall(msg.content)]
+ total_recent_newlines = sum(newline_counts)
+ # Get maximum newline group size
+ max_newline_group = max(newline_counts, default=0)
+
+ # Check first for total newlines, if this passes then check for large groupings
+ if total_recent_newlines > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_recent_newlines} newlines"
+ return True
+ if max_newline_group > self.extra_fields.consecutive_threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {max_newline_group} consecutive newlines"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/antispam/role_mentions.py b/bot/exts/filtering/_filters/antispam/role_mentions.py
new file mode 100644
index 000000000..49de642fa
--- /dev/null
+++ b/bot/exts/filtering/_filters/antispam/role_mentions.py
@@ -0,0 +1,42 @@
+from datetime import timedelta
+from itertools import takewhile
+from typing import ClassVar
+
+import arrow
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+
+class ExtraRoleMentionsSettings(BaseModel):
+ """Extra settings for when to trigger the antispam rule."""
+
+ interval_description: ClassVar[str] = (
+ "Look for rule violations in messages from the last `interval` number of seconds."
+ )
+ threshold_description: ClassVar[str] = "Maximum number of role mentions before the filter is triggered."
+
+ interval: int = 10
+ threshold: int = 3
+
+
+class DuplicatesFilter(UniqueFilter):
+ """Detects too many role mentions sent by a single user."""
+
+ name = "role_mentions"
+ events = (Event.MESSAGE,)
+ extra_fields_type = ExtraRoleMentionsSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval)
+ relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content))
+ detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author}
+ total_recent_mentions = sum(len(msg.role_mentions) for msg in relevant_messages)
+
+ if total_recent_mentions > self.extra_fields.threshold:
+ ctx.related_messages |= detected_messages
+ ctx.filter_info[self] = f"sent {total_recent_mentions} role mentions"
+ return True
+ return False
diff --git a/bot/exts/filtering/_filters/domain.py b/bot/exts/filtering/_filters/domain.py
new file mode 100644
index 000000000..7c229fdcb
--- /dev/null
+++ b/bot/exts/filtering/_filters/domain.py
@@ -0,0 +1,62 @@
+import re
+from typing import ClassVar, Optional
+from urllib.parse import urlparse
+
+import tldextract
+from discord.ext.commands import BadArgument
+from pydantic import BaseModel
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filters.filter import Filter
+
+URL_RE = re.compile(r"(?:https?://)?(\S+?)[\\/]*", flags=re.IGNORECASE)
+
+
+class ExtraDomainSettings(BaseModel):
+ """Extra settings for how domains should be matched in a message."""
+
+ subdomains_description: ClassVar[str] = (
+ "A boolean. If True, will will only trigger for subdomains and subpaths, and not for the domain itself."
+ )
+
+ # Whether to trigger only for subdomains and subpaths, and not the specified domain itself.
+ subdomains: Optional[bool] = False
+
+
+class DomainFilter(Filter):
+ """
+ A filter which looks for a specific domain given by URL.
+
+ The schema (http, https) does not need to be included in the filter.
+ Will also match subdomains.
+ """
+
+ name = "domain"
+ extra_fields_type = ExtraDomainSettings
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Searches for a domain within a given context."""
+ domain = tldextract.extract(self.content).registered_domain
+
+ for found_url in ctx.content:
+ extract = tldextract.extract(found_url)
+ if self.content in found_url and extract.registered_domain == domain:
+ if self.extra_fields.subdomains:
+ if not extract.subdomain and not urlparse(f"https://{found_url}").path:
+ return False
+ ctx.matches.append(self.content)
+ ctx.notification_domain = self.content
+ return True
+ return False
+
+ @classmethod
+ async def process_input(cls, content: str, description: str) -> tuple[str, str]:
+ """
+ Process the content and description into a form which will work with the filtering.
+
+ A BadArgument should be raised if the content can't be used.
+ """
+ match = URL_RE.fullmatch(content)
+ if not match or not match.group(1):
+ raise BadArgument(f"`{content}` is not a URL.")
+ return match.group(1), description
diff --git a/bot/exts/filtering/_filters/extension.py b/bot/exts/filtering/_filters/extension.py
new file mode 100644
index 000000000..97eddc406
--- /dev/null
+++ b/bot/exts/filtering/_filters/extension.py
@@ -0,0 +1,27 @@
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filters.filter import Filter
+
+
+class ExtensionFilter(Filter):
+ """
+ A filter which looks for a specific attachment extension in messages.
+
+ The filter stores the extension preceded by a dot.
+ """
+
+ name = "extension"
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Searches for an attachment extension in the context content, given as a set of extensions."""
+ return self.content in ctx.content
+
+ @classmethod
+ async def process_input(cls, content: str, description: str) -> tuple[str, str]:
+ """
+ Process the content and description into a form which will work with the filtering.
+
+ A BadArgument should be raised if the content can't be used.
+ """
+ if not content.startswith("."):
+ content = f".{content}"
+ return content, description
diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py
new file mode 100644
index 000000000..b5f4c127a
--- /dev/null
+++ b/bot/exts/filtering/_filters/filter.py
@@ -0,0 +1,94 @@
+from abc import abstractmethod
+from typing import Any
+
+import arrow
+from pydantic import ValidationError
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._settings import Defaults, create_settings
+from bot.exts.filtering._utils import FieldRequiring
+
+
+class Filter(FieldRequiring):
+ """
+ A class representing a filter.
+
+ Each filter looks for a specific attribute within an event (such as message sent),
+ and defines what action should be performed if it is triggered.
+ """
+
+ # Each subclass must define a name which will be used to fetch its description.
+ # Names must be unique across all types of filters.
+ name = FieldRequiring.MUST_SET_UNIQUE
+ # If a subclass uses extra fields, it should assign the pydantic model type to this variable.
+ extra_fields_type = None
+
+ def __init__(self, filter_data: dict, defaults: Defaults | None = None):
+ self.id = filter_data["id"]
+ self.content = filter_data["content"]
+ self.description = filter_data["description"]
+ self.created_at = arrow.get(filter_data["created_at"])
+ self.updated_at = arrow.get(filter_data["updated_at"])
+ self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults)
+ if self.extra_fields_type:
+ self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"] or "{}") # noqa: P103
+ else:
+ self.extra_fields = None
+
+ @property
+ def overrides(self) -> tuple[dict[str, Any], dict[str, Any]]:
+ """Return a tuple of setting overrides and filter setting overrides."""
+ settings = {}
+ if self.actions:
+ settings = self.actions.overrides
+ if self.validations:
+ settings |= self.validations.overrides
+
+ filter_settings = {}
+ if self.extra_fields:
+ filter_settings = self.extra_fields.dict(exclude_unset=True)
+
+ return settings, filter_settings
+
+ @abstractmethod
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+
+ @classmethod
+ def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, str | None]:
+ """Validate whether the supplied fields are valid for the filter, and provide the error message if not."""
+ if cls.extra_fields_type is None:
+ return True, None
+
+ try:
+ cls.extra_fields_type(**extra_fields)
+ except ValidationError as e:
+ return False, repr(e)
+ else:
+ return True, None
+
+ @classmethod
+ async def process_input(cls, content: str, description: str) -> tuple[str, str]:
+ """
+ Process the content and description into a form which will work with the filtering.
+
+ A BadArgument should be raised if the content can't be used.
+ """
+ return content, description
+
+ def __str__(self) -> str:
+ """A string representation of the filter."""
+ string = f"{self.id}. `{self.content}`"
+ if self.description:
+ string += f" - {self.description}"
+ return string
+
+
+class UniqueFilter(Filter):
+ """
+ Unique filters are ones that should only be run once in a given context.
+
+ This is as opposed to say running many domain filters on the same message.
+ """
+
+ events: tuple[Event, ...] = FieldRequiring.MUST_SET
diff --git a/bot/exts/filtering/_filters/invite.py b/bot/exts/filtering/_filters/invite.py
new file mode 100644
index 000000000..799a302b9
--- /dev/null
+++ b/bot/exts/filtering/_filters/invite.py
@@ -0,0 +1,48 @@
+from discord import NotFound
+from discord.ext.commands import BadArgument
+from pydis_core.utils.regex import DISCORD_INVITE
+
+import bot
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filters.filter import Filter
+
+
+class InviteFilter(Filter):
+ """
+ A filter which looks for invites to a specific guild in messages.
+
+ The filter stores the guild ID which is allowed or denied.
+ """
+
+ name = "invite"
+
+ def __init__(self, filter_data: dict, defaults_data: dict | None = None):
+ super().__init__(filter_data, defaults_data)
+ self.content = int(self.content)
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Searches for a guild ID in the context content, given as a set of IDs."""
+ return self.content in ctx.content
+
+ @classmethod
+ async def process_input(cls, content: str, description: str) -> tuple[str, str]:
+ """
+ Process the content and description into a form which will work with the filtering.
+
+ A BadArgument should be raised if the content can't be used.
+ """
+ match = DISCORD_INVITE.fullmatch(content)
+ if not match or not match.group("invite"):
+ raise BadArgument(f"`{content}` is not a valid Discord invite.")
+ invite_code = match.group("invite")
+ try:
+ invite = await bot.instance.fetch_invite(invite_code)
+ except NotFound:
+ raise BadArgument(f"`{invite_code}` is not a valid Discord invite code.")
+ if not invite.guild:
+ raise BadArgument("Did you just try to add a group DM?")
+
+ guild_name = invite.guild.name if hasattr(invite.guild, "name") else ""
+ if guild_name.lower() not in description.lower():
+ description = " - ".join(part for part in (f'Guild "{guild_name}"', description) if part)
+ return str(invite.guild.id), description
diff --git a/bot/exts/filtering/_filters/token.py b/bot/exts/filtering/_filters/token.py
new file mode 100644
index 000000000..3cd9b909d
--- /dev/null
+++ b/bot/exts/filtering/_filters/token.py
@@ -0,0 +1,35 @@
+import re
+
+from discord.ext.commands import BadArgument
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filters.filter import Filter
+
+
+class TokenFilter(Filter):
+ """A filter which looks for a specific token given by regex."""
+
+ name = "token"
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Searches for a regex pattern within a given context."""
+ pattern = self.content
+
+ match = re.search(pattern, ctx.content, flags=re.IGNORECASE)
+ if match:
+ ctx.matches.append(match[0])
+ return True
+ return False
+
+ @classmethod
+ async def process_input(cls, content: str, description: str) -> tuple[str, str]:
+ """
+ Process the content and description into a form which will work with the filtering.
+
+ A BadArgument should be raised if the content can't be used.
+ """
+ try:
+ re.compile(content)
+ except re.error as e:
+ raise BadArgument(str(e))
+ return content, description
diff --git a/bot/exts/filtering/_filters/unique/__init__.py b/bot/exts/filtering/_filters/unique/__init__.py
new file mode 100644
index 000000000..ce78d6922
--- /dev/null
+++ b/bot/exts/filtering/_filters/unique/__init__.py
@@ -0,0 +1,9 @@
+from os.path import dirname
+
+from bot.exts.filtering._filters.filter import UniqueFilter
+from bot.exts.filtering._utils import subclasses_in_package
+
+unique_filter_types = subclasses_in_package(dirname(__file__), f"{__name__}.", UniqueFilter)
+unique_filter_types = {filter_.name: filter_ for filter_ in unique_filter_types}
+
+__all__ = [unique_filter_types]
diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filtering/_filters/unique/discord_token.py
index 29f80671d..6174ee30b 100644
--- a/bot/exts/filters/token_remover.py
+++ b/bot/exts/filtering/_filters/unique/discord_token.py
@@ -1,38 +1,34 @@
import base64
import re
-import typing as t
-
-from discord import Colour, Message, NotFound
-from discord.ext.commands import Cog
-
-from bot import utils
-from bot.bot import Bot
-from bot.constants import Channels, Colours, Event, Icons
+from collections.abc import Callable, Coroutine
+from typing import ClassVar, NamedTuple
+
+import discord
+from pydantic import BaseModel, Field
+from pydis_core.utils.logging import get_logger
+from pydis_core.utils.members import get_or_fetch_member
+
+import bot
+from bot import constants, utils
+from bot.constants import Guild
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+from bot.exts.filtering._utils import resolve_mention
from bot.exts.moderation.modlog import ModLog
-from bot.log import get_logger
-from bot.utils.members import get_or_fetch_member
from bot.utils.messages import format_user
log = get_logger(__name__)
+
LOG_MESSAGE = (
- "Censored a seemingly valid token sent by {author} in {channel}, "
- "token was `{user_id}.{timestamp}.{hmac}`"
+ "Censored a seemingly valid token sent by {author} in {channel}. "
+ "Token was: `{user_id}.{timestamp}.{hmac}`."
)
UNKNOWN_USER_LOG_MESSAGE = "Decoded user ID: `{user_id}` (Not present in server)."
KNOWN_USER_LOG_MESSAGE = (
"Decoded user ID: `{user_id}` **(Present in server)**.\n"
"This matches `{user_name}` and means this is likely a valid **{kind}** token."
)
-DELETION_MESSAGE_TEMPLATE = (
- "Hey {mention}! I noticed you posted a seemingly valid Discord API "
- "token in your message and have removed your message. "
- "This means that your token has been **compromised**. "
- "Please change your token **immediately** at: "
- "<https://discord.com/developers/applications>\n\n"
- "Feel free to re-post it with the token removed. "
- "If you believe this was a mistake, please let us know!"
-)
DISCORD_EPOCH = 1_420_070_400
TOKEN_EPOCH = 1_293_840_000
@@ -43,7 +39,17 @@ TOKEN_EPOCH = 1_293_840_000
TOKEN_RE = re.compile(r"([\w-]{10,})\.([\w-]{5,})\.([\w-]{10,})")
-class Token(t.NamedTuple):
+class ExtraDiscordTokenSettings(BaseModel):
+ """Extra settings for who should be pinged when a Discord token is detected."""
+
+ pings_for_bot_description: ClassVar[str] = "A sequence. Who should be pinged if the token found belongs to a bot."
+ pings_for_user_description: ClassVar[str] = "A sequence. Who should be pinged if the token found belongs to a user."
+
+ pings_for_bot: set[str] = Field(default_factory=set)
+ pings_for_user: set[str] = Field(default_factory=lambda: {"Moderators"})
+
+
+class Token(NamedTuple):
"""A Discord Bot token."""
user_id: str
@@ -51,84 +57,64 @@ class Token(t.NamedTuple):
hmac: str
-class TokenRemover(Cog):
+class DiscordTokenFilter(UniqueFilter):
"""Scans messages for potential discord client tokens and removes them."""
- def __init__(self, bot: Bot):
- self.bot = bot
+ name = "discord_token"
+ events = (Event.MESSAGE, Event.MESSAGE_EDIT)
+ extra_fields_type = ExtraDiscordTokenSettings
@property
- def mod_log(self) -> ModLog:
+ def mod_log(self) -> ModLog | None:
"""Get currently loaded ModLog cog instance."""
- return self.bot.get_cog("ModLog")
-
- @Cog.listener()
- async def on_message(self, msg: Message) -> None:
- """
- Check each message for a string that matches Discord's token pattern.
+ return bot.instance.get_cog("ModLog")
- See: https://discordapp.com/developers/docs/reference#snowflakes
- """
- # Ignore DMs; can't delete messages in there anyway.
- if not msg.guild or msg.author.bot:
- return
-
- found_token = self.find_token_in_message(msg)
- if found_token:
- await self.take_action(msg, found_token)
-
- @Cog.listener()
- async def on_message_edit(self, before: Message, after: Message) -> None:
- """
- Check each edit for a string that matches Discord's token pattern.
-
- See: https://discordapp.com/developers/docs/reference#snowflakes
- """
- await self.on_message(after)
-
- async def take_action(self, msg: Message, found_token: Token) -> None:
- """Remove the `msg` containing the `found_token` and send a mod log message."""
- self.mod_log.ignore(Event.message_delete, msg.id)
-
- try:
- await msg.delete()
- except NotFound:
- log.debug(f"Failed to remove token in message {msg.id}: message already deleted.")
- return
-
- await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention))
-
- log_message = self.format_log_message(msg, found_token)
- userid_message, mention_everyone = await self.format_userid_log_message(msg, found_token)
- log.debug(log_message)
-
- # Send pretty mod log embed to mod-alerts
- await self.mod_log.send_log_message(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Token removed!",
- text=log_message + "\n" + userid_message,
- thumbnail=msg.author.display_avatar.url,
- channel_id=Channels.mod_alerts,
- ping_everyone=mention_everyone,
- )
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Return whether the message contains Discord client tokens."""
+ found_token = self.find_token_in_message(ctx.content)
+ if not found_token:
+ return False
- self.bot.stats.incr("tokens.removed_tokens")
+ if ctx.message and (mod_log := self.mod_log):
+ mod_log.ignore(constants.Event.message_delete, ctx.message.id)
+ ctx.content = ctx.content.replace(found_token.hmac, self.censor_hmac(found_token.hmac))
+ ctx.additional_actions.append(self._create_token_alert_embed_wrapper(found_token))
+ return True
+
+ def _create_token_alert_embed_wrapper(self, found_token: Token) -> Callable[[FilterContext], Coroutine]:
+ """Create the action to perform when an alert should be sent for a message containing a Discord token."""
+ async def _create_token_alert_embed(ctx: FilterContext) -> None:
+ """Add an alert embed to the context with info about the token sent."""
+ userid_message, is_user = await self.format_userid_log_message(found_token)
+ log_message = self.format_log_message(ctx.author, ctx.channel, found_token)
+ log.debug(log_message)
+
+ if is_user:
+ mentions = map(resolve_mention, self.extra_fields.pings_for_user)
+ color = discord.Colour.red()
+ else:
+ mentions = map(resolve_mention, self.extra_fields.pings_for_bot)
+ color = discord.Colour.blue()
+ unmentioned = [mention for mention in mentions if mention not in ctx.alert_content]
+ if unmentioned:
+ ctx.alert_content = f"{' '.join(unmentioned)} {ctx.alert_content}"
+ ctx.alert_embeds.append(discord.Embed(colour=color, description=userid_message))
+
+ return _create_token_alert_embed
@classmethod
- async def format_userid_log_message(cls, msg: Message, token: Token) -> t.Tuple[str, bool]:
+ async def format_userid_log_message(cls, token: Token) -> tuple[str, bool]:
"""
Format the portion of the log message that includes details about the detected user ID.
If the user is resolved to a member, the format includes the user ID, name, and the
kind of user detected.
-
- If we resolve to a member and it is not a bot, we also return True to ping everyone.
-
- Returns a tuple of (log_message, mention_everyone)
+ If it is resolved to a user or a member, and it is not a bot, also return True.
+ Returns a tuple of (log_message, is_user)
"""
user_id = cls.extract_user_id(token.user_id)
- user = await get_or_fetch_member(msg.guild, user_id)
+ guild = bot.instance.get_guild(Guild.id)
+ user = await get_or_fetch_member(guild, user_id)
if user:
return KNOWN_USER_LOG_MESSAGE.format(
@@ -140,22 +126,27 @@ class TokenRemover(Cog):
return UNKNOWN_USER_LOG_MESSAGE.format(user_id=user_id), False
@staticmethod
- def format_log_message(msg: Message, token: Token) -> str:
+ def censor_hmac(hmac: str) -> str:
+ """Return a censored version of the hmac."""
+ return 'x' * (len(hmac) - 3) + hmac[-3:]
+
+ @classmethod
+ def format_log_message(cls, author: discord.User, channel: discord.abc.GuildChannel, token: Token) -> str:
"""Return the generic portion of the log message to send for `token` being censored in `msg`."""
return LOG_MESSAGE.format(
- author=format_user(msg.author),
- channel=msg.channel.mention,
+ author=format_user(author),
+ channel=channel.mention,
user_id=token.user_id,
timestamp=token.timestamp,
- hmac='x' * (len(token.hmac) - 3) + token.hmac[-3:],
+ hmac=cls.censor_hmac(token.hmac),
)
@classmethod
- def find_token_in_message(cls, msg: Message) -> t.Optional[Token]:
- """Return a seemingly valid token found in `msg` or `None` if no token is found."""
+ def find_token_in_message(cls, content: str) -> Token | None:
+ """Return a seemingly valid token found in `content` or `None` if no token is found."""
# Use finditer rather than search to guard against method calls prematurely returning the
# token check (e.g. `message.channel.send` also matches our token pattern)
- for match in TOKEN_RE.finditer(msg.content):
+ for match in TOKEN_RE.finditer(content):
token = Token(*match.groups())
if (
(cls.extract_user_id(token.user_id) is not None)
@@ -169,7 +160,7 @@ class TokenRemover(Cog):
return None
@staticmethod
- def extract_user_id(b64_content: str) -> t.Optional[int]:
+ def extract_user_id(b64_content: str) -> int | None:
"""Return a user ID integer from part of a potential token, or None if it couldn't be decoded."""
b64_content = utils.pad_base64(b64_content)
@@ -214,7 +205,7 @@ class TokenRemover(Cog):
"""
Determine if a given HMAC portion of a token is potentially valid.
- If the HMAC has 3 or less characters, it's probably a dummy value like "xxxxxxxxxx",
+ If the HMAC has 3 or fewer characters, it's probably a dummy value like "xxxxxxxxxx",
and thus the token can probably be skipped.
"""
unique = len(set(b64_content.lower()))
@@ -226,8 +217,3 @@ class TokenRemover(Cog):
return False
else:
return True
-
-
-async def setup(bot: Bot) -> None:
- """Load the TokenRemover cog."""
- await bot.add_cog(TokenRemover(bot))
diff --git a/bot/exts/filtering/_filters/unique/everyone.py b/bot/exts/filtering/_filters/unique/everyone.py
new file mode 100644
index 000000000..a32e67cc5
--- /dev/null
+++ b/bot/exts/filtering/_filters/unique/everyone.py
@@ -0,0 +1,28 @@
+import re
+
+from bot.constants import Guild
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+
+EVERYONE_PING_RE = re.compile(rf"@everyone|<@&{Guild.id}>|@here")
+CODE_BLOCK_RE = re.compile(
+ r"(?P<delim>``?)[^`]+?(?P=delim)(?!`+)" # Inline codeblock
+ r"|```(.+?)```", # Multiline codeblock
+ re.DOTALL | re.MULTILINE
+)
+
+
+class EveryoneFilter(UniqueFilter):
+ """Filter messages which contain `@everyone` and `@here` tags outside a codeblock."""
+
+ name = "everyone"
+ events = (Event.MESSAGE, Event.MESSAGE_EDIT)
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for the filter's content within a given context."""
+ # First pass to avoid running re.sub on every message
+ if not EVERYONE_PING_RE.search(ctx.content):
+ return False
+
+ content_without_codeblocks = CODE_BLOCK_RE.sub("", ctx.content)
+ return bool(EVERYONE_PING_RE.search(content_without_codeblocks))
diff --git a/bot/exts/filtering/_filters/unique/rich_embed.py b/bot/exts/filtering/_filters/unique/rich_embed.py
new file mode 100644
index 000000000..2ee469f51
--- /dev/null
+++ b/bot/exts/filtering/_filters/unique/rich_embed.py
@@ -0,0 +1,51 @@
+import re
+
+from pydis_core.utils.logging import get_logger
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+from bot.utils.helpers import remove_subdomain_from_url
+
+log = get_logger(__name__)
+
+URL_RE = re.compile(r"(https?://\S+)", flags=re.IGNORECASE)
+
+
+class RichEmbedFilter(UniqueFilter):
+ """Filter messages which contain rich embeds not auto-generated from a URL."""
+
+ name = "rich_embed"
+ events = (Event.MESSAGE, Event.MESSAGE_EDIT)
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Determine if `msg` contains any rich embeds not auto-generated from a URL."""
+ if ctx.embeds:
+ if ctx.event == Event.MESSAGE_EDIT:
+ if not ctx.message.edited_at: # This might happen, apparently.
+ return False
+ # If the edit delta is less than 100 microseconds, it's probably a double filter trigger.
+ delta = ctx.message.edited_at - (ctx.before_message.edited_at or ctx.before_message.created_at)
+ if delta.total_seconds() < 0.0001:
+ return False
+
+ for embed in ctx.embeds:
+ if embed.type == "rich":
+ urls = URL_RE.findall(ctx.content)
+ final_urls = set(urls)
+ # This is due to the way discord renders relative urls in Embeds
+ # if the following url is sent: https://mobile.twitter.com/something
+ # Discord renders it as https://twitter.com/something
+ for url in urls:
+ final_urls.add(remove_subdomain_from_url(url))
+ if not embed.url or embed.url not in final_urls:
+ # If `embed.url` does not exist or if `embed.url` is not part of the content
+ # of the message, it's unlikely to be an auto-generated embed by Discord.
+ ctx.alert_embeds.extend(ctx.embeds)
+ return True
+ else:
+ log.trace(
+ "Found a rich embed sent by a regular user account, "
+ "but it was likely just an automatic URL embed."
+ )
+
+ return False
diff --git a/bot/exts/filtering/_filters/unique/webhook.py b/bot/exts/filtering/_filters/unique/webhook.py
new file mode 100644
index 000000000..965ef42eb
--- /dev/null
+++ b/bot/exts/filtering/_filters/unique/webhook.py
@@ -0,0 +1,63 @@
+import re
+from collections.abc import Callable, Coroutine
+
+from pydis_core.utils.logging import get_logger
+
+import bot
+from bot import constants
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.filter import UniqueFilter
+from bot.exts.moderation.modlog import ModLog
+
+log = get_logger(__name__)
+
+
+WEBHOOK_URL_RE = re.compile(
+ r"((?:https?://)?(?:ptb\.|canary\.)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?",
+ re.IGNORECASE
+)
+
+
+class WebhookFilter(UniqueFilter):
+ """Scan messages to detect Discord webhooks links."""
+
+ name = "webhook"
+ events = (Event.MESSAGE, Event.MESSAGE_EDIT)
+
+ @property
+ def mod_log(self) -> ModLog | None:
+ """Get current instance of `ModLog`."""
+ return bot.instance.get_cog("ModLog")
+
+ async def triggered_on(self, ctx: FilterContext) -> bool:
+ """Search for a webhook in the given content. If found, attempt to delete it."""
+ matches = set(WEBHOOK_URL_RE.finditer(ctx.content))
+ if not matches:
+ return False
+
+ # Don't log this.
+ if ctx.message and (mod_log := self.mod_log):
+ mod_log.ignore(constants.Event.message_delete, ctx.message.id)
+
+ for i, match in enumerate(matches, start=1):
+ extra = "" if len(matches) == 1 else f" ({i})"
+ # Queue the webhook for deletion.
+ ctx.additional_actions.append(self._delete_webhook_wrapper(match[0], extra))
+ # Don't show the full webhook in places such as the mod alert.
+ ctx.content = ctx.content.replace(match[0], match[1] + "xxx")
+
+ return True
+
+ @staticmethod
+ def _delete_webhook_wrapper(webhook_url: str, extra_message: str) -> Callable[[FilterContext], Coroutine]:
+ """Create the action to perform when a webhook should be deleted."""
+ async def _delete_webhook(ctx: FilterContext) -> None:
+ """Delete the given webhook and update the filter context."""
+ async with bot.instance.http_session.delete(webhook_url) as resp:
+ # The Discord API Returns a 204 NO CONTENT response on success.
+ if resp.status == 204:
+ ctx.action_descriptions.append("webhook deleted" + extra_message)
+ else:
+ ctx.action_descriptions.append("failed to delete webhook" + extra_message)
+
+ return _delete_webhook
diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py
new file mode 100644
index 000000000..75e810df5
--- /dev/null
+++ b/bot/exts/filtering/_settings.py
@@ -0,0 +1,233 @@
+from __future__ import annotations
+
+import operator
+import traceback
+from abc import abstractmethod
+from copy import copy
+from functools import reduce
+from typing import Any, NamedTuple, Optional, TypeVar
+
+from typing_extensions import Self
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types import settings_types
+from bot.exts.filtering._settings_types.settings_entry import ActionEntry, SettingsEntry, ValidationEntry
+from bot.exts.filtering._utils import FieldRequiring
+from bot.log import get_logger
+
+TSettings = TypeVar("TSettings", bound="Settings")
+
+log = get_logger(__name__)
+
+_already_warned: set[str] = set()
+
+T = TypeVar("T", bound=SettingsEntry)
+
+
+def create_settings(
+ settings_data: dict, *, defaults: Defaults | None = None, keep_empty: bool = False
+) -> tuple[Optional[ActionSettings], Optional[ValidationSettings]]:
+ """
+ Create and return instances of the Settings subclasses from the given data.
+
+ Additionally, warn for data entries with no matching class.
+
+ In case these are setting overrides, the defaults can be provided to keep track of the correct values.
+ """
+ action_data = {}
+ validation_data = {}
+ for entry_name, entry_data in settings_data.items():
+ if entry_name in settings_types["ActionEntry"]:
+ action_data[entry_name] = entry_data
+ elif entry_name in settings_types["ValidationEntry"]:
+ validation_data[entry_name] = entry_data
+ elif entry_name not in _already_warned:
+ log.warning(
+ f"A setting named {entry_name} was loaded from the database, but no matching class."
+ )
+ _already_warned.add(entry_name)
+ if defaults is None:
+ default_actions = None
+ default_validations = None
+ else:
+ default_actions, default_validations = defaults
+ return (
+ ActionSettings.create(action_data, defaults=default_actions, keep_empty=keep_empty),
+ ValidationSettings.create(validation_data, defaults=default_validations, keep_empty=keep_empty)
+ )
+
+
+class Settings(FieldRequiring, dict[str, T]):
+ """
+ A collection of settings.
+
+ For processing the settings parts in the database and evaluating them on given contexts.
+
+ Each filter list and filter has its own settings.
+
+ A filter doesn't have to have its own settings. For every undefined setting, it falls back to the value defined in
+ the filter list which contains the filter.
+ """
+
+ entry_type = T
+
+ _already_warned: set[str] = set()
+
+ @abstractmethod
+ def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False):
+ super().__init__()
+
+ entry_classes = settings_types.get(self.entry_type.__name__)
+ for entry_name, entry_data in settings_data.items():
+ try:
+ entry_cls = entry_classes[entry_name]
+ except KeyError:
+ if entry_name not in self._already_warned:
+ log.warning(
+ f"A setting named {entry_name} was loaded from the database, "
+ f"but no matching {self.entry_type.__name__} class."
+ )
+ self._already_warned.add(entry_name)
+ else:
+ try:
+ entry_defaults = None if defaults is None else defaults[entry_name]
+ new_entry = entry_cls.create(
+ entry_data, defaults=entry_defaults, keep_empty=keep_empty
+ )
+ if new_entry:
+ self[entry_name] = new_entry
+ except TypeError as e:
+ raise TypeError(
+ f"Attempted to load a {entry_name} setting, but the response is malformed: {entry_data}"
+ ) from e
+
+ @property
+ def overrides(self) -> dict[str, Any]:
+ """Return a dictionary of overrides across all entries."""
+ return reduce(operator.or_, (entry.overrides for entry in self.values() if entry), {})
+
+ def copy(self: TSettings) -> TSettings:
+ """Create a shallow copy of the object."""
+ return copy(self)
+
+ def get_setting(self, key: str, default: Optional[Any] = None) -> Any:
+ """Get the setting matching the key, or fall back to the default value if the key is missing."""
+ for entry in self.values():
+ if hasattr(entry, key):
+ return getattr(entry, key)
+ return default
+
+ @classmethod
+ def create(
+ cls, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False
+ ) -> Optional[Settings]:
+ """
+ Returns a Settings object from `settings_data` if it holds any value, None otherwise.
+
+ Use this method to create Settings objects instead of the init.
+ The None value is significant for how a filter list iterates over its filters.
+ """
+ settings = cls(settings_data, defaults=defaults, keep_empty=keep_empty)
+ # If an entry doesn't hold any values, its `create` method will return None.
+ # If all entries are None, then the settings object holds no values.
+ if not keep_empty and not any(settings.values()):
+ return None
+
+ return settings
+
+
+class ValidationSettings(Settings[ValidationEntry]):
+ """
+ A collection of validation settings.
+
+ A filter is triggered only if all of its validation settings (e.g whether to invoke in DM) approve
+ (the check returns True).
+ """
+
+ entry_type = ValidationEntry
+
+ def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False):
+ super().__init__(settings_data, defaults=defaults, keep_empty=keep_empty)
+
+ def evaluate(self, ctx: FilterContext) -> tuple[set[str], set[str]]:
+ """Evaluates for each setting whether the context is relevant to the filter."""
+ passed = set()
+ failed = set()
+
+ for name, validation in self.items():
+ if validation:
+ if validation.triggers_on(ctx):
+ passed.add(name)
+ else:
+ failed.add(name)
+
+ return passed, failed
+
+
+class ActionSettings(Settings[ActionEntry]):
+ """
+ A collection of action settings.
+
+ If a filter is triggered, its action settings (e.g how to infract the user) are combined with the action settings of
+ other triggered filters in the same event, and action is taken according to the combined action settings.
+ """
+
+ entry_type = ActionEntry
+
+ def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False):
+ super().__init__(settings_data, defaults=defaults, keep_empty=keep_empty)
+
+ def union(self, other: Self) -> Self:
+ """Combine the entries of two collections of settings into a new ActionsSettings."""
+ actions = {}
+ # A settings object doesn't necessarily have all types of entries (e.g in the case of filter overrides).
+ for entry in self:
+ if entry in other:
+ actions[entry] = self[entry].union(other[entry])
+ else:
+ actions[entry] = self[entry]
+ for entry in other:
+ if entry not in actions:
+ actions[entry] = other[entry]
+
+ result = ActionSettings({})
+ result.update(actions)
+ return result
+
+ async def action(self, ctx: FilterContext) -> None:
+ """Execute the action of every action entry stored, as well as any additional actions in the context."""
+ for entry in self.values():
+ try:
+ await entry.action(ctx)
+ # Filtering should not stop even if one type of action raised an exception.
+ # For example, if deleting the message raised somehow, it should still try to infract the user.
+ except Exception:
+ log.exception(traceback.format_exc())
+
+ for action in ctx.additional_actions:
+ try:
+ await action(ctx)
+ except Exception:
+ log.exception(traceback.format_exc())
+
+ def fallback_to(self, fallback: ActionSettings) -> ActionSettings:
+ """Fill in missing entries from `fallback`."""
+ new_actions = self.copy()
+ for entry_name, entry_value in fallback.items():
+ if entry_name not in self:
+ new_actions[entry_name] = entry_value
+ return new_actions
+
+
+class Defaults(NamedTuple):
+ """Represents an atomic list's default settings."""
+
+ actions: ActionSettings
+ validations: ValidationSettings
+
+ def dict(self) -> dict[str, Any]:
+ """Return a dict representation of the stored fields across all entries."""
+ dict_ = {}
+ for settings in self:
+ dict_ = reduce(operator.or_, (entry.dict() for entry in settings.values()), dict_)
+ return dict_
diff --git a/bot/exts/filtering/_settings_types/__init__.py b/bot/exts/filtering/_settings_types/__init__.py
new file mode 100644
index 000000000..61b5737d4
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/__init__.py
@@ -0,0 +1,9 @@
+from bot.exts.filtering._settings_types.actions import action_types
+from bot.exts.filtering._settings_types.validations import validation_types
+
+settings_types = {
+ "ActionEntry": {settings_type.name: settings_type for settings_type in action_types},
+ "ValidationEntry": {settings_type.name: settings_type for settings_type in validation_types}
+}
+
+__all__ = [settings_types]
diff --git a/bot/exts/filtering/_settings_types/actions/__init__.py b/bot/exts/filtering/_settings_types/actions/__init__.py
new file mode 100644
index 000000000..a8175b976
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/actions/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname
+
+from bot.exts.filtering._settings_types.settings_entry import ActionEntry
+from bot.exts.filtering._utils import subclasses_in_package
+
+action_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ActionEntry)
+
+__all__ = [action_types]
diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
new file mode 100644
index 000000000..5ae4901b6
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
@@ -0,0 +1,204 @@
+from datetime import timedelta
+from enum import Enum, auto
+from typing import ClassVar
+
+import arrow
+import discord.abc
+from discord import Colour, Embed, Member, User
+from discord.errors import Forbidden
+from pydantic import validator
+from pydis_core.utils.logging import get_logger
+from pydis_core.utils.members import get_or_fetch_member
+from typing_extensions import Self
+
+import bot as bot_module
+from bot.constants import Channels
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ActionEntry
+from bot.exts.filtering._utils import FakeContext
+
+log = get_logger(__name__)
+
+passive_form = {
+ "BAN": "banned",
+ "KICK": "kicked",
+ "TIMEOUT": "timed out",
+ "VOICE_MUTE": "voice muted",
+ "SUPERSTAR": "superstarred",
+ "WARNING": "warned",
+ "WATCH": "watch",
+ "NOTE": "noted",
+}
+
+
+class Infraction(Enum):
+ """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy."""
+
+ BAN = auto()
+ KICK = auto()
+ TIMEOUT = auto()
+ VOICE_MUTE = auto()
+ SUPERSTAR = auto()
+ WARNING = auto()
+ WATCH = auto()
+ NOTE = auto()
+ NONE = auto()
+
+ def __str__(self) -> str:
+ return self.name
+
+ async def invoke(
+ self,
+ user: Member | User,
+ message: discord.Message,
+ channel: discord.abc.GuildChannel | discord.DMChannel,
+ alerts_channel: discord.TextChannel,
+ duration: float,
+ reason: str
+ ) -> None:
+ """Invokes the command matching the infraction name."""
+ command_name = self.name.lower()
+ command = bot_module.instance.get_command(command_name)
+ if not command:
+ await alerts_channel.send(f":warning: Could not apply {command_name} to {user.mention}: command not found.")
+ log.warning(f":warning: Could not apply {command_name} to {user.mention}: command not found.")
+ return
+
+ if isinstance(user, discord.User): # For example because a message was sent in a DM.
+ member = await get_or_fetch_member(channel.guild, user.id)
+ if member:
+ user = member
+ ctx = FakeContext(message, channel, command)
+ if self.name in ("KICK", "WARNING", "WATCH", "NOTE"):
+ await command(ctx, user, reason=reason or None)
+ else:
+ duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None
+ await command(ctx, user, duration, reason=reason or None)
+
+
+class InfractionAndNotification(ActionEntry):
+ """
+ A setting entry which specifies what infraction to issue and the notification to DM the user.
+
+ Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together.
+ """
+
+ name: ClassVar[str] = "infraction_and_notification"
+ description: ClassVar[dict[str, str]] = {
+ "infraction_type": (
+ "The type of infraction to issue when the filter triggers, or 'NONE'. "
+ "If two infractions are triggered for the same message, "
+ "the harsher one will be applied (by type or duration).\n\n"
+ "Valid infraction types in order of harshness: "
+ ) + ", ".join(infraction.name for infraction in Infraction),
+ "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.",
+ "infraction_reason": "The reason delivered with the infraction.",
+ "infraction_channel": (
+ "The channel ID in which to invoke the infraction (and send the confirmation message). "
+ "If 0, the infraction will be sent in the context channel. If the ID otherwise fails to resolve, "
+ "it will default to the mod-alerts channel."
+ ),
+ "dm_content": "The contents of a message to be DMed to the offending user.",
+ "dm_embed": "The contents of the embed to be DMed to the offending user."
+ }
+
+ dm_content: str
+ dm_embed: str
+ infraction_type: Infraction
+ infraction_reason: str
+ infraction_duration: float
+ infraction_channel: int
+
+ @validator("infraction_type", pre=True)
+ @classmethod
+ def convert_infraction_name(cls, infr_type: str | Infraction) -> Infraction:
+ """Convert the string to an Infraction by name."""
+ if isinstance(infr_type, Infraction):
+ return infr_type
+ return Infraction[infr_type.replace(" ", "_").upper()]
+
+ async def action(self, ctx: FilterContext) -> None:
+ """Send the notification to the user, and apply any specified infractions."""
+ # If there is no infraction to apply, any DM contents already provided in the context take precedence.
+ if self.infraction_type == Infraction.NONE and (ctx.dm_content or ctx.dm_embed):
+ dm_content = ctx.dm_content
+ dm_embed = ctx.dm_embed
+ else:
+ dm_content = self.dm_content
+ dm_embed = self.dm_embed
+
+ if dm_content or dm_embed:
+ formatting = {"domain": ctx.notification_domain}
+ dm_content = f"Hey {ctx.author.mention}!\n{dm_content.format(**formatting)}"
+ if dm_embed:
+ dm_embed = Embed(description=dm_embed.format(**formatting), colour=Colour.og_blurple())
+ else:
+ dm_embed = None
+
+ try:
+ await ctx.author.send(dm_content, embed=dm_embed)
+ ctx.action_descriptions.append("notified")
+ except Forbidden:
+ ctx.action_descriptions.append("failed to notify")
+
+ if self.infraction_type != Infraction.NONE:
+ alerts_channel = bot_module.instance.get_channel(Channels.mod_alerts)
+ if self.infraction_channel:
+ channel = bot_module.instance.get_channel(self.infraction_channel)
+ if not channel:
+ log.info(f"Could not find a channel with ID {self.infraction_channel}, infracting in mod-alerts.")
+ channel = alerts_channel
+ elif not ctx.channel:
+ channel = alerts_channel
+ else:
+ channel = ctx.channel
+ if not channel: # If somehow it's set to `alerts_channel` and it can't be found.
+ log.error(f"Unable to apply infraction as the context channel {channel} can't be found.")
+ return
+
+ await self.infraction_type.invoke(
+ ctx.author, ctx.message, channel, alerts_channel, self.infraction_duration, self.infraction_reason
+ )
+ ctx.action_descriptions.append(passive_form[self.infraction_type.name])
+
+ def union(self, other: Self) -> Self:
+ """
+ Combines two actions of the same type. Each type of action is executed once per filter.
+
+ If the infractions are different, take the data of the one higher up the hierarchy.
+
+ There is no clear way to properly combine several notification messages, especially when it's in two parts.
+ To avoid bombarding the user with several notifications, the message with the more significant infraction
+ is used. If the more significant infraction has no accompanying message, use the one from the other infraction,
+ if it exists.
+ """
+ # Lower number -> higher in the hierarchy
+ if self.infraction_type is None:
+ return other.copy()
+ elif other.infraction_type is None:
+ return self.copy()
+
+ if self.infraction_type.value < other.infraction_type.value:
+ result = self.copy()
+ elif self.infraction_type.value > other.infraction_type.value:
+ result = other.copy()
+ other = self
+ else:
+ if self.infraction_duration is None or (
+ other.infraction_duration is not None and self.infraction_duration > other.infraction_duration
+ ):
+ result = self.copy()
+ else:
+ result = other.copy()
+ other = self
+
+ # If the winner has no message but the loser does, copy the message to the winner.
+ result_overrides = result.overrides
+ if "dm_content" not in result_overrides and "dm_embed" not in result_overrides:
+ other_overrides = other.overrides
+ if "dm_content" in other_overrides:
+ result.dm_content = other_overrides["dm_content"]
+ if "dm_embed" in other_overrides:
+ result.dm_content = other_overrides["dm_embed"]
+
+ return result
diff --git a/bot/exts/filtering/_settings_types/actions/ping.py b/bot/exts/filtering/_settings_types/actions/ping.py
new file mode 100644
index 000000000..ee40c54fe
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/actions/ping.py
@@ -0,0 +1,45 @@
+from typing import ClassVar
+
+from pydantic import validator
+from typing_extensions import Self
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ActionEntry
+from bot.exts.filtering._utils import resolve_mention
+
+
+class Ping(ActionEntry):
+ """A setting entry which adds the appropriate pings to the alert."""
+
+ name: ClassVar[str] = "mentions"
+ description: ClassVar[dict[str, str]] = {
+ "guild_pings": (
+ "A list of role IDs/role names/user IDs/user names/here/everyone. "
+ "If a mod-alert is generated for a filter triggered in a public channel, these will be pinged."
+ ),
+ "dm_pings": (
+ "A list of role IDs/role names/user IDs/user names/here/everyone. "
+ "If a mod-alert is generated for a filter triggered in DMs, these will be pinged."
+ )
+ }
+
+ guild_pings: set[str]
+ dm_pings: set[str]
+
+ @validator("*", pre=True)
+ @classmethod
+ def init_sequence_if_none(cls, pings: list[str] | None) -> list[str]:
+ """Initialize an empty sequence if the value is None."""
+ if pings is None:
+ return []
+ return pings
+
+ async def action(self, ctx: FilterContext) -> None:
+ """Add the stored pings to the alert message content."""
+ mentions = self.guild_pings if not ctx.channel or ctx.channel.guild else self.dm_pings
+ new_content = " ".join([resolve_mention(mention) for mention in mentions])
+ ctx.alert_content = f"{new_content} {ctx.alert_content}"
+
+ def union(self, other: Self) -> Self:
+ """Combines two actions of the same type. Each type of action is executed once per filter."""
+ return Ping(guild_pings=self.guild_pings | other.guild_pings, dm_pings=self.dm_pings | other.dm_pings)
diff --git a/bot/exts/filtering/_settings_types/actions/remove_context.py b/bot/exts/filtering/_settings_types/actions/remove_context.py
new file mode 100644
index 000000000..7ead88818
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/actions/remove_context.py
@@ -0,0 +1,113 @@
+from collections import defaultdict
+from typing import ClassVar
+
+from discord import Message
+from discord.errors import HTTPException
+from pydis_core.utils import scheduling
+from pydis_core.utils.logging import get_logger
+from typing_extensions import Self
+
+import bot
+from bot.constants import Channels
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ActionEntry
+from bot.exts.filtering._utils import FakeContext
+from bot.utils.messages import send_attachments
+
+log = get_logger(__name__)
+
+SUPERSTAR_REASON = (
+ "Your nickname was found to be in violation of our code of conduct. "
+ "If you believe this is a mistake, please let us know."
+)
+
+
+async def upload_messages_attachments(ctx: FilterContext, messages: list[Message]) -> None:
+ """Re-upload the messages' attachments for future logging."""
+ if not messages:
+ return
+ destination = messages[0].guild.get_channel(Channels.attachment_log)
+ for message in messages:
+ if message.attachments and message.id not in ctx.attachments:
+ ctx.attachments[message.id] = await send_attachments(message, destination, link_large=False)
+
+
+class RemoveContext(ActionEntry):
+ """A setting entry which tells whether to delete the offending message(s)."""
+
+ name: ClassVar[str] = "remove_context"
+ description: ClassVar[str] = (
+ "A boolean field. If True, the filter being triggered will cause the offending context to be removed. "
+ "An offending message will be deleted, while an offending nickname will be superstarified."
+ )
+
+ remove_context: bool
+
+ async def action(self, ctx: FilterContext) -> None:
+ """Remove the offending context."""
+ if not self.remove_context:
+ return
+
+ if ctx.event in (Event.MESSAGE, Event.MESSAGE_EDIT):
+ await self._handle_messages(ctx)
+ elif ctx.event == Event.NICKNAME:
+ await self._handle_nickname(ctx)
+
+ @staticmethod
+ async def _handle_messages(ctx: FilterContext) -> None:
+ """Delete any messages involved in this context."""
+ if not ctx.message or not ctx.message.guild:
+ return
+
+ # If deletion somehow fails at least this will allow scheduling for deletion.
+ ctx.messages_deletion = True
+ channel_messages = defaultdict(set) # Duplicates will cause batch deletion to fail.
+ for message in {ctx.message} | ctx.related_messages:
+ channel_messages[message.channel].add(message)
+
+ success = fail = 0
+ deleted = list()
+ for channel, messages in channel_messages.items():
+ try:
+ await channel.delete_messages(messages)
+ except HTTPException:
+ fail += len(messages)
+ else:
+ success += len(messages)
+ deleted.extend(messages)
+ scheduling.create_task(upload_messages_attachments(ctx, deleted))
+
+ if not fail:
+ if success == 1:
+ ctx.action_descriptions.append("deleted")
+ else:
+ ctx.action_descriptions.append("deleted all")
+ elif not success:
+ if fail == 1:
+ ctx.action_descriptions.append("failed to delete")
+ else:
+ ctx.action_descriptions.append("all failed to delete")
+ else:
+ ctx.action_descriptions.append(f"{success} deleted, {fail} failed to delete")
+
+ @staticmethod
+ async def _handle_nickname(ctx: FilterContext) -> None:
+ """Apply a superstar infraction to remove the user's nickname."""
+ alerts_channel = bot.instance.get_channel(Channels.mod_alerts)
+ if not alerts_channel:
+ log.error(f"Unable to apply superstar as the context channel {alerts_channel} can't be found.")
+ return
+ command = bot.instance.get_command("superstar")
+ if not command:
+ user = ctx.author
+ await alerts_channel.send(f":warning: Could not apply superstar to {user.mention}: command not found.")
+ log.warning(f":warning: Could not apply superstar to {user.mention}: command not found.")
+ ctx.action_descriptions.append("failed to superstar")
+ return
+
+ await command(FakeContext(ctx.message, alerts_channel, command), ctx.author, None, reason=SUPERSTAR_REASON)
+ ctx.action_descriptions.append("superstar")
+
+ def union(self, other: Self) -> Self:
+ """Combines two actions of the same type. Each type of action is executed once per filter."""
+ return RemoveContext(remove_context=self.remove_context or other.remove_context)
diff --git a/bot/exts/filtering/_settings_types/actions/send_alert.py b/bot/exts/filtering/_settings_types/actions/send_alert.py
new file mode 100644
index 000000000..f554cdd4d
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/actions/send_alert.py
@@ -0,0 +1,23 @@
+from typing import ClassVar
+
+from typing_extensions import Self
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ActionEntry
+
+
+class SendAlert(ActionEntry):
+ """A setting entry which tells whether to send an alert message."""
+
+ name: ClassVar[str] = "send_alert"
+ description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created."
+
+ send_alert: bool
+
+ async def action(self, ctx: FilterContext) -> None:
+ """Add the stored pings to the alert message content."""
+ ctx.send_alert = self.send_alert
+
+ def union(self, other: Self) -> Self:
+ """Combines two actions of the same type. Each type of action is executed once per filter."""
+ return SendAlert(send_alert=self.send_alert or other.send_alert)
diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py
new file mode 100644
index 000000000..e41ef5c7a
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/settings_entry.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import Any, ClassVar, Union
+
+from pydantic import BaseModel, PrivateAttr
+from typing_extensions import Self
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._utils import FieldRequiring
+
+
+class SettingsEntry(BaseModel, FieldRequiring):
+ """
+ A basic entry in the settings field appearing in every filter list and filter.
+
+ For a filter list, this is the default setting for it. For a filter, it's an override of the default entry.
+ """
+
+ # Each subclass must define a name matching the entry name we're expecting to receive from the database.
+ # Names must be unique across all filter lists.
+ name: ClassVar[str] = FieldRequiring.MUST_SET_UNIQUE
+ # Each subclass must define a description of what it does. If the data an entry type receives comprises
+ # several DB fields, the value should a dictionary of field names and their descriptions.
+ description: ClassVar[Union[str, dict[str, str]]] = FieldRequiring.MUST_SET
+
+ _overrides: set[str] = PrivateAttr(default_factory=set)
+
+ def __init__(self, defaults: SettingsEntry | None = None, /, **data):
+ overrides = set()
+ if defaults:
+ defaults_dict = defaults.dict()
+ for field_name, field_value in list(data.items()):
+ if field_value is None:
+ data[field_name] = defaults_dict[field_name]
+ else:
+ overrides.add(field_name)
+ super().__init__(**data)
+ self._overrides |= overrides
+
+ @property
+ def overrides(self) -> dict[str, Any]:
+ """Return a dictionary of overrides."""
+ return {name: getattr(self, name) for name in self._overrides}
+
+ @classmethod
+ def create(
+ cls, entry_data: dict[str, Any] | None, *, defaults: SettingsEntry | None = None, keep_empty: bool = False
+ ) -> SettingsEntry | None:
+ """
+ Returns a SettingsEntry object from `entry_data` if it holds any value, None otherwise.
+
+ Use this method to create SettingsEntry objects instead of the init.
+ The None value is significant for how a filter list iterates over its filters.
+ """
+ if entry_data is None:
+ return None
+ if not keep_empty and hasattr(entry_data, "values") and all(value is None for value in entry_data.values()):
+ return None
+
+ if not isinstance(entry_data, dict):
+ entry_data = {cls.name: entry_data}
+ return cls(defaults, **entry_data)
+
+
+class ValidationEntry(SettingsEntry):
+ """A setting entry to validate whether the filter should be triggered in the given context."""
+
+ @abstractmethod
+ def triggers_on(self, ctx: FilterContext) -> bool:
+ """Return whether the filter should be triggered with this setting in the given context."""
+ ...
+
+
+class ActionEntry(SettingsEntry):
+ """A setting entry defining what the bot should do if the filter it belongs to is triggered."""
+
+ @abstractmethod
+ async def action(self, ctx: FilterContext) -> None:
+ """Execute an action that should be taken when the filter this setting belongs to is triggered."""
+ ...
+
+ @abstractmethod
+ def union(self, other: Self) -> Self:
+ """
+ Combine two actions of the same type. Each type of action is executed once per filter.
+
+ The following condition must hold: if self == other, then self | other == self.
+ """
+ ...
diff --git a/bot/exts/filtering/_settings_types/validations/__init__.py b/bot/exts/filtering/_settings_types/validations/__init__.py
new file mode 100644
index 000000000..5c44e8b27
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/validations/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname
+
+from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
+from bot.exts.filtering._utils import subclasses_in_package
+
+validation_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ValidationEntry)
+
+__all__ = [validation_types]
diff --git a/bot/exts/filtering/_settings_types/validations/bypass_roles.py b/bot/exts/filtering/_settings_types/validations/bypass_roles.py
new file mode 100644
index 000000000..d42e6407c
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/validations/bypass_roles.py
@@ -0,0 +1,24 @@
+from typing import ClassVar, Union
+
+from discord import Member
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
+
+
+class RoleBypass(ValidationEntry):
+ """A setting entry which tells whether the roles the member has allow them to bypass the filter."""
+
+ name: ClassVar[str] = "bypass_roles"
+ description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter."
+
+ bypass_roles: set[Union[int, str]]
+
+ def triggers_on(self, ctx: FilterContext) -> bool:
+ """Return whether the filter should be triggered on this user given their roles."""
+ if not isinstance(ctx.author, Member):
+ return True
+ return all(
+ member_role.id not in self.bypass_roles and member_role.name not in self.bypass_roles
+ for member_role in ctx.author.roles
+ )
diff --git a/bot/exts/filtering/_settings_types/validations/channel_scope.py b/bot/exts/filtering/_settings_types/validations/channel_scope.py
new file mode 100644
index 000000000..d37efaa09
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/validations/channel_scope.py
@@ -0,0 +1,70 @@
+from typing import ClassVar, Union
+
+from pydantic import validator
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
+
+
+class ChannelScope(ValidationEntry):
+ """A setting entry which tells whether the filter was invoked in a whitelisted channel or category."""
+
+ name: ClassVar[str] = "channel_scope"
+ description: ClassVar[str] = {
+ "disabled_channels": (
+ "A list of channel IDs or channel names. "
+ "The filter will not trigger in these channels even if the category is expressly enabled."
+ ),
+ "disabled_categories": (
+ "A list of category IDs or category names. The filter will not trigger in these categories."
+ ),
+ "enabled_channels": (
+ "A list of channel IDs or channel names. "
+ "The filter can trigger in these channels even if the category is disabled or not expressly enabled."
+ ),
+ "enabled_categories": (
+ "A list of category IDs or category names. "
+ "If the list is not empty, filters will trigger only in channels of these categories, "
+ "unless the channel is expressly disabled."
+ )
+ }
+
+ disabled_channels: set[Union[int, str]]
+ disabled_categories: set[Union[int, str]]
+ enabled_channels: set[Union[int, str]]
+ enabled_categories: set[Union[int, str]]
+
+ @validator("*", pre=True)
+ @classmethod
+ def init_if_sequence_none(cls, sequence: list[str] | None) -> list[str]:
+ """Initialize an empty sequence if the value is None."""
+ if sequence is None:
+ return []
+ return sequence
+
+ def triggers_on(self, ctx: FilterContext) -> bool:
+ """
+ Return whether the filter should be triggered in the given channel.
+
+ The filter is invoked by default.
+ If the channel is explicitly enabled, it bypasses the set disabled channels and categories.
+ """
+ channel = ctx.channel
+
+ if not channel:
+ return True
+ if not hasattr(channel, "category"): # This is not a guild channel, outside the scope of this setting.
+ return True
+ if hasattr(channel, "parent"):
+ channel = channel.parent
+
+ enabled_channel = channel.id in self.enabled_channels or channel.name in self.enabled_channels
+ disabled_channel = channel.id in self.disabled_channels or channel.name in self.disabled_channels
+ enabled_category = channel.category and (not self.enabled_categories or (
+ channel.category.id in self.enabled_categories or channel.category.name in self.enabled_categories
+ ))
+ disabled_category = channel.category and (
+ channel.category.id in self.disabled_categories or channel.category.name in self.disabled_categories
+ )
+
+ return enabled_channel or (enabled_category and not disabled_channel and not disabled_category)
diff --git a/bot/exts/filtering/_settings_types/validations/enabled.py b/bot/exts/filtering/_settings_types/validations/enabled.py
new file mode 100644
index 000000000..3b5e3e446
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/validations/enabled.py
@@ -0,0 +1,19 @@
+from typing import ClassVar
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
+
+
+class Enabled(ValidationEntry):
+ """A setting entry which tells whether the filter is enabled."""
+
+ name: ClassVar[str] = "enabled"
+ description: ClassVar[str] = (
+ "A boolean field. Setting it to False allows disabling the filter without deleting it entirely."
+ )
+
+ enabled: bool
+
+ def triggers_on(self, ctx: FilterContext) -> bool:
+ """Return whether the filter is enabled."""
+ return self.enabled
diff --git a/bot/exts/filtering/_settings_types/validations/filter_dm.py b/bot/exts/filtering/_settings_types/validations/filter_dm.py
new file mode 100644
index 000000000..9961984d6
--- /dev/null
+++ b/bot/exts/filtering/_settings_types/validations/filter_dm.py
@@ -0,0 +1,20 @@
+from typing import ClassVar
+
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
+
+
+class FilterDM(ValidationEntry):
+ """A setting entry which tells whether to apply the filter to DMs."""
+
+ name: ClassVar[str] = "filter_dm"
+ description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs."
+
+ filter_dm: bool
+
+ def triggers_on(self, ctx: FilterContext) -> bool:
+ """Return whether the filter should be triggered even if it was triggered in DMs."""
+ if not ctx.channel: # No channel - out of scope for this setting.
+ return True
+
+ return ctx.channel.guild is not None or self.filter_dm
diff --git a/bot/exts/filtering/_ui/__init__.py b/bot/exts/filtering/_ui/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/bot/exts/filtering/_ui/__init__.py
diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py
new file mode 100644
index 000000000..1ef25f17a
--- /dev/null
+++ b/bot/exts/filtering/_ui/filter.py
@@ -0,0 +1,464 @@
+from __future__ import annotations
+
+from typing import Any, Callable
+
+import discord
+import discord.ui
+from discord import Embed, Interaction, User
+from discord.ext.commands import BadArgument
+from discord.ui.select import SelectOption
+from pydis_core.site_api import ResponseCodeError
+
+from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._ui.ui import (
+ COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MAX_EMBED_DESCRIPTION, MISSING, SETTINGS_DELIMITER,
+ SINGLE_SETTING_PATTERN, format_response_error, parse_value, populate_embed_from_dict
+)
+from bot.exts.filtering._utils import repr_equals, to_serializable
+from bot.log import get_logger
+
+log = get_logger(__name__)
+
+
+def build_filter_repr_dict(
+ filter_list: FilterList,
+ list_type: ListType,
+ filter_type: type[Filter],
+ settings_overrides: dict,
+ extra_fields_overrides: dict
+) -> dict:
+ """Build a dictionary of field names and values to pass to `populate_embed_from_dict`."""
+ # Get filter list settings
+ default_setting_values = {}
+ for settings_group in filter_list[list_type].defaults:
+ for _, setting in settings_group.items():
+ default_setting_values.update(to_serializable(setting.dict()))
+
+ # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings.
+ total_values = {}
+ for name, value in default_setting_values.items():
+ if name not in settings_overrides or repr_equals(settings_overrides[name], value):
+ total_values[name] = value
+ else:
+ total_values[f"{name}*"] = settings_overrides[name]
+
+ # Add the filter-specific settings.
+ if filter_type.extra_fields_type:
+ # This iterates over the default values of the extra fields model.
+ for name, value in filter_type.extra_fields_type().dict().items():
+ if name not in extra_fields_overrides or repr_equals(extra_fields_overrides[name], value):
+ total_values[f"{filter_type.name}/{name}"] = value
+ else:
+ total_values[f"{filter_type.name}/{name}*"] = extra_fields_overrides[name]
+
+ return total_values
+
+
+class EditContentModal(discord.ui.Modal, title="Edit Content"):
+ """A modal to input a filter's content."""
+
+ content = discord.ui.TextInput(label="Content")
+
+ def __init__(self, embed_view: FilterEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new content."""
+ await interaction.response.defer()
+ await self.embed_view.update_embed(self.message, content=self.content.value)
+
+
+class EditDescriptionModal(discord.ui.Modal, title="Edit Description"):
+ """A modal to input a filter's description."""
+
+ description = discord.ui.TextInput(label="Description")
+
+ def __init__(self, embed_view: FilterEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new description."""
+ await interaction.response.defer()
+ await self.embed_view.update_embed(self.message, description=self.description.value)
+
+
+class TemplateModal(discord.ui.Modal, title="Template"):
+ """A modal to enter a filter ID to copy its overrides over."""
+
+ template = discord.ui.TextInput(label="Template Filter ID")
+
+ def __init__(self, embed_view: FilterEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new description."""
+ await self.embed_view.apply_template(self.template.value, self.message, interaction)
+
+
+class FilterEditView(EditBaseView):
+ """A view used to edit a filter's settings before updating the database."""
+
+ class _REMOVE:
+ """Sentinel value for when an override should be removed."""
+
+ def __init__(
+ self,
+ filter_list: FilterList,
+ list_type: ListType,
+ filter_type: type[Filter],
+ content: str | None,
+ description: str | None,
+ settings_overrides: dict,
+ filter_settings_overrides: dict,
+ loaded_settings: dict,
+ loaded_filter_settings: dict,
+ author: User,
+ embed: Embed,
+ confirm_callback: Callable
+ ):
+ super().__init__(author)
+ self.filter_list = filter_list
+ self.list_type = list_type
+ self.filter_type = filter_type
+ self.content = content
+ self.description = description
+ self.settings_overrides = settings_overrides
+ self.filter_settings_overrides = filter_settings_overrides
+ self.loaded_settings = loaded_settings
+ self.loaded_filter_settings = loaded_filter_settings
+ self.embed = embed
+ self.confirm_callback = confirm_callback
+
+ all_settings_repr_dict = build_filter_repr_dict(
+ filter_list, list_type, filter_type, settings_overrides, filter_settings_overrides
+ )
+ populate_embed_from_dict(embed, all_settings_repr_dict)
+
+ self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()}
+ self.type_per_setting_name.update({
+ f"{filter_type.name}/{name}": type_
+ for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items()
+ })
+
+ add_select = CustomCallbackSelect(
+ self._prompt_new_value,
+ placeholder="Select a setting to edit",
+ options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)],
+ row=1
+ )
+ self.add_item(add_select)
+
+ if settings_overrides or filter_settings_overrides:
+ override_names = (
+ list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides]
+ )
+ remove_select = CustomCallbackSelect(
+ self._remove_override,
+ placeholder="Select an override to remove",
+ options=[SelectOption(label=name) for name in sorted(override_names)],
+ row=2
+ )
+ self.add_item(remove_select)
+
+ @discord.ui.button(label="Edit Content", row=3)
+ async def edit_content(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to edit the filter's content. Pressing the button invokes a modal."""
+ modal = EditContentModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="Edit Description", row=3)
+ async def edit_description(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to edit the filter's description. Pressing the button invokes a modal."""
+ modal = EditDescriptionModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="Empty Description", row=3)
+ async def empty_description(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to empty the filter's description."""
+ await self.update_embed(interaction, description=self._REMOVE)
+
+ @discord.ui.button(label="Template", row=3)
+ async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to enter a filter template ID and copy its overrides over."""
+ modal = TemplateModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=4)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Confirm the content, description, and settings, and update the filters database."""
+ if self.content is None:
+ await interaction.response.send_message(
+ ":x: Cannot add a filter with no content.", ephemeral=True, reference=interaction.message
+ )
+ if self.description is None:
+ self.description = ""
+ await interaction.response.edit_message(view=None) # Make sure the interaction succeeds first.
+ try:
+ await self.confirm_callback(
+ interaction.message,
+ self.filter_list,
+ self.list_type,
+ self.filter_type,
+ self.content,
+ self.description,
+ self.settings_overrides,
+ self.filter_settings_overrides
+ )
+ except ResponseCodeError as e:
+ await interaction.message.reply(embed=format_response_error(e))
+ await interaction.message.edit(view=self)
+ except BadArgument as e:
+ await interaction.message.reply(
+ embed=Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e))
+ )
+ await interaction.message.edit(view=self)
+ else:
+ self.stop()
+
+ @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=4)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the operation."""
+ await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None)
+ self.stop()
+
+ def current_value(self, setting_name: str) -> Any:
+ """Get the current value stored for the setting or MISSING if none found."""
+ if setting_name in self.settings_overrides:
+ return self.settings_overrides[setting_name]
+ if "/" in setting_name:
+ _, setting_name = setting_name.split("/", maxsplit=1)
+ if setting_name in self.filter_settings_overrides:
+ return self.filter_settings_overrides[setting_name]
+ return MISSING
+
+ async def update_embed(
+ self,
+ interaction_or_msg: discord.Interaction | discord.Message,
+ *,
+ content: str | None = None,
+ description: str | type[FilterEditView._REMOVE] | None = None,
+ setting_name: str | None = None,
+ setting_value: str | type[FilterEditView._REMOVE] | None = None,
+ ) -> None:
+ """
+ Update the embed with the new information.
+
+ If a setting name is provided with a _REMOVE value, remove the override.
+ If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function.
+ """
+ if content is not None or description is not None:
+ if content is not None:
+ filter_type = self.filter_list.get_filter_type(content)
+ if not filter_type:
+ if isinstance(interaction_or_msg, discord.Message):
+ send_method = interaction_or_msg.channel.send
+ else:
+ send_method = interaction_or_msg.response.send_message
+ await send_method(f":x: Could not find a filter type appropriate for `{content}`.")
+ return
+ self.content = content
+ self.filter_type = filter_type
+ else:
+ content = self.content # If there's no content or description, use the existing values.
+ if description is self._REMOVE:
+ self.description = None
+ elif description is not None:
+ self.description = description
+ else:
+ description = self.description
+
+ # Update the embed with the new content and/or description.
+ self.embed.description = f"`{content}`" if content else "*No content*"
+ if description and description is not self._REMOVE:
+ self.embed.description += f" - {description}"
+ if len(self.embed.description) > MAX_EMBED_DESCRIPTION:
+ self.embed.description = self.embed.description[:MAX_EMBED_DESCRIPTION - 5] + "[...]"
+
+ if setting_name:
+ # Find the right dictionary to update.
+ if "/" in setting_name:
+ filter_name, setting_name = setting_name.split("/", maxsplit=1)
+ dict_to_edit = self.filter_settings_overrides
+ default_value = self.filter_type.extra_fields_type().dict()[setting_name]
+ else:
+ dict_to_edit = self.settings_overrides
+ default_value = self.filter_list[self.list_type].default(setting_name)
+ # Update the setting override value or remove it
+ if setting_value is not self._REMOVE:
+ if not repr_equals(setting_value, default_value):
+ dict_to_edit[setting_name] = setting_value
+ # If there's already an override, remove it, since the new value is the same as the default.
+ elif setting_name in dict_to_edit:
+ dict_to_edit.pop(setting_name)
+ elif setting_name in dict_to_edit:
+ dict_to_edit.pop(setting_name)
+
+ # This is inefficient, but otherwise the selects go insane if the user attempts to edit the same setting
+ # multiple times, even when replacing the select with a new one.
+ self.embed.clear_fields()
+ new_view = self.copy()
+
+ try:
+ if isinstance(interaction_or_msg, discord.Interaction):
+ await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view)
+ else:
+ await interaction_or_msg.edit(embed=self.embed, view=new_view)
+ except discord.errors.HTTPException: # Various unexpected errors.
+ pass
+ else:
+ self.stop()
+
+ async def edit_setting_override(self, interaction: Interaction, setting_name: str, override_value: Any) -> None:
+ """
+ Update the overrides with the new value and edit the embed.
+
+ The interaction needs to be the selection of the setting attached to the embed.
+ """
+ await self.update_embed(interaction, setting_name=setting_name, setting_value=override_value)
+
+ async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None:
+ """Replace any non-overridden settings with overrides from the given filter."""
+ try:
+ settings, filter_settings = template_settings(
+ template_id, self.filter_list, self.list_type, self.filter_type
+ )
+ except BadArgument as e: # The interaction object is necessary to send an ephemeral message.
+ await interaction.response.send_message(f":x: {e}", ephemeral=True)
+ return
+ else:
+ await interaction.response.defer()
+
+ self.settings_overrides = settings | self.settings_overrides
+ self.filter_settings_overrides = filter_settings | self.filter_settings_overrides
+ self.embed.clear_fields()
+ await embed_message.edit(embed=self.embed, view=self.copy())
+ self.stop()
+
+ async def _remove_override(self, interaction: Interaction, select: discord.ui.Select) -> None:
+ """
+ Remove the override for the setting the user selected, and edit the embed.
+
+ The interaction needs to be the selection of the setting attached to the embed.
+ """
+ await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE)
+
+ def copy(self) -> FilterEditView:
+ """Create a copy of this view."""
+ return FilterEditView(
+ self.filter_list,
+ self.list_type,
+ self.filter_type,
+ self.content,
+ self.description,
+ self.settings_overrides,
+ self.filter_settings_overrides,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ self.author,
+ self.embed,
+ self.confirm_callback
+ )
+
+
+def description_and_settings_converter(
+ filter_list: FilterList,
+ list_type: ListType,
+ filter_type: type[Filter],
+ loaded_settings: dict,
+ loaded_filter_settings: dict,
+ input_data: str
+) -> tuple[str, dict[str, Any], dict[str, Any]]:
+ """Parse a string representing a possible description and setting overrides, and validate the setting names."""
+ if not input_data:
+ return "", {}, {}
+
+ parsed = SETTINGS_DELIMITER.split(input_data)
+ if not parsed:
+ return "", {}, {}
+
+ description = ""
+ if not SINGLE_SETTING_PATTERN.match(parsed[0]):
+ description, *parsed = parsed
+
+ settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]}
+ template = None
+ if "--template" in settings:
+ template = settings.pop("--template")
+
+ filter_settings = {}
+ for setting, _ in list(settings.items()):
+ if setting in loaded_settings: # It's a filter list setting
+ type_ = loaded_settings[setting][2]
+ try:
+ parsed_value = parse_value(settings.pop(setting), type_)
+ if not repr_equals(parsed_value, filter_list[list_type].default(setting)):
+ settings[setting] = parsed_value
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+ elif "/" not in setting:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ else: # It's a filter setting
+ filter_name, filter_setting_name = setting.split("/", maxsplit=1)
+ if filter_name.lower() != filter_type.name.lower():
+ raise BadArgument(
+ f"A setting for a {filter_name!r} filter was provided, but the filter name is {filter_type.name!r}"
+ )
+ if filter_setting_name not in loaded_filter_settings[filter_type.name]:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2]
+ try:
+ parsed_value = parse_value(settings.pop(setting), type_)
+ if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)):
+ filter_settings[filter_setting_name] = parsed_value
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+
+ # Pull templates settings and apply them.
+ if template is not None:
+ try:
+ t_settings, t_filter_settings = template_settings(template, filter_list, list_type, filter_type)
+ except ValueError as e:
+ raise BadArgument(str(e))
+ else:
+ # The specified settings go on top of the template
+ settings = t_settings | settings
+ filter_settings = t_filter_settings | filter_settings
+
+ return description, settings, filter_settings
+
+
+def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]:
+ """Get a serializable version of the filter's overrides."""
+ overrides_values, extra_fields_overrides = filter_.overrides
+ return to_serializable(overrides_values), to_serializable(extra_fields_overrides)
+
+
+def template_settings(
+ filter_id: str, filter_list: FilterList, list_type: ListType, filter_type: type[Filter]
+) -> tuple[dict, dict]:
+ """Find the filter with specified ID, and return its settings."""
+ try:
+ filter_id = int(filter_id)
+ if filter_id < 0:
+ raise ValueError()
+ except ValueError:
+ raise BadArgument("Template value must be a non-negative integer.")
+
+ if filter_id not in filter_list[list_type].filters:
+ raise BadArgument(
+ f"Could not find filter with ID `{filter_id}` in the {list_type.name} {filter_list.name} list."
+ )
+ filter_ = filter_list[list_type].filters[filter_id]
+
+ if not isinstance(filter_, filter_type):
+ raise BadArgument(
+ f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}"
+ )
+ return filter_serializable_overrides(filter_)
diff --git a/bot/exts/filtering/_ui/filter_list.py b/bot/exts/filtering/_ui/filter_list.py
new file mode 100644
index 000000000..a4526f090
--- /dev/null
+++ b/bot/exts/filtering/_ui/filter_list.py
@@ -0,0 +1,271 @@
+from __future__ import annotations
+
+from typing import Any, Callable
+
+import discord
+from discord import Embed, Interaction, SelectOption, User
+from discord.ext.commands import BadArgument
+from pydis_core.site_api import ResponseCodeError
+
+from bot.exts.filtering._filter_lists import FilterList, ListType
+from bot.exts.filtering._ui.ui import (
+ CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, format_response_error, parse_value,
+ populate_embed_from_dict
+)
+from bot.exts.filtering._utils import repr_equals, to_serializable
+
+
+def settings_converter(loaded_settings: dict, input_data: str) -> dict[str, Any]:
+ """Parse a string representing settings, and validate the setting names."""
+ if not input_data:
+ return {}
+
+ parsed = SETTINGS_DELIMITER.split(input_data)
+ if not parsed:
+ return {}
+
+ try:
+ settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]}
+ except ValueError:
+ raise BadArgument("The settings provided are not in the correct format.")
+
+ for setting in settings:
+ if setting not in loaded_settings:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ else:
+ type_ = loaded_settings[setting][2]
+ try:
+ parsed_value = parse_value(settings.pop(setting), type_)
+ settings[setting] = parsed_value
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+
+ return settings
+
+
+def build_filterlist_repr_dict(filter_list: FilterList, list_type: ListType, new_settings: dict) -> dict:
+ """Build a dictionary of field names and values to pass to `_build_embed_from_dict`."""
+ # Get filter list settings
+ default_setting_values = {}
+ for settings_group in filter_list[list_type].defaults:
+ for _, setting in settings_group.items():
+ default_setting_values.update(to_serializable(setting.dict()))
+
+ # Add new values. It's done in this way to preserve field order, since the new_values won't have all settings.
+ total_values = {}
+ for name, value in default_setting_values.items():
+ if name not in new_settings or repr_equals(new_settings[name], value):
+ total_values[name] = value
+ else:
+ total_values[f"{name}~"] = new_settings[name]
+
+ return total_values
+
+
+class FilterListAddView(EditBaseView):
+ """A view used to add a new filter list."""
+
+ def __init__(
+ self,
+ list_name: str,
+ list_type: ListType,
+ settings: dict,
+ loaded_settings: dict,
+ author: User,
+ embed: Embed,
+ confirm_callback: Callable
+ ):
+ super().__init__(author)
+ self.list_name = list_name
+ self.list_type = list_type
+ self.settings = settings
+ self.loaded_settings = loaded_settings
+ self.embed = embed
+ self.confirm_callback = confirm_callback
+
+ self.settings_repr_dict = {name: to_serializable(value) for name, value in settings.items()}
+ populate_embed_from_dict(embed, self.settings_repr_dict)
+
+ self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()}
+
+ edit_select = CustomCallbackSelect(
+ self._prompt_new_value,
+ placeholder="Select a setting to edit",
+ options=[SelectOption(label=name) for name in sorted(settings)],
+ row=0
+ )
+ self.add_item(edit_select)
+
+ @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=1)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Confirm the content, description, and settings, and update the filters database."""
+ await interaction.response.edit_message(view=None) # Make sure the interaction succeeds first.
+ try:
+ await self.confirm_callback(interaction.message, self.list_name, self.list_type, self.settings)
+ except ResponseCodeError as e:
+ await interaction.message.reply(embed=format_response_error(e))
+ await interaction.message.edit(view=self)
+ else:
+ self.stop()
+
+ @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=1)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the operation."""
+ await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None)
+ self.stop()
+
+ def current_value(self, setting_name: str) -> Any:
+ """Get the current value stored for the setting or MISSING if none found."""
+ if setting_name in self.settings:
+ return self.settings[setting_name]
+ return MISSING
+
+ async def update_embed(
+ self,
+ interaction_or_msg: discord.Interaction | discord.Message,
+ *,
+ setting_name: str | None = None,
+ setting_value: str | None = None,
+ ) -> None:
+ """
+ Update the embed with the new information.
+
+ If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function.
+ """
+ if not setting_name: # Obligatory check to match the signature in the parent class.
+ return
+
+ self.settings[setting_name] = setting_value
+
+ self.embed.clear_fields()
+ new_view = self.copy()
+
+ try:
+ if isinstance(interaction_or_msg, discord.Interaction):
+ await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view)
+ else:
+ await interaction_or_msg.edit(embed=self.embed, view=new_view)
+ except discord.errors.HTTPException: # Various unexpected errors.
+ pass
+ else:
+ self.stop()
+
+ def copy(self) -> FilterListAddView:
+ """Create a copy of this view."""
+ return FilterListAddView(
+ self.list_name,
+ self.list_type,
+ self.settings,
+ self.loaded_settings,
+ self.author,
+ self.embed,
+ self.confirm_callback
+ )
+
+
+class FilterListEditView(EditBaseView):
+ """A view used to edit a filter list's settings before updating the database."""
+
+ def __init__(
+ self,
+ filter_list: FilterList,
+ list_type: ListType,
+ new_settings: dict,
+ loaded_settings: dict,
+ author: User,
+ embed: Embed,
+ confirm_callback: Callable
+ ):
+ super().__init__(author)
+ self.filter_list = filter_list
+ self.list_type = list_type
+ self.settings = new_settings
+ self.loaded_settings = loaded_settings
+ self.embed = embed
+ self.confirm_callback = confirm_callback
+
+ self.settings_repr_dict = build_filterlist_repr_dict(filter_list, list_type, new_settings)
+ populate_embed_from_dict(embed, self.settings_repr_dict)
+
+ self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()}
+
+ edit_select = CustomCallbackSelect(
+ self._prompt_new_value,
+ placeholder="Select a setting to edit",
+ options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)],
+ row=0
+ )
+ self.add_item(edit_select)
+
+ @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=1)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Confirm the content, description, and settings, and update the filters database."""
+ await interaction.response.edit_message(view=None) # Make sure the interaction succeeds first.
+ try:
+ await self.confirm_callback(interaction.message, self.filter_list, self.list_type, self.settings)
+ except ResponseCodeError as e:
+ await interaction.message.reply(embed=format_response_error(e))
+ await interaction.message.edit(view=self)
+ else:
+ self.stop()
+
+ @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=1)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the operation."""
+ await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None)
+ self.stop()
+
+ def current_value(self, setting_name: str) -> Any:
+ """Get the current value stored for the setting or MISSING if none found."""
+ if setting_name in self.settings:
+ return self.settings[setting_name]
+ if setting_name in self.settings_repr_dict:
+ return self.settings_repr_dict[setting_name]
+ return MISSING
+
+ async def update_embed(
+ self,
+ interaction_or_msg: discord.Interaction | discord.Message,
+ *,
+ setting_name: str | None = None,
+ setting_value: str | None = None,
+ ) -> None:
+ """
+ Update the embed with the new information.
+
+ If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function.
+ """
+ if not setting_name: # Obligatory check to match the signature in the parent class.
+ return
+
+ default_value = self.filter_list[self.list_type].default(setting_name)
+ if not repr_equals(setting_value, default_value):
+ self.settings[setting_name] = setting_value
+ # If there's already a new value, remove it, since the new value is the same as the default.
+ elif setting_name in self.settings:
+ self.settings.pop(setting_name)
+
+ self.embed.clear_fields()
+ new_view = self.copy()
+
+ try:
+ if isinstance(interaction_or_msg, discord.Interaction):
+ await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view)
+ else:
+ await interaction_or_msg.edit(embed=self.embed, view=new_view)
+ except discord.errors.HTTPException: # Various errors such as embed description being too long.
+ pass
+ else:
+ self.stop()
+
+ def copy(self) -> FilterListEditView:
+ """Create a copy of this view."""
+ return FilterListEditView(
+ self.filter_list,
+ self.list_type,
+ self.settings,
+ self.loaded_settings,
+ self.author,
+ self.embed,
+ self.confirm_callback
+ )
diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py
new file mode 100644
index 000000000..d553c28ea
--- /dev/null
+++ b/bot/exts/filtering/_ui/search.py
@@ -0,0 +1,365 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import discord
+from discord import Interaction, SelectOption
+from discord.ext.commands import BadArgument
+
+from bot.exts.filtering._filter_lists import FilterList, ListType
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._settings_types.settings_entry import SettingsEntry
+from bot.exts.filtering._ui.filter import filter_serializable_overrides
+from bot.exts.filtering._ui.ui import (
+ COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value,
+ populate_embed_from_dict
+)
+
+
+def search_criteria_converter(
+ filter_lists: dict,
+ loaded_filters: dict,
+ loaded_settings: dict,
+ loaded_filter_settings: dict,
+ filter_type: type[Filter] | None,
+ input_data: str
+) -> tuple[dict[str, Any], dict[str, Any], type[Filter]]:
+ """Parse a string representing setting overrides, and validate the setting names."""
+ if not input_data:
+ return {}, {}, filter_type
+
+ parsed = SETTINGS_DELIMITER.split(input_data)
+ if not parsed:
+ return {}, {}, filter_type
+
+ try:
+ settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]}
+ except ValueError:
+ raise BadArgument("The settings provided are not in the correct format.")
+
+ template = None
+ if "--template" in settings:
+ template = settings.pop("--template")
+
+ filter_settings = {}
+ for setting, _ in list(settings.items()):
+ if setting in loaded_settings: # It's a filter list setting
+ type_ = loaded_settings[setting][2]
+ try:
+ settings[setting] = parse_value(settings[setting], type_)
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+ elif "/" not in setting:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ else: # It's a filter setting
+ filter_name, filter_setting_name = setting.split("/", maxsplit=1)
+ if not filter_type:
+ if filter_name in loaded_filters:
+ filter_type = loaded_filters[filter_name]
+ else:
+ raise BadArgument(f"There's no filter type named {filter_name!r}.")
+ if filter_name.lower() != filter_type.name.lower():
+ raise BadArgument(
+ f"A setting for a {filter_name!r} filter was provided, "
+ f"but the filter name is {filter_type.name!r}"
+ )
+ if filter_setting_name not in loaded_filter_settings[filter_type.name]:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2]
+ try:
+ filter_settings[filter_setting_name] = parse_value(settings.pop(setting), type_)
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+
+ # Pull templates settings and apply them.
+ if template is not None:
+ try:
+ t_settings, t_filter_settings, filter_type = template_settings(template, filter_lists, filter_type)
+ except ValueError as e:
+ raise BadArgument(str(e))
+ else:
+ # The specified settings go on top of the template
+ settings = t_settings | settings
+ filter_settings = t_filter_settings | filter_settings
+
+ return settings, filter_settings, filter_type
+
+
+def get_filter(filter_id: int, filter_lists: dict) -> tuple[Filter, FilterList, ListType] | None:
+ """Return a filter with the specific filter_id, if found."""
+ for filter_list in filter_lists.values():
+ for list_type, sublist in filter_list.items():
+ if filter_id in sublist.filters:
+ return sublist.filters[filter_id], filter_list, list_type
+ return None
+
+
+def template_settings(
+ filter_id: str, filter_lists: dict, filter_type: type[Filter] | None
+) -> tuple[dict, dict, type[Filter]]:
+ """Find a filter with the specified ID and filter type, and return its settings and (maybe newly found) type."""
+ try:
+ filter_id = int(filter_id)
+ if filter_id < 0:
+ raise ValueError()
+ except ValueError:
+ raise BadArgument("Template value must be a non-negative integer.")
+
+ result = get_filter(filter_id, filter_lists)
+ if not result:
+ raise BadArgument(f"Could not find a filter with ID `{filter_id}`.")
+ filter_, filter_list, list_type = result
+
+ if filter_type and not isinstance(filter_, filter_type):
+ raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.")
+
+ settings, filter_settings = filter_serializable_overrides(filter_)
+ return settings, filter_settings, type(filter_)
+
+
+def build_search_repr_dict(
+ settings: dict[str, Any], filter_settings: dict[str, Any], filter_type: type[Filter] | None
+) -> dict:
+ """Build a dictionary of field names and values to pass to `populate_embed_from_dict`."""
+ total_values = settings.copy()
+ if filter_type:
+ for setting_name, value in filter_settings.items():
+ total_values[f"{filter_type.name}/{setting_name}"] = value
+
+ return total_values
+
+
+class SearchEditView(EditBaseView):
+ """A view used to edit the search criteria before performing the search."""
+
+ class _REMOVE:
+ """Sentinel value for when an override should be removed."""
+
+ def __init__(
+ self,
+ filter_type: type[Filter] | None,
+ settings: dict[str, Any],
+ filter_settings: dict[str, Any],
+ loaded_filter_lists: dict[str, FilterList],
+ loaded_filters: dict[str, type[Filter]],
+ loaded_settings: dict[str, tuple[str, SettingsEntry, type]],
+ loaded_filter_settings: dict[str, dict[str, tuple[str, SettingsEntry, type]]],
+ author: discord.User | discord.Member,
+ embed: discord.Embed,
+ confirm_callback: Callable
+ ):
+ super().__init__(author)
+ self.filter_type = filter_type
+ self.settings = settings
+ self.filter_settings = filter_settings
+ self.loaded_filter_lists = loaded_filter_lists
+ self.loaded_filters = loaded_filters
+ self.loaded_settings = loaded_settings
+ self.loaded_filter_settings = loaded_filter_settings
+ self.embed = embed
+ self.confirm_callback = confirm_callback
+
+ title = "Filters Search"
+ if filter_type:
+ title += f" - {filter_type.name.title()}"
+ embed.set_author(name=title)
+
+ settings_repr_dict = build_search_repr_dict(settings, filter_settings, filter_type)
+ populate_embed_from_dict(embed, settings_repr_dict)
+
+ self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()}
+ if filter_type:
+ self.type_per_setting_name.update({
+ f"{filter_type.name}/{name}": type_
+ for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items()
+ })
+
+ add_select = CustomCallbackSelect(
+ self._prompt_new_value,
+ placeholder="Add or edit criterion",
+ options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)],
+ row=0
+ )
+ self.add_item(add_select)
+
+ if settings_repr_dict:
+ remove_select = CustomCallbackSelect(
+ self._remove_criterion,
+ placeholder="Select a criterion to remove",
+ options=[SelectOption(label=name) for name in sorted(settings_repr_dict)],
+ row=1
+ )
+ self.add_item(remove_select)
+
+ @discord.ui.button(label="Template", row=2)
+ async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to enter a filter template ID and copy its overrides over."""
+ modal = TemplateModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="Filter Type", row=2)
+ async def enter_filter_type(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to enter a filter type."""
+ modal = FilterTypeModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=3)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Confirm the search criteria and perform the search."""
+ await interaction.response.edit_message(view=None) # Make sure the interaction succeeds first.
+ try:
+ await self.confirm_callback(interaction.message, self.filter_type, self.settings, self.filter_settings)
+ except BadArgument as e:
+ await interaction.message.reply(
+ embed=discord.Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e))
+ )
+ await interaction.message.edit(view=self)
+ else:
+ self.stop()
+
+ @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=3)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the operation."""
+ await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None)
+ self.stop()
+
+ def current_value(self, setting_name: str) -> Any:
+ """Get the current value stored for the setting or MISSING if none found."""
+ if setting_name in self.settings:
+ return self.settings[setting_name]
+ if "/" in setting_name:
+ _, setting_name = setting_name.split("/", maxsplit=1)
+ if setting_name in self.filter_settings:
+ return self.filter_settings[setting_name]
+ return MISSING
+
+ async def update_embed(
+ self,
+ interaction_or_msg: discord.Interaction | discord.Message,
+ *,
+ setting_name: str | None = None,
+ setting_value: str | type[SearchEditView._REMOVE] | None = None,
+ ) -> None:
+ """
+ Update the embed with the new information.
+
+ If a setting name is provided with a _REMOVE value, remove the override.
+ If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function.
+ """
+ if not setting_name: # Can be None just to make the function signature compatible with the parent class.
+ return
+
+ if "/" in setting_name:
+ filter_name, setting_name = setting_name.split("/", maxsplit=1)
+ dict_to_edit = self.filter_settings
+ else:
+ dict_to_edit = self.settings
+
+ # Update the criterion value or remove it
+ if setting_value is not self._REMOVE:
+ dict_to_edit[setting_name] = setting_value
+ elif setting_name in dict_to_edit:
+ dict_to_edit.pop(setting_name)
+
+ self.embed.clear_fields()
+ new_view = self.copy()
+
+ try:
+ if isinstance(interaction_or_msg, discord.Interaction):
+ await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view)
+ else:
+ await interaction_or_msg.edit(embed=self.embed, view=new_view)
+ except discord.errors.HTTPException: # Just in case of faulty input.
+ pass
+ else:
+ self.stop()
+
+ async def _remove_criterion(self, interaction: Interaction, select: discord.ui.Select) -> None:
+ """
+ Remove the criterion the user selected, and edit the embed.
+
+ The interaction needs to be the selection of the setting attached to the embed.
+ """
+ await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE)
+
+ async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None:
+ """Set any unset criteria with settings values from the given filter."""
+ try:
+ settings, filter_settings, self.filter_type = template_settings(
+ template_id, self.loaded_filter_lists, self.filter_type
+ )
+ except BadArgument as e: # The interaction object is necessary to send an ephemeral message.
+ await interaction.response.send_message(f":x: {e}", ephemeral=True)
+ return
+ else:
+ await interaction.response.defer()
+
+ self.settings = settings | self.settings
+ self.filter_settings = filter_settings | self.filter_settings
+ self.embed.clear_fields()
+ await embed_message.edit(embed=self.embed, view=self.copy())
+ self.stop()
+
+ async def apply_filter_type(self, type_name: str, embed_message: discord.Message, interaction: Interaction) -> None:
+ """Set a new filter type and reset any criteria for settings of the old filter type."""
+ if type_name.lower() not in self.loaded_filters:
+ if type_name.lower()[:-1] not in self.loaded_filters: # In case the user entered the plural form.
+ await interaction.response.send_message(f":x: No such filter type {type_name!r}.", ephemeral=True)
+ return
+ type_name = type_name[:-1]
+ type_name = type_name.lower()
+ await interaction.response.defer()
+
+ if self.filter_type and type_name == self.filter_type.name:
+ return
+ self.filter_type = self.loaded_filters[type_name]
+ self.filter_settings = {}
+ self.embed.clear_fields()
+ await embed_message.edit(embed=self.embed, view=self.copy())
+ self.stop()
+
+ def copy(self) -> SearchEditView:
+ """Create a copy of this view."""
+ return SearchEditView(
+ self.filter_type,
+ self.settings,
+ self.filter_settings,
+ self.loaded_filter_lists,
+ self.loaded_filters,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ self.author,
+ self.embed,
+ self.confirm_callback
+ )
+
+
+class TemplateModal(discord.ui.Modal, title="Template"):
+ """A modal to enter a filter ID to copy its overrides over."""
+
+ template = discord.ui.TextInput(label="Template Filter ID", required=False)
+
+ def __init__(self, embed_view: SearchEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new description."""
+ await self.embed_view.apply_template(self.template.value, self.message, interaction)
+
+
+class FilterTypeModal(discord.ui.Modal, title="Template"):
+ """A modal to enter a filter ID to copy its overrides over."""
+
+ filter_type = discord.ui.TextInput(label="Filter Type")
+
+ def __init__(self, embed_view: SearchEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new description."""
+ await self.embed_view.apply_filter_type(self.filter_type.value, self.message, interaction)
diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py
new file mode 100644
index 000000000..157906d6b
--- /dev/null
+++ b/bot/exts/filtering/_ui/ui.py
@@ -0,0 +1,565 @@
+from __future__ import annotations
+
+import re
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from enum import EnumMeta
+from functools import partial
+from typing import Any, Callable, Coroutine, Optional, TypeVar
+
+import discord
+from discord import Embed, Interaction
+from discord.ext.commands import Context
+from discord.ui.select import MISSING as SELECT_MISSING, SelectOption
+from discord.utils import escape_markdown
+from pydis_core.site_api import ResponseCodeError
+from pydis_core.utils import scheduling
+from pydis_core.utils.logging import get_logger
+from pydis_core.utils.members import get_or_fetch_member
+
+import bot
+from bot.constants import Colours
+from bot.exts.filtering._filter_context import FilterContext
+from bot.exts.filtering._filter_lists import FilterList
+from bot.exts.filtering._utils import FakeContext
+from bot.utils.messages import format_channel, format_user, upload_log
+
+log = get_logger(__name__)
+
+
+# Max number of characters in a Discord embed field value, minus 6 characters for a placeholder.
+MAX_FIELD_SIZE = 1018
+# Max number of characters for an embed field's value before it should take its own line.
+MAX_INLINE_SIZE = 50
+# Number of seconds before a settings editing view timeout.
+EDIT_TIMEOUT = 600
+# Number of seconds before timeout of an editing component.
+COMPONENT_TIMEOUT = 180
+# Amount of seconds to confirm the operation.
+DELETION_TIMEOUT = 60
+# Max length of modal title
+MAX_MODAL_TITLE_LENGTH = 45
+# Max number of items in a select
+MAX_SELECT_ITEMS = 25
+MAX_EMBED_DESCRIPTION = 4080
+# Number of seconds before timeout of the alert view
+ALERT_VIEW_TIMEOUT = 3600
+
+SETTINGS_DELIMITER = re.compile(r"\s+(?=\S+=\S+)")
+SINGLE_SETTING_PATTERN = re.compile(r"[\w/]+=.+")
+
+# Sentinel value to denote that a value is missing
+MISSING = object()
+
+T = TypeVar('T')
+
+
+async def _build_alert_message_content(ctx: FilterContext, current_message_length: int) -> str:
+ """Build the content section of the alert."""
+ # For multiple messages and those with attachments or excessive newlines, use the logs API
+ if ctx.messages_deletion and ctx.upload_deletion_logs and any((
+ ctx.related_messages,
+ len(ctx.attachments) > 0,
+ ctx.content.count('\n') > 15
+ )):
+ url = await upload_log(ctx.related_messages, bot.instance.user.id, ctx.attachments)
+ return f"A complete log of the offending messages can be found [here]({url})"
+
+ alert_content = escape_markdown(ctx.content)
+ remaining_chars = MAX_EMBED_DESCRIPTION - current_message_length
+
+ if len(alert_content) > remaining_chars:
+ if ctx.messages_deletion and ctx.upload_deletion_logs:
+ url = await upload_log([ctx.message], bot.instance.user.id, ctx.attachments)
+ log_site_msg = f"The full message can be found [here]({url})"
+ # 7 because that's the length of "[...]\n\n"
+ return alert_content[:remaining_chars - (7 + len(log_site_msg))] + "[...]\n\n" + log_site_msg
+ else:
+ return alert_content[:remaining_chars - 5] + "[...]"
+
+ return alert_content
+
+
+async def build_mod_alert(ctx: FilterContext, triggered_filters: dict[FilterList, Iterable[str]]) -> Embed:
+ """Build an alert message from the filter context."""
+ embed = Embed(color=Colours.soft_orange)
+ embed.set_thumbnail(url=ctx.author.display_avatar.url)
+ triggered_by = f"**Triggered by:** {format_user(ctx.author)}"
+ if ctx.channel:
+ if ctx.channel.guild:
+ triggered_in = f"**Triggered in:** {format_channel(ctx.channel)}\n"
+ else:
+ triggered_in = "**Triggered in:** :warning:**DM**:warning:\n"
+ if len(ctx.related_channels) > 1:
+ triggered_in += f"**Channels:** {', '.join(channel.mention for channel in ctx.related_channels)}\n"
+ else:
+ triggered_by += "\n"
+ triggered_in = ""
+
+ filters = []
+ for filter_list, list_message in triggered_filters.items():
+ if list_message:
+ filters.append(f"**{filter_list.name.title()} Filters:** {', '.join(list_message)}")
+ filters = "\n".join(filters)
+
+ matches = "**Matches:** " + ", ".join(repr(match) for match in ctx.matches) if ctx.matches else ""
+ actions = "\n**Actions Taken:** " + (", ".join(ctx.action_descriptions) if ctx.action_descriptions else "-")
+
+ mod_alert_message = "\n".join(part for part in (triggered_by, triggered_in, filters, matches, actions) if part)
+ log.debug(f"{ctx.event.name} Filter:\n{mod_alert_message}")
+
+ if ctx.message:
+ mod_alert_message += f"\n**[Original Content]({ctx.message.jump_url})**:\n"
+ else:
+ mod_alert_message += "\n**Original Content**:\n"
+ mod_alert_message += await _build_alert_message_content(ctx, len(mod_alert_message))
+
+ embed.description = mod_alert_message
+ return embed
+
+
+def populate_embed_from_dict(embed: Embed, data: dict) -> None:
+ """Populate a Discord embed by populating fields from the given dict."""
+ for setting, value in data.items():
+ if setting.startswith("_"):
+ continue
+ if isinstance(value, (list, set, tuple)):
+ value = f"[{', '.join(map(str, value))}]"
+ else:
+ value = str(value) if value not in ("", None) else "-"
+ if len(value) > MAX_FIELD_SIZE:
+ value = value[:MAX_FIELD_SIZE] + " [...]"
+ embed.add_field(name=setting, value=value, inline=len(value) < MAX_INLINE_SIZE)
+
+
+def parse_value(value: str, type_: type[T]) -> T:
+ """Parse the value and attempt to convert it to the provided type."""
+ if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias
+ type_ = type_.__origin__
+ if value == '""':
+ return type_()
+ if type_ in (tuple, list, set):
+ return list(value.split(","))
+ if type_ is bool:
+ return value.lower() == "true" or value == "1"
+ if isinstance(type_, EnumMeta):
+ return type_[value.upper()]
+
+ return type_(value)
+
+
+def format_response_error(e: ResponseCodeError) -> Embed:
+ """Format the response error into an embed."""
+ description = ""
+ if isinstance(e.response_json, list):
+ description = "\n".join(f"• {error}" for error in e.response_json)
+ elif isinstance(e.response_json, dict):
+ if "non_field_errors" in e.response_json:
+ non_field_errors = e.response_json.pop("non_field_errors")
+ description += "\n".join(f"• {error}" for error in non_field_errors) + "\n"
+ for field, errors in e.response_json.items():
+ description += "\n".join(f"• {field} - {error}" for error in errors) + "\n"
+
+ description = description.strip()
+ if len(description) > MAX_EMBED_DESCRIPTION:
+ description = description[:MAX_EMBED_DESCRIPTION] + "[...]"
+ if not description:
+ description = "Something unexpected happened, check the logs."
+
+ embed = Embed(colour=discord.Colour.red(), title="Oops...", description=description)
+ return embed
+
+
+class ArgumentCompletionSelect(discord.ui.Select):
+ """A select detailing the options that can be picked to assign to a missing argument."""
+
+ def __init__(
+ self,
+ ctx: Context,
+ args: list,
+ arg_name: str,
+ options: list[str],
+ position: int,
+ converter: Optional[Callable] = None
+ ):
+ super().__init__(
+ placeholder=f"Select a value for {arg_name!r}",
+ options=[discord.SelectOption(label=option) for option in options]
+ )
+ self.ctx = ctx
+ self.args = args
+ self.position = position
+ self.converter = converter
+
+ async def callback(self, interaction: discord.Interaction) -> None:
+ """re-invoke the context command with the completed argument value."""
+ await interaction.response.defer()
+ value = interaction.data["values"][0]
+ if self.converter:
+ value = self.converter(value)
+ args = self.args.copy() # This makes the view reusable.
+ args.insert(self.position, value)
+ log.trace(f"Argument filled with the value {value}. Re-invoking command")
+ await self.ctx.invoke(self.ctx.command, *args)
+
+
+class ArgumentCompletionView(discord.ui.View):
+ """A view used to complete a missing argument in an in invoked command."""
+
+ def __init__(
+ self,
+ ctx: Context,
+ args: list,
+ arg_name: str,
+ options: list[str],
+ position: int,
+ converter: Optional[Callable] = None
+ ):
+ super().__init__()
+ log.trace(f"The {arg_name} argument was designated missing in the invocation {ctx.view.buffer!r}")
+ self.add_item(ArgumentCompletionSelect(ctx, args, arg_name, options, position, converter))
+ self.ctx = ctx
+
+ async def interaction_check(self, interaction: discord.Interaction) -> bool:
+ """Check to ensure that the interacting user is the user who invoked the command."""
+ if interaction.user != self.ctx.author:
+ embed = discord.Embed(description="Sorry, but this dropdown menu can only be used by the original author.")
+ await interaction.response.send_message(embed=embed, ephemeral=True)
+ return False
+ return True
+
+
+class CustomCallbackSelect(discord.ui.Select):
+ """A selection which calls the provided callback on interaction."""
+
+ def __init__(
+ self,
+ callback: Callable[[Interaction, discord.ui.Select], Coroutine[None]],
+ *,
+ custom_id: str = SELECT_MISSING,
+ placeholder: str | None = None,
+ min_values: int = 1,
+ max_values: int = 1,
+ options: list[SelectOption] = SELECT_MISSING,
+ disabled: bool = False,
+ row: int | None = None,
+ ):
+ super().__init__(
+ custom_id=custom_id,
+ placeholder=placeholder,
+ min_values=min_values,
+ max_values=max_values,
+ options=options,
+ disabled=disabled,
+ row=row
+ )
+ self.custom_callback = callback
+
+ async def callback(self, interaction: Interaction) -> Any:
+ """Invoke the provided callback."""
+ await self.custom_callback(interaction, self)
+
+
+class BooleanSelectView(discord.ui.View):
+ """A view containing an instance of BooleanSelect."""
+
+ class BooleanSelect(discord.ui.Select):
+ """Select a true or false value and send it to the supplied callback."""
+
+ def __init__(self, setting_name: str, update_callback: Callable):
+ super().__init__(options=[SelectOption(label="True"), SelectOption(label="False")])
+ self.setting_name = setting_name
+ self.update_callback = update_callback
+
+ async def callback(self, interaction: Interaction) -> Any:
+ """Respond to the interaction by sending the boolean value to the update callback."""
+ await interaction.response.edit_message(content="✅ Edit confirmed", view=None)
+ value = self.values[0] == "True"
+ await self.update_callback(setting_name=self.setting_name, setting_value=value)
+
+ def __init__(self, setting_name: str, update_callback: Callable):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.add_item(self.BooleanSelect(setting_name, update_callback))
+
+
+class FreeInputModal(discord.ui.Modal):
+ """A modal to freely enter a value for a setting."""
+
+ def __init__(self, setting_name: str, type_: type, update_callback: Callable):
+ title = f"{setting_name} Input" if len(setting_name) < MAX_MODAL_TITLE_LENGTH - 6 else "Setting Input"
+ super().__init__(timeout=COMPONENT_TIMEOUT, title=title)
+
+ self.setting_name = setting_name
+ self.type_ = type_
+ self.update_callback = update_callback
+
+ label = setting_name if len(setting_name) < MAX_MODAL_TITLE_LENGTH else "Value"
+ self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=False)
+ self.add_item(self.setting_input)
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the setting with the new value in the embed."""
+ try:
+ if not self.setting_input.value:
+ value = self.type_()
+ else:
+ value = self.type_(self.setting_input.value)
+ except (ValueError, TypeError):
+ await interaction.response.send_message(
+ f"Could not process the input value for `{self.setting_name}`.", ephemeral=True
+ )
+ else:
+ await interaction.response.defer()
+ await self.update_callback(setting_name=self.setting_name, setting_value=value)
+
+
+class SequenceEditView(discord.ui.View):
+ """A view to modify the contents of a sequence of values."""
+
+ class SingleItemModal(discord.ui.Modal):
+ """A modal to enter a single list item."""
+
+ new_item = discord.ui.TextInput(label="New Item")
+
+ def __init__(self, view: SequenceEditView):
+ super().__init__(title="Item Addition", timeout=COMPONENT_TIMEOUT)
+ self.view = view
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Send the submitted value to be added to the list."""
+ await self.view.apply_addition(interaction, self.new_item.value)
+
+ class NewListModal(discord.ui.Modal):
+ """A modal to enter new contents for the list."""
+
+ new_value = discord.ui.TextInput(label="Enter comma separated values", style=discord.TextStyle.paragraph)
+
+ def __init__(self, view: SequenceEditView):
+ super().__init__(title="New List", timeout=COMPONENT_TIMEOUT)
+ self.view = view
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Send the submitted value to be added to the list."""
+ await self.view.apply_edit(interaction, self.new_value.value)
+
+ def __init__(self, setting_name: str, starting_value: list, update_callback: Callable):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.setting_name = setting_name
+ self.stored_value = starting_value
+ self.update_callback = update_callback
+
+ options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]]
+ self.removal_select = CustomCallbackSelect(
+ self.apply_removal, placeholder="Enter an item to remove", options=options, row=1
+ )
+ if self.stored_value:
+ self.add_item(self.removal_select)
+
+ async def apply_removal(self, interaction: Interaction, select: discord.ui.Select) -> None:
+ """Remove an item from the list."""
+ # The value might not be stored as a string.
+ _i = len(self.stored_value)
+ for _i, element in enumerate(self.stored_value):
+ if str(element) == select.values[0]:
+ break
+ if _i != len(self.stored_value):
+ self.stored_value.pop(_i)
+
+ await interaction.response.edit_message(
+ content=f"Current list: [{', '.join(self.stored_value)}]", view=self.copy()
+ )
+ self.stop()
+
+ async def apply_addition(self, interaction: Interaction, item: str) -> None:
+ """Add an item to the list."""
+ if item in self.stored_value: # Ignore duplicates
+ await interaction.response.defer()
+ return
+
+ self.stored_value.append(item)
+ await interaction.response.edit_message(
+ content=f"Current list: [{', '.join(self.stored_value)}]", view=self.copy()
+ )
+ self.stop()
+
+ async def apply_edit(self, interaction: Interaction, new_list: str) -> None:
+ """Change the contents of the list."""
+ self.stored_value = list(set(part.strip() for part in new_list.split(",") if part.strip()))
+ await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self.copy())
+ self.stop()
+
+ @discord.ui.button(label="Add Value")
+ async def add_value(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to add an item to the list."""
+ await interaction.response.send_modal(self.SingleItemModal(self))
+
+ @discord.ui.button(label="Free Input")
+ async def free_input(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to change the entire list."""
+ await interaction.response.send_modal(self.NewListModal(self))
+
+ @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Send the final value to the embed editor."""
+ # Edit first, it might time out otherwise.
+ await interaction.response.edit_message(content="✅ Edit confirmed", view=None)
+ await self.update_callback(setting_name=self.setting_name, setting_value=self.stored_value)
+ self.stop()
+
+ @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the list editing."""
+ await interaction.response.edit_message(content="🚫 Canceled", view=None)
+ self.stop()
+
+ def copy(self) -> SequenceEditView:
+ """Return a copy of this view."""
+ return SequenceEditView(self.setting_name, self.stored_value, self.update_callback)
+
+
+class EnumSelectView(discord.ui.View):
+ """A view containing an instance of EnumSelect."""
+
+ class EnumSelect(discord.ui.Select):
+ """Select an enum value and send it to the supplied callback."""
+
+ def __init__(self, setting_name: str, enum_cls: EnumMeta, update_callback: Callable):
+ super().__init__(options=[SelectOption(label=elem.name) for elem in enum_cls])
+ self.setting_name = setting_name
+ self.enum_cls = enum_cls
+ self.update_callback = update_callback
+
+ async def callback(self, interaction: Interaction) -> Any:
+ """Respond to the interaction by sending the enum value to the update callback."""
+ await interaction.response.edit_message(content="✅ Edit confirmed", view=None)
+ await self.update_callback(setting_name=self.setting_name, setting_value=self.values[0])
+
+ def __init__(self, setting_name: str, enum_cls: EnumMeta, update_callback: Callable):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.add_item(self.EnumSelect(setting_name, enum_cls, update_callback))
+
+
+class EditBaseView(ABC, discord.ui.View):
+ """A view used to edit embed fields based on a provided type."""
+
+ def __init__(self, author: discord.User):
+ super().__init__(timeout=EDIT_TIMEOUT)
+ self.author = author
+ self.type_per_setting_name = {}
+
+ async def interaction_check(self, interaction: Interaction) -> bool:
+ """Only allow interactions from the command invoker."""
+ return interaction.user.id == self.author.id
+
+ async def _prompt_new_value(self, interaction: Interaction, select: discord.ui.Select) -> None:
+ """Prompt the user to give an override value for the setting they selected, and respond to the interaction."""
+ setting_name = select.values[0]
+ type_ = self.type_per_setting_name[setting_name]
+ if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias
+ type_ = type_.__origin__
+ new_view = self.copy()
+ # This is in order to not block the interaction response. There's a potential race condition here, since
+ # a view's method is used without guaranteeing the task completed, but since it depends on user input
+ # realistically it shouldn't happen.
+ scheduling.create_task(interaction.message.edit(view=new_view))
+ update_callback = partial(new_view.update_embed, interaction_or_msg=interaction.message)
+ if type_ is bool:
+ view = BooleanSelectView(setting_name, update_callback)
+ await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True)
+ elif type_ in (set, list, tuple):
+ if (current_value := self.current_value(setting_name)) is not MISSING:
+ current_list = list(current_value)
+ else:
+ current_list = []
+ await interaction.response.send_message(
+ f"Current list: [{', '.join(current_list)}]",
+ view=SequenceEditView(setting_name, current_list, update_callback),
+ ephemeral=True
+ )
+ elif isinstance(type_, EnumMeta):
+ view = EnumSelectView(setting_name, type_, update_callback)
+ await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True)
+ else:
+ await interaction.response.send_modal(FreeInputModal(setting_name, type_, update_callback))
+ self.stop()
+
+ @abstractmethod
+ def current_value(self, setting_name: str) -> Any:
+ """Get the current value stored for the setting or MISSING if none found."""
+
+ @abstractmethod
+ async def update_embed(self, interaction_or_msg: Interaction | discord.Message) -> None:
+ """
+ Update the embed with the new information.
+
+ If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function.
+ """
+
+ @abstractmethod
+ def copy(self) -> EditBaseView:
+ """Create a copy of this view."""
+
+
+class DeleteConfirmationView(discord.ui.View):
+ """A view to confirm a deletion."""
+
+ def __init__(self, author: discord.Member | discord.User, callback: Callable):
+ super().__init__(timeout=DELETION_TIMEOUT)
+ self.author = author
+ self.callback = callback
+
+ async def interaction_check(self, interaction: Interaction) -> bool:
+ """Only allow interactions from the command invoker."""
+ return interaction.user.id == self.author.id
+
+ @discord.ui.button(label="Delete", style=discord.ButtonStyle.red, row=0)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Invoke the filter list deletion."""
+ await interaction.response.edit_message(view=None)
+ await self.callback()
+
+ @discord.ui.button(label="Cancel", row=0)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the filter list deletion."""
+ await interaction.response.edit_message(content="🚫 Operation canceled.", view=None)
+
+
+class AlertView(discord.ui.View):
+ """A view providing info about the offending user."""
+
+ def __init__(self, ctx: FilterContext):
+ super().__init__(timeout=ALERT_VIEW_TIMEOUT)
+ self.ctx = ctx
+
+ @discord.ui.button(label="ID")
+ async def user_id(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Reply with the ID of the offending user."""
+ await interaction.response.send_message(self.ctx.author.id, ephemeral=True)
+
+ @discord.ui.button(emoji="👤")
+ async def user_info(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Send the info embed of the offending user."""
+ command = bot.instance.get_command("user")
+ if not command:
+ await interaction.response.send_message("The command `user` is not loaded.", ephemeral=True)
+ return
+
+ await interaction.response.defer()
+ fake_ctx = FakeContext(interaction.message, interaction.channel, command, author=interaction.user)
+ # Get the most updated user/member object every time the button is pressed.
+ author = await get_or_fetch_member(interaction.guild, self.ctx.author.id)
+ if author is None:
+ author = await bot.instance.fetch_user(self.ctx.author.id)
+ await command(fake_ctx, author)
+
+ @discord.ui.button(emoji="⚠")
+ async def user_infractions(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Send the infractions embed of the offending user."""
+ command = bot.instance.get_command("infraction search")
+ if not command:
+ await interaction.response.send_message("The command `infraction search` is not loaded.", ephemeral=True)
+ return
+
+ await interaction.response.defer()
+ fake_ctx = FakeContext(interaction.message, interaction.channel, command, author=interaction.user)
+ await command(fake_ctx, self.ctx.author)
diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py
new file mode 100644
index 000000000..da433330f
--- /dev/null
+++ b/bot/exts/filtering/_utils.py
@@ -0,0 +1,224 @@
+import importlib
+import importlib.util
+import inspect
+import pkgutil
+import types
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from functools import cache
+from typing import Any, Iterable, TypeVar, Union, get_args, get_origin
+
+import discord
+import regex
+from discord.ext.commands import Command
+
+import bot
+from bot.bot import Bot
+from bot.constants import Guild
+
+VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF"
+INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1)
+ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1)
+
+
+T = TypeVar('T')
+
+
+def subclasses_in_package(package: str, prefix: str, parent: T) -> set[T]:
+ """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path."""
+ subclasses = set()
+
+ # Find all modules in the package.
+ for module_info in pkgutil.iter_modules([package], prefix):
+ if not module_info.ispkg:
+ module = importlib.import_module(module_info.name)
+ # Find all classes in each module...
+ for _, class_ in inspect.getmembers(module, inspect.isclass):
+ # That are a subclass of the given class.
+ if parent in class_.__mro__:
+ subclasses.add(class_)
+
+ return subclasses
+
+
+def clean_input(string: str) -> str:
+ """Remove zalgo and invisible characters from `string`."""
+ # For future consideration: remove characters in the Mc, Sk, and Lm categories too.
+ # Can be normalised with form C to merge char + combining char into a single char to avoid
+ # removing legit diacritics, but this would open up a way to bypass _filters.
+ no_zalgo = ZALGO_RE.sub("", string)
+ return INVISIBLE_RE.sub("", no_zalgo)
+
+
+def past_tense(word: str) -> str:
+ """Return the past tense form of the input word."""
+ if not word:
+ return word
+ if word.endswith("e"):
+ return word + "d"
+ if word.endswith("y") and len(word) > 1 and word[-2] not in "aeiou":
+ return word[:-1] + "ied"
+ return word + "ed"
+
+
+def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]:
+ """Convert the item into an object that can be converted to JSON."""
+ if isinstance(item, (bool, int, float, str, type(None))):
+ return item
+ if isinstance(item, dict):
+ result = {}
+ for key, value in item.items():
+ if not isinstance(key, (bool, int, float, str, type(None))):
+ key = str(key)
+ result[key] = to_serializable(value)
+ return result
+ if isinstance(item, Iterable):
+ return [to_serializable(subitem) for subitem in item]
+ return str(item)
+
+
+@cache
+def resolve_mention(mention: str) -> str:
+ """Return the appropriate formatting for the mention, be it a literal, a user ID, or a role ID."""
+ guild = bot.instance.get_guild(Guild.id)
+ if mention in ("here", "everyone"):
+ return f"@{mention}"
+ try:
+ mention = int(mention) # It's an ID.
+ except ValueError:
+ pass
+ else:
+ if any(mention == role.id for role in guild.roles):
+ return f"<@&{mention}>"
+ else:
+ return f"<@{mention}>"
+
+ # It's a name
+ for role in guild.roles:
+ if role.name == mention:
+ return role.mention
+ for member in guild.members:
+ if str(member) == mention:
+ return member.mention
+ return mention
+
+
+def repr_equals(override: Any, default: Any) -> bool:
+ """Return whether the override and the default have the same representation."""
+ if override is None: # It's not an override
+ return True
+
+ override_is_sequence = isinstance(override, (tuple, list, set))
+ default_is_sequence = isinstance(default, (tuple, list, set))
+ if override_is_sequence != default_is_sequence: # One is a sequence and the other isn't.
+ return False
+ if override_is_sequence:
+ if len(override) != len(default):
+ return False
+ return all(str(item1) == str(item2) for item1, item2 in zip(set(override), set(default)))
+ return str(override) == str(default)
+
+
+def starting_value(type_: type[T]) -> T:
+ """Return a value of the given type."""
+ if get_origin(type_) in (Union, types.UnionType): # In case of a Union
+ args = get_args(type_)
+ if type(None) in args:
+ return None
+ type_ = args[0] # Pick one, doesn't matter
+ if origin := get_origin(type_): # In case of a parameterized List, Set, Dict etc.
+ type_ = origin
+
+ try:
+ return type_()
+ except TypeError: # In case it all fails, return a string and let the user handle it.
+ return ""
+
+
+class FieldRequiring(ABC):
+ """A mixin class that can force its concrete subclasses to set a value for specific class attributes."""
+
+ # Sentinel value that mustn't remain in a concrete subclass.
+ MUST_SET = object()
+
+ # Sentinel value that mustn't remain in a concrete subclass.
+ # Overriding value must be unique in the subclasses of the abstract class in which the attribute was set.
+ MUST_SET_UNIQUE = object()
+
+ # A mapping of the attributes which must be unique, and their unique values, per FieldRequiring subclass.
+ __unique_attributes: defaultdict[type, dict[str, set]] = defaultdict(dict)
+
+ @abstractmethod
+ def __init__(self):
+ ...
+
+ def __init_subclass__(cls, **kwargs):
+ def inherited(attr: str) -> bool:
+ """True if `attr` was inherited from a parent class."""
+ for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object.
+ if hasattr(parent, attr): # The attribute was inherited.
+ return True
+ return False
+
+ # If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it.
+ if inspect.isabstract(cls):
+ for attribute in dir(cls):
+ if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE:
+ if not inherited(attribute):
+ # A new attribute with the value MUST_SET_UNIQUE.
+ FieldRequiring.__unique_attributes[cls][attribute] = set()
+ return
+
+ for attribute in dir(cls):
+ if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"):
+ continue
+ value = getattr(cls, attribute)
+ if value is FieldRequiring.MUST_SET and inherited(attribute):
+ raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}")
+ elif value is FieldRequiring.MUST_SET_UNIQUE and inherited(attribute):
+ raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}")
+ else:
+ # Check if the value needs to be unique.
+ for parent in cls.__mro__[1:-1]:
+ # Find the parent class the attribute was first defined in.
+ if attribute in FieldRequiring.__unique_attributes[parent]:
+ if value in FieldRequiring.__unique_attributes[parent][attribute]:
+ raise ValueError(f"Value of {attribute!r} in {cls!r} is not unique for parent {parent!r}.")
+ else:
+ # Add to the set of unique values for that field.
+ FieldRequiring.__unique_attributes[parent][attribute].add(value)
+
+
+@dataclass
+class FakeContext:
+ """
+ A class representing a context-like object that can be sent to infraction commands.
+
+ The goal is to be able to apply infractions without depending on the existence of a message or an interaction
+ (which are the two ways to create a Context), e.g. in API events which aren't message-driven, or in custom filtering
+ events.
+ """
+
+ message: discord.Message
+ channel: discord.abc.Messageable
+ command: Command | None
+ bot: Bot | None = None
+ guild: discord.Guild | None = None
+ author: discord.Member | discord.User | None = None
+ me: discord.Member | None = None
+
+ def __post_init__(self):
+ """Initialize the missing information."""
+ if not self.bot:
+ self.bot = bot.instance
+ if not self.guild:
+ self.guild = self.bot.get_guild(Guild.id)
+ if not self.me:
+ self.me = self.guild.me
+ if not self.author:
+ self.author = self.me
+
+ async def send(self, *args, **kwargs) -> discord.Message:
+ """A wrapper for channel.send."""
+ return await self.channel.send(*args, **kwargs)
diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py
new file mode 100644
index 000000000..c4417e5e0
--- /dev/null
+++ b/bot/exts/filtering/filtering.py
@@ -0,0 +1,1424 @@
+import datetime
+import json
+import re
+import unicodedata
+from collections import defaultdict
+from collections.abc import Iterable, Mapping
+from functools import partial, reduce
+from io import BytesIO
+from operator import attrgetter
+from typing import Literal, Optional, get_type_hints
+
+import arrow
+import discord
+from async_rediscache import RedisCache
+from discord import Colour, Embed, HTTPException, Message, MessageType
+from discord.ext import commands, tasks
+from discord.ext.commands import BadArgument, Cog, Context, command, has_any_role
+from pydis_core.site_api import ResponseCodeError
+from pydis_core.utils import scheduling
+
+import bot
+import bot.exts.filtering._ui.filter as filters_ui
+from bot import constants
+from bot.bot import Bot
+from bot.constants import Channels, Guild, MODERATION_ROLES, Roles
+from bot.exts.backend.branding._repository import HEADERS, PARAMS
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists import FilterList, ListType, filter_list_types, list_type_converter
+from bot.exts.filtering._filter_lists.filter_list import AtomicList
+from bot.exts.filtering._filters.filter import Filter, UniqueFilter
+from bot.exts.filtering._settings import ActionSettings
+from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction
+from bot.exts.filtering._ui.filter import (
+ build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict
+)
+from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter
+from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter
+from bot.exts.filtering._ui.ui import (
+ AlertView, ArgumentCompletionView, DeleteConfirmationView, build_mod_alert, format_response_error
+)
+from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable
+from bot.exts.moderation.infraction.infractions import COMP_BAN_DURATION, COMP_BAN_REASON
+from bot.log import get_logger
+from bot.pagination import LinePaginator
+from bot.utils.channel import is_mod_channel
+from bot.utils.lock import lock_arg
+from bot.utils.message_cache import MessageCache
+
+log = get_logger(__name__)
+
+WEBHOOK_ICON_URL = r"https://github.com/python-discord/branding/raw/main/icons/filter/filter_pfp.png"
+WEBHOOK_NAME = "Filtering System"
+CACHE_SIZE = 1000
+HOURS_BETWEEN_NICKNAME_ALERTS = 1
+OFFENSIVE_MSG_DELETE_TIME = datetime.timedelta(days=7)
+WEEKLY_REPORT_ISO_DAY = 3 # 1=Monday, 7=Sunday
+
+
+class Filtering(Cog):
+ """Filtering and alerting for content posted on the server."""
+
+ # A set of filter list names with missing implementations that already caused a warning.
+ already_warned = set()
+
+ # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent.
+ name_alerts = RedisCache()
+
+ # region: init
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self.filter_lists: dict[str, FilterList] = {}
+ self._subscriptions: defaultdict[Event, list[FilterList]] = defaultdict(list)
+ self.delete_scheduler = scheduling.Scheduler(self.__class__.__name__)
+ self.webhook: discord.Webhook | None = None
+
+ self.loaded_settings = {}
+ self.loaded_filters = {}
+ self.loaded_filter_settings = {}
+
+ self.message_cache = MessageCache(CACHE_SIZE, newest_first=True)
+
+ async def cog_load(self) -> None:
+ """
+ Fetch the filter data from the API, parse it, and load it to the appropriate data structures.
+
+ Additionally, fetch the alerting webhook.
+ """
+ await self.bot.wait_until_guild_available()
+
+ log.trace("Loading filtering information from the database.")
+ raw_filter_lists = await self.bot.api_client.get("bot/filter/filter_lists")
+ example_list = None
+ for raw_filter_list in raw_filter_lists:
+ loaded_list = self._load_raw_filter_list(raw_filter_list)
+ if not example_list and loaded_list:
+ example_list = loaded_list
+
+ # The webhook must be generated by the bot to send messages with components through it.
+ self.webhook = await self._fetch_or_generate_filtering_webhook()
+
+ self.collect_loaded_types(example_list)
+ await self.schedule_offending_messages_deletion()
+ self.weekly_auto_infraction_report_task.start()
+
+ def subscribe(self, filter_list: FilterList, *events: Event) -> None:
+ """
+ Subscribe a filter list to the given events.
+
+ The filter list is added to a list for each event. When the event is triggered, the filter context will be
+ dispatched to the subscribed filter lists.
+
+ While it's possible to just make each filter list check the context's event, these are only the events a filter
+ list expects to receive from the filtering cog, there isn't an actual limitation on the kinds of events a filter
+ list can handle as long as the filter context is built properly. If for whatever reason we want to invoke a
+ filter list outside of the usual procedure with the filtering cog, it will be more problematic if the events are
+ hard-coded into each filter list.
+ """
+ for event in events:
+ if filter_list not in self._subscriptions[event]:
+ self._subscriptions[event].append(filter_list)
+
+ def unsubscribe(self, filter_list: FilterList, *events: Event) -> None:
+ """Unsubscribe a filter list from the given events. If no events given, unsubscribe from every event."""
+ if not events:
+ events = list(self._subscriptions)
+
+ for event in events:
+ if filter_list in self._subscriptions.get(event, []):
+ self._subscriptions[event].remove(filter_list)
+
+ def collect_loaded_types(self, example_list: AtomicList) -> None:
+ """
+ Go over the classes used in initialization and collect them to dictionaries.
+
+ The information that is collected is about the types actually used to load the API response, not all types
+ available in the filtering extension.
+
+ Any filter list has the fields for all settings in the DB schema, so picking any one of them is enough.
+ """
+ # Get the filter types used by each filter list.
+ for filter_list in self.filter_lists.values():
+ self.loaded_filters.update({filter_type.name: filter_type for filter_type in filter_list.filter_types})
+
+ # Get the setting types used by each filter list.
+ if self.filter_lists:
+ settings_entries = set()
+ # The settings are split between actions and validations.
+ for settings_group in example_list.defaults:
+ settings_entries.update(type(setting) for _, setting in settings_group.items())
+
+ for setting_entry in settings_entries:
+ type_hints = get_type_hints(setting_entry)
+ # The description should be either a string or a dictionary.
+ if isinstance(setting_entry.description, str):
+ # If it's a string, then the settings entry matches a single field in the DB,
+ # and its name is the setting type's name attribute.
+ self.loaded_settings[setting_entry.name] = (
+ setting_entry.description, setting_entry, type_hints[setting_entry.name]
+ )
+ else:
+ # Otherwise, the setting entry works with compound settings.
+ self.loaded_settings.update({
+ subsetting: (description, setting_entry, type_hints[subsetting])
+ for subsetting, description in setting_entry.description.items()
+ })
+
+ # Get the settings per filter as well.
+ for filter_name, filter_type in self.loaded_filters.items():
+ extra_fields_type = filter_type.extra_fields_type
+ if not extra_fields_type:
+ continue
+ type_hints = get_type_hints(extra_fields_type)
+ # A class var with a `_description` suffix is expected per field name.
+ self.loaded_filter_settings[filter_name] = {
+ field_name: (
+ getattr(extra_fields_type, f"{field_name}_description", ""),
+ extra_fields_type,
+ type_hints[field_name]
+ )
+ for field_name in extra_fields_type.__fields__
+ }
+
+ async def schedule_offending_messages_deletion(self) -> None:
+ """Load the messages that need to be scheduled for deletion from the database."""
+ response = await self.bot.api_client.get('bot/offensive-messages')
+
+ now = arrow.utcnow()
+ for msg in response:
+ delete_at = arrow.get(msg['delete_date'])
+ if delete_at < now:
+ await self._delete_offensive_msg(msg)
+ else:
+ self._schedule_msg_delete(msg)
+
+ async def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators to invoke the commands in this cog."""
+ return await has_any_role(*MODERATION_ROLES).predicate(ctx)
+
+ # endregion
+ # region: listeners and event handlers
+
+ @Cog.listener()
+ async def on_message(self, msg: Message) -> None:
+ """Filter the contents of a sent message."""
+ if msg.author.bot or msg.webhook_id or msg.type == MessageType.auto_moderation_action:
+ return
+ self.message_cache.append(msg)
+
+ ctx = FilterContext.from_message(Event.MESSAGE, msg, None, self.message_cache)
+ result_actions, list_messages, triggers = await self._resolve_action(ctx)
+ self.message_cache.update(msg, metadata=triggers)
+ if result_actions:
+ await result_actions.action(ctx)
+ if ctx.send_alert:
+ await self._send_alert(ctx, list_messages)
+
+ nick_ctx = FilterContext.from_message(Event.NICKNAME, msg)
+ nick_ctx.content = msg.author.display_name
+ await self._check_bad_name(nick_ctx)
+
+ await self._maybe_schedule_msg_delete(ctx, result_actions)
+ self._increment_stats(triggers)
+
+ @Cog.listener()
+ async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
+ """Filter the contents of an edited message. Don't reinvoke filters already invoked on the `before` version."""
+ # Only check changes to the message contents/attachments and embed additions, not pin status etc.
+ if all((
+ before.content == after.content, # content hasn't changed
+ before.attachments == after.attachments, # attachments haven't changed
+ len(before.embeds) >= len(after.embeds) # embeds haven't been added
+ )):
+ return
+
+ # Update the cache first, it might be used by the antispam filter.
+ # No need to update the triggers, they're going to be updated inside the sublists if necessary.
+ self.message_cache.update(after)
+ ctx = FilterContext.from_message(Event.MESSAGE_EDIT, after, before, self.message_cache)
+ result_actions, list_messages, triggers = await self._resolve_action(ctx)
+ if result_actions:
+ await result_actions.action(ctx)
+ if ctx.send_alert:
+ await self._send_alert(ctx, list_messages)
+ await self._maybe_schedule_msg_delete(ctx, result_actions)
+ self._increment_stats(triggers)
+
+ @Cog.listener()
+ async def on_voice_state_update(self, member: discord.Member, *_) -> None:
+ """Checks for bad words in usernames when users join, switch or leave a voice channel."""
+ ctx = FilterContext(Event.NICKNAME, member, None, member.display_name, None)
+ await self._check_bad_name(ctx)
+
+ async def filter_snekbox_output(self, snekbox_result: str, msg: Message) -> bool:
+ """
+ Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly.
+
+ Also requires the original message, to check whether to filter and for alerting.
+ Any action (deletion, infraction) will be applied in the context of the original message.
+
+ Returns whether a filter was triggered or not.
+ """
+ ctx = FilterContext.from_message(Event.MESSAGE, msg).replace(content=snekbox_result)
+ result_actions, list_messages, triggers = await self._resolve_action(ctx)
+ if result_actions:
+ await result_actions.action(ctx)
+ if ctx.send_alert:
+ await self._send_alert(ctx, list_messages)
+ self._increment_stats(triggers)
+
+ return result_actions is not None
+
+ # endregion
+ # region: blacklist commands
+
+ @commands.group(aliases=("bl", "blacklist", "denylist", "dl"))
+ async def blocklist(self, ctx: Context) -> None:
+ """Group for managing blacklisted items."""
+ if not ctx.invoked_subcommand:
+ await ctx.send_help(ctx.command)
+
+ @blocklist.command(name="list", aliases=("get",))
+ async def bl_list(self, ctx: Context, list_name: Optional[str] = None) -> None:
+ """List the contents of a specified blacklist."""
+ result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name, exclude="list_type")
+ if not result:
+ return
+ list_type, filter_list = result
+ await self._send_list(ctx, filter_list, list_type)
+
+ @blocklist.command(name="add", aliases=("a",))
+ async def bl_add(
+ self,
+ ctx: Context,
+ noui: Optional[Literal["noui"]],
+ list_name: Optional[str],
+ content: str,
+ *,
+ description_and_settings: Optional[str] = None
+ ) -> None:
+ """
+ Add a blocked filter to the specified filter list.
+
+ Unless `noui` is specified, a UI will be provided to edit the content, description, and settings
+ before confirmation.
+
+ The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the
+ equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces.
+ """
+ result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name, exclude="list_type")
+ if result is None:
+ return
+ list_type, filter_list = result
+ await self._add_filter(ctx, noui, list_type, filter_list, content, description_and_settings)
+
+ # endregion
+ # region: whitelist commands
+
+ @commands.group(aliases=("wl", "whitelist", "al"))
+ async def allowlist(self, ctx: Context) -> None:
+ """Group for managing blacklisted items."""
+ if not ctx.invoked_subcommand:
+ await ctx.send_help(ctx.command)
+
+ @allowlist.command(name="list", aliases=("get",))
+ async def al_list(self, ctx: Context, list_name: Optional[str] = None) -> None:
+ """List the contents of a specified whitelist."""
+ result = await self._resolve_list_type_and_name(ctx, ListType.ALLOW, list_name, exclude="list_type")
+ if not result:
+ return
+ list_type, filter_list = result
+ await self._send_list(ctx, filter_list, list_type)
+
+ @allowlist.command(name="add", aliases=("a",))
+ async def al_add(
+ self,
+ ctx: Context,
+ noui: Optional[Literal["noui"]],
+ list_name: Optional[str],
+ content: str,
+ *,
+ description_and_settings: Optional[str] = None
+ ) -> None:
+ """
+ Add an allowed filter to the specified filter list.
+
+ Unless `noui` is specified, a UI will be provided to edit the content, description, and settings
+ before confirmation.
+
+ The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the
+ equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces.
+ """
+ result = await self._resolve_list_type_and_name(ctx, ListType.ALLOW, list_name, exclude="list_type")
+ if result is None:
+ return
+ list_type, filter_list = result
+ await self._add_filter(ctx, noui, list_type, filter_list, content, description_and_settings)
+
+ # endregion
+ # region: filter commands
+
+ @commands.group(aliases=("filters", "f"), invoke_without_command=True)
+ async def filter(self, ctx: Context, id_: Optional[int] = None) -> None:
+ """
+ Group for managing filters.
+
+ If a valid filter ID is provided, an embed describing the filter will be posted.
+ """
+ if not ctx.invoked_subcommand and not id_:
+ await ctx.send_help(ctx.command)
+ return
+
+ result = self._get_filter_by_id(id_)
+ if result is None:
+ await ctx.send(f":x: Could not find a filter with ID `{id_}`.")
+ return
+ filter_, filter_list, list_type = result
+
+ overrides_values, extra_fields_overrides = filter_serializable_overrides(filter_)
+
+ all_settings_repr_dict = build_filter_repr_dict(
+ filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides
+ )
+ embed = Embed(colour=Colour.blue())
+ populate_embed_from_dict(embed, all_settings_repr_dict)
+ embed.description = f"`{filter_.content}`"
+ if filter_.description:
+ embed.description += f" - {filter_.description}"
+ embed.set_author(name=f"Filter {id_} - " + f"{filter_list[list_type].label}".title())
+ embed.set_footer(text=(
+ "Field names with an asterisk have values which override the defaults of the containing filter list. "
+ f"To view all defaults of the list, "
+ f"run `{constants.Bot.prefix}filterlist describe {list_type.name} {filter_list.name}`."
+ ))
+ await ctx.send(embed=embed)
+
+ @filter.command(name="list", aliases=("get",))
+ async def f_list(
+ self, ctx: Context, list_type: Optional[list_type_converter] = None, list_name: Optional[str] = None
+ ) -> None:
+ """List the contents of a specified list of filters."""
+ result = await self._resolve_list_type_and_name(ctx, list_type, list_name)
+ if result is None:
+ return
+ list_type, filter_list = result
+
+ await self._send_list(ctx, filter_list, list_type)
+
+ @filter.command(name="describe", aliases=("explain", "manual"))
+ async def f_describe(self, ctx: Context, filter_name: Optional[str]) -> None:
+ """Show a description of the specified filter, or a list of possible values if no name is specified."""
+ if not filter_name:
+ filter_names = [f"» {f}" for f in self.loaded_filters]
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name="List of filter names")
+ await LinePaginator.paginate(filter_names, ctx, embed, max_lines=10, empty=False)
+ else:
+ filter_type = self.loaded_filters.get(filter_name)
+ if not filter_type:
+ filter_type = self.loaded_filters.get(filter_name[:-1]) # A plural form or a typo.
+ if not filter_type:
+ await ctx.send(f":x: There's no filter type named {filter_name!r}.")
+ return
+ # Use the class's docstring, and ignore single newlines.
+ embed = Embed(description=re.sub(r"(?<!\n)\n(?!\n)", " ", filter_type.__doc__), colour=Colour.blue())
+ embed.set_author(name=f"Description of the {filter_name} filter")
+ await ctx.send(embed=embed)
+
+ @filter.command(name="add", aliases=("a",))
+ async def f_add(
+ self,
+ ctx: Context,
+ noui: Optional[Literal["noui"]],
+ list_type: Optional[list_type_converter],
+ list_name: Optional[str],
+ content: str,
+ *,
+ description_and_settings: Optional[str] = None
+ ) -> None:
+ """
+ Add a filter to the specified filter list.
+
+ Unless `noui` is specified, a UI will be provided to edit the content, description, and settings
+ before confirmation.
+
+ The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the
+ equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces.
+
+ A template filter can be specified in the settings area to copy overrides from. The setting name is "--template"
+ and the value is the filter ID. The template will be used before applying any other override.
+
+ Example: `!filter add denied token "Scaleios is great" remove_context=True send_alert=False --template=100`
+ """
+ result = await self._resolve_list_type_and_name(ctx, list_type, list_name)
+ if result is None:
+ return
+ list_type, filter_list = result
+ await self._add_filter(ctx, noui, list_type, filter_list, content, description_and_settings)
+
+ @filter.command(name="edit", aliases=("e",))
+ async def f_edit(
+ self,
+ ctx: Context,
+ noui: Optional[Literal["noui"]],
+ filter_id: int,
+ *,
+ description_and_settings: Optional[str] = None
+ ) -> None:
+ """
+ Edit a filter specified by its ID.
+
+ Unless `noui` is specified, a UI will be provided to edit the content, description, and settings
+ before confirmation.
+
+ The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the
+ equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces.
+
+ A template filter can be specified in the settings area to copy overrides from. The setting name is "--template"
+ and the value is the filter ID. The template will be used before applying any other override.
+
+ To edit the filter's content, use the UI.
+ """
+ result = self._get_filter_by_id(filter_id)
+ if result is None:
+ await ctx.send(f":x: Could not find a filter with ID `{filter_id}`.")
+ return
+ filter_, filter_list, list_type = result
+ filter_type = type(filter_)
+ settings, filter_settings = filter_serializable_overrides(filter_)
+ description, new_settings, new_filter_settings = description_and_settings_converter(
+ filter_list,
+ list_type, filter_type,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ description_and_settings
+ )
+
+ content = filter_.content
+ description = description or filter_.description
+ settings.update(new_settings)
+ filter_settings.update(new_filter_settings)
+ patch_func = partial(self._patch_filter, filter_)
+
+ if noui:
+ try:
+ await patch_func(
+ ctx.message, filter_list, list_type, filter_type, content, description, settings, filter_settings
+ )
+ except ResponseCodeError as e:
+ await ctx.reply(embed=format_response_error(e))
+ return
+
+ embed = Embed(colour=Colour.blue())
+ embed.description = f"`{filter_.content}`"
+ if description:
+ embed.description += f" - {description}"
+ embed.set_author(
+ name=f"Filter {filter_id} - {filter_list[list_type].label}".title())
+ embed.set_footer(text=(
+ "Field names with an asterisk have values which override the defaults of the containing filter list. "
+ f"To view all defaults of the list, "
+ f"run `{constants.Bot.prefix}filterlist describe {list_type.name} {filter_list.name}`."
+ ))
+
+ view = filters_ui.FilterEditView(
+ filter_list,
+ list_type,
+ filter_type,
+ content,
+ description,
+ settings,
+ filter_settings,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ ctx.author,
+ embed,
+ patch_func
+ )
+ await ctx.send(embed=embed, reference=ctx.message, view=view)
+
+ @filter.command(name="delete", aliases=("d", "remove"))
+ async def f_delete(self, ctx: Context, filter_id: int) -> None:
+ """Delete the filter specified by its ID."""
+ async def delete_list() -> None:
+ """The actual removal routine."""
+ await bot.instance.api_client.delete(f'bot/filter/filters/{filter_id}')
+ log.info(f"Successfully deleted filter with ID {filter_id}.")
+ filter_list[list_type].filters.pop(filter_id)
+ await ctx.reply(f"✅ Deleted filter: {filter_}")
+
+ result = self._get_filter_by_id(filter_id)
+ if result is None:
+ await ctx.send(f":x: Could not find a filter with ID `{filter_id}`.")
+ return
+ filter_, filter_list, list_type = result
+ await ctx.reply(
+ f"Are you sure you want to delete filter {filter_}?",
+ view=DeleteConfirmationView(ctx.author, delete_list)
+ )
+
+ @filter.command(aliases=("settings",))
+ async def setting(self, ctx: Context, setting_name: str | None) -> None:
+ """Show a description of the specified setting, or a list of possible settings if no name is specified."""
+ if not setting_name:
+ settings_list = [f"» {setting_name}" for setting_name in self.loaded_settings]
+ for filter_name, filter_settings in self.loaded_filter_settings.items():
+ settings_list.extend(f"» {filter_name}/{setting}" for setting in filter_settings)
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name="List of setting names")
+ await LinePaginator.paginate(settings_list, ctx, embed, max_lines=10, empty=False)
+
+ else:
+ # The setting is either in a SettingsEntry subclass, or a pydantic model.
+ setting_data = self.loaded_settings.get(setting_name)
+ description = None
+ if setting_data:
+ description = setting_data[0]
+ elif "/" in setting_name: # It's a filter specific setting.
+ filter_name, filter_setting_name = setting_name.split("/", maxsplit=1)
+ if filter_name in self.loaded_filter_settings:
+ if filter_setting_name in self.loaded_filter_settings[filter_name]:
+ description = self.loaded_filter_settings[filter_name][filter_setting_name][0]
+ if description is None:
+ await ctx.send(f":x: There's no setting type named {setting_name!r}.")
+ return
+ embed = Embed(colour=Colour.blue(), description=description)
+ embed.set_author(name=f"Description of the {setting_name} setting")
+ await ctx.send(embed=embed)
+
+ @filter.command(name="match")
+ async def f_match(
+ self, ctx: Context, no_user: bool | None, message: Message | None, *, string: str | None
+ ) -> None:
+ """
+ Post any responses from the filter lists for the given message or string.
+
+ If there's a `message`, the `string` will be ignored. Note that if a `message` is provided, it will go through
+ all validations appropriate to where it was sent and who sent it. To check for matches regardless of the author
+ (for example if the message was sent by another staff member or yourself) set `no_user` to '1' or 'True'.
+
+ If a `string` is provided, it will be validated in the context of a user with no roles in python-general.
+ """
+ if not message and not string:
+ raise BadArgument("Please provide input.")
+ if message:
+ user = None if no_user else message.author
+ filter_ctx = FilterContext(Event.MESSAGE, user, message.channel, message.content, message, message.embeds)
+ else:
+ python_general = ctx.guild.get_channel(Channels.python_general)
+ filter_ctx = FilterContext(Event.MESSAGE, None, python_general, string, None)
+
+ _, _, triggers = await self._resolve_action(filter_ctx)
+ lines = []
+ for sublist, sublist_triggers in triggers.items():
+ if sublist_triggers:
+ triggers_repr = map(str, sublist_triggers)
+ lines.extend([f"**{sublist.label.title()}s**", *triggers_repr, "\n"])
+ lines = lines[:-1] # Remove last newline.
+
+ embed = Embed(colour=Colour.blue(), title="Match results")
+ await LinePaginator.paginate(lines, ctx, embed, max_lines=10, empty=False)
+
+ @filter.command(name="search")
+ async def f_search(
+ self,
+ ctx: Context,
+ noui: Literal["noui"] | None,
+ filter_type_name: str | None,
+ *,
+ settings: str = ""
+ ) -> None:
+ """
+ Find filters with the provided settings. The format is identical to that of the add and edit commands.
+
+ If a list type and/or a list name are provided, the search will be limited to those parameters. A list name must
+ be provided in order to search by filter-specific settings.
+ """
+ filter_type = None
+ if filter_type_name:
+ filter_type_name = filter_type_name.lower()
+ filter_type = self.loaded_filters.get(filter_type_name)
+ if not filter_type:
+ self.loaded_filters.get(filter_type_name[:-1]) # In case the user tried to specify the plural form.
+ # If settings were provided with no filter_type, discord.py will capture the first word as the filter type.
+ if filter_type is None and filter_type_name is not None:
+ if settings:
+ settings = f"{filter_type_name} {settings}"
+ else:
+ settings = filter_type_name
+ filter_type_name = None
+
+ settings, filter_settings, filter_type = search_criteria_converter(
+ self.filter_lists,
+ self.loaded_filters,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ filter_type,
+ settings
+ )
+
+ if noui:
+ await self._search_filters(ctx.message, filter_type, settings, filter_settings)
+ return
+
+ embed = Embed(colour=Colour.blue())
+ view = SearchEditView(
+ filter_type,
+ settings,
+ filter_settings,
+ self.filter_lists,
+ self.loaded_filters,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ ctx.author,
+ embed,
+ self._search_filters
+ )
+ await ctx.send(embed=embed, reference=ctx.message, view=view)
+
+ @filter.command(root_aliases=("compfilter", "compf"))
+ async def compadd(
+ self, ctx: Context, list_name: Optional[str], content: str, *, description: Optional[str] = "Phishing"
+ ) -> None:
+ """Add a filter to detect a compromised account. Will apply the equivalent of a compban if triggered."""
+ result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name, exclude="list_type")
+ if result is None:
+ return
+ list_type, filter_list = result
+
+ settings = (
+ "remove_context=True "
+ "dm_pings=Moderators "
+ "infraction_type=BAN "
+ "infraction_channel=1 " # Post the ban in #mod-alerts
+ f"infraction_duration={COMP_BAN_DURATION.total_seconds()} "
+ f"infraction_reason={COMP_BAN_REASON}"
+ )
+ description_and_settings = f"{description} {settings}"
+ await self._add_filter(ctx, "noui", list_type, filter_list, content, description_and_settings)
+
+ # endregion
+ # region: filterlist group
+
+ @commands.group(aliases=("fl",))
+ async def filterlist(self, ctx: Context) -> None:
+ """Group for managing filter lists."""
+ if not ctx.invoked_subcommand:
+ await ctx.send_help(ctx.command)
+
+ @filterlist.command(name="describe", aliases=("explain", "manual", "id"))
+ async def fl_describe(
+ self, ctx: Context, list_type: Optional[list_type_converter] = None, list_name: Optional[str] = None
+ ) -> None:
+ """Show a description of the specified filter list, or a list of possible values if no values are provided."""
+ if not list_type and not list_name:
+ list_names = [f"» {fl}" for fl in self.filter_lists]
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name="List of filter lists names")
+ await LinePaginator.paginate(list_names, ctx, embed, max_lines=10, empty=False)
+ return
+
+ result = await self._resolve_list_type_and_name(ctx, list_type, list_name)
+ if result is None:
+ return
+ list_type, filter_list = result
+
+ setting_values = {}
+ for settings_group in filter_list[list_type].defaults:
+ for _, setting in settings_group.items():
+ setting_values.update(to_serializable(setting.dict()))
+
+ embed = Embed(colour=Colour.blue())
+ populate_embed_from_dict(embed, setting_values)
+ # Use the class's docstring, and ignore single newlines.
+ embed.description = re.sub(r"(?<!\n)\n(?!\n)", " ", filter_list.__doc__)
+ embed.set_author(
+ name=f"Description of the {filter_list[list_type].label} filter list"
+ )
+ await ctx.send(embed=embed)
+
+ @filterlist.command(name="add", aliases=("a",))
+ @has_any_role(Roles.admins)
+ async def fl_add(self, ctx: Context, list_type: list_type_converter, list_name: str) -> None:
+ """Add a new filter list."""
+ # Check if there's an implementation.
+ if list_name.lower() not in filter_list_types:
+ if list_name.lower()[:-1] not in filter_list_types: # Maybe the name was given with uppercase or in plural?
+ await ctx.reply(f":x: Cannot add a `{list_name}` filter list, as there is no matching implementation.")
+ return
+ else:
+ list_name = list_name.lower()[:-1]
+
+ # Check it doesn't already exist.
+ list_description = f"{past_tense(list_type.name.lower())} {list_name.lower()}"
+ if list_name in self.filter_lists:
+ filter_list = self.filter_lists[list_name]
+ if list_type in filter_list:
+ await ctx.reply(f":x: The {list_description} filter list already exists.")
+ return
+
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name=f"New Filter List - {list_description.title()}")
+ settings = {name: starting_value(value[2]) for name, value in self.loaded_settings.items()}
+
+ view = FilterListAddView(
+ list_name,
+ list_type,
+ settings,
+ self.loaded_settings,
+ ctx.author,
+ embed,
+ self._post_filter_list
+ )
+ await ctx.send(embed=embed, reference=ctx.message, view=view)
+
+ @filterlist.command(name="edit", aliases=("e",))
+ @has_any_role(Roles.admins)
+ async def fl_edit(
+ self,
+ ctx: Context,
+ noui: Optional[Literal["noui"]],
+ list_type: Optional[list_type_converter] = None,
+ list_name: Optional[str] = None,
+ *,
+ settings: str | None
+ ) -> None:
+ """
+ Edit the filter list.
+
+ Unless `noui` is specified, a UI will be provided to edit the settings before confirmation.
+
+ The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the
+ equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces.
+ """
+ result = await self._resolve_list_type_and_name(ctx, list_type, list_name)
+ if result is None:
+ return
+ list_type, filter_list = result
+ settings = settings_converter(self.loaded_settings, settings)
+ if noui:
+ try:
+ await self._patch_filter_list(ctx.message, filter_list, list_type, settings)
+ except ResponseCodeError as e:
+ await ctx.reply(embed=format_response_error(e))
+ return
+
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name=f"{filter_list[list_type].label.title()} Filter List")
+ embed.set_footer(text="Field names with a ~ have values which change the existing value in the filter list.")
+
+ view = FilterListEditView(
+ filter_list,
+ list_type,
+ settings,
+ self.loaded_settings,
+ ctx.author,
+ embed,
+ self._patch_filter_list
+ )
+ await ctx.send(embed=embed, reference=ctx.message, view=view)
+
+ @filterlist.command(name="delete", aliases=("remove",))
+ @has_any_role(Roles.admins)
+ async def fl_delete(
+ self, ctx: Context, list_type: Optional[list_type_converter] = None, list_name: Optional[str] = None
+ ) -> None:
+ """Remove the filter list and all of its filters from the database."""
+ async def delete_list() -> None:
+ """The actual removal routine."""
+ list_data = await bot.instance.api_client.get(f"bot/filter/filter_lists/{list_id}")
+ file = discord.File(BytesIO(json.dumps(list_data, indent=4).encode("utf-8")), f"{list_description}.json")
+ message = await ctx.send("⏳ Annihilation in progress, please hold...", file=file)
+ # Unload the filter list.
+ filter_list.pop(list_type)
+ if not filter_list: # There's nothing left, remove from the cog.
+ self.filter_lists.pop(filter_list.name)
+ self.unsubscribe(filter_list)
+
+ await bot.instance.api_client.delete(f"bot/filter/filter_lists/{list_id}")
+ log.info(f"Successfully deleted the {filter_list[list_type].label} filterlist.")
+ await message.edit(content=f"✅ The {list_description} list has been deleted.")
+
+ result = await self._resolve_list_type_and_name(ctx, list_type, list_name)
+ if result is None:
+ return
+ list_type, filter_list = result
+ list_id = filter_list[list_type].id
+ list_description = filter_list[list_type].label
+ await ctx.reply(
+ f"Are you sure you want to delete the {list_description} list?",
+ view=DeleteConfirmationView(ctx.author, delete_list)
+ )
+
+ # endregion
+ # region: utility commands
+
+ @command(name="filter_report")
+ async def force_send_weekly_report(self, ctx: Context) -> None:
+ """Respond with a list of auto-infractions added in the last 7 days."""
+ await self.send_weekly_auto_infraction_report(ctx.channel)
+
+ # endregion
+ # region: helper functions
+
+ def _load_raw_filter_list(self, list_data: dict) -> AtomicList | None:
+ """Load the raw list data to the cog."""
+ list_name = list_data["name"]
+ if list_name not in self.filter_lists:
+ if list_name not in filter_list_types:
+ if list_name not in self.already_warned:
+ log.warning(
+ f"A filter list named {list_name} was loaded from the database, but no matching class."
+ )
+ self.already_warned.add(list_name)
+ return None
+ self.filter_lists[list_name] = filter_list_types[list_name](self)
+ return self.filter_lists[list_name].add_list(list_data)
+
+ async def _fetch_or_generate_filtering_webhook(self) -> discord.Webhook | None:
+ """Generate a webhook with the filtering avatar."""
+ alerts_channel = self.bot.get_guild(Guild.id).get_channel(Channels.mod_alerts)
+ # Try to find an existing webhook.
+ for webhook in await alerts_channel.webhooks():
+ if webhook.name == WEBHOOK_NAME and webhook.user == self.bot.user and webhook.is_authenticated():
+ log.trace(f"Found existing filters webhook with ID {webhook.id}.")
+ return webhook
+
+ # Download the filtering avatar from the branding repository.
+ webhook_icon = None
+ async with self.bot.http_session.get(WEBHOOK_ICON_URL, params=PARAMS, headers=HEADERS) as response:
+ if response.status == 200:
+ log.debug("Successfully fetched filtering webhook icon, reading payload.")
+ webhook_icon = await response.read()
+ else:
+ log.warning(f"Failed to fetch filtering webhook icon due to status: {response.status}")
+
+ # Generate a new webhook.
+ try:
+ webhook = await alerts_channel.create_webhook(name=WEBHOOK_NAME, avatar=webhook_icon)
+ log.trace(f"Generated new filters webhook with ID {webhook.id},")
+ return webhook
+ except HTTPException as e:
+ log.error(f"Failed to create filters webhook: {e}")
+ return None
+
+ async def _resolve_action(
+ self, ctx: FilterContext
+ ) -> tuple[ActionSettings | None, dict[FilterList, list[str]], dict[AtomicList, list[Filter]]]:
+ """
+ Return the actions that should be taken for all filter lists in the given context.
+
+ Additionally, a message is possibly provided from each filter list describing the triggers,
+ which should be relayed to the moderators.
+ """
+ actions = []
+ messages = {}
+ triggers = {}
+ for filter_list in self._subscriptions[ctx.event]:
+ list_actions, list_message, list_triggers = await filter_list.actions_for(ctx)
+ triggers.update({filter_list[list_type]: filters for list_type, filters in list_triggers.items()})
+ if list_actions:
+ actions.append(list_actions)
+ if list_message:
+ messages[filter_list] = list_message
+
+ result_actions = None
+ if actions:
+ result_actions = reduce(ActionSettings.union, actions)
+
+ return result_actions, messages, triggers
+
+ async def _send_alert(self, ctx: FilterContext, triggered_filters: dict[FilterList, Iterable[str]]) -> None:
+ """Build an alert message from the filter context, and send it via the alert webhook."""
+ if not self.webhook:
+ return
+
+ name = f"{ctx.event.name.replace('_', ' ').title()} Filter"
+ embed = await build_mod_alert(ctx, triggered_filters)
+ # There shouldn't be more than 10, but if there are it's not very useful to send them all.
+ await self.webhook.send(
+ username=name, content=ctx.alert_content, embeds=[embed, *ctx.alert_embeds][:10], view=AlertView(ctx)
+ )
+
+ def _increment_stats(self, triggered_filters: dict[AtomicList, list[Filter]]) -> None:
+ """Increment the stats for every filter triggered."""
+ for filters in triggered_filters.values():
+ for filter_ in filters:
+ if isinstance(filter_, UniqueFilter):
+ self.bot.stats.incr(f"filters.{filter_.name}")
+
+ async def _recently_alerted_name(self, member: discord.Member) -> bool:
+ """When it hasn't been `HOURS_BETWEEN_NICKNAME_ALERTS` since last alert, return False, otherwise True."""
+ if last_alert := await self.name_alerts.get(member.id):
+ last_alert = arrow.get(last_alert)
+ if arrow.utcnow() - last_alert < datetime.timedelta(days=HOURS_BETWEEN_NICKNAME_ALERTS):
+ log.trace(f"Last alert was too recent for {member}'s nickname.")
+ return True
+
+ return False
+
+ @lock_arg("filtering.check_bad_name", "ctx", attrgetter("author.id"))
+ async def _check_bad_name(self, ctx: FilterContext) -> None:
+ """Check filter triggers in the passed context - a member's display name."""
+ if await self._recently_alerted_name(ctx.author):
+ return
+
+ name = ctx.content
+ normalised_name = unicodedata.normalize("NFKC", name)
+ cleaned_normalised_name = "".join([c for c in normalised_name if not unicodedata.combining(c)])
+
+ # Run filters against normalised, cleaned normalised and the original name,
+ # in case there are filters for one but not another.
+ names_to_check = (name, normalised_name, cleaned_normalised_name)
+
+ new_ctx = ctx.replace(content=" ".join(names_to_check))
+ result_actions, list_messages, triggers = await self._resolve_action(new_ctx)
+ if result_actions:
+ await result_actions.action(ctx)
+ if ctx.send_alert:
+ await self._send_alert(ctx, list_messages) # `ctx` has the original content.
+ # Update time when alert sent
+ await self.name_alerts.set(ctx.author.id, arrow.utcnow().timestamp())
+ self._increment_stats(triggers)
+
+ async def _resolve_list_type_and_name(
+ self, ctx: Context, list_type: ListType | None = None, list_name: str | None = None, *, exclude: str = ""
+ ) -> tuple[ListType, FilterList] | None:
+ """Prompt the user to complete the list type or list name if one of them is missing."""
+ if list_name is None:
+ args = [list_type] if exclude != "list_type" else []
+ await ctx.send(
+ "The **list_name** argument is unspecified. Please pick a value from the options below:",
+ view=ArgumentCompletionView(ctx, args, "list_name", list(self.filter_lists), 1, None)
+ )
+ return None
+
+ filter_list = self._get_list_by_name(list_name)
+ if list_type is None:
+ if len(filter_list) > 1:
+ args = [list_name] if exclude != "list_name" else []
+ await ctx.send(
+ "The **list_type** argument is unspecified. Please pick a value from the options below:",
+ view=ArgumentCompletionView(
+ ctx, args, "list_type", [option.name for option in ListType], 0, list_type_converter
+ )
+ )
+ return None
+ list_type = list(filter_list)[0]
+ return list_type, filter_list
+
+ def _get_list_by_name(self, list_name: str) -> FilterList:
+ """Get a filter list by its name, or raise an error if there's no such list."""
+ log.trace(f"Getting the filter list matching the name {list_name}")
+ filter_list = self.filter_lists.get(list_name)
+ if not filter_list:
+ if list_name.endswith("s"): # The user may have attempted to use the plural form.
+ filter_list = self.filter_lists.get(list_name[:-1])
+ if not filter_list:
+ raise BadArgument(f"There's no filter list named {list_name!r}.")
+ log.trace(f"Found list named {filter_list.name}")
+ return filter_list
+
+ @staticmethod
+ async def _send_list(ctx: Context, filter_list: FilterList, list_type: ListType) -> None:
+ """Show the list of filters identified by the list name and type."""
+ if list_type not in filter_list:
+ await ctx.send(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.")
+ return
+
+ lines = list(map(str, filter_list[list_type].filters.values()))
+ log.trace(f"Sending a list of {len(lines)} filters.")
+
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name=f"List of {filter_list[list_type].label}s ({len(lines)} total)")
+
+ await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True)
+
+ def _get_filter_by_id(self, id_: int) -> Optional[tuple[Filter, FilterList, ListType]]:
+ """Get the filter object corresponding to the provided ID, along with its containing list and list type."""
+ for filter_list in self.filter_lists.values():
+ for list_type, sublist in filter_list.items():
+ if id_ in sublist.filters:
+ return sublist.filters[id_], filter_list, list_type
+
+ async def _add_filter(
+ self,
+ ctx: Context,
+ noui: Optional[Literal["noui"]],
+ list_type: ListType,
+ filter_list: FilterList,
+ content: str,
+ description_and_settings: Optional[str] = None
+ ) -> None:
+ """Add a filter to the database."""
+ # Validations.
+ if list_type not in filter_list:
+ await ctx.reply(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.")
+ return
+ filter_type = filter_list.get_filter_type(content)
+ if not filter_type:
+ await ctx.reply(f":x: Could not find a filter type appropriate for `{content}`.")
+ return
+ # Parse the description and settings.
+ description, settings, filter_settings = description_and_settings_converter(
+ filter_list,
+ list_type,
+ filter_type,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ description_and_settings
+ )
+
+ if noui: # Add directly with no UI.
+ try:
+ await self._post_new_filter(
+ ctx.message, filter_list, list_type, filter_type, content, description, settings, filter_settings
+ )
+ except ResponseCodeError as e:
+ await ctx.reply(embed=format_response_error(e))
+ except ValueError as e:
+ raise BadArgument(str(e))
+ return
+ # Bring up the UI.
+ embed = Embed(colour=Colour.blue())
+ embed.description = f"`{content}`" if content else "*No content*"
+ if description:
+ embed.description += f" - {description}"
+ embed.set_author(
+ name=f"New Filter - {filter_list[list_type].label}".title())
+ embed.set_footer(text=(
+ "Field names with an asterisk have values which override the defaults of the containing filter list. "
+ f"To view all defaults of the list, "
+ f"run `{constants.Bot.prefix}filterlist describe {list_type.name} {filter_list.name}`."
+ ))
+
+ view = filters_ui.FilterEditView(
+ filter_list,
+ list_type,
+ filter_type,
+ content,
+ description,
+ settings,
+ filter_settings,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ ctx.author,
+ embed,
+ self._post_new_filter
+ )
+ await ctx.send(embed=embed, reference=ctx.message, view=view)
+
+ @staticmethod
+ def _identical_filters_message(content: str, filter_list: FilterList, list_type: ListType, filter_: Filter) -> str:
+ """Returns all the filters in the list with content identical to the content supplied."""
+ if list_type not in filter_list:
+ return ""
+ duplicates = [
+ f for f in filter_list[list_type].filters.values()
+ if f.content == content and f.id != filter_.id
+ ]
+ msg = ""
+ if duplicates:
+ msg = f"\n:warning: The filter(s) #{', #'.join(str(dup.id) for dup in duplicates)} have the same content. "
+ msg += "Please make sure this is intentional."
+
+ return msg
+
+ @staticmethod
+ async def _maybe_alert_auto_infraction(
+ filter_list: FilterList, list_type: ListType, filter_: Filter, old_filter: Filter | None = None
+ ) -> None:
+ """If the filter is new and applies an auto-infraction, or was edited to apply a different one, log it."""
+ infraction_type = filter_.overrides[0].get("infraction_type")
+ if not infraction_type:
+ infraction_type = filter_list[list_type].default("infraction_type")
+ if old_filter:
+ old_infraction_type = old_filter.overrides[0].get("infraction_type")
+ if not old_infraction_type:
+ old_infraction_type = filter_list[list_type].default("infraction_type")
+ if infraction_type == old_infraction_type:
+ return
+
+ if infraction_type != Infraction.NONE:
+ filter_log = bot.instance.get_channel(Channels.filter_log)
+ if filter_log:
+ await filter_log.send(
+ f":warning: Heads up! The new {filter_list[list_type].label} filter "
+ f"({filter_}) will automatically {infraction_type.name.lower()} users."
+ )
+
+ async def _post_new_filter(
+ self,
+ msg: Message,
+ filter_list: FilterList,
+ list_type: ListType,
+ filter_type: type[Filter],
+ content: str,
+ description: str | None,
+ settings: dict,
+ filter_settings: dict
+ ) -> None:
+ """POST the data of the new filter to the site API."""
+ valid, error_msg = filter_type.validate_filter_settings(filter_settings)
+ if not valid:
+ raise BadArgument(f"Error while validating filter-specific settings: {error_msg}")
+
+ content, description = await filter_type.process_input(content, description)
+
+ list_id = filter_list[list_type].id
+ description = description or None
+ payload = {
+ "filter_list": list_id, "content": content, "description": description,
+ "additional_field": json.dumps(filter_settings), **settings
+ }
+ response = await bot.instance.api_client.post('bot/filter/filters', json=to_serializable(payload))
+ new_filter = filter_list.add_filter(list_type, response)
+ log.info(f"Added new filter: {new_filter}.")
+ if new_filter:
+ await self._maybe_alert_auto_infraction(filter_list, list_type, new_filter)
+ extra_msg = Filtering._identical_filters_message(content, filter_list, list_type, new_filter)
+ await msg.reply(f"✅ Added filter: {new_filter}" + extra_msg)
+ else:
+ await msg.reply(":x: Could not create the filter. Are you sure it's implemented?")
+
+ async def _patch_filter(
+ self,
+ filter_: Filter,
+ msg: Message,
+ filter_list: FilterList,
+ list_type: ListType,
+ filter_type: type[Filter],
+ content: str,
+ description: str | None,
+ settings: dict,
+ filter_settings: dict
+ ) -> None:
+ """PATCH the new data of the filter to the site API."""
+ valid, error_msg = filter_type.validate_filter_settings(filter_settings)
+ if not valid:
+ raise BadArgument(f"Error while validating filter-specific settings: {error_msg}")
+
+ if content != filter_.content:
+ content, description = await filter_type.process_input(content, description)
+
+ # If the setting is not in `settings`, the override was either removed, or there wasn't one in the first place.
+ for current_settings in (filter_.actions, filter_.validations):
+ if current_settings:
+ for setting_entry in current_settings.values():
+ settings.update({setting: None for setting in setting_entry.dict() if setting not in settings})
+
+ # Even though the list ID remains unchanged, it still needs to be provided for correct serializer validation.
+ list_id = filter_list[list_type].id
+ description = description or None
+ payload = {
+ "filter_list": list_id, "content": content, "description": description,
+ "additional_field": json.dumps(filter_settings), **settings
+ }
+ response = await bot.instance.api_client.patch(
+ f'bot/filter/filters/{filter_.id}', json=to_serializable(payload)
+ )
+ # Return type can be None, but if it's being edited then it's not supposed to be.
+ edited_filter = filter_list.add_filter(list_type, response)
+ log.info(f"Successfully patched filter {edited_filter}.")
+ await self._maybe_alert_auto_infraction(filter_list, list_type, edited_filter, filter_)
+ extra_msg = Filtering._identical_filters_message(content, filter_list, list_type, edited_filter)
+ await msg.reply(f"✅ Edited filter: {edited_filter}" + extra_msg)
+
+ async def _post_filter_list(self, msg: Message, list_name: str, list_type: ListType, settings: dict) -> None:
+ """POST the new data of the filter list to the site API."""
+ payload = {"name": list_name, "list_type": list_type.value, **to_serializable(settings)}
+ filterlist_name = f"{past_tense(list_type.name.lower())} {list_name}"
+ response = await bot.instance.api_client.post('bot/filter/filter_lists', json=payload)
+ log.info(f"Successfully posted the new {filterlist_name} filterlist.")
+ self._load_raw_filter_list(response)
+ await msg.reply(f"✅ Added a new filter list: {filterlist_name}")
+
+ @staticmethod
+ async def _patch_filter_list(msg: Message, filter_list: FilterList, list_type: ListType, settings: dict) -> None:
+ """PATCH the new data of the filter list to the site API."""
+ list_id = filter_list[list_type].id
+ response = await bot.instance.api_client.patch(
+ f'bot/filter/filter_lists/{list_id}', json=to_serializable(settings)
+ )
+ log.info(f"Successfully patched the {filter_list[list_type].label} filterlist, reloading...")
+ filter_list.pop(list_type, None)
+ filter_list.add_list(response)
+ await msg.reply(f"✅ Edited filter list: {filter_list[list_type].label}")
+
+ def _filter_match_query(
+ self, filter_: Filter, settings_query: dict, filter_settings_query: dict, differ_by_default: set[str]
+ ) -> bool:
+ """Return whether the given filter matches the query."""
+ override_matches = set()
+ overrides, _ = filter_.overrides
+ for setting_name, setting_value in settings_query.items():
+ if setting_name not in overrides:
+ continue
+ if repr_equals(overrides[setting_name], setting_value):
+ override_matches.add(setting_name)
+ else: # If an override doesn't match then the filter doesn't match.
+ return False
+ if not (differ_by_default <= override_matches): # The overrides didn't cover for the default mismatches.
+ return False
+
+ filter_settings = filter_.extra_fields.dict() if filter_.extra_fields else {}
+ # If the dict changes then some fields were not the same.
+ return (filter_settings | filter_settings_query) == filter_settings
+
+ def _search_filter_list(
+ self, atomic_list: AtomicList, filter_type: type[Filter] | None, settings: dict, filter_settings: dict
+ ) -> list[Filter]:
+ """Find all filters in the filter list which match the settings."""
+ # If the default answers are known, only the overrides need to be checked for each filter.
+ all_defaults = atomic_list.defaults.dict()
+ match_by_default = set()
+ differ_by_default = set()
+ for setting_name, setting_value in settings.items():
+ if repr_equals(all_defaults[setting_name], setting_value):
+ match_by_default.add(setting_name)
+ else:
+ differ_by_default.add(setting_name)
+
+ result_filters = []
+ for filter_ in atomic_list.filters.values():
+ if filter_type and not isinstance(filter_, filter_type):
+ continue
+ if self._filter_match_query(filter_, settings, filter_settings, differ_by_default):
+ result_filters.append(filter_)
+
+ return result_filters
+
+ async def _search_filters(
+ self, message: Message, filter_type: type[Filter] | None, settings: dict, filter_settings: dict
+ ) -> None:
+ """Find all filters which match the settings and display them."""
+ lines = []
+ result_count = 0
+ for filter_list in self.filter_lists.values():
+ if filter_type and filter_type not in filter_list.filter_types:
+ continue
+ for atomic_list in filter_list.values():
+ list_results = self._search_filter_list(atomic_list, filter_type, settings, filter_settings)
+ if list_results:
+ lines.append(f"**{atomic_list.label.title()}**")
+ lines.extend(map(str, list_results))
+ lines.append("")
+ result_count += len(list_results)
+
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name=f"Search Results ({result_count} total)")
+ ctx = await bot.instance.get_context(message)
+ await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True)
+
+ async def _delete_offensive_msg(self, msg: Mapping[str, int]) -> None:
+ """Delete an offensive message, and then delete it from the DB."""
+ try:
+ channel = self.bot.get_channel(msg['channel_id'])
+ if channel:
+ msg_obj = await channel.fetch_message(msg['id'])
+ await msg_obj.delete()
+ except discord.NotFound:
+ log.info(
+ f"Tried to delete message {msg['id']}, but the message can't be found "
+ f"(it has been probably already deleted)."
+ )
+ except HTTPException as e:
+ log.warning(f"Failed to delete message {msg['id']}: status {e.status}")
+
+ await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}')
+ log.info(f"Deleted the offensive message with id {msg['id']}.")
+
+ def _schedule_msg_delete(self, msg: dict) -> None:
+ """Delete an offensive message once its deletion date is reached."""
+ delete_at = arrow.get(msg['delete_date']).datetime
+ self.delete_scheduler.schedule_at(delete_at, msg['id'], self._delete_offensive_msg(msg))
+
+ async def _maybe_schedule_msg_delete(self, ctx: FilterContext, actions: ActionSettings | None) -> None:
+ """Post the message to the database and schedule it for deletion if it's not set to be deleted already."""
+ msg = ctx.message
+ if not msg or not actions or actions.get_setting("remove_context", True):
+ return
+
+ delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat()
+ data = {
+ 'id': msg.id,
+ 'channel_id': msg.channel.id,
+ 'delete_date': delete_date
+ }
+
+ try:
+ await self.bot.api_client.post('bot/offensive-messages', json=data)
+ except ResponseCodeError as e:
+ if e.status == 400 and "already exists" in e.response_json.get("id", [""])[0]:
+ log.debug(f"Offensive message {msg.id} already exists.")
+ else:
+ log.error(f"Offensive message {msg.id} failed to post: {e}")
+ else:
+ self._schedule_msg_delete(data)
+ log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}")
+
+ # endregion
+ # region: tasks
+
+ @tasks.loop(time=datetime.time(hour=18))
+ async def weekly_auto_infraction_report_task(self) -> None:
+ """Trigger an auto-infraction report to be sent if it is the desired day of the week (WEEKLY_REPORT_ISO_DAY)."""
+ if arrow.utcnow().isoweekday() != WEEKLY_REPORT_ISO_DAY:
+ return
+
+ await self.send_weekly_auto_infraction_report()
+
+ async def send_weekly_auto_infraction_report(self, channel: discord.TextChannel | discord.Thread = None) -> None:
+ """
+ Send a list of auto-infractions added in the last 7 days to the specified channel.
+
+ If `channel` is not specified, it is sent to #mod-meta.
+ """
+ log.trace("Preparing weekly auto-infraction report.")
+ seven_days_ago = arrow.utcnow().shift(days=-7)
+ if not channel:
+ log.info("Auto-infraction report: the channel to report to is missing.")
+ channel = self.bot.get_channel(Channels.mod_meta)
+ elif not is_mod_channel(channel):
+ # Silently fail if output is going to be a non-mod channel.
+ log.info(f"Auto-infraction report: the channel {channel} is not a mod channel.")
+ return
+
+ found_filters = defaultdict(list)
+ # Extract all auto-infraction filters added in the past 7 days from each filter type
+ for filter_list in self.filter_lists.values():
+ for sublist in filter_list.values():
+ default_infraction_type = sublist.default("infraction_type")
+ for filter_ in sublist.filters.values():
+ if max(filter_.created_at, filter_.updated_at) < seven_days_ago:
+ continue
+ infraction_type = filter_.overrides[0].get("infraction_type")
+ if (
+ (infraction_type and infraction_type != Infraction.NONE)
+ or (not infraction_type and default_infraction_type != Infraction.NONE)
+ ):
+ found_filters[sublist.label].append((filter_, infraction_type or default_infraction_type))
+
+ # Nicely format the output so each filter list type is grouped
+ lines = [f"**Auto-infraction filters added since {seven_days_ago.format('YYYY-MM-DD')}**"]
+ for list_label, filters in found_filters.items():
+ lines.append("\n".join([f"**{list_label.title()}**"]+[f"{filter_} ({infr})" for filter_, infr in filters]))
+
+ if len(lines) == 1:
+ lines.append("Nothing to show")
+
+ await channel.send("\n\n".join(lines))
+ log.info("Successfully sent auto-infraction report.")
+
+ # endregion
+
+ async def cog_unload(self) -> None:
+ """Cancel the weekly auto-infraction filter report and deletion scheduling on cog unload."""
+ self.weekly_auto_infraction_report_task.cancel()
+ self.delete_scheduler.cancel_all()
+
+
+async def setup(bot: Bot) -> None:
+ """Load the Filtering cog."""
+ await bot.add_cog(Filtering(bot))
diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py
deleted file mode 100644
index ff39700a6..000000000
--- a/bot/exts/filters/antimalware.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import typing as t
-from os.path import splitext
-
-from discord import Embed, Message, NotFound
-from discord.ext.commands import Cog
-
-from bot.bot import Bot
-from bot.constants import Channels, Filter, URLs
-from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
-from bot.log import get_logger
-
-log = get_logger(__name__)
-
-PY_EMBED_DESCRIPTION = (
- "It looks like you tried to attach a Python file - "
- f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"
-)
-
-TXT_LIKE_FILES = {".txt", ".csv", ".json"}
-TXT_EMBED_DESCRIPTION = (
- "You either uploaded a `{blocked_extension}` file or entered a message that was too long. "
- f"Please use our [paste bin]({URLs.site_schema}{URLs.site_paste}) instead."
-)
-
-DISALLOWED_EMBED_DESCRIPTION = (
- "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). "
- "We currently allow the following file types: **{joined_whitelist}**.\n\n"
- "Feel free to ask in {meta_channel_mention} if you think this is a mistake."
-)
-
-
-class AntiMalware(Cog):
- """Delete messages which contain attachments with non-whitelisted file extensions."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
-
- def _get_whitelisted_file_formats(self) -> list:
- """Get the file formats currently on the whitelist."""
- return self.bot.filter_list_cache['FILE_FORMAT.True'].keys()
-
- def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]:
- """Get an iterable containing all the disallowed extensions of attachments."""
- file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments}
- extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats())
- return extensions_blocked
-
- @Cog.listener()
- async def on_message(self, message: Message) -> None:
- """Identify messages with prohibited attachments."""
- # Return when message don't have attachment and don't moderate DMs
- if not message.attachments or not message.guild:
- return
-
- # Ignore webhook and bot messages
- if message.webhook_id or message.author.bot:
- return
-
- # Ignore code jam channels
- if getattr(message.channel, "category", None) and message.channel.category.name == JAM_CATEGORY_NAME:
- return
-
- # Check if user is staff, if is, return
- # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
- if hasattr(message.author, "roles") and any(role.id in Filter.role_whitelist for role in message.author.roles):
- return
-
- embed = Embed()
- extensions_blocked = self._get_disallowed_extensions(message)
- blocked_extensions_str = ', '.join(extensions_blocked)
- if ".py" in extensions_blocked:
- # Short-circuit on *.py files to provide a pastebin link
- embed.description = PY_EMBED_DESCRIPTION
- elif extensions := TXT_LIKE_FILES.intersection(extensions_blocked):
- # Work around Discord AutoConversion of messages longer than 2000 chars to .txt
- cmd_channel = self.bot.get_channel(Channels.bot_commands)
- embed.description = TXT_EMBED_DESCRIPTION.format(
- blocked_extension=extensions.pop(),
- cmd_channel_mention=cmd_channel.mention
- )
- elif extensions_blocked:
- meta_channel = self.bot.get_channel(Channels.meta)
- embed.description = DISALLOWED_EMBED_DESCRIPTION.format(
- joined_whitelist=', '.join(self._get_whitelisted_file_formats()),
- blocked_extensions_str=blocked_extensions_str,
- meta_channel_mention=meta_channel.mention,
- )
-
- if embed.description:
- log.info(
- f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}",
- extra={"attachment_list": [attachment.filename for attachment in message.attachments]}
- )
-
- await message.channel.send(f"Hey {message.author.mention}!", embed=embed)
-
- # Delete the offending message:
- try:
- await message.delete()
- except NotFound:
- log.info(f"Tried to delete message `{message.id}`, but message could not be found.")
-
-
-async def setup(bot: Bot) -> None:
- """Load the AntiMalware cog."""
- await bot.add_cog(AntiMalware(bot))
diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py
deleted file mode 100644
index 0d02edabf..000000000
--- a/bot/exts/filters/antispam.py
+++ /dev/null
@@ -1,326 +0,0 @@
-import asyncio
-from collections import defaultdict
-from collections.abc import Mapping
-from dataclasses import dataclass, field
-from datetime import timedelta
-from itertools import takewhile
-from operator import attrgetter, itemgetter
-from typing import Dict, Iterable, List, Set
-
-import arrow
-from discord import Colour, Member, Message, MessageType, NotFound, TextChannel
-from discord.ext.commands import Cog
-from pydis_core.utils import scheduling
-
-from bot import rules
-from bot.bot import Bot
-from bot.constants import (
- AntiSpam as AntiSpamConfig, Channels, Colours, DEBUG_MODE, Event, Filter, Guild as GuildConfig, Icons
-)
-from bot.converters import Duration
-from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
-from bot.exts.moderation.modlog import ModLog
-from bot.log import get_logger
-from bot.utils import lock
-from bot.utils.message_cache import MessageCache
-from bot.utils.messages import format_user, send_attachments
-
-log = get_logger(__name__)
-
-RULE_FUNCTION_MAPPING = {
- 'attachments': rules.apply_attachments,
- 'burst': rules.apply_burst,
- # burst shared is temporarily disabled due to a bug
- # 'burst_shared': rules.apply_burst_shared,
- 'chars': rules.apply_chars,
- 'discord_emojis': rules.apply_discord_emojis,
- 'duplicates': rules.apply_duplicates,
- 'links': rules.apply_links,
- 'mentions': rules.apply_mentions,
- 'newlines': rules.apply_newlines,
- 'role_mentions': rules.apply_role_mentions,
-}
-
-ANTI_SPAM_RULES = AntiSpamConfig.rules.dict()
-
-
-@dataclass
-class DeletionContext:
- """Represents a Deletion Context for a single spam event."""
-
- members: frozenset[Member]
- triggered_in: TextChannel
- channels: set[TextChannel] = field(default_factory=set)
- rules: Set[str] = field(default_factory=set)
- messages: Dict[int, Message] = field(default_factory=dict)
- attachments: List[List[str]] = field(default_factory=list)
-
- async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None:
- """Adds new rule violation events to the deletion context."""
- self.rules.add(rule_name)
-
- self.channels.update(channels)
-
- for message in messages:
- if message.id not in self.messages:
- self.messages[message.id] = message
-
- # Re-upload attachments
- destination = message.guild.get_channel(Channels.attachment_log)
- urls = await send_attachments(message, destination, link_large=False)
- self.attachments.append(urls)
-
- async def upload_messages(self, actor_id: int, modlog: ModLog) -> None:
- """Method that takes care of uploading the queue and posting modlog alert."""
- triggered_by_users = ", ".join(format_user(m) for m in self.members)
- triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else ""
- channels_description = ", ".join(channel.mention for channel in self.channels)
-
- mod_alert_message = (
- f"**Triggered by:** {triggered_by_users}\n"
- f"{triggered_in_channel}"
- f"**Channels:** {channels_description}\n"
- f"**Rules:** {', '.join(rule for rule in self.rules)}\n"
- )
-
- messages_as_list = list(self.messages.values())
- first_message = messages_as_list[0]
- # For multiple messages and those with attachments or excessive newlines, use the logs API
- if any((
- len(messages_as_list) > 1,
- len(first_message.attachments) > 0,
- first_message.content.count('\n') > 15
- )):
- url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments)
- mod_alert_message += f"A complete log of the offending messages can be found [here]({url})"
- else:
- mod_alert_message += "Message:\n"
- content = first_message.clean_content
- remaining_chars = 4080 - len(mod_alert_message)
-
- if len(content) > remaining_chars:
- url = await modlog.upload_log([first_message], actor_id, self.attachments)
- log_site_msg = f"The full message can be found [here]({url})"
- content = content[:remaining_chars - (3 + len(log_site_msg))] + "..."
-
- mod_alert_message += content
-
- await modlog.send_log_message(
- content=", ".join(str(m.id) for m in self.members), # quality-of-life improvement for mobile moderators
- icon_url=Icons.filtering,
- colour=Colour(Colours.soft_red),
- title="Spam detected!",
- text=mod_alert_message,
- thumbnail=first_message.author.display_avatar.url,
- channel_id=Channels.mod_alerts,
- ping_everyone=AntiSpamConfig.ping_everyone
- )
-
-
-class AntiSpam(Cog):
- """Cog that controls our anti-spam measures."""
-
- def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None:
- self.bot = bot
- self.validation_errors = validation_errors
- self.expiration_date_converter = Duration()
-
- self.message_deletion_queue = dict()
-
- # Fetch the rule configuration with the highest rule interval.
- max_interval_config = max(
- ANTI_SPAM_RULES.values(),
- key=itemgetter('interval')
- )
- self.max_interval = max_interval_config['interval']
- self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True)
-
- @property
- def mod_log(self) -> ModLog:
- """Allows for easy access of the ModLog cog."""
- return self.bot.get_cog("ModLog")
-
- async def cog_load(self) -> None:
- """Unloads the cog and alerts admins if configuration validation failed."""
- await self.bot.wait_until_guild_available()
- if self.validation_errors:
- body = "**The following errors were encountered:**\n"
- body += "\n".join(f"- {error}" for error in self.validation_errors.values())
- body += "\n\n**The cog has been unloaded.**"
-
- await self.mod_log.send_log_message(
- title="Error: AntiSpam configuration validation failed!",
- text=body,
- ping_everyone=True,
- icon_url=Icons.token_removed,
- colour=Colour.red()
- )
-
- await self.bot.remove_cog(self.__class__.__name__)
- return
-
- @Cog.listener()
- async def on_message(self, message: Message) -> None:
- """Applies the antispam rules to each received message."""
- if (
- not message.guild
- or message.guild.id != GuildConfig.id
- or message.author.bot
- or (getattr(message.channel, "category", None) and message.channel.category.name == JAM_CATEGORY_NAME)
- or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE)
- or (any(role.id in Filter.role_whitelist for role in message.author.roles) and not DEBUG_MODE)
- or message.type == MessageType.auto_moderation_action
- ):
- return
-
- self.cache.append(message)
-
- earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.max_interval)
- relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache))
-
- for rule_name, rule_config in ANTI_SPAM_RULES.items():
- rule_function = RULE_FUNCTION_MAPPING[rule_name]
-
- # Create a list of messages that were sent in the interval that the rule cares about.
- latest_interesting_stamp = arrow.utcnow() - timedelta(seconds=rule_config['interval'])
- messages_for_rule = list(
- takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) # noqa: B023
- )
-
- result = await rule_function(message, messages_for_rule, rule_config)
-
- # If the rule returns `None`, that means the message didn't violate it.
- # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])`
- # which contains the reason for why the message violated the rule and
- # an iterable of all members that violated the rule.
- if result is not None:
- self.bot.stats.incr(f"mod_alerts.{rule_name}")
- reason, members, relevant_messages = result
- full_reason = f"`{rule_name}` rule: {reason}"
-
- # If there's no spam event going on for this channel, start a new Message Deletion Context
- authors_set = frozenset(members)
- if authors_set not in self.message_deletion_queue:
- log.trace(f"Creating queue for members `{authors_set}`")
- self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel)
- scheduling.create_task(
- self._process_deletion_context(authors_set),
- name=f"AntiSpam._process_deletion_context({authors_set})"
- )
-
- # Add the relevant of this trigger to the Deletion Context
- await self.message_deletion_queue[authors_set].add(
- rule_name=rule_name,
- channels=set(message.channel for message in relevant_messages),
- messages=relevant_messages
- )
-
- for member in members:
- scheduling.create_task(
- self.punish(message, member, full_reason),
- name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"
- )
-
- await self.maybe_delete_messages(relevant_messages)
- break
-
- @lock.lock_arg("antispam.punish", "member", attrgetter("id"))
- async def punish(self, msg: Message, member: Member, reason: str) -> None:
- """Punishes the given member for triggering an antispam rule."""
- if not member.is_timed_out():
- remove_timeout_after = AntiSpamConfig.remove_timeout_after
-
- # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes
- context = await self.bot.get_context(msg)
- command = self.bot.get_command("timeout")
- context.author = context.guild.get_member(self.bot.user.id)
- context.command = command
-
- # Since we're going to invoke the timeout command directly, we need to manually call the converter.
- dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_timeout_after}S")
- await context.invoke(
- command,
- member,
- dt_remove_role_after,
- reason=reason
- )
-
- async def maybe_delete_messages(self, messages: List[Message]) -> None:
- """Cleans the messages if cleaning is configured."""
- if AntiSpamConfig.clean_offending:
- # If we have more than one message, we can use bulk delete.
- if len(messages) > 1:
- message_ids = [message.id for message in messages]
- self.mod_log.ignore(Event.message_delete, *message_ids)
- channel_messages = defaultdict(list)
- for message in messages:
- channel_messages[message.channel].append(message)
- for channel, messages in channel_messages.items():
- try:
- await channel.delete_messages(messages)
- except NotFound:
- # In the rare case where we found messages matching the
- # spam filter across multiple channels, it is possible
- # that a single channel will only contain a single message
- # to delete. If that should be the case, discord.py will
- # use the "delete single message" endpoint instead of the
- # bulk delete endpoint, and the single message deletion
- # endpoint will complain if you give it that does not exist.
- # As this means that we have no other message to delete in
- # this channel (and message deletes work per-channel),
- # we can just log an exception and carry on with business.
- log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.")
-
- # Otherwise, the bulk delete endpoint will throw up.
- # Delete the message directly instead.
- else:
- self.mod_log.ignore(Event.message_delete, messages[0].id)
- try:
- await messages[0].delete()
- except NotFound:
- log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.")
-
- async def _process_deletion_context(self, context_id: frozenset) -> None:
- """Processes the Deletion Context queue."""
- log.trace("Sleeping before processing message deletion queue.")
- await asyncio.sleep(10)
-
- if context_id not in self.message_deletion_queue:
- log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!")
- return
-
- deletion_context = self.message_deletion_queue.pop(context_id)
- await deletion_context.upload_messages(self.bot.user.id, self.mod_log)
-
- @Cog.listener()
- async def on_message_edit(self, before: Message, after: Message) -> None:
- """Updates the message in the cache, if it's cached."""
- self.cache.update(after)
-
-
-def validate_config(rules_: Mapping = ANTI_SPAM_RULES) -> Dict[str, str]:
- """Validates the antispam configs."""
- validation_errors = {}
- for name, config in rules_.items():
- config = config
- if name not in RULE_FUNCTION_MAPPING:
- log.error(
- f"Unrecognized antispam rule `{name}`. "
- f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}"
- )
- validation_errors[name] = f"`{name}` is not recognized as an antispam rule."
- continue
- for required_key in ('interval', 'max'):
- if required_key not in config:
- log.error(
- f"`{required_key}` is required but was not "
- f"set in rule `{name}`'s configuration."
- )
- validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`"
- return validation_errors
-
-
-async def setup(bot: Bot) -> None:
- """Validate the AntiSpam configs and load the AntiSpam cog."""
- validation_errors = validate_config()
- await bot.add_cog(AntiSpam(bot, validation_errors))
diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py
deleted file mode 100644
index 538744204..000000000
--- a/bot/exts/filters/filter_lists.py
+++ /dev/null
@@ -1,359 +0,0 @@
-import datetime
-import re
-from collections import defaultdict
-from typing import Optional
-
-import arrow
-import discord
-from discord.ext import tasks
-from discord.ext.commands import BadArgument, Cog, Context, IDConverter, command, group, has_any_role
-from pydis_core.site_api import ResponseCodeError
-
-from bot import constants
-from bot.bot import Bot
-from bot.constants import Channels, Colours
-from bot.converters import ValidDiscordServerInvite, ValidFilterListType
-from bot.log import get_logger
-from bot.pagination import LinePaginator
-from bot.utils.channel import is_mod_channel
-
-log = get_logger(__name__)
-WEEKLY_REPORT_ISO_DAY = 3 # 1=Monday, 7=Sunday
-
-
-class FilterLists(Cog):
- """Commands for blacklisting and whitelisting things."""
-
- methods_with_filterlist_types = [
- "allow_add",
- "allow_delete",
- "allow_get",
- "deny_add",
- "deny_delete",
- "deny_get",
- ]
-
- def __init__(self, bot: Bot) -> None:
- self.bot = bot
-
- async def cog_load(self) -> None:
- """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations."""
- await self.bot.wait_until_guild_available()
- self.weekly_autoban_report_task.start()
-
- # Add valid filterlist types to the docstrings
- valid_types = await ValidFilterListType.get_valid_types(self.bot)
- valid_types = [f"`{type_.lower()}`" for type_ in valid_types]
-
- for method_name in self.methods_with_filterlist_types:
- command = getattr(self, method_name)
- command.help = (
- f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}."
- )
-
- async def _add_data(
- self,
- ctx: Context,
- allowed: bool,
- list_type: ValidFilterListType,
- content: str,
- comment: Optional[str] = None,
- ) -> None:
- """Add an item to a filterlist."""
- allow_type = "whitelist" if allowed else "blacklist"
-
- # If this is a guild invite, we gotta validate it.
- if list_type == "GUILD_INVITE":
- guild_data = await self._validate_guild_invite(ctx, content)
- content = guild_data.get("id")
-
- # Some guild invites are autoban filters, which require the mod
- # to set a comment which includes [autoban].
- # Having the guild name in the comment is still useful when reviewing
- # filter list, so prepend it to the set comment in case some mod forgets.
- guild_name_part = f'Guild "{guild_data["name"]}"' if "name" in guild_data else None
-
- comment = " - ".join(
- comment_part
- for comment_part in (guild_name_part, comment)
- if comment_part
- )
-
- # If it's a file format, let's make sure it has a leading dot.
- elif list_type == "FILE_FORMAT" and not content.startswith("."):
- content = f".{content}"
-
- # If it's a filter token, validate the passed regex
- elif list_type == "FILTER_TOKEN":
- try:
- re.compile(content)
- except re.error as e:
- await ctx.message.add_reaction("❌")
- await ctx.send(
- f"{ctx.author.mention} that's not a valid regex! "
- f"Regex error message: {e.msg}."
- )
- return
-
- # Try to add the item to the database
- log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}")
- payload = {
- "allowed": allowed,
- "type": list_type,
- "content": content,
- "comment": comment,
- }
-
- try:
- item = await self.bot.api_client.post(
- "bot/filter-lists",
- json=payload
- )
- except ResponseCodeError as e:
- if e.status == 400:
- await ctx.message.add_reaction("❌")
- log.debug(
- f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, "
- "probably because the request violated the UniqueConstraint."
- )
- raise BadArgument(
- f"Unable to add the item to the {allow_type}. "
- "The item probably already exists. Keep in mind that a "
- "blacklist and a whitelist for the same item cannot co-exist, "
- "and we do not permit any duplicates."
- )
- raise
-
- # If it is an autoban trigger we send a warning in #filter-log
- if comment and "[autoban]" in comment:
- await self.bot.get_channel(Channels.filter_log).send(
- f":warning: Heads-up! The new `{list_type}` filter "
- f"`{content}` (`{comment}`) will automatically ban users."
- )
-
- # Insert the item into the cache
- self.bot.insert_item_into_filter_list_cache(item)
- await ctx.message.add_reaction("✅")
-
- async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None:
- """Remove an item from a filterlist."""
- allow_type = "whitelist" if allowed else "blacklist"
-
- # If this is a server invite, we need to convert it.
- if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content):
- guild_data = await self._validate_guild_invite(ctx, content)
- content = guild_data.get("id")
-
- # If it's a file format, let's make sure it has a leading dot.
- elif list_type == "FILE_FORMAT" and not content.startswith("."):
- content = f".{content}"
-
- # Find the content and delete it.
- log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}")
- item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content)
-
- if item is not None:
- try:
- await self.bot.api_client.delete(
- f"bot/filter-lists/{item['id']}"
- )
- del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content]
- await ctx.message.add_reaction("✅")
- except ResponseCodeError as e:
- log.debug(
- f"{ctx.author} tried to delete an item with the id {item['id']}, but "
- f"the API raised an unexpected error: {e}"
- )
- await ctx.message.add_reaction("❌")
- else:
- await ctx.message.add_reaction("❌")
-
- async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None:
- """Paginate and display all items in a filterlist."""
- allow_type = "whitelist" if allowed else "blacklist"
- result = self.bot.filter_list_cache[f"{list_type}.{allowed}"]
-
- # Build a list of lines we want to show in the paginator
- lines = []
- for content, metadata in result.items():
- line = f"• `{content}`"
-
- if comment := metadata.get("comment"):
- line += f" - {comment}"
-
- lines.append(line)
- lines = sorted(lines)
-
- # Build the embed
- list_type_plural = list_type.lower().replace("_", " ").title() + "s"
- embed = discord.Embed(
- title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)",
- colour=Colours.blue
- )
- log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}")
-
- if result:
- await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False)
- else:
- embed.description = "Hmmm, seems like there's nothing here yet."
- await ctx.send(embed=embed)
- await ctx.message.add_reaction("❌")
-
- async def _sync_data(self, ctx: Context) -> None:
- """Syncs the filterlists with the API."""
- try:
- log.trace("Attempting to sync FilterList cache with data from the API.")
- await self.bot.cache_filter_list_data()
- await ctx.message.add_reaction("✅")
- except ResponseCodeError as e:
- log.debug(
- f"{ctx.author} tried to sync FilterList cache data but "
- f"the API raised an unexpected error: {e}"
- )
- await ctx.message.add_reaction("❌")
-
- @staticmethod
- async def _validate_guild_invite(ctx: Context, invite: str) -> dict:
- """
- Validates a guild invite, and returns the guild info as a dict.
-
- Will raise a BadArgument if the guild invite is invalid.
- """
- log.trace(f"Attempting to validate whether or not {invite} is a guild invite.")
- validator = ValidDiscordServerInvite()
- guild_data = await validator.convert(ctx, invite)
-
- # If we make it this far without raising a BadArgument, the invite is
- # valid. Let's return a dict of guild information.
- log.trace(f"{invite} validated as server invite. Converting to ID.")
- return guild_data
-
- @group(aliases=("allowlist", "allow", "al", "wl"))
- async def whitelist(self, ctx: Context) -> None:
- """Group for whitelisting commands."""
- if not ctx.invoked_subcommand:
- await ctx.send_help(ctx.command)
-
- @group(aliases=("denylist", "deny", "bl", "dl"))
- async def blacklist(self, ctx: Context) -> None:
- """Group for blacklisting commands."""
- if not ctx.invoked_subcommand:
- await ctx.send_help(ctx.command)
-
- @whitelist.command(name="add", aliases=("a", "set"))
- async def allow_add(
- self,
- ctx: Context,
- list_type: ValidFilterListType,
- content: str,
- *,
- comment: Optional[str] = None,
- ) -> None:
- """Add an item to the specified allowlist."""
- await self._add_data(ctx, True, list_type, content, comment)
-
- @blacklist.command(name="add", aliases=("a", "set"))
- async def deny_add(
- self,
- ctx: Context,
- list_type: ValidFilterListType,
- content: str,
- *,
- comment: Optional[str] = None,
- ) -> None:
- """Add an item to the specified denylist."""
- await self._add_data(ctx, False, list_type, content, comment)
-
- @whitelist.command(name="remove", aliases=("delete", "rm",))
- async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None:
- """Remove an item from the specified allowlist."""
- await self._delete_data(ctx, True, list_type, content)
-
- @blacklist.command(name="remove", aliases=("delete", "rm",))
- async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None:
- """Remove an item from the specified denylist."""
- await self._delete_data(ctx, False, list_type, content)
-
- @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show"))
- async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None:
- """Get the contents of a specified allowlist."""
- await self._list_all_data(ctx, True, list_type)
-
- @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show"))
- async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None:
- """Get the contents of a specified denylist."""
- await self._list_all_data(ctx, False, list_type)
-
- @whitelist.command(name="sync", aliases=("s",))
- async def allow_sync(self, ctx: Context) -> None:
- """Syncs both allowlists and denylists with the API."""
- await self._sync_data(ctx)
-
- @blacklist.command(name="sync", aliases=("s",))
- async def deny_sync(self, ctx: Context) -> None:
- """Syncs both allowlists and denylists with the API."""
- await self._sync_data(ctx)
-
- @command(name="filter_report")
- async def force_send_weekly_report(self, ctx: Context) -> None:
- """Respond with a list of autobans added in the last 7 days."""
- await self.send_weekly_autoban_report(ctx.channel)
-
- @tasks.loop(time=datetime.time(hour=18))
- async def weekly_autoban_report_task(self) -> None:
- """Trigger autoban report to be sent if it is the desired day of the week (WEEKLY_REPORT_ISO_DAY)."""
- if arrow.utcnow().isoweekday() != WEEKLY_REPORT_ISO_DAY:
- return
-
- await self.send_weekly_autoban_report()
-
- async def send_weekly_autoban_report(self, channel: discord.abc.Messageable = None) -> None:
- """
- Send a list of autobans added in the last 7 days to the specified channel.
-
- If chanel is not specified, it is sent to #mod-meta.
- """
- seven_days_ago = arrow.utcnow().shift(days=-7)
- if not channel:
- channel = self.bot.get_channel(Channels.mod_meta)
- elif not is_mod_channel(channel):
- # Silently fail if output is going to be a non-mod channel.
- return
-
- added_autobans = defaultdict(list)
- # Extract all autoban filters added in the past 7 days from each filter type
- for filter_list, filters in self.bot.filter_list_cache.items():
- filter_type, allow = filter_list.split(".")
- allow_type = "Allow list" if allow.lower() == "true" else "Deny list"
-
- for filter_content, filter_details in filters.items():
- created_at = arrow.get(filter_details["created_at"])
- updated_at = arrow.get(filter_details["updated_at"])
- # Default to empty string so that the in check below doesn't error on None type
- comment = filter_details["comment"] or ""
- if max(created_at, updated_at) > seven_days_ago and "[autoban]" in comment:
- line = f"`{filter_content}`: {comment}"
- added_autobans[f"**{filter_type} {allow_type}**"].append(line)
-
- # Nicely format the output so each filter list type is grouped
- lines = [f"**Autoban filters added since {seven_days_ago.format('YYYY-MM-DD')}**"]
- for filter_list, recently_added_autobans in added_autobans.items():
- lines.append("\n".join([filter_list]+recently_added_autobans))
-
- if len(lines) == 1:
- lines.append("Nothing to show")
-
- await channel.send("\n\n".join(lines))
-
- async def cog_check(self, ctx: Context) -> bool:
- """Only allow moderators to invoke the commands in this cog."""
- return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx)
-
- async def cog_unload(self) -> None:
- """Cancel the weekly autoban filter report on cog unload."""
- self.weekly_autoban_report_task.cancel()
-
-
-async def setup(bot: Bot) -> None:
- """Load the FilterLists cog."""
- await bot.add_cog(FilterLists(bot))
diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py
deleted file mode 100644
index 23a6f2d92..000000000
--- a/bot/exts/filters/filtering.py
+++ /dev/null
@@ -1,743 +0,0 @@
-import asyncio
-import re
-import unicodedata
-import urllib.parse
-from datetime import timedelta
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
-
-import arrow
-import dateutil.parser
-import regex
-import tldextract
-from async_rediscache import RedisCache
-from dateutil.relativedelta import relativedelta
-from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member, Message, NotFound, TextChannel
-from discord.ext.commands import Cog
-from discord.utils import escape_markdown
-from pydis_core.site_api import ResponseCodeError
-from pydis_core.utils import scheduling
-from pydis_core.utils.regex import DISCORD_INVITE
-
-from bot.bot import Bot
-from bot.constants import Bot as BotConfig, Channels, Colours, Filter, Guild, Icons, URLs
-from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
-from bot.exts.moderation.modlog import ModLog
-from bot.log import get_logger
-from bot.utils.helpers import remove_subdomain_from_url
-from bot.utils.messages import format_user
-
-log = get_logger(__name__)
-
-
-# Regular expressions
-CODE_BLOCK_RE = re.compile(
- r"(?P<delim>``?)[^`]+?(?P=delim)(?!`+)" # Inline codeblock
- r"|```(.+?)```", # Multiline codeblock
- re.DOTALL | re.MULTILINE
-)
-EVERYONE_PING_RE = re.compile(rf"@everyone|<@&{Guild.id}>|@here")
-SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL)
-URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE)
-
-# Exclude variation selectors from zalgo because they're actually invisible.
-VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF"
-INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1)
-ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1)
-
-# Other constants.
-DAYS_BETWEEN_ALERTS = 3
-OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days)
-
-# Autoban
-LINK_PASSWORD = "https://support.discord.com/hc/en-us/articles/218410947-I-forgot-my-Password-Where-can-I-set-a-new-one"
-LINK_2FA = "https://support.discord.com/hc/en-us/articles/219576828-Setting-up-Two-Factor-Authentication"
-AUTO_BAN_REASON = (
- "Your account has been used to send links to a phishing website. You have been automatically banned. "
- "If you are not aware of sending them, that means your account has been compromised.\n\n"
-
- f"Here is a guide from Discord on [how to change your password]({LINK_PASSWORD}).\n\n"
-
- f"We also highly recommend that you [enable 2 factor authentication on your account]({LINK_2FA}), "
- "for heightened security.\n\n"
-
- "Once you have changed your password, feel free to follow the instructions at the bottom of "
- "this message to appeal your ban."
-)
-AUTO_BAN_DURATION = timedelta(days=4)
-
-FilterMatch = Union[re.Match, dict, bool, List[Embed]]
-
-
-class Stats(NamedTuple):
- """Additional stats on a triggered filter to append to a mod log."""
-
- message_content: str
- additional_embeds: Optional[List[Embed]]
-
-
-class Filtering(Cog):
- """Filtering out invites, blacklisting domains, and warning us of certain regular expressions."""
-
- # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent
- name_alerts = RedisCache()
-
- def __init__(self, bot: Bot):
- self.bot = bot
- self.scheduler = scheduling.Scheduler(self.__class__.__name__)
- self.name_lock = asyncio.Lock()
-
- staff_mistake_str = "If you believe this was a mistake, please let staff know!"
- self.filters = {
- "filter_zalgo": {
- "enabled": Filter.filter_zalgo,
- "function": self._has_zalgo,
- "type": "filter",
- "content_only": True,
- "user_notification": Filter.notify_user_zalgo,
- "notification_msg": (
- "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). "
- f"{staff_mistake_str}"
- ),
- "schedule_deletion": False
- },
- "filter_invites": {
- "enabled": Filter.filter_invites,
- "function": self._has_invites,
- "type": "filter",
- "content_only": True,
- "user_notification": Filter.notify_user_invites,
- "notification_msg": (
- f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n"
- r"Our server rules can be found here: <https://pythondiscord.com/pages/rules>"
- ),
- "schedule_deletion": False
- },
- "filter_domains": {
- "enabled": Filter.filter_domains,
- "function": self._has_urls,
- "type": "filter",
- "content_only": True,
- "user_notification": Filter.notify_user_domains,
- "notification_msg": (
- f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}"
- ),
- "schedule_deletion": False
- },
- "watch_regex": {
- "enabled": Filter.watch_regex,
- "function": self._has_watch_regex_match,
- "type": "watchlist",
- "content_only": True,
- "schedule_deletion": True
- },
- "watch_rich_embeds": {
- "enabled": Filter.watch_rich_embeds,
- "function": self._has_rich_embed,
- "type": "watchlist",
- "content_only": False,
- "schedule_deletion": False
- },
- "filter_everyone_ping": {
- "enabled": Filter.filter_everyone_ping,
- "function": self._has_everyone_ping,
- "type": "filter",
- "content_only": True,
- "user_notification": Filter.notify_user_everyone_ping,
- "notification_msg": (
- "Please don't try to ping `@everyone` or `@here`. "
- f"Your message has been removed. {staff_mistake_str}"
- ),
- "schedule_deletion": False,
- "ping_everyone": False
- },
- }
-
- async def cog_unload(self) -> None:
- """Cancel scheduled tasks."""
- self.scheduler.cancel_all()
-
- def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list:
- """Fetch items from the filter_list_cache."""
- return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys()
-
- def _get_filterlist_value(self, list_type: str, value: Any, *, allowed: bool) -> dict:
- """Fetch one specific value from filter_list_cache."""
- return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"][value]
-
- @staticmethod
- def _expand_spoilers(text: str) -> str:
- """Return a string containing all interpretations of a spoilered message."""
- split_text = SPOILER_RE.split(text)
- return ''.join(
- split_text[0::2] + split_text[1::2] + split_text
- )
-
- @property
- def mod_log(self) -> ModLog:
- """Get currently loaded ModLog cog instance."""
- return self.bot.get_cog("ModLog")
-
- @Cog.listener()
- async def on_message(self, msg: Message) -> None:
- """Invoke message filter for new messages."""
- await self._filter_message(msg)
-
- # Ignore webhook messages.
- if msg.webhook_id is None:
- await self.check_bad_words_in_name(msg.author)
-
- @Cog.listener()
- async def on_message_edit(self, before: Message, after: Message) -> None:
- """
- Invoke message filter for message edits.
-
- Also calculates the time delta from the previous edit or when message was sent if there's no prior edits.
- """
- # We only care about changes to the message contents/attachments and embed additions, not pin status etc.
- if all((
- before.content == after.content, # content hasn't changed
- before.attachments == after.attachments, # attachments haven't changed
- len(before.embeds) >= len(after.embeds) # embeds haven't been added
- )):
- return
-
- if not before.edited_at:
- delta = relativedelta(after.edited_at, before.created_at).microseconds
- else:
- delta = relativedelta(after.edited_at, before.edited_at).microseconds
- await self._filter_message(after, delta)
-
- @Cog.listener()
- async def on_voice_state_update(self, member: Member, *_) -> None:
- """Checks for bad words in usernames when users join, switch or leave a voice channel."""
- await self.check_bad_words_in_name(member)
-
- def get_name_match(self, name: str) -> Optional[re.Match]:
- """Check bad words from passed string (name). Return the first match found."""
- normalised_name = unicodedata.normalize("NFKC", name)
- cleaned_normalised_name = "".join([c for c in normalised_name if not unicodedata.combining(c)])
-
- # Run filters against normalised, cleaned normalised and the original name,
- # in case we have filters for one but not the other.
- names_to_check = (name, normalised_name, cleaned_normalised_name)
-
- watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False)
- for pattern in watchlist_patterns:
- for name in names_to_check:
- if match := re.search(pattern, name, flags=re.IGNORECASE):
- return match
- return None
-
- async def check_send_alert(self, member: Member) -> bool:
- """When there is less than 3 days after last alert, return `False`, otherwise `True`."""
- if last_alert := await self.name_alerts.get(member.id):
- last_alert = arrow.get(last_alert)
- if arrow.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert:
- log.trace(f"Last alert was too recent for {member}'s nickname.")
- return False
-
- return True
-
- async def check_bad_words_in_name(self, member: Member) -> None:
- """Send a mod alert every 3 days if a username still matches a watchlist pattern."""
- # Use lock to avoid race conditions
- async with self.name_lock:
- # Check if we recently alerted about this user first,
- # to avoid running all the filter tokens against their name again.
- if not await self.check_send_alert(member):
- return
-
- # Check whether the users display name contains any words in our blacklist
- match = self.get_name_match(member.display_name)
- if not match:
- return
-
- log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).")
-
- log_string = (
- f"**User:** {format_user(member)}\n"
- f"**Display Name:** {escape_markdown(member.display_name)}\n"
- f"**Bad Match:** {match.group()}"
- )
-
- await self.mod_log.send_log_message(
- content=str(member.id), # quality-of-life improvement for mobile moderators
- icon_url=Icons.token_removed,
- colour=Colours.soft_red,
- title="Username filtering alert",
- text=log_string,
- channel_id=Channels.mod_alerts,
- thumbnail=member.display_avatar.url,
- ping_everyone=True
- )
-
- # Update time when alert sent
- await self.name_alerts.set(member.id, arrow.utcnow().timestamp())
-
- async def filter_snekbox_output(self, result: str, msg: Message) -> bool:
- """
- Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly.
-
- Also requires the original message, to check whether to filter and for mod logs.
- Returns whether a filter was triggered or not.
- """
- filter_triggered = False
- # Should we filter this message?
- if self._check_filter(msg):
- for filter_name, _filter in self.filters.items():
- # Is this specific filter enabled in the config?
- # We also do not need to worry about filters that take the full message,
- # since all we have is an arbitrary string.
- if _filter["enabled"] and _filter["content_only"]:
- filter_result = await _filter["function"](result)
- reason = None
-
- if isinstance(filter_result, tuple):
- match, reason = filter_result
- else:
- match = filter_result
-
- if match:
- # If this is a filter (not a watchlist), we set the variable so we know
- # that it has been triggered
- if _filter["type"] == "filter":
- filter_triggered = True
-
- stats = self._add_stats(filter_name, match, result)
- await self._send_log(filter_name, _filter, msg, stats, reason, is_eval=True)
-
- break # We don't want multiple filters to trigger
-
- return filter_triggered
-
- async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None:
- """Filter the input message to see if it violates any of our rules, and then respond accordingly."""
- # Should we filter this message?
- if self._check_filter(msg):
- for filter_name, _filter in self.filters.items():
- # Is this specific filter enabled in the config?
- if _filter["enabled"]:
- # Double trigger check for the embeds filter
- if filter_name == "watch_rich_embeds":
- # If the edit delta is less than 0.001 seconds, then we're probably dealing
- # with a double filter trigger.
- if delta is not None and delta < 100:
- continue
-
- if filter_name in ("filter_invites", "filter_everyone_ping"):
- # Disable invites filter in codejam team channels
- category = getattr(msg.channel, "category", None)
- if category and category.name == JAM_CATEGORY_NAME:
- continue
-
- # Does the filter only need the message content or the full message?
- if _filter["content_only"]:
- payload = msg.content
- else:
- payload = msg
-
- result = await _filter["function"](payload)
- reason = None
-
- if isinstance(result, tuple):
- match, reason = result
- else:
- match = result
-
- if match:
- is_private = msg.channel.type is ChannelType.private
-
- # If this is a filter (not a watchlist) and not in a DM, delete the message.
- if _filter["type"] == "filter" and not is_private:
- try:
- # Embeds (can?) trigger both the `on_message` and `on_message_edit`
- # event handlers, triggering filtering twice for the same message.
- #
- # If `on_message`-triggered filtering already deleted the message
- # then `on_message_edit`-triggered filtering will raise exception
- # since the message no longer exists.
- #
- # In addition, to avoid sending two notifications to the user, the
- # logs, and mod_alert, we return if the message no longer exists.
- await msg.delete()
- except NotFound:
- return
-
- # Notify the user if the filter specifies
- if _filter["user_notification"]:
- await self.notify_member(msg.author, _filter["notification_msg"], msg.channel)
-
- # If the message is classed as offensive, we store it in the site db and
- # it will be deleted after one week.
- if _filter["schedule_deletion"] and not is_private:
- delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat()
- data = {
- 'id': msg.id,
- 'channel_id': msg.channel.id,
- 'delete_date': delete_date
- }
-
- try:
- await self.bot.api_client.post('bot/offensive-messages', json=data)
- except ResponseCodeError as e:
- if e.status == 400 and "already exists" in e.response_json.get("id", [""])[0]:
- log.debug(f"Offensive message {msg.id} already exists.")
- else:
- log.error(f"Offensive message {msg.id} failed to post: {e}")
- else:
- self.schedule_msg_delete(data)
- log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}")
-
- stats = self._add_stats(filter_name, match, msg.content)
-
- # If the filter reason contains `[autoban]`, we want to auto-ban the user.
- # Also pass this to _send_log so mods are not pinged filter matches that are auto-actioned
- autoban = reason and "[autoban]" in reason.lower()
- if not autoban and filter_name == "filter_invites" and isinstance(result, dict):
- autoban = any(
- "[autoban]" in invite_info["reason"].lower()
- for invite_info in result.values()
- if invite_info.get("reason")
- )
-
- await self._send_log(filter_name, _filter, msg, stats, reason, autoban=autoban)
-
- if autoban:
- # Create a new context, with the author as is the bot, and the channel as #mod-alerts.
- # This sends the ban confirmation directly under watchlist trigger embed, to inform
- # mods that the user was auto-banned for the message.
- context = await self.bot.get_context(msg)
- context.guild = self.bot.get_guild(Guild.id)
- context.author = context.guild.get_member(self.bot.user.id)
- context.channel = self.bot.get_channel(Channels.mod_alerts)
- context.command = self.bot.get_command("tempban")
-
- await context.invoke(
- context.command,
- msg.author,
- (arrow.utcnow() + AUTO_BAN_DURATION).datetime,
- reason=AUTO_BAN_REASON
- )
-
- break # We don't want multiple filters to trigger
-
- async def _send_log(
- self,
- filter_name: str,
- _filter: Dict[str, Any],
- msg: Message,
- stats: Stats,
- reason: Optional[str] = None,
- *,
- is_eval: bool = False,
- autoban: bool = False,
- ) -> None:
- """Send a mod log for a triggered filter."""
- if msg.channel.type is ChannelType.private:
- channel_str = "via DM"
- ping_everyone = False
- else:
- channel_str = f"in {msg.channel.mention}"
- # Allow specific filters to override ping_everyone
- ping_everyone = Filter.ping_everyone and _filter.get("ping_everyone", True)
-
- content = str(msg.author.id) # quality-of-life improvement for mobile moderators
-
- # If we are going to autoban, we don't want to ping and don't need the user ID
- if autoban:
- ping_everyone = False
- content = None
-
- eval_msg = f"using {BotConfig.prefix}eval " if is_eval else ""
- footer = f"Reason: {reason}" if reason else None
- message = (
- f"The {filter_name} {_filter['type']} was triggered by {format_user(msg.author)} "
- f"{channel_str} {eval_msg}with [the following message]({msg.jump_url}):\n\n"
- f"{stats.message_content}"
- )
-
- log.debug(message)
-
- # Send pretty mod log embed to mod-alerts
- await self.mod_log.send_log_message(
- content=content,
- icon_url=Icons.filtering,
- colour=Colour(Colours.soft_red),
- title=f"{_filter['type'].title()} triggered!",
- text=message,
- thumbnail=msg.author.display_avatar.url,
- channel_id=Channels.mod_alerts,
- ping_everyone=ping_everyone,
- additional_embeds=stats.additional_embeds,
- footer=footer,
- )
-
- def _add_stats(self, name: str, match: FilterMatch, content: str) -> Stats:
- """Adds relevant statistical information to the relevant filter and increments the bot's stats."""
- # Word and match stats for watch_regex
- if name == "watch_regex":
- surroundings = match.string[max(match.start() - 10, 0): match.end() + 10]
- message_content = (
- f"**Match:** '{match[0]}'\n"
- f"**Location:** '...{escape_markdown(surroundings)}...'\n"
- f"\n**Original Message:**\n{escape_markdown(content)}"
- )
- else: # Use original content
- message_content = content
-
- additional_embeds = None
-
- self.bot.stats.incr(f"filters.{name}")
-
- # The function returns True for invalid invites.
- # They have no data so additional embeds can't be created for them.
- if name == "filter_invites" and match is not True:
- additional_embeds = []
- for _, data in match.items():
- reason = f"Reason: {data['reason']} | " if data.get('reason') else ""
- embed = Embed(description=(
- f"**Members:**\n{data['members']}\n"
- f"**Active:**\n{data['active']}"
- ))
- embed.set_author(name=data["name"])
- embed.set_thumbnail(url=data["icon"])
- embed.set_footer(text=f"{reason}Guild ID: {data['id']}")
- additional_embeds.append(embed)
-
- elif name == "watch_rich_embeds":
- additional_embeds = match
-
- return Stats(message_content, additional_embeds)
-
- @staticmethod
- def _check_filter(msg: Message) -> bool:
- """Check whitelists to see if we should filter this message."""
- role_whitelisted = False
-
- if type(msg.author) is Member: # Only Member has roles, not User.
- for role in msg.author.roles:
- if role.id in Filter.role_whitelist:
- role_whitelisted = True
-
- return (
- msg.channel.id not in Filter.channel_whitelist # Channel not in whitelist
- and not role_whitelisted # Role not in whitelist
- and not msg.author.bot # Author not a bot
- )
-
- async def _has_watch_regex_match(self, text: str) -> Tuple[Union[bool, re.Match], Optional[str]]:
- """
- Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs.
-
- `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is
- matched as-is. Spoilers are expanded, if any, and URLs are ignored.
- Second return value is a reason written to database about blacklist entry (can be None).
- """
- if SPOILER_RE.search(text):
- text = self._expand_spoilers(text)
-
- text = self.clean_input(text)
-
- watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False)
- for pattern in watchlist_patterns:
- match = re.search(pattern, text, flags=re.IGNORECASE)
- if match:
- return match, self._get_filterlist_value('filter_token', pattern, allowed=False)['comment']
-
- return False, None
-
- async def _has_urls(self, text: str) -> Tuple[bool, Optional[str]]:
- """
- Returns True if the text contains one of the blacklisted URLs from the config file.
-
- Second return value is a reason of URL blacklisting (can be None).
- """
- text = self.clean_input(text)
-
- domain_blacklist = self._get_filterlist_items("domain_name", allowed=False)
- for match in URL_RE.finditer(text):
- for url in domain_blacklist:
- if url.lower() in match.group(1).lower():
- blacklisted_parsed = tldextract.extract(url.lower())
- url_parsed = tldextract.extract(match.group(1).lower())
- if blacklisted_parsed.registered_domain == url_parsed.registered_domain:
- return True, self._get_filterlist_value("domain_name", url, allowed=False)["comment"]
- return False, None
-
- @staticmethod
- async def _has_zalgo(text: str) -> bool:
- """
- Returns True if the text contains zalgo characters.
-
- Zalgo range is \u0300 – \u036F and \u0489.
- """
- return bool(ZALGO_RE.search(text))
-
- async def _has_invites(self, text: str) -> Union[dict, bool]:
- """
- Checks if there's any invites in the text content that aren't in the guild whitelist.
-
- If any are detected, a dictionary of invite data is returned, with a key per invite.
- If none are detected, False is returned.
- If we are unable to process an invite, True is returned.
-
- Attempts to catch some of common ways to try to cheat the system.
- """
- text = self.clean_input(text)
-
- # Remove backslashes to prevent escape character fuckaroundery like
- # discord\.gg/gdudes-pony-farm
- text = text.replace("\\", "")
-
- invites = [m.group("invite") for m in DISCORD_INVITE.finditer(text)]
- invite_data = dict()
- for invite in invites:
- invite = urllib.parse.quote_plus(invite.rstrip("/"))
- if invite in invite_data:
- continue
-
- response = await self.bot.http_session.get(
- f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"}
- )
- response = await response.json()
- guild = response.get("guild")
- if guild is None:
- # Lack of a "guild" key in the JSON response indicates either an group DM invite, an
- # expired invite, or an invalid invite. The API does not currently differentiate
- # between invalid and expired invites
- return True
-
- guild_id = guild.get("id")
- guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True)
- guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False)
-
- # Is this invite allowed?
- guild_partnered_or_verified = (
- 'PARTNERED' in guild.get("features", [])
- or 'VERIFIED' in guild.get("features", [])
- )
- invite_not_allowed = (
- guild_id in guild_invite_blacklist # Blacklisted guilds are never permitted.
- or guild_id not in guild_invite_whitelist # Whitelisted guilds are always permitted.
- and not guild_partnered_or_verified # Otherwise guilds have to be Verified or Partnered.
- )
-
- if invite_not_allowed:
- reason = None
- if guild_id in guild_invite_blacklist:
- reason = self._get_filterlist_value("guild_invite", guild_id, allowed=False)["comment"]
-
- guild_icon_hash = guild["icon"]
- guild_icon = (
- "https://cdn.discordapp.com/icons/"
- f"{guild_id}/{guild_icon_hash}.png?size=512"
- )
-
- invite_data[invite] = {
- "name": guild["name"],
- "id": guild['id'],
- "icon": guild_icon,
- "members": response["approximate_member_count"],
- "active": response["approximate_presence_count"],
- "reason": reason
- }
-
- return invite_data if invite_data else False
-
- @staticmethod
- async def _has_rich_embed(msg: Message) -> Union[bool, List[Embed]]:
- """Determines if `msg` contains any rich embeds not auto-generated from a URL."""
- if msg.embeds:
- for embed in msg.embeds:
- if embed.type == "rich":
- urls = URL_RE.findall(msg.content)
- final_urls = set(urls)
- # This is due to way discord renders relative urls in Embdes
- # if we send the following url: https://mobile.twitter.com/something
- # Discord renders it as https://twitter.com/something
- for url in urls:
- final_urls.add(remove_subdomain_from_url(url))
- if not embed.url or embed.url not in final_urls:
- # If `embed.url` does not exist or if `embed.url` is not part of the content
- # of the message, it's unlikely to be an auto-generated embed by Discord.
- return msg.embeds
- else:
- log.trace(
- "Found a rich embed sent by a regular user account, "
- "but it was likely just an automatic URL embed."
- )
- return False
- return False
-
- @staticmethod
- async def _has_everyone_ping(text: str) -> bool:
- """Determines if `msg` contains an @everyone or @here ping outside of a codeblock."""
- # First pass to avoid running re.sub on every message
- if not EVERYONE_PING_RE.search(text):
- return False
-
- content_without_codeblocks = CODE_BLOCK_RE.sub("", text)
- return bool(EVERYONE_PING_RE.search(content_without_codeblocks))
-
- async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None:
- """
- Notify filtered_member about a moderation action with the reason str.
-
- First attempts to DM the user, fall back to in-channel notification if user has DMs disabled
- """
- try:
- await filtered_member.send(reason)
- except Forbidden:
- await channel.send(f"{filtered_member.mention} {reason}")
-
- def schedule_msg_delete(self, msg: dict) -> None:
- """Delete an offensive message once its deletion date is reached."""
- delete_at = dateutil.parser.isoparse(msg['delete_date'])
- self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg))
-
- async def cog_load(self) -> None:
- """Get all the pending message deletion from the API and reschedule them."""
- await self.bot.wait_until_ready()
- response = await self.bot.api_client.get('bot/offensive-messages',)
-
- now = arrow.utcnow()
-
- for msg in response:
- delete_at = dateutil.parser.isoparse(msg['delete_date'])
-
- if delete_at < now:
- await self.delete_offensive_msg(msg)
- else:
- self.schedule_msg_delete(msg)
-
- async def delete_offensive_msg(self, msg: Mapping[str, int]) -> None:
- """Delete an offensive message, and then delete it from the db."""
- try:
- channel = self.bot.get_channel(msg['channel_id'])
- if channel:
- msg_obj = await channel.fetch_message(msg['id'])
- await msg_obj.delete()
- except NotFound:
- log.info(
- f"Tried to delete message {msg['id']}, but the message can't be found "
- f"(it has been probably already deleted)."
- )
- except HTTPException as e:
- log.warning(f"Failed to delete message {msg['id']}: status {e.status}")
-
- await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}')
- log.info(f"Deleted the offensive message with id {msg['id']}.")
-
- @staticmethod
- def clean_input(string: str) -> str:
- """Remove zalgo and invisible characters from `string`."""
- # For future consideration: remove characters in the Mc, Sk, and Lm categories too.
- # Can be normalised with form C to merge char + combining char into a single char to avoid
- # removing legit diacritics, but this would open up a way to bypass filters.
- no_zalgo = ZALGO_RE.sub("", string)
- return INVISIBLE_RE.sub("", no_zalgo)
-
-
-async def setup(bot: Bot) -> None:
- """Load the Filtering cog."""
- await bot.add_cog(Filtering(bot))
diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py
deleted file mode 100644
index b42613804..000000000
--- a/bot/exts/filters/webhook_remover.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import re
-
-from discord import Colour, Message, NotFound
-from discord.ext.commands import Cog
-
-from bot.bot import Bot
-from bot.constants import Channels, Colours, Event, Icons
-from bot.exts.moderation.modlog import ModLog
-from bot.log import get_logger
-from bot.utils.messages import format_user
-
-WEBHOOK_URL_RE = re.compile(
- r"((?:https?:\/\/)?(?:ptb\.|canary\.)?discord(?:app)?\.com\/api\/webhooks\/\d+\/)\S+\/?",
- re.IGNORECASE
-)
-
-ALERT_MESSAGE_TEMPLATE = (
- "{user}, looks like you posted a Discord webhook URL. Therefore, your "
- "message has been removed, and your webhook has been deleted. "
- "You can re-create it if you wish to. If you believe this was a "
- "mistake, please let us know."
-)
-
-log = get_logger(__name__)
-
-
-class WebhookRemover(Cog):
- """Scan messages to detect Discord webhooks links."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
-
- @property
- def mod_log(self) -> ModLog:
- """Get current instance of `ModLog`."""
- return self.bot.get_cog("ModLog")
-
- async def delete_and_respond(self, msg: Message, redacted_url: str, *, webhook_deleted: bool) -> None:
- """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`."""
- # Don't log this, due internal delete, not by user. Will make different entry.
- self.mod_log.ignore(Event.message_delete, msg.id)
-
- try:
- await msg.delete()
- except NotFound:
- log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.")
- return
-
- await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention))
- if webhook_deleted:
- delete_state = "The webhook was successfully deleted."
- else:
- delete_state = "There was an error when deleting the webhook, it might have already been removed."
- message = (
- f"{format_user(msg.author)} posted a Discord webhook URL to {msg.channel.mention}. {delete_state} "
- f"Webhook URL was `{redacted_url}`"
- )
- log.debug(message)
-
- # Send entry to moderation alerts.
- await self.mod_log.send_log_message(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Discord webhook URL removed!",
- text=message,
- thumbnail=msg.author.display_avatar.url,
- channel_id=Channels.mod_alerts
- )
-
- self.bot.stats.incr("tokens.removed_webhooks")
-
- @Cog.listener()
- async def on_message(self, msg: Message) -> None:
- """Check if a Discord webhook URL is in `message`."""
- # Ignore DMs; can't delete messages in there anyway.
- if not msg.guild or msg.author.bot:
- return
-
- matches = WEBHOOK_URL_RE.search(msg.content)
- if matches:
- async with self.bot.http_session.delete(matches[0]) as resp:
- # The Discord API Returns a 204 NO CONTENT response on success.
- deleted_successfully = resp.status == 204
- await self.delete_and_respond(msg, matches[1] + "xxx", webhook_deleted=deleted_successfully)
-
- @Cog.listener()
- async def on_message_edit(self, before: Message, after: Message) -> None:
- """Check if a Discord webhook URL is in the edited message `after`."""
- await self.on_message(after)
-
-
-async def setup(bot: Bot) -> None:
- """Load `WebhookRemover` cog."""
- await bot.add_cog(WebhookRemover(bot))
diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py
index 073a91a53..e72f32887 100644
--- a/bot/exts/info/codeblock/_cog.py
+++ b/bot/exts/info/codeblock/_cog.py
@@ -8,8 +8,8 @@ from pydis_core.utils import scheduling
from bot import constants
from bot.bot import Bot
-from bot.exts.filters.token_remover import TokenRemover
-from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE
+from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter
+from bot.exts.filtering._filters.unique.webhook import WEBHOOK_URL_RE
from bot.exts.help_channels._channel import is_help_forum_post
from bot.exts.info.codeblock._instructions import get_instructions
from bot.log import get_logger
@@ -135,7 +135,7 @@ class CodeBlockCog(Cog, name="Code Block"):
not message.author.bot
and self.is_valid_channel(message.channel)
and has_lines(message.content, constants.CodeBlock.minimum_lines)
- and not TokenRemover.find_token_in_message(message)
+ and not DiscordTokenFilter.find_token_in_message(message.content)
and not WEBHOOK_URL_RE.search(message.content)
)
diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py
index fd9404b1a..aee751345 100644
--- a/bot/exts/moderation/clean.py
+++ b/bot/exts/moderation/clean.py
@@ -19,6 +19,7 @@ from bot.converters import Age, ISODateTime
from bot.exts.moderation.modlog import ModLog
from bot.log import get_logger
from bot.utils.channel import is_mod_channel
+from bot.utils.messages import upload_log
log = get_logger(__name__)
@@ -351,7 +352,7 @@ class Clean(Cog):
# Reverse the list to have reverse chronological order
log_messages = reversed(messages)
- log_url = await self.mod_log.upload_log(log_messages, ctx.author.id)
+ log_url = await upload_log(log_messages, ctx.author.id)
# Build the embed and send it
if channels == "*":
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index d61a3fa5c..e785216c9 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -14,7 +14,6 @@ from bot.bot import Bot
from bot.constants import Channels, Event
from bot.converters import Age, Duration, DurationOrExpiry, MemberOrUser, UnambiguousMemberOrUser
from bot.decorators import ensure_future_timestamp, respect_role_hierarchy
-from bot.exts.filters.filtering import AUTO_BAN_DURATION, AUTO_BAN_REASON
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._scheduler import InfractionScheduler
from bot.log import get_logger
@@ -30,6 +29,23 @@ if t.TYPE_CHECKING:
from bot.exts.moderation.watchchannels.bigbrother import BigBrother
+# Comp ban
+LINK_PASSWORD = "https://support.discord.com/hc/en-us/articles/218410947-I-forgot-my-Password-Where-can-I-set-a-new-one"
+LINK_2FA = "https://support.discord.com/hc/en-us/articles/219576828-Setting-up-Two-Factor-Authentication"
+COMP_BAN_REASON = (
+ "Your account has been used to send links to a phishing website. You have been automatically banned. "
+ "If you are not aware of sending them, that means your account has been compromised.\n\n"
+
+ f"Here is a guide from Discord on [how to change your password]({LINK_PASSWORD}).\n\n"
+
+ f"We also highly recommend that you [enable 2 factor authentication on your account]({LINK_2FA}), "
+ "for heightened security.\n\n"
+
+ "Once you have changed your password, feel free to follow the instructions at the bottom of "
+ "this message to appeal your ban."
+)
+COMP_BAN_DURATION = timedelta(days=4)
+# Timeout
MAXIMUM_TIMEOUT_DAYS = timedelta(days=28)
TIMEOUT_CAP_MESSAGE = (
f"The timeout for {{0}} can't be longer than {MAXIMUM_TIMEOUT_DAYS.days} days."
@@ -51,7 +67,7 @@ class Infractions(InfractionScheduler, commands.Cog):
# region: Permanent infractions
- @command()
+ @command(aliases=("warning",))
async def warn(self, ctx: Context, user: UnambiguousMemberOrUser, *, reason: t.Optional[str] = None) -> None:
"""Warn a user for the given reason."""
if not isinstance(user, Member):
@@ -147,7 +163,7 @@ class Infractions(InfractionScheduler, commands.Cog):
@command()
async def compban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None:
"""Same as cleanban, but specifically with the ban reason and duration used for compromised accounts."""
- await self.cleanban(ctx, user, duration=(arrow.utcnow() + AUTO_BAN_DURATION).datetime, reason=AUTO_BAN_REASON)
+ await self.cleanban(ctx, user, duration=(arrow.utcnow() + COMP_BAN_DURATION).datetime, reason=COMP_BAN_REASON)
@command(aliases=("vban",))
async def voiceban(self, ctx: Context) -> None:
diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py
index 2c94d1af8..47a21753c 100644
--- a/bot/exts/moderation/modlog.py
+++ b/bot/exts/moderation/modlog.py
@@ -3,7 +3,6 @@ import difflib
import itertools
import typing as t
from datetime import datetime, timezone
-from itertools import zip_longest
import discord
from dateutil.relativedelta import relativedelta
@@ -12,14 +11,12 @@ from discord import Colour, Message, Thread
from discord.abc import GuildChannel
from discord.ext.commands import Cog, Context
from discord.utils import escape_markdown, format_dt, snowflake_time
-from pydis_core.site_api import ResponseCodeError
-from sentry_sdk import add_breadcrumb
from bot.bot import Bot
-from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs
+from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles
from bot.log import get_logger
from bot.utils import time
-from bot.utils.messages import format_user
+from bot.utils.messages import format_user, upload_log
log = get_logger(__name__)
@@ -45,48 +42,6 @@ class ModLog(Cog, name="ModLog"):
self._cached_edits = []
- async def upload_log(
- self,
- messages: t.Iterable[discord.Message],
- actor_id: int,
- attachments: t.Iterable[t.List[str]] = None
- ) -> str:
- """Upload message logs to the database and return a URL to a page for viewing the logs."""
- if attachments is None:
- attachments = []
-
- deletedmessage_set = [
- {
- "id": message.id,
- "author": message.author.id,
- "channel_id": message.channel.id,
- "content": message.content.replace("\0", ""), # Null chars cause 400.
- "embeds": [embed.to_dict() for embed in message.embeds],
- "attachments": attachment,
- }
- for message, attachment in zip_longest(messages, attachments, fillvalue=[])
- ]
-
- try:
- response = await self.bot.api_client.post(
- "bot/deleted-messages",
- json={
- "actor": actor_id,
- "creation": datetime.now(timezone.utc).isoformat(),
- "deletedmessage_set": deletedmessage_set,
- }
- )
- except ResponseCodeError as e:
- add_breadcrumb(
- category="api_error",
- message=str(e),
- level="error",
- data=deletedmessage_set,
- )
- raise
-
- return f"{URLs.site_logs_view}/{response['id']}"
-
def ignore(self, event: Event, *items: int) -> None:
"""Add event to ignored events to suppress log emission."""
for item in items:
@@ -604,7 +559,7 @@ class ModLog(Cog, name="ModLog"):
remaining_chars = 4090 - len(response)
if len(content) > remaining_chars:
- botlog_url = await self.upload_log(messages=[message], actor_id=message.author.id)
+ botlog_url = await upload_log(messages=[message], actor_id=message.author.id)
ending = f"\n\nMessage truncated, [full message here]({botlog_url})."
truncation_point = remaining_chars - len(ending)
content = f"{content[:truncation_point]}...{ending}"
diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py
index bc70a8c1d..7566021c5 100644
--- a/bot/exts/moderation/watchchannels/_watchchannel.py
+++ b/bot/exts/moderation/watchchannels/_watchchannel.py
@@ -14,8 +14,8 @@ from pydis_core.utils import scheduling
from bot.bot import Bot
from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons
-from bot.exts.filters.token_remover import TokenRemover
-from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE
+from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter
+from bot.exts.filtering._filters.unique.webhook import WEBHOOK_URL_RE
from bot.exts.moderation.modlog import ModLog
from bot.log import CustomLogger, get_logger
from bot.pagination import LinePaginator
@@ -235,7 +235,7 @@ class WatchChannel(metaclass=CogABCMeta):
await self.send_header(msg)
- if TokenRemover.find_token_in_message(msg) or WEBHOOK_URL_RE.search(msg.content):
+ if DiscordTokenFilter.find_token_in_message(msg.content) or WEBHOOK_URL_RE.search(msg.content):
cleaned_content = "Content is censored because it contains a bot or webhook token."
elif cleaned_content := msg.clean_content:
# Put all non-media URLs in a code block to prevent embeds
diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py
index b48fcf592..567fe6c24 100644
--- a/bot/exts/utils/snekbox/_cog.py
+++ b/bot/exts/utils/snekbox/_cog.py
@@ -14,10 +14,10 @@ from pydis_core.utils import interactions
from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX
from bot.bot import Bot
-from bot.constants import Channels, Emojis, Filter, MODERATION_ROLES, Roles, URLs
+from bot.constants import Channels, Emojis, MODERATION_ROLES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES, URLs
from bot.decorators import redirect_output
from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
-from bot.exts.filters.antimalware import TXT_LIKE_FILES
+from bot.exts.filtering._filter_lists.extension import TXT_LIKE_FILES
from bot.exts.help_channels._channel import is_help_forum_post
from bot.exts.utils.snekbox._eval import EvalJob, EvalResult
from bot.exts.utils.snekbox._io import FileAttachment
@@ -27,7 +27,7 @@ from bot.utils.lock import LockedResourceError, lock_arg
from bot.utils.services import PasteTooLongError, PasteUploadError
if TYPE_CHECKING:
- from bot.exts.filters.filtering import Filtering
+ from bot.exts.filtering.filtering import Filtering
log = get_logger(__name__)
@@ -296,7 +296,7 @@ class Snekbox(Cog):
"""Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists."""
# Check if user is staff, if is, return
# Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
- if hasattr(ctx.author, "roles") and any(role.id in Filter.role_whitelist for role in ctx.author.roles):
+ if hasattr(ctx.author, "roles") and any(role.id in STAFF_PARTNERS_COMMUNITY_ROLES for role in ctx.author.roles):
return FilteredFiles(files, [])
# Ignore code jam channels
if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME:
diff --git a/bot/pagination.py b/bot/pagination.py
index c39ce211b..679108933 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -204,6 +204,7 @@ class LinePaginator(Paginator):
footer_text: str = None,
url: str = None,
exception_on_empty_embed: bool = False,
+ reply: bool = False,
) -> t.Optional[discord.Message]:
"""
Use a paginator and set of reactions to provide pagination over a set of lines.
@@ -254,6 +255,8 @@ class LinePaginator(Paginator):
embed.description = paginator.pages[current_page]
+ reference = ctx.message if reply else None
+
if len(paginator.pages) <= 1:
if footer_text:
embed.set_footer(text=footer_text)
@@ -264,9 +267,10 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("There's less than two pages, so we won't paginate - sending single page on its own")
+
if isinstance(ctx, discord.Interaction):
return await ctx.response.send_message(embed=embed)
- return await ctx.send(embed=embed)
+ return await ctx.send(embed=embed, reference=reference)
else:
if footer_text:
embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
@@ -279,11 +283,12 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("Sending first page to channel...")
+
if isinstance(ctx, discord.Interaction):
await ctx.response.send_message(embed=embed)
message = await ctx.original_response()
else:
- message = await ctx.send(embed=embed)
+ message = await ctx.send(embed=embed, reference=reference)
log.debug("Adding emoji reactions to message...")
diff --git a/bot/rules/__init__.py b/bot/rules/__init__.py
deleted file mode 100644
index a01ceae73..000000000
--- a/bot/rules/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# flake8: noqa
-
-from .attachments import apply as apply_attachments
-from .burst import apply as apply_burst
-from .burst_shared import apply as apply_burst_shared
-from .chars import apply as apply_chars
-from .discord_emojis import apply as apply_discord_emojis
-from .duplicates import apply as apply_duplicates
-from .links import apply as apply_links
-from .mentions import apply as apply_mentions
-from .newlines import apply as apply_newlines
-from .role_mentions import apply as apply_role_mentions
diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py
deleted file mode 100644
index 8903c385c..000000000
--- a/bot/rules/attachments.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects total attachments exceeding the limit sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
- )
- total_recent_attachments = sum(len(msg.attachments) for msg in relevant_messages)
-
- if total_recent_attachments > config['max']:
- return (
- f"sent {total_recent_attachments} attachments in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/rules/burst.py b/bot/rules/burst.py
deleted file mode 100644
index 25c5a2f33..000000000
--- a/bot/rules/burst.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects repeated messages sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
- total_relevant = len(relevant_messages)
-
- if total_relevant > config['max']:
- return (
- f"sent {total_relevant} messages in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/rules/burst_shared.py b/bot/rules/burst_shared.py
deleted file mode 100644
index bbe9271b3..000000000
--- a/bot/rules/burst_shared.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects repeated messages sent by multiple users."""
- total_recent = len(recent_messages)
-
- if total_recent > config['max']:
- return (
- f"sent {total_recent} messages in {config['interval']}s",
- set(msg.author for msg in recent_messages),
- recent_messages
- )
- return None
diff --git a/bot/rules/chars.py b/bot/rules/chars.py
deleted file mode 100644
index 1f587422c..000000000
--- a/bot/rules/chars.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects total message char count exceeding the limit sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
-
- total_recent_chars = sum(len(msg.content) for msg in relevant_messages)
-
- if total_recent_chars > config['max']:
- return (
- f"sent {total_recent_chars} characters in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/rules/discord_emojis.py b/bot/rules/discord_emojis.py
deleted file mode 100644
index d979ac5e7..000000000
--- a/bot/rules/discord_emojis.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import re
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-from emoji import demojize
-
-DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:")
-CODE_BLOCK_RE = re.compile(r"```.*?```", flags=re.DOTALL)
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects total Discord emojis exceeding the limit sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
-
- # Get rid of code blocks in the message before searching for emojis.
- # Convert Unicode emojis to :emoji: format to get their count.
- total_emojis = sum(
- len(DISCORD_EMOJI_RE.findall(demojize(CODE_BLOCK_RE.sub("", msg.content))))
- for msg in relevant_messages
- )
-
- if total_emojis > config['max']:
- return (
- f"sent {total_emojis} emojis in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/rules/duplicates.py b/bot/rules/duplicates.py
deleted file mode 100644
index 8e4fbc12d..000000000
--- a/bot/rules/duplicates.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects duplicated messages sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if (
- msg.author == last_message.author
- and msg.content == last_message.content
- and msg.content
- )
- )
-
- total_duplicated = len(relevant_messages)
-
- if total_duplicated > config['max']:
- return (
- f"sent {total_duplicated} duplicated messages in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/rules/links.py b/bot/rules/links.py
deleted file mode 100644
index c46b783c5..000000000
--- a/bot/rules/links.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import re
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-LINK_RE = re.compile(r"(https?://[^\s]+)")
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects total links exceeding the limit sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
- total_links = 0
- messages_with_links = 0
-
- for msg in relevant_messages:
- total_matches = len(LINK_RE.findall(msg.content))
- if total_matches:
- messages_with_links += 1
- total_links += total_matches
-
- # Only apply the filter if we found more than one message with
- # links to prevent wrongfully firing the rule on users posting
- # e.g. an installation log of pip packages from GitHub.
- if total_links > config['max'] and messages_with_links > 1:
- return (
- f"sent {total_links} links in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/rules/newlines.py b/bot/rules/newlines.py
deleted file mode 100644
index 4e66e1359..000000000
--- a/bot/rules/newlines.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import re
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects total newlines exceeding the set limit sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
-
- # Identify groups of newline characters and get group & total counts
- exp = r"(\n+)"
- newline_counts = []
- for msg in relevant_messages:
- newline_counts += [len(group) for group in re.findall(exp, msg.content)]
- total_recent_newlines = sum(newline_counts)
-
- # Get maximum newline group size
- if newline_counts:
- max_newline_group = max(newline_counts)
- else:
- # If no newlines are found, newline_counts will be an empty list, which will error out max()
- max_newline_group = 0
-
- # Check first for total newlines, if this passes then check for large groupings
- if total_recent_newlines > config['max']:
- return (
- f"sent {total_recent_newlines} newlines in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- elif max_newline_group > config['max_consecutive']:
- return (
- f"sent {max_newline_group} consecutive newlines in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
-
- return None
diff --git a/bot/rules/role_mentions.py b/bot/rules/role_mentions.py
deleted file mode 100644
index 0649540b6..000000000
--- a/bot/rules/role_mentions.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Dict, Iterable, List, Optional, Tuple
-
-from discord import Member, Message
-
-
-async def apply(
- last_message: Message, recent_messages: List[Message], config: Dict[str, int]
-) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
- """Detects total role mentions exceeding the limit sent by a single user."""
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
-
- total_recent_mentions = sum(len(msg.role_mentions) for msg in relevant_messages)
-
- if total_recent_mentions > config['max']:
- return (
- f"sent {total_recent_mentions} role mentions in {config['interval']}s",
- (last_message.author,),
- relevant_messages
- )
- return None
diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py
index f68d280c9..5deb2376b 100644
--- a/bot/utils/message_cache.py
+++ b/bot/utils/message_cache.py
@@ -31,20 +31,23 @@ class MessageCache:
self._start = 0
self._end = 0
- self._messages: list[t.Optional[Message]] = [None] * self.maxlen
+ self._messages: list[Message | None] = [None] * self.maxlen
self._message_id_mapping = {}
+ self._message_metadata = {}
- def append(self, message: Message) -> None:
+ def append(self, message: Message, *, metadata: dict | None = None) -> None:
"""Add the received message to the cache, depending on the order of messages defined by `newest_first`."""
if self.newest_first:
self._appendleft(message)
else:
self._appendright(message)
+ self._message_metadata[message.id] = metadata
def _appendright(self, message: Message) -> None:
"""Add the received message to the end of the cache."""
if self._is_full():
del self._message_id_mapping[self._messages[self._start].id]
+ del self._message_metadata[self._messages[self._start].id]
self._start = (self._start + 1) % self.maxlen
self._messages[self._end] = message
@@ -56,6 +59,7 @@ class MessageCache:
if self._is_full():
self._end = (self._end - 1) % self.maxlen
del self._message_id_mapping[self._messages[self._end].id]
+ del self._message_metadata[self._messages[self._end].id]
self._start = (self._start - 1) % self.maxlen
self._messages[self._start] = message
@@ -69,6 +73,7 @@ class MessageCache:
self._end = (self._end - 1) % self.maxlen
message = self._messages[self._end]
del self._message_id_mapping[message.id]
+ del self._message_metadata[message.id]
self._messages[self._end] = None
return message
@@ -80,6 +85,7 @@ class MessageCache:
message = self._messages[self._start]
del self._message_id_mapping[message.id]
+ del self._message_metadata[message.id]
self._messages[self._start] = None
self._start = (self._start + 1) % self.maxlen
@@ -89,16 +95,21 @@ class MessageCache:
"""Remove all messages from the cache."""
self._messages = [None] * self.maxlen
self._message_id_mapping = {}
+ self._message_metadata = {}
self._start = 0
self._end = 0
- def get_message(self, message_id: int) -> t.Optional[Message]:
+ def get_message(self, message_id: int) -> Message | None:
"""Return the message that has the given message ID, if it is cached."""
index = self._message_id_mapping.get(message_id, None)
return self._messages[index] if index is not None else None
- def update(self, message: Message) -> bool:
+ def get_message_metadata(self, message_id: int) -> dict | None:
+ """Return the metadata of the message that has the given message ID, if it is cached."""
+ return self._message_metadata.get(message_id, None)
+
+ def update(self, message: Message, *, metadata: dict | None = None) -> bool:
"""
Update a cached message with new contents.
@@ -108,13 +119,15 @@ class MessageCache:
if index is None:
return False
self._messages[index] = message
+ if metadata is not None:
+ self._message_metadata[message.id] = metadata
return True
def __contains__(self, message_id: int) -> bool:
"""Return True if the cache contains a message with the given ID ."""
return message_id in self._message_id_mapping
- def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]:
+ def __getitem__(self, item: int | slice) -> Message | list[Message]:
"""
Return the message(s) in the index or slice provided.
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index f6bdceaef..8d765ebfc 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -1,16 +1,21 @@
import asyncio
import random
import re
+from collections.abc import Iterable
+from datetime import datetime, timezone
from functools import partial
from io import BytesIO
from typing import Callable, List, Optional, Sequence, Union
import discord
+from discord import Message
from discord.ext.commands import Context
+from pydis_core.site_api import ResponseCodeError
from pydis_core.utils import scheduling
+from sentry_sdk import add_breadcrumb
import bot
-from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES
+from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES, URLs
from bot.log import get_logger
log = get_logger(__name__)
@@ -241,6 +246,55 @@ async def send_denial(ctx: Context, reason: str) -> discord.Message:
return await ctx.send(embed=embed)
-def format_user(user: discord.abc.User) -> str:
+def format_user(user: discord.User | discord.Member) -> str:
"""Return a string for `user` which has their mention and ID."""
return f"{user.mention} (`{user.id}`)"
+
+
+def format_channel(channel: discord.abc.Messageable) -> str:
+ """Return a string for `channel` with its mention, ID, and the parent channel if it is a thread."""
+ formatted = f"{channel.mention} ({channel.category}/#{channel}"
+ if hasattr(channel, "parent"):
+ formatted += f"/{channel.parent}"
+ formatted += ")"
+ return formatted
+
+
+async def upload_log(messages: Iterable[Message], actor_id: int, attachments: dict[int, list[str]] = None) -> str:
+ """Upload message logs to the database and return a URL to a page for viewing the logs."""
+ if attachments is None:
+ attachments = []
+ else:
+ attachments = [attachments.get(message.id, []) for message in messages]
+
+ deletedmessage_set = [
+ {
+ "id": message.id,
+ "author": message.author.id,
+ "channel_id": message.channel.id,
+ "content": message.content.replace("\0", ""), # Null chars cause 400.
+ "embeds": [embed.to_dict() for embed in message.embeds],
+ "attachments": attachment,
+ }
+ for message, attachment in zip(messages, attachments)
+ ]
+
+ try:
+ response = await bot.instance.api_client.post(
+ "bot/deleted-messages",
+ json={
+ "actor": actor_id,
+ "creation": datetime.now(timezone.utc).isoformat(),
+ "deletedmessage_set": deletedmessage_set,
+ }
+ )
+ except ResponseCodeError as e:
+ add_breadcrumb(
+ category="api_error",
+ message=str(e),
+ level="error",
+ data=deletedmessage_set,
+ )
+ raise
+
+ return f"{URLs.site_logs_view}/{response['id']}"
diff --git a/tests/bot/exts/filtering/__init__.py b/tests/bot/exts/filtering/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/exts/filtering/__init__.py
diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py
new file mode 100644
index 000000000..ef124e6ff
--- /dev/null
+++ b/tests/bot/exts/filtering/test_discord_token_filter.py
@@ -0,0 +1,276 @@
+import unittest
+from re import Match
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import arrow
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.unique import discord_token
+from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter, Token
+from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel, autospec
+
+
+class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase):
+ """Tests the DiscordTokenFilter class."""
+
+ def setUp(self):
+ """Adds the filter, a bot, and a message to the instance for usage in tests."""
+ now = arrow.utcnow().timestamp()
+ self.filter = DiscordTokenFilter({
+ "id": 1,
+ "content": "discord_token",
+ "description": None,
+ "settings": {},
+ "additional_field": "{}", # noqa: P103
+ "created_at": now,
+ "updated_at": now
+ })
+
+ self.msg = MockMessage(id=555, content="hello world")
+ self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
+
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.msg)
+
+ def test_extract_user_id_valid(self):
+ """Should consider user IDs valid if they decode into an integer ID."""
+ id_pairs = (
+ ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332),
+ ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904),
+ ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641),
+ )
+
+ for token_id, user_id in id_pairs:
+ with self.subTest(token_id=token_id):
+ result = DiscordTokenFilter.extract_user_id(token_id)
+ self.assertEqual(result, user_id)
+
+ def test_extract_user_id_invalid(self):
+ """Should consider non-digit and non-ASCII IDs invalid."""
+ ids = (
+ ("SGVsbG8gd29ybGQ", "non-digit ASCII"),
+ ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"),
+ ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"),
+ ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"),
+ ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
+
+ for user_id, msg in ids:
+ with self.subTest(msg=msg):
+ result = DiscordTokenFilter.extract_user_id(user_id)
+ self.assertIsNone(result)
+
+ def test_is_valid_timestamp_valid(self):
+ """Should consider timestamps valid if they're greater than the Discord epoch."""
+ timestamps = (
+ "XsyRkw",
+ "Xrim9Q",
+ "XsyR-w",
+ "XsySD_",
+ "Dn9r_A",
+ )
+
+ for timestamp in timestamps:
+ with self.subTest(timestamp=timestamp):
+ result = DiscordTokenFilter.is_valid_timestamp(timestamp)
+ self.assertTrue(result)
+
+ def test_is_valid_timestamp_invalid(self):
+ """Should consider timestamps invalid if they're before Discord epoch or can't be parsed."""
+ timestamps = (
+ ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"),
+ ("ew", "123"),
+ ("AoIKgA", "42076800"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
+
+ for timestamp, msg in timestamps:
+ with self.subTest(msg=msg):
+ result = DiscordTokenFilter.is_valid_timestamp(timestamp)
+ self.assertFalse(result)
+
+ def test_is_valid_hmac_valid(self):
+ """Should consider an HMAC valid if it has at least 3 unique characters."""
+ valid_hmacs = (
+ "VXmErH7j511turNpfURmb0rVNm8",
+ "Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "sJf6omBPORBPju3WJEIAcwW9Zds",
+ "s45jqDV_Iisn-symw0yDRrk_jf4",
+ )
+
+ for hmac in valid_hmacs:
+ with self.subTest(msg=hmac):
+ result = DiscordTokenFilter.is_maybe_valid_hmac(hmac)
+ self.assertTrue(result)
+
+ def test_is_invalid_hmac_invalid(self):
+ """Should consider an HMAC invalid if has fewer than 3 unique characters."""
+ invalid_hmacs = (
+ ("xxxxxxxxxxxxxxxxxx", "Single character"),
+ ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"),
+ ("ASFasfASFasfASFASsf", "Three characters alternating-case"),
+ ("asdasdasdasdasdasdasd", "Three characters one case"),
+ )
+
+ for hmac, msg in invalid_hmacs:
+ with self.subTest(msg=msg):
+ result = DiscordTokenFilter.is_maybe_valid_hmac(hmac)
+ self.assertFalse(result)
+
+ async def test_no_trigger_when_no_token(self):
+ """False should be returned if the message doesn't contain a Discord token."""
+ return_value = await self.filter.triggered_on(self.ctx)
+
+ self.assertFalse(return_value)
+
+ @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "Token")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE")
+ def test_find_token_valid_match(
+ self,
+ token_re,
+ token_cls,
+ extract_user_id,
+ is_valid_timestamp,
+ is_maybe_valid_hmac,
+ ):
+ """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`."""
+ matches = [
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ ]
+ tokens = [
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ ]
+
+ token_re.finditer.return_value = matches
+ token_cls.side_effect = tokens
+ extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid.
+ is_valid_timestamp.return_value = True
+ is_maybe_valid_hmac.return_value = True
+
+ return_value = DiscordTokenFilter.find_token_in_message(self.msg)
+
+ self.assertEqual(tokens[1], return_value)
+
+ @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "Token")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE")
+ def test_find_token_invalid_matches(
+ self,
+ token_re,
+ token_cls,
+ extract_user_id,
+ is_valid_timestamp,
+ is_maybe_valid_hmac,
+ ):
+ """None should be returned if no matches have valid user IDs, HMACs, and timestamps."""
+ token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
+ token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
+ extract_user_id.return_value = None
+ is_valid_timestamp.return_value = False
+ is_maybe_valid_hmac.return_value = False
+
+ return_value = DiscordTokenFilter.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+
+ def test_regex_invalid_tokens(self):
+ """Messages without anything looking like a token are not matched."""
+ tokens = (
+ "",
+ "lemon wins",
+ "..",
+ "x.y",
+ "x.y.",
+ ".y.z",
+ ".y.",
+ "..z",
+ "x..z",
+ " . . ",
+ "\n.\n.\n",
+ "hellö.world.bye",
+ "base64.nötbåse64.morebase64",
+ "19jd3J.dfkm3d.€víł§tüff",
+ )
+
+ for token in tokens:
+ with self.subTest(token=token):
+ results = discord_token.TOKEN_RE.findall(token)
+ self.assertEqual(len(results), 0)
+
+ def test_regex_valid_tokens(self):
+ """Messages that look like tokens should be matched."""
+ # Don't worry, these tokens have been invalidated.
+ tokens = (
+ "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8",
+ "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds",
+ "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",
+ )
+
+ for token in tokens:
+ with self.subTest(token=token):
+ results = discord_token.TOKEN_RE.fullmatch(token)
+ self.assertIsNotNone(results, f"{token} was not matched by the regex")
+
+ def test_regex_matches_multiple_valid(self):
+ """Should support multiple matches in the middle of a string."""
+ token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8"
+ token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc"
+ message = f"garbage {token_1} hello {token_2} world"
+
+ results = discord_token.TOKEN_RE.finditer(message)
+ results = [match[0] for match in results]
+ self.assertCountEqual((token_1, token_2), results)
+
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "LOG_MESSAGE")
+ def test_format_log_message(self, log_message):
+ """Should correctly format the log message with info from the message and token."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ log_message.format.return_value = "Howdy"
+
+ return_value = DiscordTokenFilter.format_log_message(self.msg.author, self.msg.channel, token)
+
+ self.assertEqual(return_value, log_message.format.return_value)
+
+ @patch("bot.instance", MockBot())
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "UNKNOWN_USER_LOG_MESSAGE")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "get_or_fetch_member")
+ async def test_format_userid_log_message_unknown(self, get_or_fetch_member, unknown_user_log_message):
+ """Should correctly format the user ID portion when the actual user it belongs to is unknown."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ unknown_user_log_message.format.return_value = " Partner"
+ get_or_fetch_member.return_value = None
+
+ return_value = await DiscordTokenFilter.format_userid_log_message(token)
+
+ self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False))
+
+ @patch("bot.instance", MockBot())
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE")
+ async def test_format_userid_log_message_bot(self, known_user_log_message):
+ """Should correctly format the user ID portion when the ID belongs to a known bot."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ known_user_log_message.format.return_value = " Partner"
+
+ return_value = await DiscordTokenFilter.format_userid_log_message(token)
+
+ self.assertEqual(return_value, (known_user_log_message.format.return_value, True))
+
+ @patch("bot.instance", MockBot())
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE")
+ async def test_format_log_message_user_token_user(self, user_token_message):
+ """Should correctly format the user ID portion when the ID belongs to a known user."""
+ token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ user_token_message.format.return_value = "Partner"
+
+ return_value = await DiscordTokenFilter.format_userid_log_message(token)
+
+ self.assertEqual(return_value, (user_token_message.format.return_value, True))
diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py
new file mode 100644
index 000000000..0ad41116d
--- /dev/null
+++ b/tests/bot/exts/filtering/test_extension_filter.py
@@ -0,0 +1,139 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import arrow
+
+from bot.constants import Channels
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists import extension
+from bot.exts.filtering._filter_lists.extension import ExtensionsList
+from bot.exts.filtering._filter_lists.filter_list import ListType
+from tests.helpers import MockAttachment, MockBot, MockMember, MockMessage, MockTextChannel
+
+BOT = MockBot()
+
+
+class ExtensionsListTests(unittest.IsolatedAsyncioTestCase):
+ """Test the ExtensionsList class."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.filter_list = ExtensionsList(MagicMock())
+ now = arrow.utcnow().timestamp()
+ filters = []
+ self.whitelist = [".first", ".second", ".third"]
+ for i, filter_content in enumerate(self.whitelist, start=1):
+ filters.append({
+ "id": i, "content": filter_content, "description": None, "settings": {},
+ "additional_field": "{}", "created_at": now, "updated_at": now # noqa: P103
+ })
+ self.filter_list.add_list({
+ "id": 1,
+ "list_type": 1,
+ "created_at": now,
+ "updated_at": now,
+ "settings": {},
+ "filters": filters
+ })
+
+ self.message = MockMessage()
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.message)
+
+ @patch("bot.instance", BOT)
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should trigger the whitelist and result in no actions or messages."""
+ attachment = MockAttachment(filename="python.first")
+ self.message.attachments = [attachment]
+
+ result = await self.filter_list.actions_for(self.ctx)
+
+ self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]}))
+
+ @patch("bot.instance", BOT)
+ async def test_message_without_attachment(self):
+ """Messages without attachments should return no triggers, messages, or actions."""
+ result = await self.filter_list.actions_for(self.ctx)
+
+ self.assertEqual(result, (None, [], {}))
+
+ @patch("bot.instance", BOT)
+ async def test_message_with_illegal_extension(self):
+ """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ result = await self.filter_list.actions_for(self.ctx)
+
+ self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []}))
+
+ @patch("bot.instance", BOT)
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site."""
+ attachment = MockAttachment(filename="python.py")
+ self.message.attachments = [attachment]
+
+ await self.filter_list.actions_for(self.ctx)
+
+ self.assertEqual(self.ctx.dm_embed, extension.PY_EMBED_DESCRIPTION)
+
+ @patch("bot.instance", BOT)
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt/.json/.csv file should result in the correct embed."""
+ test_values = (
+ ("text", ".txt"),
+ ("json", ".json"),
+ ("csv", ".csv"),
+ )
+
+ for file_name, disallowed_extension in test_values:
+ with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension):
+
+ attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}")
+ self.message.attachments = [attachment]
+
+ await self.filter_list.actions_for(self.ctx)
+
+ self.assertEqual(
+ self.ctx.dm_embed,
+ extension.TXT_EMBED_DESCRIPTION.format(
+ blocked_extension=disallowed_extension,
+ )
+ )
+
+ @patch("bot.instance", BOT)
+ async def test_other_disallowed_extension_embed_description(self):
+ """Test the description for a non .py/.txt/.json/.csv disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.filter_list.actions_for(self.ctx)
+ meta_channel = BOT.get_channel(Channels.meta)
+
+ self.assertEqual(
+ self.ctx.dm_embed,
+ extension.DISALLOWED_EMBED_DESCRIPTION.format(
+ joined_whitelist=", ".join(self.whitelist),
+ blocked_extensions_str=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+ )
+
+ @patch("bot.instance", BOT)
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (self.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], ["`.disallowed`"]),
+ ([".disallowed"], ["`.disallowed`"]),
+ ([".disallowed", ".illegal"], ["`.disallowed`", "`.illegal`"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ self.message.attachments = [MockAttachment(filename=f"filename{ext}") for ext in extensions]
+ result = await self.filter_list.actions_for(self.ctx)
+ self.assertCountEqual(result[1], expected_disallowed_extensions)
diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py
new file mode 100644
index 000000000..5a289c1cf
--- /dev/null
+++ b/tests/bot/exts/filtering/test_settings.py
@@ -0,0 +1,20 @@
+import unittest
+
+import bot.exts.filtering._settings
+from bot.exts.filtering._settings import create_settings
+
+
+class FilterTests(unittest.TestCase):
+ """Test functionality of the Settings class and its subclasses."""
+
+ def test_create_settings_returns_none_for_empty_data(self):
+ """`create_settings` should return a tuple of two Nones when passed an empty dict."""
+ result = create_settings({})
+
+ self.assertEqual(result, (None, None))
+
+ def test_unrecognized_entry_makes_a_warning(self):
+ """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`."""
+ create_settings({"abcd": {}})
+
+ self.assertIn("abcd", bot.exts.filtering._settings._already_warned)
diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py
new file mode 100644
index 000000000..5a1eb6fe6
--- /dev/null
+++ b/tests/bot/exts/filtering/test_settings_entries.py
@@ -0,0 +1,216 @@
+import unittest
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification
+from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass
+from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope
+from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM
+from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel
+
+
+class FilterTests(unittest.TestCase):
+ """Test functionality of the Settings class and its subclasses."""
+
+ def setUp(self) -> None:
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ message = MockMessage(author=member, channel=channel)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message)
+
+ def test_role_bypass_is_off_for_user_without_roles(self):
+ """The role bypass should trigger when a user has no roles."""
+ member = MockMember()
+ self.ctx.author = member
+ bypass_entry = RoleBypass(bypass_roles=["123"])
+
+ result = bypass_entry.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_role_bypass_is_on_for_a_user_with_the_right_role(self):
+ """The role bypass should not trigger when the user has one of its roles."""
+ cases = (
+ ([123], ["123"]),
+ ([123, 234], ["123"]),
+ ([123], ["123", "234"]),
+ ([123, 234], ["123", "234"])
+ )
+
+ for user_role_ids, bypasses in cases:
+ with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses):
+ user_roles = [MockRole(id=role_id) for role_id in user_role_ids]
+ member = MockMember(roles=user_roles)
+ self.ctx.author = member
+ bypass_entry = RoleBypass(bypass_roles=bypasses)
+
+ result = bypass_entry.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_doesnt_trigger_for_empty_channel_scope(self):
+ """A filter is enabled for all channels by default."""
+ channel = MockTextChannel()
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_context_doesnt_trigger_for_disabled_channel(self):
+ """A filter shouldn't trigger if it's been disabled in the channel."""
+ channel = MockTextChannel(id=123)
+ scope = ChannelScope(
+ disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_doesnt_trigger_in_disabled_category(self):
+ """A filter shouldn't trigger if it's been disabled in the category."""
+ channel = MockTextChannel(category=MockCategoryChannel(id=456))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=["456"], enabled_channels=None, enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_triggers_in_enabled_channel_in_disabled_category(self):
+ """A filter should trigger in an enabled channel even if it's been disabled in the category."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"], enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_context_triggers_inside_enabled_category(self):
+ """A filter shouldn't trigger outside enabled categories, if there are any."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["234"]
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_context_doesnt_trigger_outside_enabled_category(self):
+ """A filter shouldn't trigger outside enabled categories, if there are any."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["789"]
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_doesnt_trigger_inside_disabled_channel_in_enabled_category(self):
+ """A filter shouldn't trigger outside enabled categories, if there are any."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=["234"]
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_filtering_dms_when_necessary(self):
+ """A filter correctly ignores or triggers in a channel depending on the value of FilterDM."""
+ cases = (
+ (True, MockDMChannel(), True),
+ (False, MockDMChannel(), False),
+ (True, MockTextChannel(), True),
+ (False, MockTextChannel(), True)
+ )
+
+ for apply_in_dms, channel, expected in cases:
+ with self.subTest(apply_in_dms=apply_in_dms, channel=channel):
+ filter_dms = FilterDM(filter_dm=apply_in_dms)
+ self.ctx.channel = channel
+
+ result = filter_dms.triggers_on(self.ctx)
+
+ self.assertEqual(expected, result)
+
+ def test_infraction_merge_of_same_infraction_type(self):
+ """When both infractions are of the same type, the one with the longer duration wins."""
+ infraction1 = InfractionAndNotification(
+ infraction_type="MUTE",
+ infraction_reason="hi",
+ infraction_duration=10,
+ dm_content="how",
+ dm_embed="what is",
+ infraction_channel=0
+ )
+ infraction2 = InfractionAndNotification(
+ infraction_type="MUTE",
+ infraction_reason="there",
+ infraction_duration=20,
+ dm_content="are you",
+ dm_embed="your name",
+ infraction_channel=0
+ )
+
+ result = infraction1 | infraction2
+
+ self.assertDictEqual(
+ result.dict(),
+ {
+ "infraction_type": Infraction.MUTE,
+ "infraction_reason": "there",
+ "infraction_duration": 20.0,
+ "dm_content": "are you",
+ "dm_embed": "your name",
+ "infraction_channel": 0
+ }
+ )
+
+ def test_infraction_merge_of_different_infraction_types(self):
+ """If there are two different infraction types, the one higher up the hierarchy should be picked."""
+ infraction1 = InfractionAndNotification(
+ infraction_type="MUTE",
+ infraction_reason="hi",
+ infraction_duration=20,
+ dm_content="",
+ dm_embed="",
+ infraction_channel=0
+ )
+ infraction2 = InfractionAndNotification(
+ infraction_type="BAN",
+ infraction_reason="",
+ infraction_duration=10,
+ dm_content="there",
+ dm_embed="",
+ infraction_channel=0
+ )
+
+ result = infraction1 | infraction2
+
+ self.assertDictEqual(
+ result.dict(),
+ {
+ "infraction_type": Infraction.BAN,
+ "infraction_reason": "",
+ "infraction_duration": 10.0,
+ "dm_content": "there",
+ "dm_embed": "",
+ "infraction_channel": 0
+ }
+ )
diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py
new file mode 100644
index 000000000..0dfc8ae9f
--- /dev/null
+++ b/tests/bot/exts/filtering/test_token_filter.py
@@ -0,0 +1,49 @@
+import unittest
+
+import arrow
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.token import TokenFilter
+from tests.helpers import MockMember, MockMessage, MockTextChannel
+
+
+class TokenFilterTests(unittest.IsolatedAsyncioTestCase):
+ """Test functionality of the token filter."""
+
+ def setUp(self) -> None:
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ message = MockMessage(author=member, channel=channel)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message)
+
+ async def test_token_filter_triggers(self):
+ """The filter should evaluate to True only if its token is found in the context content."""
+ test_cases = (
+ (r"hi", "oh hi there", True),
+ (r"hi", "goodbye", False),
+ (r"bla\d{2,4}", "bla18", True),
+ (r"bla\d{2,4}", "bla1", False),
+ # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6
+ (r"TOKEN", "https://google.com TOKEN", True),
+ (r"TOKEN", "https://google.com something else", False)
+ )
+ now = arrow.utcnow().timestamp()
+
+ for pattern, content, expected in test_cases:
+ with self.subTest(
+ pattern=pattern,
+ content=content,
+ expected=expected,
+ ):
+ filter_ = TokenFilter({
+ "id": 1,
+ "content": pattern,
+ "description": None,
+ "settings": {},
+ "additional_field": "{}", # noqa: P103
+ "created_at": now,
+ "updated_at": now
+ })
+ self.ctx.content = content
+ result = await filter_.triggered_on(self.ctx)
+ self.assertEqual(result, expected)
diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py
deleted file mode 100644
index 7282334e2..000000000
--- a/tests/bot/exts/filters/test_antimalware.py
+++ /dev/null
@@ -1,202 +0,0 @@
-import unittest
-from unittest.mock import AsyncMock, Mock
-
-from discord import NotFound
-
-from bot.constants import Channels, STAFF_ROLES
-from bot.exts.filters import antimalware
-from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
-
-
-class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
- """Test the AntiMalware cog."""
-
- def setUp(self):
- """Sets up fresh objects for each test."""
- self.bot = MockBot()
- self.bot.filter_list_cache = {
- "FILE_FORMAT.True": {
- ".first": {},
- ".second": {},
- ".third": {},
- }
- }
- self.cog = antimalware.AntiMalware(self.bot)
- self.message = MockMessage()
- self.message.webhook_id = None
- self.message.author.bot = None
- self.whitelist = [".first", ".second", ".third"]
-
- async def test_message_with_allowed_attachment(self):
- """Messages with allowed extensions should not be deleted"""
- attachment = MockAttachment(filename="python.first")
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
- self.message.delete.assert_not_called()
-
- async def test_message_without_attachment(self):
- """Messages without attachments should result in no action."""
- await self.cog.on_message(self.message)
- self.message.delete.assert_not_called()
-
- async def test_direct_message_with_attachment(self):
- """Direct messages should have no action taken."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
- self.message.guild = None
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_webhook_message_with_illegal_extension(self):
- """A webhook message containing an illegal extension should be ignored."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.webhook_id = 697140105563078727
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_bot_message_with_illegal_extension(self):
- """A bot message containing an illegal extension should be ignored."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.author.bot = 409107086526644234
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_message_with_illegal_extension_gets_deleted(self):
- """A message containing an illegal extension should send an embed."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_called_once()
-
- async def test_message_send_by_staff(self):
- """A message send by a member of staff should be ignored."""
- staff_role = MockRole(id=STAFF_ROLES[0])
- self.message.author.roles.append(staff_role)
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_python_file_redirect_embed_description(self):
- """A message containing a .py file should result in an embed redirecting the user to our paste site"""
- attachment = MockAttachment(filename="python.py")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
-
- self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
-
- async def test_txt_file_redirect_embed_description(self):
- """A message containing a .txt/.json/.csv file should result in the correct embed."""
- test_values = (
- ("text", ".txt"),
- ("json", ".json"),
- ("csv", ".csv"),
- )
-
- for file_name, disallowed_extension in test_values:
- with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension):
-
- attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
- antimalware.TXT_EMBED_DESCRIPTION = Mock()
- antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
- cmd_channel = self.bot.get_channel(Channels.bot_commands)
-
- self.assertEqual(
- embed.description,
- antimalware.TXT_EMBED_DESCRIPTION.format.return_value
- )
- antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(
- blocked_extension=disallowed_extension,
- cmd_channel_mention=cmd_channel.mention
- )
-
- async def test_other_disallowed_extension_embed_description(self):
- """Test the description for a non .py/.txt/.json/.csv disallowed extension."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
- antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
- antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
- meta_channel = self.bot.get_channel(Channels.meta)
-
- self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
- antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
- joined_whitelist=", ".join(self.whitelist),
- blocked_extensions_str=".disallowed",
- meta_channel_mention=meta_channel.mention
- )
-
- async def test_removing_deleted_message_logs(self):
- """Removing an already deleted message logs the correct message"""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
- self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
-
- with self.assertLogs(logger=antimalware.log, level="INFO"):
- await self.cog.on_message(self.message)
- self.message.delete.assert_called_once()
-
- async def test_message_with_illegal_attachment_logs(self):
- """Deleting a message with an illegal attachment should result in a log."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
-
- with self.assertLogs(logger=antimalware.log, level="INFO"):
- await self.cog.on_message(self.message)
-
- async def test_get_disallowed_extensions(self):
- """The return value should include all non-whitelisted extensions."""
- test_values = (
- ([], []),
- (self.whitelist, []),
- ([".first"], []),
- ([".first", ".disallowed"], [".disallowed"]),
- ([".disallowed"], [".disallowed"]),
- ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
- )
-
- for extensions, expected_disallowed_extensions in test_values:
- with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
- self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
- disallowed_extensions = self.cog._get_disallowed_extensions(self.message)
- self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
-
-
-class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase):
- """Tests setup of the `AntiMalware` cog."""
-
- async def test_setup(self):
- """Setup of the extension should call add_cog."""
- bot = MockBot()
- await antimalware.setup(bot)
- bot.add_cog.assert_awaited_once()
diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py
deleted file mode 100644
index 6a0e4fded..000000000
--- a/tests/bot/exts/filters/test_antispam.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import unittest
-
-from bot.exts.filters import antispam
-
-
-class AntispamConfigurationValidationTests(unittest.TestCase):
- """Tests validation of the antispam cog configuration."""
-
- def test_default_antispam_config_is_valid(self):
- """The default antispam configuration is valid."""
- validation_errors = antispam.validate_config()
- self.assertEqual(validation_errors, {})
-
- def test_unknown_rule_returns_error(self):
- """Configuring an unknown rule returns an error."""
- self.assertEqual(
- antispam.validate_config({'invalid-rule': {}}),
- {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."}
- )
-
- def test_missing_keys_returns_error(self):
- """Not configuring required keys returns an error."""
- keys = (('interval', 'max'), ('max', 'interval'))
- for configured_key, unconfigured_key in keys:
- with self.subTest(
- configured_key=configured_key,
- unconfigured_key=unconfigured_key
- ):
- config = {'burst': {configured_key: 10}}
- error = f"Key `{unconfigured_key}` is required but not set for rule `burst`"
-
- self.assertEqual(
- antispam.validate_config(config),
- {'burst': error}
- )
diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py
deleted file mode 100644
index e47cf627b..000000000
--- a/tests/bot/exts/filters/test_filtering.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import unittest
-from unittest.mock import patch
-
-from bot.exts.filters import filtering
-from tests.helpers import MockBot, autospec
-
-
-class FilteringCogTests(unittest.IsolatedAsyncioTestCase):
- """Tests the `Filtering` cog."""
-
- def setUp(self):
- """Instantiate the bot and cog."""
- self.bot = MockBot()
- with patch("pydis_core.utils.scheduling.create_task", new=lambda task, **_: task.close()):
- self.cog = filtering.Filtering(self.bot)
-
- @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"])
- async def test_token_filter(self):
- """Ensure that a filter token is correctly detected in a message."""
- messages = {
- "": False,
- "no matches": False,
- "TOKEN": True,
-
- # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6
- "https://google.com TOKEN": True,
- "https://google.com something else": False,
- }
-
- for message, match in messages.items():
- with self.subTest(input=message, match=match):
- result, _ = await self.cog._has_watch_regex_match(message)
-
- self.assertEqual(
- match,
- bool(result),
- msg=f"Hit was {'expected' if match else 'not expected'} for this input."
- )
- if result:
- self.assertEqual("TOKEN", result.group())
diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py
deleted file mode 100644
index c1f3762ac..000000000
--- a/tests/bot/exts/filters/test_token_remover.py
+++ /dev/null
@@ -1,409 +0,0 @@
-import unittest
-from re import Match
-from unittest import mock
-from unittest.mock import MagicMock
-
-from discord import Colour, NotFound
-
-from bot import constants
-from bot.exts.filters import token_remover
-from bot.exts.filters.token_remover import Token, TokenRemover
-from bot.exts.moderation.modlog import ModLog
-from bot.utils.messages import format_user
-from tests.helpers import MockBot, MockMessage, autospec
-
-
-class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
- """Tests the `TokenRemover` cog."""
-
- def setUp(self):
- """Adds the cog, a bot, and a message to the instance for usage in tests."""
- self.bot = MockBot()
- self.cog = TokenRemover(bot=self.bot)
-
- self.msg = MockMessage(id=555, content="hello world")
- self.msg.channel.mention = "#lemonade-stand"
- self.msg.guild.get_member.return_value.bot = False
- self.msg.guild.get_member.return_value.__str__.return_value = "Woody"
- self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
- self.msg.author.display_avatar.url = "picture-lemon.png"
-
- def test_extract_user_id_valid(self):
- """Should consider user IDs valid if they decode into an integer ID."""
- id_pairs = (
- ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332),
- ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904),
- ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641),
- )
-
- for token_id, user_id in id_pairs:
- with self.subTest(token_id=token_id):
- result = TokenRemover.extract_user_id(token_id)
- self.assertEqual(result, user_id)
-
- def test_extract_user_id_invalid(self):
- """Should consider non-digit and non-ASCII IDs invalid."""
- ids = (
- ("SGVsbG8gd29ybGQ", "non-digit ASCII"),
- ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"),
- ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"),
- ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"),
- ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"),
- ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
- ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
- )
-
- for user_id, msg in ids:
- with self.subTest(msg=msg):
- result = TokenRemover.extract_user_id(user_id)
- self.assertIsNone(result)
-
- def test_is_valid_timestamp_valid(self):
- """Should consider timestamps valid if they're greater than the Discord epoch."""
- timestamps = (
- "XsyRkw",
- "Xrim9Q",
- "XsyR-w",
- "XsySD_",
- "Dn9r_A",
- )
-
- for timestamp in timestamps:
- with self.subTest(timestamp=timestamp):
- result = TokenRemover.is_valid_timestamp(timestamp)
- self.assertTrue(result)
-
- def test_is_valid_timestamp_invalid(self):
- """Should consider timestamps invalid if they're before Discord epoch or can't be parsed."""
- timestamps = (
- ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"),
- ("ew", "123"),
- ("AoIKgA", "42076800"),
- ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
- ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
- )
-
- for timestamp, msg in timestamps:
- with self.subTest(msg=msg):
- result = TokenRemover.is_valid_timestamp(timestamp)
- self.assertFalse(result)
-
- def test_is_valid_hmac_valid(self):
- """Should consider an HMAC valid if it has at least 3 unique characters."""
- valid_hmacs = (
- "VXmErH7j511turNpfURmb0rVNm8",
- "Ysnu2wacjaKs7qnoo46S8Dm2us8",
- "sJf6omBPORBPju3WJEIAcwW9Zds",
- "s45jqDV_Iisn-symw0yDRrk_jf4",
- )
-
- for hmac in valid_hmacs:
- with self.subTest(msg=hmac):
- result = TokenRemover.is_maybe_valid_hmac(hmac)
- self.assertTrue(result)
-
- def test_is_invalid_hmac_invalid(self):
- """Should consider an HMAC invalid if has fewer than 3 unique characters."""
- invalid_hmacs = (
- ("xxxxxxxxxxxxxxxxxx", "Single character"),
- ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"),
- ("ASFasfASFasfASFASsf", "Three characters alternating-case"),
- ("asdasdasdasdasdasdasd", "Three characters one case"),
- )
-
- for hmac, msg in invalid_hmacs:
- with self.subTest(msg=msg):
- result = TokenRemover.is_maybe_valid_hmac(hmac)
- self.assertFalse(result)
-
- def test_mod_log_property(self):
- """The `mod_log` property should ask the bot to return the `ModLog` cog."""
- self.bot.get_cog.return_value = 'lemon'
- self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value)
- self.bot.get_cog.assert_called_once_with('ModLog')
-
- async def test_on_message_edit_uses_on_message(self):
- """The edit listener should delegate handling of the message to the normal listener."""
- self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True)
-
- await self.cog.on_message_edit(MockMessage(), self.msg)
- self.cog.on_message.assert_awaited_once_with(self.msg)
-
- @autospec(TokenRemover, "find_token_in_message", "take_action")
- async def test_on_message_takes_action(self, find_token_in_message, take_action):
- """Should take action if a valid token is found when a message is sent."""
- cog = TokenRemover(self.bot)
- found_token = "foobar"
- find_token_in_message.return_value = found_token
-
- await cog.on_message(self.msg)
-
- find_token_in_message.assert_called_once_with(self.msg)
- take_action.assert_awaited_once_with(cog, self.msg, found_token)
-
- @autospec(TokenRemover, "find_token_in_message", "take_action")
- async def test_on_message_skips_missing_token(self, find_token_in_message, take_action):
- """Shouldn't take action if a valid token isn't found when a message is sent."""
- cog = TokenRemover(self.bot)
- find_token_in_message.return_value = False
-
- await cog.on_message(self.msg)
-
- find_token_in_message.assert_called_once_with(self.msg)
- take_action.assert_not_awaited()
-
- @autospec(TokenRemover, "find_token_in_message")
- async def test_on_message_ignores_dms_bots(self, find_token_in_message):
- """Shouldn't parse a message if it is a DM or authored by a bot."""
- cog = TokenRemover(self.bot)
- dm_msg = MockMessage(guild=None)
- bot_msg = MockMessage(author=MagicMock(bot=True))
-
- for msg in (dm_msg, bot_msg):
- await cog.on_message(msg)
- find_token_in_message.assert_not_called()
-
- @autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_no_matches(self, token_re):
- """None should be returned if the regex matches no tokens in a message."""
- token_re.finditer.return_value = ()
-
- return_value = TokenRemover.find_token_in_message(self.msg)
-
- self.assertIsNone(return_value)
- token_re.finditer.assert_called_once_with(self.msg.content)
-
- @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
- @autospec("bot.exts.filters.token_remover", "Token")
- @autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_valid_match(
- self,
- token_re,
- token_cls,
- extract_user_id,
- is_valid_timestamp,
- is_maybe_valid_hmac,
- ):
- """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`."""
- matches = [
- mock.create_autospec(Match, spec_set=True, instance=True),
- mock.create_autospec(Match, spec_set=True, instance=True),
- ]
- tokens = [
- mock.create_autospec(Token, spec_set=True, instance=True),
- mock.create_autospec(Token, spec_set=True, instance=True),
- ]
-
- token_re.finditer.return_value = matches
- token_cls.side_effect = tokens
- extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid.
- is_valid_timestamp.return_value = True
- is_maybe_valid_hmac.return_value = True
-
- return_value = TokenRemover.find_token_in_message(self.msg)
-
- self.assertEqual(tokens[1], return_value)
- token_re.finditer.assert_called_once_with(self.msg.content)
-
- @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
- @autospec("bot.exts.filters.token_remover", "Token")
- @autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_invalid_matches(
- self,
- token_re,
- token_cls,
- extract_user_id,
- is_valid_timestamp,
- is_maybe_valid_hmac,
- ):
- """None should be returned if no matches have valid user IDs, HMACs, and timestamps."""
- token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
- token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
- extract_user_id.return_value = None
- is_valid_timestamp.return_value = False
- is_maybe_valid_hmac.return_value = False
-
- return_value = TokenRemover.find_token_in_message(self.msg)
-
- self.assertIsNone(return_value)
- token_re.finditer.assert_called_once_with(self.msg.content)
-
- def test_regex_invalid_tokens(self):
- """Messages without anything looking like a token are not matched."""
- tokens = (
- "",
- "lemon wins",
- "..",
- "x.y",
- "x.y.",
- ".y.z",
- ".y.",
- "..z",
- "x..z",
- " . . ",
- "\n.\n.\n",
- "hellö.world.bye",
- "base64.nötbåse64.morebase64",
- "19jd3J.dfkm3d.€víł§tüff",
- )
-
- for token in tokens:
- with self.subTest(token=token):
- results = token_remover.TOKEN_RE.findall(token)
- self.assertEqual(len(results), 0)
-
- def test_regex_valid_tokens(self):
- """Messages that look like tokens should be matched."""
- # Don't worry, these tokens have been invalidated.
- tokens = (
- "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8",
- "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8",
- "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds",
- "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",
- )
-
- for token in tokens:
- with self.subTest(token=token):
- results = token_remover.TOKEN_RE.fullmatch(token)
- self.assertIsNotNone(results, f"{token} was not matched by the regex")
-
- def test_regex_matches_multiple_valid(self):
- """Should support multiple matches in the middle of a string."""
- token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8"
- token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc"
- message = f"garbage {token_1} hello {token_2} world"
-
- results = token_remover.TOKEN_RE.finditer(message)
- results = [match[0] for match in results]
- self.assertCountEqual((token_1, token_2), results)
-
- @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE")
- def test_format_log_message(self, log_message):
- """Should correctly format the log message with info from the message and token."""
- token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- log_message.format.return_value = "Howdy"
-
- return_value = TokenRemover.format_log_message(self.msg, token)
-
- self.assertEqual(return_value, log_message.format.return_value)
- log_message.format.assert_called_once_with(
- author=format_user(self.msg.author),
- channel=self.msg.channel.mention,
- user_id=token.user_id,
- timestamp=token.timestamp,
- hmac="xxxxxxxxxxxxxxxxxxxxxxxxjf4",
- )
-
- @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE")
- async def test_format_userid_log_message_unknown(self, unknown_user_log_message,):
- """Should correctly format the user ID portion when the actual user it belongs to is unknown."""
- token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- unknown_user_log_message.format.return_value = " Partner"
- msg = MockMessage(id=555, content="hello world")
- msg.guild.get_member.return_value = None
- msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found")
-
- return_value = await TokenRemover.format_userid_log_message(msg, token)
-
- self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False))
- unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332)
-
- @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
- async def test_format_userid_log_message_bot(self, known_user_log_message):
- """Should correctly format the user ID portion when the ID belongs to a known bot."""
- token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- known_user_log_message.format.return_value = " Partner"
- msg = MockMessage(id=555, content="hello world")
- msg.guild.get_member.return_value.__str__.return_value = "Sam"
- msg.guild.get_member.return_value.bot = True
-
- return_value = await TokenRemover.format_userid_log_message(msg, token)
-
- self.assertEqual(return_value, (known_user_log_message.format.return_value, True))
-
- known_user_log_message.format.assert_called_once_with(
- user_id=472265943062413332,
- user_name="Sam",
- kind="BOT",
- )
-
- @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
- async def test_format_log_message_user_token_user(self, user_token_message):
- """Should correctly format the user ID portion when the ID belongs to a known user."""
- token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- user_token_message.format.return_value = "Partner"
-
- return_value = await TokenRemover.format_userid_log_message(self.msg, token)
-
- self.assertEqual(return_value, (user_token_message.format.return_value, True))
- user_token_message.format.assert_called_once_with(
- user_id=467223230650777641,
- user_name="Woody",
- kind="USER",
- )
-
- @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
- @autospec("bot.exts.filters.token_remover", "log")
- @autospec(TokenRemover, "format_log_message", "format_userid_log_message")
- async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property):
- """Should delete the message and send a mod log."""
- cog = TokenRemover(self.bot)
- mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True)
- token = mock.create_autospec(Token, spec_set=True, instance=True)
- token.user_id = "no-id"
- log_msg = "testing123"
- userid_log_message = "userid-log-message"
-
- mod_log_property.return_value = mod_log
- format_log_message.return_value = log_msg
- format_userid_log_message.return_value = (userid_log_message, True)
-
- await cog.take_action(self.msg, token)
-
- self.msg.delete.assert_called_once_with()
- self.msg.channel.send.assert_called_once_with(
- token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention)
- )
-
- format_log_message.assert_called_once_with(self.msg, token)
- format_userid_log_message.assert_called_once_with(self.msg, token)
- logger.debug.assert_called_with(log_msg)
- self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens")
-
- mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id)
- mod_log.send_log_message.assert_called_once_with(
- icon_url=constants.Icons.token_removed,
- colour=Colour(constants.Colours.soft_red),
- title="Token removed!",
- text=log_msg + "\n" + userid_log_message,
- thumbnail=self.msg.author.display_avatar.url,
- channel_id=constants.Channels.mod_alerts,
- ping_everyone=True,
- )
-
- @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
- async def test_take_action_delete_failure(self, mod_log_property):
- """Shouldn't send any messages if the token message can't be deleted."""
- cog = TokenRemover(self.bot)
- mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True)
- self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock())
-
- token = mock.create_autospec(Token, spec_set=True, instance=True)
- await cog.take_action(self.msg, token)
-
- self.msg.delete.assert_called_once_with()
- self.msg.channel.send.assert_not_awaited()
-
-
-class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase):
- """Tests for the token_remover extension."""
-
- @autospec("bot.exts.filters.token_remover", "TokenRemover")
- async def test_extension_setup(self, cog):
- """The TokenRemover cog should be added."""
- bot = MockBot()
- await token_remover.setup(bot)
-
- cog.assert_called_once_with(bot)
- bot.add_cog.assert_awaited_once()
- self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover))
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
deleted file mode 100644
index 0d570f5a3..000000000
--- a/tests/bot/rules/__init__.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import unittest
-from abc import ABCMeta, abstractmethod
-from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
-
-from tests.helpers import MockMessage
-
-
-class DisallowedCase(NamedTuple):
- """Encapsulation for test cases expected to fail."""
- recent_messages: List[MockMessage]
- culprits: Iterable[str]
- n_violations: int
-
-
-class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):
- """
- Abstract class for antispam rule test cases.
-
- Tests for specific rules should inherit from `RuleTest` and implement
- `relevant_messages` and `get_report`. Each instance should also set the
- `apply` and `config` attributes as necessary.
-
- The execution of test cases can then be delegated to the `run_allowed`
- and `run_disallowed` methods.
- """
-
- apply: Callable # The tested rule's apply function
- config: Dict[str, int]
-
- async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
- """Run all `cases` against `self.apply` expecting them to pass."""
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- ):
- self.assertIsNone(
- await self.apply(last_message, recent_messages, self.config)
- )
-
- async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
- """Run all `cases` against `self.apply` expecting them to fail."""
- for case in cases:
- recent_messages, culprits, n_violations = case
- last_message = recent_messages[0]
- relevant_messages = self.relevant_messages(case)
- desired_output = (
- self.get_report(case),
- culprits,
- relevant_messages,
- )
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- n_violations=n_violations,
- config=self.config,
- ):
- self.assertTupleEqual(
- await self.apply(last_message, recent_messages, self.config),
- desired_output,
- )
-
- @abstractmethod
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- """Give expected relevant messages for `case`."""
- raise NotImplementedError # pragma: no cover
-
- @abstractmethod
- def get_report(self, case: DisallowedCase) -> str:
- """Give expected error report for `case`."""
- raise NotImplementedError # pragma: no cover
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
deleted file mode 100644
index d7e779221..000000000
--- a/tests/bot/rules/test_attachments.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from typing import Iterable
-
-from bot.rules import attachments
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, total_attachments: int) -> MockMessage:
- """Builds a message with `total_attachments` attachments."""
- return MockMessage(author=author, attachments=list(range(total_attachments)))
-
-
-class AttachmentRuleTests(RuleTest):
- """Tests applying the `attachments` antispam rule."""
-
- def setUp(self):
- self.apply = attachments.apply
- self.config = {"max": 5, "interval": 10}
-
- async def test_allows_messages_without_too_many_attachments(self):
- """Messages without too many attachments are allowed as-is."""
- cases = (
- [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)],
- [make_msg("bob", 2), make_msg("bob", 2)],
- [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_with_too_many_attachments(self):
- """Messages with too many attachments trigger the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
- ("bob",),
- 10,
- ),
- DisallowedCase(
- [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
- ("bob",),
- 6,
- ),
- DisallowedCase(
- [make_msg("alice", 6)],
- ("alice",),
- 6,
- ),
- DisallowedCase(
- [make_msg("alice", 1) for _ in range(6)],
- ("alice",),
- 6,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
deleted file mode 100644
index 03682966b..000000000
--- a/tests/bot/rules/test_burst.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from typing import Iterable
-
-from bot.rules import burst
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str) -> MockMessage:
- """
- Init a MockMessage instance with author set to `author`.
-
- This serves as a shorthand / alias to keep the test cases visually clean.
- """
- return MockMessage(author=author)
-
-
-class BurstRuleTests(RuleTest):
- """Tests the `burst` antispam rule."""
-
- def setUp(self):
- self.apply = burst.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases which do not violate the rule."""
- cases = (
- [make_msg("bob"), make_msg("bob")],
- [make_msg("bob"), make_msg("alice"), make_msg("bob")],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases where the amount of messages exceeds the limit, triggering the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("bob")],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
- ("bob",),
- 3,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
deleted file mode 100644
index 3275143d5..000000000
--- a/tests/bot/rules/test_burst_shared.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Iterable
-
-from bot.rules import burst_shared
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str) -> MockMessage:
- """
- Init a MockMessage instance with the passed arg.
-
- This serves as a shorthand / alias to keep the test cases visually clean.
- """
- return MockMessage(author=author)
-
-
-class BurstSharedRuleTests(RuleTest):
- """Tests the `burst_shared` antispam rule."""
-
- def setUp(self):
- self.apply = burst_shared.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """
- Cases that do not violate the rule.
-
- There really isn't more to test here than a single case.
- """
- cases = (
- [make_msg("spongebob"), make_msg("patrick")],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases where the amount of messages exceeds the limit, triggering the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("bob")],
- {"bob"},
- 3,
- ),
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
- {"bob", "alice"},
- 4,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- return case.recent_messages
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
deleted file mode 100644
index f1e3c76a7..000000000
--- a/tests/bot/rules/test_chars.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Iterable
-
-from bot.rules import chars
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, n_chars: int) -> MockMessage:
- """Build a message with arbitrary content of `n_chars` length."""
- return MockMessage(author=author, content="A" * n_chars)
-
-
-class CharsRuleTests(RuleTest):
- """Tests the `chars` antispam rule."""
-
- def setUp(self):
- self.apply = chars.apply
- self.config = {
- "max": 20, # Max allowed sum of chars per user
- "interval": 10,
- }
-
- async def test_allows_messages_within_limit(self):
- """Cases with a total amount of chars within limit."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 20)],
- [make_msg("bob", 15), make_msg("alice", 15)],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases where the total amount of chars exceeds the limit, triggering the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 21)],
- ("bob",),
- 21,
- ),
- DisallowedCase(
- [make_msg("bob", 15), make_msg("bob", 15)],
- ("bob",),
- 30,
- ),
- DisallowedCase(
- [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
- ("alice",),
- 30,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
deleted file mode 100644
index 66c2d9f92..000000000
--- a/tests/bot/rules/test_discord_emojis.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from typing import Iterable
-
-from bot.rules import discord_emojis
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
-unicode_emoji = "🧪"
-
-
-def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage:
- """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis."""
- return MockMessage(author=author, content=emoji * n_emojis)
-
-
-class DiscordEmojisRuleTests(RuleTest):
- """Tests for the `discord_emojis` antispam rule."""
-
- def setUp(self):
- self.apply = discord_emojis.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases with a total amount of discord and unicode emojis within limit."""
- cases = (
- [make_msg("bob", 2)],
- [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
- [make_msg("bob", 2, unicode_emoji)],
- [
- make_msg("alice", 1, unicode_emoji),
- make_msg("bob", 2, unicode_emoji),
- make_msg("alice", 1, unicode_emoji)
- ],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases with more than the allowed amount of discord and unicode emojis."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
- ("alice",),
- 4,
- ),
- DisallowedCase(
- [make_msg("bob", 3, unicode_emoji)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [
- make_msg("alice", 2, unicode_emoji),
- make_msg("bob", 2, unicode_emoji),
- make_msg("alice", 2, unicode_emoji)
- ],
- ("alice",),
- 4
- )
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
deleted file mode 100644
index 9bd886a77..000000000
--- a/tests/bot/rules/test_duplicates.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Iterable
-
-from bot.rules import duplicates
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, content: str) -> MockMessage:
- """Give a MockMessage instance with `author` and `content` attrs."""
- return MockMessage(author=author, content=content)
-
-
-class DuplicatesRuleTests(RuleTest):
- """Tests the `duplicates` antispam rule."""
-
- def setUp(self):
- self.apply = duplicates.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases which do not violate the rule."""
- cases = (
- [make_msg("alice", "A"), make_msg("alice", "A")],
- [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate
- [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases with too many duplicate messages from the same author."""
- cases = (
- DisallowedCase(
- [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")],
- ("alice",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")],
- ("bob",),
- 3, # 4 duplicate messages, but only 3 from bob
- ),
- DisallowedCase(
- [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")],
- ("bob",),
- 3, # 4 message from bob, but only 3 duplicates
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if (
- msg.author == last_message.author
- and msg.content == last_message.content
- )
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
deleted file mode 100644
index b091bd9d7..000000000
--- a/tests/bot/rules/test_links.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from typing import Iterable
-
-from bot.rules import links
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, total_links: int) -> MockMessage:
- """Makes a message with `total_links` links."""
- content = " ".join(["https://pydis.com"] * total_links)
- return MockMessage(author=author, content=content)
-
-
-class LinksTests(RuleTest):
- """Tests applying the `links` rule."""
-
- def setUp(self):
- self.apply = links.apply
- self.config = {
- "max": 2,
- "interval": 10
- }
-
- async def test_links_within_limit(self):
- """Messages with an allowed amount of links."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 2)],
- [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
- [make_msg("bob", 1), make_msg("bob", 1)],
- [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
- )
-
- await self.run_allowed(cases)
-
- async def test_links_exceeding_limit(self):
- """Messages with a a higher than allowed amount of links."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 1), make_msg("bob", 2)],
- ("bob",),
- 3
- ),
- DisallowedCase(
- [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
- ("alice",),
- 3
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
- ("alice",),
- 3
- )
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
deleted file mode 100644
index e1f904917..000000000
--- a/tests/bot/rules/test_mentions.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from typing import Iterable, Optional
-
-import discord
-
-from bot.rules import mentions
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMember, MockMessage, MockMessageReference
-
-
-def make_msg(
- author: str,
- total_user_mentions: int,
- total_bot_mentions: int = 0,
- *,
- reference: Optional[MockMessageReference] = None
-) -> MockMessage:
- """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions."""
- user_mentions = [MockMember() for _ in range(total_user_mentions)]
- bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)]
-
- mentions = user_mentions + bot_mentions
- if reference is not None:
- # For the sake of these tests we assume that all references are mentions.
- mentions.append(reference.resolved.author)
- msg_type = discord.MessageType.reply
- else:
- msg_type = discord.MessageType.default
-
- return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type)
-
-
-class TestMentions(RuleTest):
- """Tests applying the `mentions` antispam rule."""
-
- def setUp(self):
- self.apply = mentions.apply
- self.config = {
- "max": 2,
- "interval": 10,
- }
-
- async def test_mentions_within_limit(self):
- """Messages with an allowed amount of mentions."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 2)],
- [make_msg("bob", 1), make_msg("bob", 1)],
- [make_msg("bob", 1), make_msg("alice", 2)],
- )
-
- await self.run_allowed(cases)
-
- async def test_mentions_exceeding_limit(self):
- """Messages with a higher than allowed amount of mentions."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
- ("alice",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
- ("bob",),
- 4,
- ),
- DisallowedCase(
- [make_msg("bob", 3, 1)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 3, reference=MockMessageReference())],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))],
- ("bob",),
- 3
- )
- )
-
- await self.run_disallowed(cases)
-
- async def test_ignore_bot_mentions(self):
- """Messages with an allowed amount of mentions, also containing bot mentions."""
- cases = (
- [make_msg("bob", 0, 3)],
- [make_msg("bob", 2, 1)],
- [make_msg("bob", 1, 2), make_msg("bob", 1, 2)],
- [make_msg("bob", 1, 5), make_msg("alice", 2, 5)]
- )
-
- await self.run_allowed(cases)
-
- async def test_ignore_reply_mentions(self):
- """Messages with an allowed amount of mentions in the content, also containing reply mentions."""
- cases = (
- [
- make_msg("bob", 2, reference=MockMessageReference())
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True))
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference()),
- make_msg("bob", 0, reference=MockMessageReference())
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)),
- make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True))
- ]
- )
-
- await self.run_allowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
deleted file mode 100644
index e35377773..000000000
--- a/tests/bot/rules/test_newlines.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from typing import Iterable, List
-
-from bot.rules import newlines
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
- """Init a MockMessage instance with `author` and content configured by `newline_groups".
-
- Configure content by passing a list of ints, where each int `n` will generate
- a separate group of `n` newlines.
-
- Example:
- newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n"
- """
- content = " ".join("\n" * n for n in newline_groups)
- return MockMessage(author=author, content=content)
-
-
-class TotalNewlinesRuleTests(RuleTest):
- """Tests the `newlines` antispam rule against allowed cases and total newline count violations."""
-
- def setUp(self):
- self.apply = newlines.apply
- self.config = {
- "max": 5, # Max sum of newlines in relevant messages
- "max_consecutive": 3, # Max newlines in one group, in one message
- "interval": 10,
- }
-
- async def test_allows_messages_within_limit(self):
- """Cases which do not violate the rule."""
- cases = (
- [make_msg("alice", [])], # Single message with no newlines
- [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages
- [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author
- [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_total(self):
- """Cases which violate the rule by having too many newlines in total."""
- cases = (
- DisallowedCase( # Alice sends a total of 6 newlines (disallowed)
- [make_msg("alice", [2, 2]), make_msg("alice", [2])],
- ("alice",),
- 6,
- ),
- DisallowedCase( # Here we test that only alice's newlines count in the sum
- [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])],
- ("alice",),
- 7,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_author = case.recent_messages[0].author
- return tuple(msg for msg in case.recent_messages if msg.author == last_author)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} newlines in {self.config['interval']}s"
-
-
-class GroupNewlinesRuleTests(RuleTest):
- """
- Tests the `newlines` antispam rule against max consecutive newline violations.
-
- As these violations yield a different error report, they require a different
- `get_report` implementation.
- """
-
- def setUp(self):
- self.apply = newlines.apply
- self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
-
- async def test_disallows_messages_consecutive(self):
- """Cases which violate the rule due to having too many consecutive newlines."""
- cases = (
- DisallowedCase( # Bob sends a group of newlines too large
- [make_msg("bob", [4])],
- ("bob",),
- 4,
- ),
- DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed)
- [make_msg("alice", [1]), make_msg("alice", [4])],
- ("alice",),
- 4,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_author = case.recent_messages[0].author
- return tuple(msg for msg in case.recent_messages if msg.author == last_author)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
deleted file mode 100644
index 26c05d527..000000000
--- a/tests/bot/rules/test_role_mentions.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from typing import Iterable
-
-from bot.rules import role_mentions
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, n_mentions: int) -> MockMessage:
- """Build a MockMessage instance with `n_mentions` role mentions."""
- return MockMessage(author=author, role_mentions=[None] * n_mentions)
-
-
-class RoleMentionsRuleTests(RuleTest):
- """Tests for the `role_mentions` antispam rule."""
-
- def setUp(self):
- self.apply = role_mentions.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases with a total amount of role mentions within limit."""
- cases = (
- [make_msg("bob", 2)],
- [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases with more than the allowed amount of role mentions."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
- ("alice",),
- 4,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} role mentions in {self.config['interval']}s"
diff --git a/tests/helpers.py b/tests/helpers.py
index 1a71f210a..020f1aee5 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data
class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
- A MagicMock subclass to mock TextChannel objects.
+ A MagicMock subclass to mock DMChannel objects.
- Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ Instances of this class will follow the specifications of `discord.DMChannel` instances. For
more information, see the `MockGuild` docstring.
"""
spec_set = dm_channel_instance
def __init__(self, **kwargs) -> None:
- default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()}
+ default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None}
super().__init__(**collections.ChainMap(kwargs, default_kwargs))
@@ -423,7 +423,7 @@ category_channel_instance = discord.CategoryChannel(
class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
def __init__(self, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id)}
- super().__init__(**collections.ChainMap(default_kwargs, kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
# Create a Message instance to get a realistic MagicMock of `discord.Message`