diff options
Diffstat (limited to '')
104 files changed, 7265 insertions, 3782 deletions
| diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7cd00a0d6..816bdf290 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,6 +1,5 @@  # Extensions  **/bot/exts/backend/sync/**             @MarkKoz -**/bot/exts/filters/*token_remover.py   @MarkKoz  **/bot/exts/moderation/*silence.py      @MarkKoz  bot/exts/info/codeblock/**              @MarkKoz  bot/exts/utils/extensions.py            @MarkKoz @@ -8,14 +7,11 @@ bot/exts/utils/snekbox.py               @MarkKoz @jb3  bot/exts/moderation/**                  @mbaruh @Den4200 @ks129 @jb3  bot/exts/info/**                        @Den4200 @jb3  bot/exts/info/information.py            @mbaruh @jb3 -bot/exts/filters/**                     @mbaruh @jb3 +bot/exts/filtering/**                   @mbaruh  bot/exts/fun/**                         @ks129  bot/exts/utils/**                       @ks129 @jb3  bot/exts/recruitment/**                 @wookie184 -# Rules -bot/rules/**                            @mbaruh -  # Utils  bot/utils/function.py                   @MarkKoz  bot/utils/lock.py                       @MarkKoz diff --git a/bot/bot.py b/bot/bot.py index 6164ba9fd..f56aec38e 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,5 +1,4 @@  import asyncio -from collections import defaultdict  import aiohttp  from pydis_core import BotBase @@ -27,8 +26,6 @@ class Bot(BotBase):          super().__init__(*args, **kwargs) -        self.filter_list_cache = defaultdict(dict) -      async def ping_services(self) -> None:          """A helper to make sure all the services the bot relies on are available on startup."""          # Connect Site/API @@ -45,33 +42,10 @@ class Bot(BotBase):                      raise                  await asyncio.sleep(constants.URLs.connect_cooldown) -    def insert_item_into_filter_list_cache(self, item: dict[str, str]) -> None: -        """Add an item to the bots filter_list_cache.""" -        type_ = item["type"] -        allowed = item["allowed"] -        content = item["content"] - -        self.filter_list_cache[f"{type_}.{allowed}"][content] = { -            "id": item["id"], -            "comment": item["comment"], -            "created_at": item["created_at"], -            "updated_at": item["updated_at"], -        } - -    async def cache_filter_list_data(self) -> None: -        """Cache all the data in the FilterList on the site.""" -        full_cache = await self.api_client.get('bot/filter-lists') - -        for item in full_cache: -            self.insert_item_into_filter_list_cache(item) -      async def setup_hook(self) -> None:          """Default async initialisation method for discord.py."""          await super().setup_hook() -        # Build the FilterList cache -        await self.cache_filter_list_data() -          # This is not awaited to avoid a deadlock with any cogs that have          # wait_until_guild_available in their cog_load method.          scheduling.create_task(self.load_extensions(exts)) diff --git a/bot/constants.py b/bot/constants.py index 0b75153d3..3d0ae9542 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -336,43 +336,6 @@ class _Free(EnvConfig):  Free = _Free() -class Rule(BaseModel): -    interval: int -    max: int - - -# Some help in choosing an appropriate name for this is appreciated -class ExtendedRule(Rule): -    max_consecutive: int - - -class Rules(BaseModel): -    attachments: Rule = Rule(interval=10, max=6) -    burst: Rule = Rule(interval=10, max=7) -    chars: Rule = Rule(interval=5, max=4_200) -    discord_emojis: Rule = Rule(interval=10, max=20) -    duplicates: Rule = Rule(interval=10, max=3) -    links: Rule = Rule(interval=10, max=10) -    mentions: Rule = Rule(interval=10, max=5) -    newlines: ExtendedRule = ExtendedRule(interval=10, max=100, max_consecutive=10) -    role_mentions: Rule = Rule(interval=10, max=3) - - -class _AntiSpam(EnvConfig): -    EnvConfig.Config.env_prefix = 'anti_spam_' - -    cache_size = 100 - -    clean_offending = True -    ping_everyone = True - -    remove_timeout_after = 600 -    rules = Rules() - - -AntiSpam = _AntiSpam() - -  class _HelpChannels(EnvConfig):      EnvConfig.Config.env_prefix = "help_channels_" @@ -659,47 +622,6 @@ class _Icons(EnvConfig):  Icons = _Icons() -class _Filter(EnvConfig): -    EnvConfig.Config.env_prefix = "filters_" - -    filter_domains = True -    filter_everyone_ping = True -    filter_invites = True -    filter_zalgo = False -    watch_regex = True -    watch_rich_embeds = True - -    # Notifications are not expected for "watchlist" type filters - -    notify_user_domains = False -    notify_user_everyone_ping = True -    notify_user_invites = True -    notify_user_zalgo = False - -    offensive_msg_delete_days = 7 -    ping_everyone = True - -    channel_whitelist = [ -        Channels.admins, -        Channels.big_brother, -        Channels.dev_log, -        Channels.message_log, -        Channels.mod_log, -        Channels.staff_lounge -    ] -    role_whitelist = [ -        Roles.admins, -        Roles.helpers, -        Roles.moderators, -        Roles.owners, -        Roles.python_community, -        Roles.partners -    ] - - -Filter = _Filter() - -  class _Keys(EnvConfig):      EnvConfig.Config.env_prefix = "api_keys_" diff --git a/bot/converters.py b/bot/converters.py index 544513c90..21623b597 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -9,7 +9,7 @@ import dateutil.parser  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter +from discord.ext.commands import BadArgument, Context, Converter, IDConverter, MemberConverter, UserConverter  from discord.utils import escape_markdown, snowflake_time  from pydis_core.site_api import ResponseCodeError  from pydis_core.utils import unqualify @@ -68,54 +68,6 @@ class ValidDiscordServerInvite(Converter):          raise BadArgument("This does not appear to be a valid Discord server invite.") -class ValidFilterListType(Converter): -    """ -    A converter that checks whether the given string is a valid FilterList type. - -    Raises `BadArgument` if the argument is not a valid FilterList type, and simply -    passes through the given argument otherwise. -    """ - -    @staticmethod -    async def get_valid_types(bot: Bot) -> list: -        """ -        Try to get a list of valid filter list types. - -        Raise a BadArgument if the API can't respond. -        """ -        try: -            valid_types = await bot.api_client.get('bot/filter-lists/get-types') -        except ResponseCodeError: -            raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") - -        return [enum for enum, classname in valid_types] - -    async def convert(self, ctx: Context, list_type: str) -> str: -        """Checks whether the given string is a valid FilterList type.""" -        valid_types = await self.get_valid_types(ctx.bot) -        list_type = list_type.upper() - -        if list_type not in valid_types: - -            # Maybe the user is using the plural form of this type, -            # e.g. "guild_invites" instead of "guild_invite". -            # -            # This code will support the simple plural form (a single 's' at the end), -            # which works for all current list types, but if a list type is added in the future -            # which has an irregular plural form (like 'ies'), this code will need to be -            # refactored to support this. -            if list_type.endswith("S") and list_type[:-1] in valid_types: -                list_type = list_type[:-1] - -            else: -                valid_types_list = '\n'.join([f"• {type_.lower()}" for type_ in valid_types]) -                raise BadArgument( -                    f"You have provided an invalid list type!\n\n" -                    f"Please provide one of the following: \n{valid_types_list}" -                ) -        return list_type - -  class Extension(Converter):      """      Fully qualify the name of an extension and ensure it exists. diff --git a/bot/exts/filtering/FILTERS-DEVELOPMENT.md b/bot/exts/filtering/FILTERS-DEVELOPMENT.md new file mode 100644 index 000000000..c6237b60c --- /dev/null +++ b/bot/exts/filtering/FILTERS-DEVELOPMENT.md @@ -0,0 +1,63 @@ +# Filters Development +This file gives a short overview of the extension, and shows how to perform some basic changes/additions to it. + +## Overview +The main idea is that there is a list of filters each deciding whether they apply to the given content. +For example, there can be a filter that decides it will trigger when the content contains the string "lemon". + +There are several types of filters, and two filters of the same type differ by their content. +For example, filters of type "token" search for a specific token inside the provided string. +One token filter might look for the string "lemon", while another will look for the string "joe". + +Each filter has a set of settings that decide when it triggers (e.g. in which channels, in which categories, etc.), and what happens if it does (e.g. delete the message, ping specific roles/users, etc.). +Filters of a specific type can have additional settings that are special to them. + +A list of filters is contained within a filter list. +The filter list gets content to filter, and dispatches it to each of its filters. +It takes the answers from its filters and returns a unified response (e.g. if at least one of the filters says it should be deleted, then the filter list response will include it). + +A filter list has the same set of possible settings, which act as defaults. +If a filter in the list doesn't define a value for a setting (meaning it has a value of None), it will use the value of the containing filter list. + +The cog receives "filtering events". For example, a new message is sent. +It creates a "filtering context" with everything a filtering list needs to know to provide an answer for what should be done. +For example, if the event is a new message, then the content to filter is the content of the message, embeds if any exist, etc. + +The cog dispatches the event to each filter list, gets the result from each, compiles them, and takes any action dictated by them. +For example, if any of the filter lists want the message to be deleted, then the cog will delete it. + +## Example Changes +### Creating a new type of filter list +1. Head over to `bot.exts.filtering._filter_lists` and create a new Python file. +2. Subclass the FilterList class in `bot.exts.filtering._filter_lists.filter_list` and implement its abstract methods. Make sure to set the `name` class attribute. + +You can now add filter lists to the database with the same name defined in the new FilterList subclass. + +### Creating a new type of filter +1. Head over to `bot.exts.filtering._filters` and create a new Python file. +2. Subclass the Filter class in `bot.exts.filtering._filters.filter` and implement its abstract methods. +3. Make sure to set the `name` class attribute, and have one of the FilterList subclasses return this new Filter subclass in `get_filter_type`. + +### Creating a new type of setting +1. Head over to `bot.exts.filtering._settings_types`, and open a new Python file in either `actions` or `validations`, depending on whether you want to subclass `ActionEntry` or `ValidationEntry`. +2. Subclass one of the aforementioned classes, and implement its abstract methods. Make sure to set the `name` and `description` class attributes. + +You can now make the appropriate changes to the site repo: +1. Add a new field in the `Filter` and `FilterList` models. Make sure that on `Filter` it's nullable, and on `FilterList` it isn't. +2. In `serializers.py`, add the new field to `SETTINGS_FIELDS`, and to `ALLOW_BLANK_SETTINGS` or `ALLOW_EMPTY_SETTINGS` if appropriate. If it's not a part of any group of settings, add it `BASE_SETTINGS_FIELDS`, otherwise add it to the appropriate group or create a new one. +3. If you created a new group, make sure it's used in `to_representation`. +4. Update the docs in the filter viewsets. + +You can merge the changes to the bot first - if no such field is loaded from the database it'll just be ignored. + +You can define entries that are a group of fields in the database. +In that case the created subclass should have fields whose names are the names of the fields in the database. +Then, the description will be a dictionary, whose keys are the names of the fields, and values are the descriptions for each field. + +### Creating a new type of filtering event +1. Head over to `bot.exts.filtering._filter_context` and add a new value to the `Event` enum. +2. Implement the dispatching and actioning of the new event in the cog, by either adding it to an existing even listener, or creating a new one. +3. Have the appropriate filter lists subscribe to the event, so they receive it. +4. Have the appropriate unique filters (currently under `unique` and `antispam` in `bot.exts.filtering._filters`) subscribe to the event, so they receive it. + +It should be noted that the filtering events don't need to correspond to Discord events. For example, `nickname` isn't a Discord event and is dispatched when a message is sent. diff --git a/bot/exts/filters/__init__.py b/bot/exts/filtering/__init__.py index e69de29bb..e69de29bb 100644 --- a/bot/exts/filters/__init__.py +++ b/bot/exts/filtering/__init__.py diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py new file mode 100644 index 000000000..0794a48e4 --- /dev/null +++ b/bot/exts/filtering/_filter_context.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import typing +from collections.abc import Callable, Coroutine, Iterable +from dataclasses import dataclass, field, replace +from enum import Enum, auto + +import discord +from discord import DMChannel, Embed, Member, Message, StageChannel, TextChannel, Thread, User, VoiceChannel + +from bot.utils.message_cache import MessageCache + +if typing.TYPE_CHECKING: +    from bot.exts.filtering._filters.filter import Filter +    from bot.exts.utils.snekbox._io import FileAttachment + + +class Event(Enum): +    """Types of events that can trigger filtering. Note this does not have to align with gateway event types.""" + +    MESSAGE = auto() +    MESSAGE_EDIT = auto() +    NICKNAME = auto() +    SNEKBOX = auto() + + +@dataclass +class FilterContext: +    """A dataclass containing the information that should be filtered, and output information of the filtering.""" + +    # Input context +    event: Event  # The type of event +    author: User | Member | None  # Who triggered the event +    channel: TextChannel | VoiceChannel | StageChannel | Thread | DMChannel | None  # The channel involved +    content: str | Iterable  # What actually needs filtering. The Iterable type depends on the filter list. +    message: Message | None  # The message involved +    embeds: list[Embed] = field(default_factory=list)  # Any embeds involved +    attachments: list[discord.Attachment | FileAttachment] = field(default_factory=list)  # Any attachments sent. +    before_message: Message | None = None +    message_cache: MessageCache | None = None +    # Output context +    dm_content: str = ""  # The content to DM the invoker +    dm_embed: str = ""  # The embed description to DM the invoker +    send_alert: bool = False  # Whether to send an alert for the moderators +    alert_content: str = ""  # The content of the alert +    alert_embeds: list[Embed] = field(default_factory=list)  # Any embeds to add to the alert +    action_descriptions: list[str] = field(default_factory=list)  # What actions were taken +    matches: list[str] = field(default_factory=list)  # What exactly was found +    notification_domain: str = ""  # A domain to send the user for context +    filter_info: dict['Filter', str] = field(default_factory=dict)  # Additional info from a filter. +    messages_deletion: bool = False  # Whether the messages were deleted. Can't upload deletion log otherwise. +    blocked_exts: set[str] = field(default_factory=set)  # Any extensions blocked (used for snekbox) +    # Additional actions to perform +    additional_actions: list[Callable[[FilterContext], Coroutine]] = field(default_factory=list) +    related_messages: set[Message] = field(default_factory=set)  # Deletion will include these. +    related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set) +    uploaded_attachments: dict[int, list[str]] = field(default_factory=dict)  # Message ID to attachment URLs. +    upload_deletion_logs: bool = True  # Whether it's allowed to upload deletion logs. + +    def __post_init__(self): +        # If it's in the context of a DM channel, self.channel won't be None, but self.channel.guild will. +        self.in_guild = self.channel is None or self.channel.guild is not None + +    @classmethod +    def from_message( +        cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None +    ) -> FilterContext: +        """Create a filtering context from the attributes of a message.""" +        return cls( +            event, +            message.author, +            message.channel, +            message.content, +            message, +            message.embeds, +            message.attachments, +            before, +            cache +        ) + +    def replace(self, **changes) -> FilterContext: +        """Return a new context object assigning new values to the specified fields.""" +        return replace(self, **changes) diff --git a/bot/exts/filtering/_filter_lists/__init__.py b/bot/exts/filtering/_filter_lists/__init__.py new file mode 100644 index 000000000..82e0452f9 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/__init__.py @@ -0,0 +1,9 @@ +from os.path import dirname + +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType, list_type_converter +from bot.exts.filtering._utils import subclasses_in_package + +filter_list_types = subclasses_in_package(dirname(__file__), f"{__name__}.", FilterList) +filter_list_types = {filter_list.name: filter_list for filter_list in filter_list_types} + +__all__ = [filter_list_types, FilterList, ListType, list_type_converter] diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py new file mode 100644 index 000000000..94f80e6eb --- /dev/null +++ b/bot/exts/filtering/_filter_lists/antispam.py @@ -0,0 +1,197 @@ +import asyncio +import typing +from collections import Counter +from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field +from datetime import timedelta +from functools import reduce +from itertools import takewhile +from operator import add, or_ + +import arrow +from discord import Member +from pydis_core.utils import scheduling +from pydis_core.utils.logging import get_logger + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filter_lists.filter_list import ListType, SubscribingAtomicList, UniquesListBase +from bot.exts.filtering._filters.antispam import antispam_filter_types +from bot.exts.filtering._filters.filter import Filter, UniqueFilter +from bot.exts.filtering._settings import ActionSettings +from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification +from bot.exts.filtering._ui.ui import AlertView, build_mod_alert + +if typing.TYPE_CHECKING: +    from bot.exts.filtering.filtering import Filtering + +log = get_logger(__name__) + +ALERT_DELAY = 6 + + +class AntispamList(UniquesListBase): +    """ +    A list of anti-spam rules. + +    Messages from the last X seconds are passed to each rule, which decides whether it triggers across those messages. + +    The infraction reason is set dynamically. +    """ + +    name = "antispam" + +    def __init__(self, filtering_cog: 'Filtering'): +        super().__init__(filtering_cog) +        self.message_deletion_queue: dict[Member, DeletionContext] = dict() + +    def get_filter_type(self, content: str) -> type[UniqueFilter] | None: +        """Get a subclass of filter matching the filter list and the filter's content.""" +        try: +            return antispam_filter_types[content] +        except KeyError: +            if content not in self._already_warned: +                log.warning(f"An antispam filter named {content} was supplied, but no matching implementation found.") +                self._already_warned.add(content) +            return None + +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" +        if not ctx.message or not ctx.message_cache: +            return None, [], {} + +        sublist: SubscribingAtomicList = self[ListType.DENY] +        potential_filters = [sublist.filters[id_] for id_ in sublist.subscriptions[ctx.event]] +        max_interval = max(filter_.extra_fields.interval for filter_ in potential_filters) + +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=max_interval) +        relevant_messages = list( +            takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.message_cache) +        ) +        new_ctx = ctx.replace(content=relevant_messages) +        triggers = await sublist.filter_list_result(new_ctx) +        if not triggers: +            return None, [], {} + +        if ctx.author not in self.message_deletion_queue: +            self.message_deletion_queue[ctx.author] = DeletionContext() +            ctx.additional_actions.append(self._create_deletion_context_handler(ctx.author)) +            ctx.related_channels |= {msg.channel for msg in ctx.related_messages} +        else:  # The additional messages found are already part of a deletion context +            ctx.related_messages = set() +        current_infraction = self.message_deletion_queue[ctx.author].current_infraction +        # In case another filter wants an alert, prevent deleted messages from being uploaded now and also for +        # the spam alert (upload happens during alerting). +        # Deleted messages API doesn't accept duplicates and will error. +        # Additional messages are necessarily part of the deletion. +        ctx.upload_deletion_logs = False +        self.message_deletion_queue[ctx.author].add(ctx, triggers) + +        current_actions = sublist.merge_actions(triggers) +        # Don't alert yet. +        current_actions.pop("ping", None) +        current_actions.pop("send_alert", None) + +        new_infraction = current_actions[InfractionAndNotification.name].copy() +        # Smaller infraction value => higher in hierarchy. +        if not current_infraction or new_infraction.infraction_type.value < current_infraction.value: +            # Pick the first triggered filter for the reason, there's no good way to decide between them. +            new_infraction.infraction_reason = ( +                f"{triggers[0].name.replace('_', ' ')} spam – {ctx.filter_info[triggers[0]]}" +            ) +            current_actions[InfractionAndNotification.name] = new_infraction +            self.message_deletion_queue[ctx.author].current_infraction = new_infraction.infraction_type +        else: +            current_actions.pop(InfractionAndNotification.name, None) + +        # Provide some message in case another filter list wants there to be an alert. +        return current_actions, ["Handling spam event..."], {ListType.DENY: triggers} + +    def _create_deletion_context_handler(self, member: Member) -> Callable[[FilterContext], Coroutine]: +        async def schedule_processing(ctx: FilterContext) -> None: +            """ +            Schedule a coroutine to process the deletion context. + +            It cannot be awaited directly, as it waits ALERT_DELAY seconds, and actioning a filtering context depends on +            all actions finishing. + +            This is async and takes a context to adhere to the type of ctx.additional_actions. +            """ +            async def process_deletion_context() -> None: +                """Processes the Deletion Context queue.""" +                log.trace("Sleeping before processing message deletion queue.") +                await asyncio.sleep(ALERT_DELAY) + +                if member not in self.message_deletion_queue: +                    log.error(f"Started processing deletion queue for context `{member}`, but it was not found!") +                    return + +                deletion_context = self.message_deletion_queue.pop(member) +                await deletion_context.send_alert(self) + +            scheduling.create_task(process_deletion_context()) + +        return schedule_processing + + +@dataclass +class DeletionContext: +    """Represents a Deletion Context for a single spam event.""" + +    contexts: list[FilterContext] = field(default_factory=list) +    rules: set[UniqueFilter] = field(default_factory=set) +    current_infraction: Infraction | None = None + +    def add(self, ctx: FilterContext, rules: list[UniqueFilter]) -> None: +        """Adds new rule violation events to the deletion context.""" +        self.contexts.append(ctx) +        self.rules.update(rules) + +    async def send_alert(self, antispam_list: AntispamList) -> None: +        """Post the mod alert.""" +        if not self.contexts or not self.rules: +            return + +        webhook = antispam_list.filtering_cog.webhook +        if not webhook: +            return + +        ctx, *other_contexts = self.contexts +        new_ctx = FilterContext(ctx.event, ctx.author, ctx.channel, ctx.content, ctx.message) +        all_descriptions_counts = Counter(reduce( +            add, (other_ctx.action_descriptions for other_ctx in other_contexts), ctx.action_descriptions +        )) +        new_ctx.action_descriptions = [ +            f"{action} X {count}" if count > 1 else action for action, count in all_descriptions_counts.items() +        ] +        # It shouldn't ever come to this, but just in case. +        if (descriptions_num := len(new_ctx.action_descriptions)) > 20: +            new_ctx.action_descriptions = new_ctx.action_descriptions[:20] +            new_ctx.action_descriptions[-1] += f" (+{descriptions_num - 20} other actions)" +        new_ctx.related_messages = reduce( +            or_, (other_ctx.related_messages for other_ctx in other_contexts), ctx.related_messages +        ) | {ctx.message for ctx in other_contexts} +        new_ctx.related_channels = reduce( +            or_, (other_ctx.related_channels for other_ctx in other_contexts), ctx.related_channels +        ) | {ctx.channel for ctx in other_contexts} +        new_ctx.uploaded_attachments = reduce( +            or_, (other_ctx.uploaded_attachments for other_ctx in other_contexts), ctx.uploaded_attachments +        ) +        new_ctx.upload_deletion_logs = True +        new_ctx.messages_deletion = all(ctx.messages_deletion for ctx in self.contexts) + +        rules = list(self.rules) +        actions = antispam_list[ListType.DENY].merge_actions(rules) +        for action in list(actions): +            if action not in ("ping", "send_alert"): +                actions.pop(action, None) +        await actions.action(new_ctx) + +        messages = antispam_list[ListType.DENY].format_messages(rules) +        embed = await build_mod_alert(new_ctx, {antispam_list: messages}) +        if other_contexts: +            embed.set_footer( +                text="The list of actions taken includes actions from additional contexts after deletion began." +            ) +        await webhook.send(username="Anti-Spam", content=ctx.alert_content, embeds=[embed], view=AlertView(new_ctx)) diff --git a/bot/exts/filtering/_filter_lists/domain.py b/bot/exts/filtering/_filter_lists/domain.py new file mode 100644 index 000000000..091fd14e0 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/domain.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import re +import typing + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.domain import DomainFilter +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings import ActionSettings +from bot.exts.filtering._utils import clean_input + +if typing.TYPE_CHECKING: +    from bot.exts.filtering.filtering import Filtering + +URL_RE = re.compile(r"https?://(\S+)", flags=re.IGNORECASE) + + +class DomainsList(FilterList[DomainFilter]): +    """ +    A list of filters, each looking for a specific domain given by URL. + +    The blacklist defaults dictate what happens by default when a filter is matched, and can be overridden by +    individual filters. + +    Domains are found by looking for a URL schema (http or https). +    Filters will also trigger for subdomains. +    """ + +    name = "domain" + +    def __init__(self, filtering_cog: Filtering): +        super().__init__() +        filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) + +    def get_filter_type(self, content: str) -> type[Filter]: +        """Get a subclass of filter matching the filter list and the filter's content.""" +        return DomainFilter + +    @property +    def filter_types(self) -> set[type[Filter]]: +        """Return the types of filters used by this list.""" +        return {DomainFilter} + +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" +        text = ctx.content +        if not text: +            return None, [], {} + +        text = clean_input(text) +        urls = {match.group(1).lower().rstrip("/") for match in URL_RE.finditer(text)} +        new_ctx = ctx.replace(content=urls) + +        triggers = await self[ListType.DENY].filter_list_result(new_ctx) +        ctx.notification_domain = new_ctx.notification_domain +        actions = None +        messages = [] +        if triggers: +            actions = self[ListType.DENY].merge_actions(triggers) +            messages = self[ListType.DENY].format_messages(triggers) +        return actions, messages, {ListType.DENY: triggers} diff --git a/bot/exts/filtering/_filter_lists/extension.py b/bot/exts/filtering/_filter_lists/extension.py new file mode 100644 index 000000000..d805fa7aa --- /dev/null +++ b/bot/exts/filtering/_filter_lists/extension.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import typing +from os.path import splitext + +import bot +from bot.constants import BaseURLs, Channels +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.extension import ExtensionFilter +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings import ActionSettings + +if typing.TYPE_CHECKING: +    from bot.exts.filtering.filtering import Filtering + + +PY_EMBED_DESCRIPTION = ( +    "It looks like you tried to attach a Python file - " +    f"please use a code-pasting service such as {BaseURLs.site_paste}" +) + +TXT_LIKE_FILES = {".txt", ".csv", ".json"} +TXT_EMBED_DESCRIPTION = ( +    "You either uploaded a `{blocked_extension}` file or entered a message that was too long. " +    f"Please use our [paste bin]({BaseURLs.site_paste}) instead." +) + +DISALLOWED_EMBED_DESCRIPTION = ( +    "It looks like you tried to attach file type(s) that we do not allow ({joined_blacklist}). " +    "We currently allow the following file types: **{joined_whitelist}**.\n\n" +    "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + + +class ExtensionsList(FilterList[ExtensionFilter]): +    """ +    A list of filters, each looking for a file attachment with a specific extension. + +    If an extension is not explicitly allowed, it will be blocked. + +    Whitelist defaults dictate what happens when an extension is *not* explicitly allowed, +    and whitelist filters overrides have no effect. + +    Items should be added as file extensions preceded by a dot. +    """ + +    name = "extension" + +    def __init__(self, filtering_cog: Filtering): +        super().__init__() +        filtering_cog.subscribe(self, Event.MESSAGE, Event.SNEKBOX) +        self._whitelisted_description = None + +    def get_filter_type(self, content: str) -> type[Filter]: +        """Get a subclass of filter matching the filter list and the filter's content.""" +        return ExtensionFilter + +    @property +    def filter_types(self) -> set[type[Filter]]: +        """Return the types of filters used by this list.""" +        return {ExtensionFilter} + +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" +        # Return early if the message doesn't have attachments. +        if not ctx.message or not ctx.attachments: +            return None, [], {} + +        _, failed = self[ListType.ALLOW].defaults.validations.evaluate(ctx) +        if failed:  # There's no extension filtering in this context. +            return None, [], {} + +        # Find all extensions in the message. +        all_ext = { +            (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.attachments +        } +        new_ctx = ctx.replace(content={ext for ext, _ in all_ext})  # And prepare the context for the filters to read. +        triggered = [ +            filter_ for filter_ in self[ListType.ALLOW].filters.values() if await filter_.triggered_on(new_ctx) +        ] +        allowed_ext = {filter_.content for filter_ in triggered}  # Get the extensions in the message that are allowed. + +        # See if there are any extensions left which aren't allowed. +        not_allowed = {ext: filename for ext, filename in all_ext if ext not in allowed_ext} + +        if ctx.event == Event.SNEKBOX: +            not_allowed = {ext: filename for ext, filename in not_allowed.items() if ext not in TXT_LIKE_FILES} + +        if not not_allowed:  # Yes, it's a double negative. Meaning all attachments are allowed :) +            return None, [], {ListType.ALLOW: triggered} + +        # At this point, something is disallowed. +        if ctx.event != Event.SNEKBOX:  # Don't post the embed if it's a snekbox response. +            if ".py" in not_allowed: +                # Provide a pastebin link for .py files. +                ctx.dm_embed = PY_EMBED_DESCRIPTION +            elif txt_extensions := {ext for ext in TXT_LIKE_FILES if ext in not_allowed}: +                # Work around Discord auto-conversion of messages longer than 2000 chars to .txt +                ctx.dm_embed = TXT_EMBED_DESCRIPTION.format(blocked_extension=txt_extensions.pop()) +            else: +                meta_channel = bot.instance.get_channel(Channels.meta) +                if not self._whitelisted_description: +                    self._whitelisted_description = ', '.join( +                        filter_.content for filter_ in self[ListType.ALLOW].filters.values() +                    ) +                ctx.dm_embed = DISALLOWED_EMBED_DESCRIPTION.format( +                    joined_whitelist=self._whitelisted_description, +                    joined_blacklist=", ".join(not_allowed), +                    meta_channel_mention=meta_channel.mention, +                ) + +        ctx.matches += not_allowed.values() +        ctx.blocked_exts |= set(not_allowed) +        actions = self[ListType.ALLOW].defaults.actions if ctx.event != Event.SNEKBOX else None +        return actions, [f"`{ext}`" if ext else "`No Extension`" for ext in not_allowed], {ListType.ALLOW: triggered} diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py new file mode 100644 index 000000000..d4c975766 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -0,0 +1,308 @@ +import dataclasses +import typing +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from enum import Enum +from functools import reduce +from typing import Any + +import arrow +from discord.ext.commands import BadArgument + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import Filter, UniqueFilter +from bot.exts.filtering._settings import ActionSettings, Defaults, create_settings +from bot.exts.filtering._utils import FieldRequiring, past_tense +from bot.log import get_logger + +if typing.TYPE_CHECKING: +    from bot.exts.filtering.filtering import Filtering + +log = get_logger(__name__) + + +class ListType(Enum): +    """An enumeration of list types.""" + +    DENY = 0 +    ALLOW = 1 + + +#  Alternative names with which each list type can be specified in commands. +aliases = ( +    (ListType.DENY, {"deny", "blocklist", "blacklist", "denylist", "bl", "dl"}), +    (ListType.ALLOW, {"allow", "allowlist", "whitelist", "al", "wl"}) +) + + +def list_type_converter(argument: str) -> ListType: +    """A converter to get the appropriate list type.""" +    argument = argument.lower() +    for list_type, list_aliases in aliases: +        if argument in list_aliases or argument in map(past_tense, list_aliases): +            return list_type +    raise BadArgument(f"No matching list type found for {argument!r}.") + + +# AtomicList and its subclasses must have eq=False, otherwise the dataclass deco will replace the hash function. +@dataclass(frozen=True, eq=False) +class AtomicList: +    """ +    Represents the atomic structure of a single filter list as it appears in the database. + +    This is as opposed to the FilterList class which is a combination of several list types. +    """ + +    id: int +    created_at: arrow.Arrow +    updated_at: arrow.Arrow +    name: str +    list_type: ListType +    defaults: Defaults +    filters: dict[int, Filter] + +    @property +    def label(self) -> str: +        """Provide a short description identifying the list with its name and type.""" +        return f"{past_tense(self.list_type.name.lower())} {self.name.lower()}" + +    async def filter_list_result(self, ctx: FilterContext) -> list[Filter]: +        """ +        Sift through the list of filters, and return only the ones which apply to the given context. + +        The strategy is as follows: +        1. The default settings are evaluated on the given context. The default answer for whether the filter is +        relevant in the given context is whether there aren't any validation settings which returned False. +        2. For each filter, its overrides are considered: +            - If there are no overrides, then the filter is relevant if that is the default answer. +            - Otherwise it is relevant if there are no failed overrides, and any failing default is overridden by a +            successful override. + +        If the filter is relevant in context, see if it actually triggers. +        """ +        return await self._create_filter_list_result(ctx, self.defaults, self.filters.values()) + +    async def _create_filter_list_result( +        self, ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter] +    ) -> list[Filter]: +        """A helper function to evaluate the result of `filter_list_result`.""" +        passed_by_default, failed_by_default = defaults.validations.evaluate(ctx) +        default_answer = not bool(failed_by_default) + +        relevant_filters = [] +        for filter_ in filters: +            if not filter_.validations: +                if default_answer and await filter_.triggered_on(ctx): +                    relevant_filters.append(filter_) +            else: +                passed, failed = filter_.validations.evaluate(ctx) +                if not failed and failed_by_default < passed: +                    if await filter_.triggered_on(ctx): +                        relevant_filters.append(filter_) + +        if ctx.event == Event.MESSAGE_EDIT and ctx.message and self.list_type == ListType.DENY: +            previously_triggered = ctx.message_cache.get_message_metadata(ctx.message.id) +            # The message might not be cached. +            if previously_triggered: +                ignore_filters = previously_triggered[self] +                # This updates the cache. Some filters are ignored, but they're necessary if there's another edit. +                previously_triggered[self] = relevant_filters +                relevant_filters = [filter_ for filter_ in relevant_filters if filter_ not in ignore_filters] +        return relevant_filters + +    def default(self, setting_name: str) -> Any: +        """Get the default value of a specific setting.""" +        missing = object() +        value = self.defaults.actions.get_setting(setting_name, missing) +        if value is missing: +            value = self.defaults.validations.get_setting(setting_name, missing) +            if value is missing: +                raise ValueError(f"Couldn't find a setting named {setting_name!r}.") +        return value + +    def merge_actions(self, filters: list[Filter]) -> ActionSettings | None: +        """ +        Merge the settings of the given filters, with the list's defaults as fallback. + +        If `merge_default` is True, include it in the merge instead of using it as a fallback. +        """ +        if not filters:  # Nothing to action. +            return None +        try: +            return reduce( +                ActionSettings.union, (filter_.actions or self.defaults.actions for filter_ in filters) +            ).fallback_to(self.defaults.actions) +        except TypeError: +            # The sequence fed to reduce is empty, meaning none of the filters have actions, +            # meaning they all use the defaults. +            return self.defaults.actions + +    @staticmethod +    def format_messages(triggers: list[Filter], *, expand_single_filter: bool = True) -> list[str]: +        """Convert the filters into strings that can be added to the alert embed.""" +        if len(triggers) == 1 and expand_single_filter: +            message = f"#{triggers[0].id} (`{triggers[0].content}`)" +            if triggers[0].description: +                message += f" - {triggers[0].description}" +            messages = [message] +        else: +            messages = [f"{filter_.id} (`{filter_.content}`)" for filter_ in triggers] +        return messages + +    def __hash__(self): +        return hash(id(self)) + + +T = typing.TypeVar("T", bound=Filter) + + +class FilterList(dict[ListType, AtomicList], typing.Generic[T], FieldRequiring): +    """Dispatches events to lists of _filters, and aggregates the responses into a single list of actions to take.""" + +    # Each subclass must define a name matching the filter_list name we're expecting to receive from the database. +    # Names must be unique across all filter lists. +    name = FieldRequiring.MUST_SET_UNIQUE + +    _already_warned = set() + +    def add_list(self, list_data: dict) -> AtomicList: +        """Add a new type of list (such as a whitelist or a blacklist) this filter list.""" +        actions, validations = create_settings(list_data["settings"], keep_empty=True) +        list_type = ListType(list_data["list_type"]) +        defaults = Defaults(actions, validations) + +        filters = {} +        for filter_data in list_data["filters"]: +            new_filter = self._create_filter(filter_data, defaults) +            if new_filter: +                filters[filter_data["id"]] = new_filter + +        self[list_type] = AtomicList( +            list_data["id"], +            arrow.get(list_data["created_at"]), +            arrow.get(list_data["updated_at"]), +            self.name, +            list_type, +            defaults, +            filters +        ) +        return self[list_type] + +    def add_filter(self, list_type: ListType, filter_data: dict) -> T | None: +        """Add a filter to the list of the specified type.""" +        new_filter = self._create_filter(filter_data, self[list_type].defaults) +        if new_filter: +            self[list_type].filters[filter_data["id"]] = new_filter +        return new_filter + +    @abstractmethod +    def get_filter_type(self, content: str) -> type[T]: +        """Get a subclass of filter matching the filter list and the filter's content.""" + +    @property +    @abstractmethod +    def filter_types(self) -> set[type[T]]: +        """Return the types of filters used by this list.""" + +    @abstractmethod +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" + +    def _create_filter(self, filter_data: dict, defaults: Defaults) -> T | None: +        """Create a filter from the given data.""" +        try: +            content = filter_data["content"] +            filter_type = self.get_filter_type(content) +            if filter_type: +                return filter_type(filter_data, defaults) +            elif content not in self._already_warned: +                log.warning(f"A filter named {content} was supplied, but no matching implementation found.") +                self._already_warned.add(content) +            return None +        except TypeError as e: +            log.warning(e) + +    def __hash__(self): +        return hash(id(self)) + + +@dataclass(frozen=True, eq=False) +class SubscribingAtomicList(AtomicList): +    """ +    A base class for a list of unique filters. + +    Unique filters are ones that should only be run once in a given context. +    Each unique filter is subscribed to a subset of events to respond to. +    """ + +    subscriptions: defaultdict[Event, list[int]] = dataclasses.field(default_factory=lambda: defaultdict(list)) + +    def subscribe(self, filter_: UniqueFilter, *events: Event) -> None: +        """ +        Subscribe a unique filter to the given events. + +        The filter is added to a list for each event. When the event is triggered, the filter context will be +        dispatched to the subscribed filters. +        """ +        for event in events: +            if filter_ not in self.subscriptions[event]: +                self.subscriptions[event].append(filter_.id) + +    async def filter_list_result(self, ctx: FilterContext) -> list[Filter]: +        """Sift through the list of filters, and return only the ones which apply to the given context.""" +        event_filters = [self.filters[id_] for id_ in self.subscriptions[ctx.event]] +        return await self._create_filter_list_result(ctx, self.defaults, event_filters) + + +class UniquesListBase(FilterList[UniqueFilter], ABC): +    """ +    A list of unique filters. + +    Unique filters are ones that should only be run once in a given context. +    Each unique filter subscribes to a subset of events to respond to. +    """ + +    def __init__(self, filtering_cog: 'Filtering'): +        super().__init__() +        self.filtering_cog = filtering_cog +        self.loaded_types: dict[str, type[UniqueFilter]] = {} + +    def add_list(self, list_data: dict) -> SubscribingAtomicList: +        """Add a new type of list (such as a whitelist or a blacklist) this filter list.""" +        actions, validations = create_settings(list_data["settings"], keep_empty=True) +        list_type = ListType(list_data["list_type"]) +        defaults = Defaults(actions, validations) +        new_list = SubscribingAtomicList( +            list_data["id"], +            arrow.get(list_data["created_at"]), +            arrow.get(list_data["updated_at"]), +            self.name, +            list_type, +            defaults, +            {} +        ) +        self[list_type] = new_list + +        filters = {} +        events = set() +        for filter_data in list_data["filters"]: +            new_filter = self._create_filter(filter_data, defaults) +            if new_filter: +                new_list.subscribe(new_filter, *new_filter.events) +                filters[filter_data["id"]] = new_filter +                self.loaded_types[new_filter.name] = type(new_filter) +                events.update(new_filter.events) + +        new_list.filters.update(filters) +        if hasattr(self.filtering_cog, "subscribe"):  # Subscribe the filter list to any new events found. +            self.filtering_cog.subscribe(self, *events) +        return new_list + +    @property +    def filter_types(self) -> set[type[UniqueFilter]]: +        """Return the types of filters used by this list.""" +        return set(self.loaded_types.values()) diff --git a/bot/exts/filtering/_filter_lists/invite.py b/bot/exts/filtering/_filter_lists/invite.py new file mode 100644 index 000000000..d934e9d53 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/invite.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import typing + +from discord import Embed, Invite +from discord.errors import NotFound +from pydis_core.utils.regex import DISCORD_INVITE + +import bot +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filters.invite import InviteFilter +from bot.exts.filtering._settings import ActionSettings +from bot.exts.filtering._utils import clean_input + +if typing.TYPE_CHECKING: +    from bot.exts.filtering.filtering import Filtering + + +class InviteList(FilterList[InviteFilter]): +    """ +    A list of filters, each looking for guild invites to a specific guild. + +    If the invite is not whitelisted, it will be blocked. Partnered and verified servers are allowed unless blacklisted. + +    Whitelist defaults dictate what happens when an invite is *not* explicitly allowed, +    and whitelist filters overrides have no effect. + +    Blacklist defaults dictate what happens by default when an explicitly blocked invite is found. + +    Items in the list are added through invites for the purpose of fetching the guild info. +    Items are stored as guild IDs, guild invites are *not* stored. +    """ + +    name = "invite" + +    def __init__(self, filtering_cog: Filtering): +        super().__init__() +        filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) + +    def get_filter_type(self, content: str) -> type[Filter]: +        """Get a subclass of filter matching the filter list and the filter's content.""" +        return InviteFilter + +    @property +    def filter_types(self) -> set[type[Filter]]: +        """Return the types of filters used by this list.""" +        return {InviteFilter} + +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" +        text = clean_input(ctx.content) + +        # Avoid escape characters +        text = text.replace("\\", "") + +        matches = list(DISCORD_INVITE.finditer(text)) +        invite_codes = {m.group("invite") for m in matches} +        if not invite_codes: +            return None, [], {} +        all_triggers = {} + +        _, failed = self[ListType.ALLOW].defaults.validations.evaluate(ctx) +        # If the allowed list doesn't operate in the context, unknown invites are allowed. +        check_if_allowed = not failed + +        # Sort the invites into two categories: +        invites_for_inspection = dict()  # Found guild invites requiring further inspection. +        unknown_invites = dict()  # Either don't resolve or group DMs. +        for invite_code in invite_codes: +            try: +                invite = await bot.instance.fetch_invite(invite_code) +            except NotFound: +                if check_if_allowed: +                    unknown_invites[invite_code] = None +            else: +                if invite.guild: +                    invites_for_inspection[invite_code] = invite +                elif check_if_allowed:  # Group DM +                    unknown_invites[invite_code] = invite + +        # Find any blocked invites +        new_ctx = ctx.replace(content={invite.guild.id for invite in invites_for_inspection.values()}) +        triggered = await self[ListType.DENY].filter_list_result(new_ctx) +        blocked_guilds = {filter_.content for filter_ in triggered} +        blocked_invites = { +            code: invite for code, invite in invites_for_inspection.items() if invite.guild.id in blocked_guilds +        } + +        # Remove the ones which are already confirmed as blocked, or otherwise ones which are partnered or verified. +        invites_for_inspection = { +            code: invite for code, invite in invites_for_inspection.items() +            if invite.guild.id not in blocked_guilds +            and "PARTNERED" not in invite.guild.features and "VERIFIED" not in invite.guild.features +        } + +        # Remove any remaining invites which are allowed +        guilds_for_inspection = {invite.guild.id for invite in invites_for_inspection.values()} + +        if check_if_allowed:  # Whether unknown invites need to be checked. +            new_ctx = ctx.replace(content=guilds_for_inspection) +            all_triggers[ListType.ALLOW] = [ +                filter_ for filter_ in self[ListType.ALLOW].filters.values() +                if await filter_.triggered_on(new_ctx) +            ] +            allowed = {filter_.content for filter_ in all_triggers[ListType.ALLOW]} +            unknown_invites.update({ +                code: invite for code, invite in invites_for_inspection.items() if invite.guild.id not in allowed +            }) + +        if not triggered and not unknown_invites: +            return None, [], all_triggers + +        actions = None +        if unknown_invites:  # There are invites which weren't allowed but aren't explicitly blocked. +            actions = self[ListType.ALLOW].defaults.actions +        # Blocked invites come second so that their actions have preference. +        if triggered: +            if actions: +                actions = actions.union(self[ListType.DENY].merge_actions(triggered)) +            else: +                actions = self[ListType.DENY].merge_actions(triggered) +            all_triggers[ListType.DENY] = triggered + +        blocked_invites |= unknown_invites +        ctx.matches += {match[0] for match in matches if match.group("invite") in blocked_invites} +        ctx.alert_embeds += (self._guild_embed(invite) for invite in blocked_invites.values() if invite) +        messages = self[ListType.DENY].format_messages(triggered) +        messages += [ +            f"`{code} - {invite.guild.id}`" if invite else f"`{code}`" for code, invite in unknown_invites.items() +        ] +        return actions, messages, all_triggers + +    @staticmethod +    def _guild_embed(invite: Invite) -> Embed: +        """Return an embed representing the guild invites to.""" +        embed = Embed() +        if invite.guild: +            embed.title = invite.guild.name +            embed.set_footer(text=f"Guild ID: {invite.guild.id}") +            if invite.guild.icon is not None: +                embed.set_thumbnail(url=invite.guild.icon.url) +        else: +            embed.title = "Group DM" + +        embed.description = ( +            f"**Invite Code:** {invite.code}\n" +            f"**Members:** {invite.approximate_member_count}\n" +            f"**Active:** {invite.approximate_presence_count}" +        ) + +        return embed diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py new file mode 100644 index 000000000..0c591ac3b --- /dev/null +++ b/bot/exts/filtering/_filter_lists/token.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import re +import typing + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filters.token import TokenFilter +from bot.exts.filtering._settings import ActionSettings +from bot.exts.filtering._utils import clean_input + +if typing.TYPE_CHECKING: +    from bot.exts.filtering.filtering import Filtering + +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) + + +class TokensList(FilterList[TokenFilter]): +    """ +    A list of filters, each looking for a specific token in the given content given as regex. + +    The blacklist defaults dictate what happens by default when a filter is matched, and can be overridden by +    individual filters. + +    Usually, if blocking literal strings, the literals themselves can be specified as the filter's value. +    But since this is a list of regex patterns, be careful of the items added. For example, a dot needs to be escaped +    to function as a literal dot. +    """ + +    name = "token" + +    def __init__(self, filtering_cog: Filtering): +        super().__init__() +        filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.NICKNAME, Event.SNEKBOX) + +    def get_filter_type(self, content: str) -> type[Filter]: +        """Get a subclass of filter matching the filter list and the filter's content.""" +        return TokenFilter + +    @property +    def filter_types(self) -> set[type[Filter]]: +        """Return the types of filters used by this list.""" +        return {TokenFilter} + +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" +        text = ctx.content +        if not text: +            return None, [], {} +        if SPOILER_RE.search(text): +            text = self._expand_spoilers(text) +        text = clean_input(text) +        ctx = ctx.replace(content=text) + +        triggers = await self[ListType.DENY].filter_list_result(ctx) +        actions = None +        messages = [] +        if triggers: +            actions = self[ListType.DENY].merge_actions(triggers) +            messages = self[ListType.DENY].format_messages(triggers) +        return actions, messages, {ListType.DENY: triggers} + +    @staticmethod +    def _expand_spoilers(text: str) -> str: +        """Return a string containing all interpretations of a spoilered message.""" +        split_text = SPOILER_RE.split(text) +        return ''.join( +            split_text[0::2] + split_text[1::2] + split_text +        ) diff --git a/bot/exts/filtering/_filter_lists/unique.py b/bot/exts/filtering/_filter_lists/unique.py new file mode 100644 index 000000000..a5a04d25a --- /dev/null +++ b/bot/exts/filtering/_filter_lists/unique.py @@ -0,0 +1,39 @@ +from pydis_core.utils.logging import get_logger + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filter_lists.filter_list import ListType, UniquesListBase +from bot.exts.filtering._filters.filter import Filter, UniqueFilter +from bot.exts.filtering._filters.unique import unique_filter_types +from bot.exts.filtering._settings import ActionSettings + +log = get_logger(__name__) + + +class UniquesList(UniquesListBase): +    """ +    A list of unique filters. + +    Unique filters are ones that should only be run once in a given context. +    Each unique filter subscribes to a subset of events to respond to. +    """ + +    name = "unique" + +    def get_filter_type(self, content: str) -> type[UniqueFilter] | None: +        """Get a subclass of filter matching the filter list and the filter's content.""" +        try: +            return unique_filter_types[content] +        except KeyError: +            return None + +    async def actions_for( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: +        """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" +        triggers = await self[ListType.DENY].filter_list_result(ctx) +        actions = None +        messages = [] +        if triggers: +            actions = self[ListType.DENY].merge_actions(triggers) +            messages = self[ListType.DENY].format_messages(triggers) +        return actions, messages, {ListType.DENY: triggers} diff --git a/tests/bot/exts/filters/__init__.py b/bot/exts/filtering/_filters/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/exts/filters/__init__.py +++ b/bot/exts/filtering/_filters/__init__.py diff --git a/bot/exts/filtering/_filters/antispam/__init__.py b/bot/exts/filtering/_filters/antispam/__init__.py new file mode 100644 index 000000000..637bcd410 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/__init__.py @@ -0,0 +1,9 @@ +from os.path import dirname + +from bot.exts.filtering._filters.filter import UniqueFilter +from bot.exts.filtering._utils import subclasses_in_package + +antispam_filter_types = subclasses_in_package(dirname(__file__), f"{__name__}.", UniqueFilter) +antispam_filter_types = {filter_.name: filter_ for filter_ in antispam_filter_types} + +__all__ = [antispam_filter_types] diff --git a/bot/exts/filtering/_filters/antispam/attachments.py b/bot/exts/filtering/_filters/antispam/attachments.py new file mode 100644 index 000000000..216d9b886 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/attachments.py @@ -0,0 +1,43 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraAttachmentsSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of attachments before the filter is triggered." + +    interval: int = 10 +    threshold: int = 6 + + +class AttachmentsFilter(UniqueFilter): +    """Detects too many attachments sent by a single user.""" + +    name = "attachments" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraAttachmentsSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author and len(msg.attachments) > 0} +        total_recent_attachments = sum(len(msg.attachments) for msg in detected_messages) + +        if total_recent_attachments > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_recent_attachments} attachments" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/burst.py b/bot/exts/filtering/_filters/antispam/burst.py new file mode 100644 index 000000000..d78107d0a --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/burst.py @@ -0,0 +1,41 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraBurstSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of messages before the filter is triggered." + +    interval: int = 10 +    threshold: int = 7 + + +class BurstFilter(UniqueFilter): +    """Detects too many messages sent by a single user.""" + +    name = "burst" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraBurstSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} +        if len(detected_messages) > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {len(detected_messages)} messages" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/chars.py b/bot/exts/filtering/_filters/antispam/chars.py new file mode 100644 index 000000000..5c4fa201c --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/chars.py @@ -0,0 +1,43 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraCharsSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of characters before the filter is triggered." + +    interval: int = 5 +    threshold: int = 4_200 + + +class CharsFilter(UniqueFilter): +    """Detects too many characters sent by a single user.""" + +    name = "chars" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraCharsSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} +        total_recent_chars = sum(len(msg.content) for msg in relevant_messages) + +        if total_recent_chars > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_recent_chars} characters" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/duplicates.py b/bot/exts/filtering/_filters/antispam/duplicates.py new file mode 100644 index 000000000..60d5c322c --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/duplicates.py @@ -0,0 +1,44 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraDuplicatesSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of duplicate messages before the filter is triggered." + +    interval: int = 10 +    threshold: int = 3 + + +class DuplicatesFilter(UniqueFilter): +    """Detects duplicated messages sent by a single user.""" + +    name = "duplicates" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraDuplicatesSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + +        detected_messages = { +            msg for msg in relevant_messages +            if msg.author == ctx.author and msg.content == ctx.message.content and msg.content +        } +        if len(detected_messages) > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {len(detected_messages)} duplicate messages" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/emoji.py b/bot/exts/filtering/_filters/antispam/emoji.py new file mode 100644 index 000000000..0511e4a7b --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/emoji.py @@ -0,0 +1,53 @@ +import re +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from emoji import demojize +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:") +CODE_BLOCK_RE = re.compile(r"```.*?```", flags=re.DOTALL) + + +class ExtraEmojiSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of emojis before the filter is triggered." + +    interval: int = 10 +    threshold: int = 20 + + +class EmojiFilter(UniqueFilter): +    """Detects too many emojis sent by a single user.""" + +    name = "emoji" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraEmojiSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + +        # Get rid of code blocks in the message before searching for emojis. +        # Convert Unicode emojis to :emoji: format to get their count. +        total_emojis = sum( +            len(DISCORD_EMOJI_RE.findall(demojize(CODE_BLOCK_RE.sub("", msg.content)))) +            for msg in relevant_messages +        ) + +        if total_emojis > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_emojis} emojis" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/links.py b/bot/exts/filtering/_filters/antispam/links.py new file mode 100644 index 000000000..76fe53e70 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/links.py @@ -0,0 +1,52 @@ +import re +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +LINK_RE = re.compile(r"(https?://\S+)") + + +class ExtraLinksSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of links before the filter is triggered." + +    interval: int = 10 +    threshold: int = 10 + + +class DuplicatesFilter(UniqueFilter): +    """Detects too many links sent by a single user.""" + +    name = "links" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraLinksSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + +        total_links = 0 +        messages_with_links = 0 +        for msg in relevant_messages: +            total_matches = len(LINK_RE.findall(msg.content)) +            if total_matches: +                messages_with_links += 1 +                total_links += total_matches + +        if total_links > self.extra_fields.threshold and messages_with_links > 1: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_links} links" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/mentions.py b/bot/exts/filtering/_filters/antispam/mentions.py new file mode 100644 index 000000000..f3c945e16 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/mentions.py @@ -0,0 +1,90 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from discord import DeletedReferencedMessage, MessageType, NotFound +from pydantic import BaseModel +from pydis_core.utils.logging import get_logger + +import bot +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +log = get_logger(__name__) + + +class ExtraMentionsSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of distinct mentions before the filter is triggered." + +    interval: int = 10 +    threshold: int = 5 + + +class DuplicatesFilter(UniqueFilter): +    """ +    Detects total mentions exceeding the limit sent by a single user. + +    Excludes mentions that are bots, themselves, or replied users. + +    In very rare cases, may not be able to determine a +    mention was to a reply, in which case it is not ignored. +    """ + +    name = "mentions" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraMentionsSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + +        # We use `msg.mentions` here as that is supplied by the api itself, to determine who was mentioned. +        # Additionally, `msg.mentions` includes the user replied to, even if the mention doesn't occur in the body. +        # In order to exclude users who are mentioned as a reply, we check if the msg has a reference +        # +        # While we could use regex to parse the message content, and get a list of +        # the mentions, that solution is very prone to breaking. +        # We would need to deal with codeblocks, escaping markdown, and any discrepancies between +        # our implementation and discord's Markdown parser which would cause false positives or false negatives. +        total_recent_mentions = 0 +        for msg in relevant_messages: +            # We check if the message is a reply, and if it is try to get the author +            # since we ignore mentions of a user that we're replying to +            reply_author = None + +            if msg.type == MessageType.reply: +                ref = msg.reference + +                if not (resolved := ref.resolved): +                    # It is possible, in a very unusual situation, for a message to have a reference +                    # that is both not in the cache, and deleted while running this function. +                    # In such a situation, this will throw an error which we catch. +                    try: +                        resolved = await bot.instance.get_partial_messageable(resolved.channel_id).fetch_message( +                            resolved.message_id +                        ) +                    except NotFound: +                        log.info('Could not fetch the reference message as it has been deleted.') + +                if resolved and not isinstance(resolved, DeletedReferencedMessage): +                    reply_author = resolved.author + +            for user in msg.mentions: +                # Don't count bot or self mentions, or the user being replied to (if applicable) +                if user.bot or user in {msg.author, reply_author}: +                    continue +                total_recent_mentions += 1 + +        if total_recent_mentions > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_recent_mentions} mentions" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/newlines.py b/bot/exts/filtering/_filters/antispam/newlines.py new file mode 100644 index 000000000..b15a35219 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/newlines.py @@ -0,0 +1,61 @@ +import re +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +NEWLINES = re.compile(r"(\n+)") + + +class ExtraNewlinesSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of newlines before the filter is triggered." +    consecutive_threshold_description: ClassVar[str] = ( +        "Maximum number of consecutive newlines before the filter is triggered." +    ) + +    interval: int = 10 +    threshold: int = 100 +    consecutive_threshold: int = 10 + + +class NewlinesFilter(UniqueFilter): +    """Detects too many newlines sent by a single user.""" + +    name = "newlines" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraNewlinesSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + +        # Identify groups of newline characters and get group & total counts +        newline_counts = [] +        for msg in relevant_messages: +            newline_counts += [len(group) for group in NEWLINES.findall(msg.content)] +        total_recent_newlines = sum(newline_counts) +        # Get maximum newline group size +        max_newline_group = max(newline_counts, default=0) + +        # Check first for total newlines, if this passes then check for large groupings +        if total_recent_newlines > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_recent_newlines} newlines" +            return True +        if max_newline_group > self.extra_fields.consecutive_threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {max_newline_group} consecutive newlines" +            return True +        return False diff --git a/bot/exts/filtering/_filters/antispam/role_mentions.py b/bot/exts/filtering/_filters/antispam/role_mentions.py new file mode 100644 index 000000000..49de642fa --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/role_mentions.py @@ -0,0 +1,42 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraRoleMentionsSettings(BaseModel): +    """Extra settings for when to trigger the antispam rule.""" + +    interval_description: ClassVar[str] = ( +        "Look for rule violations in messages from the last `interval` number of seconds." +    ) +    threshold_description: ClassVar[str] = "Maximum number of role mentions before the filter is triggered." + +    interval: int = 10 +    threshold: int = 3 + + +class DuplicatesFilter(UniqueFilter): +    """Detects too many role mentions sent by a single user.""" + +    name = "role_mentions" +    events = (Event.MESSAGE,) +    extra_fields_type = ExtraRoleMentionsSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) +        detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} +        total_recent_mentions = sum(len(msg.role_mentions) for msg in relevant_messages) + +        if total_recent_mentions > self.extra_fields.threshold: +            ctx.related_messages |= detected_messages +            ctx.filter_info[self] = f"sent {total_recent_mentions} role mentions" +            return True +        return False diff --git a/bot/exts/filtering/_filters/domain.py b/bot/exts/filtering/_filters/domain.py new file mode 100644 index 000000000..ac9cc9018 --- /dev/null +++ b/bot/exts/filtering/_filters/domain.py @@ -0,0 +1,62 @@ +import re +from typing import ClassVar +from urllib.parse import urlparse + +import tldextract +from discord.ext.commands import BadArgument +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filters.filter import Filter + +URL_RE = re.compile(r"(?:https?://)?(\S+?)[\\/]*", flags=re.IGNORECASE) + + +class ExtraDomainSettings(BaseModel): +    """Extra settings for how domains should be matched in a message.""" + +    only_subdomains_description: ClassVar[str] = ( +        "A boolean. If True, will only trigger for subdomains and subpaths, and not for the domain itself." +    ) + +    # Whether to trigger only for subdomains and subpaths, and not the specified domain itself. +    only_subdomains: bool = False + + +class DomainFilter(Filter): +    """ +    A filter which looks for a specific domain given by URL. + +    The schema (http, https) does not need to be included in the filter. +    Will also match subdomains. +    """ + +    name = "domain" +    extra_fields_type = ExtraDomainSettings + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Searches for a domain within a given context.""" +        domain = tldextract.extract(self.content).registered_domain + +        for found_url in ctx.content: +            extract = tldextract.extract(found_url) +            if self.content in found_url and extract.registered_domain == domain: +                if self.extra_fields.only_subdomains: +                    if not extract.subdomain and not urlparse(f"https://{found_url}").path: +                        return False +                ctx.matches.append(found_url) +                ctx.notification_domain = self.content +                return True +        return False + +    @classmethod +    async def process_input(cls, content: str, description: str) -> tuple[str, str]: +        """ +        Process the content and description into a form which will work with the filtering. + +        A BadArgument should be raised if the content can't be used. +        """ +        match = URL_RE.fullmatch(content) +        if not match or not match.group(1): +            raise BadArgument(f"`{content}` is not a URL.") +        return match.group(1), description diff --git a/bot/exts/filtering/_filters/extension.py b/bot/exts/filtering/_filters/extension.py new file mode 100644 index 000000000..97eddc406 --- /dev/null +++ b/bot/exts/filtering/_filters/extension.py @@ -0,0 +1,27 @@ +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filters.filter import Filter + + +class ExtensionFilter(Filter): +    """ +    A filter which looks for a specific attachment extension in messages. + +    The filter stores the extension preceded by a dot. +    """ + +    name = "extension" + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Searches for an attachment extension in the context content, given as a set of extensions.""" +        return self.content in ctx.content + +    @classmethod +    async def process_input(cls, content: str, description: str) -> tuple[str, str]: +        """ +        Process the content and description into a form which will work with the filtering. + +        A BadArgument should be raised if the content can't be used. +        """ +        if not content.startswith("."): +            content = f".{content}" +        return content, description diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py new file mode 100644 index 000000000..2b8f8d5d4 --- /dev/null +++ b/bot/exts/filtering/_filters/filter.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from typing import Any + +import arrow +from pydantic import ValidationError + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings import Defaults, create_settings +from bot.exts.filtering._utils import FieldRequiring + + +class Filter(FieldRequiring): +    """ +    A class representing a filter. + +    Each filter looks for a specific attribute within an event (such as message sent), +    and defines what action should be performed if it is triggered. +    """ + +    # Each subclass must define a name which will be used to fetch its description. +    # Names must be unique across all types of filters. +    name = FieldRequiring.MUST_SET_UNIQUE +    # If a subclass uses extra fields, it should assign the pydantic model type to this variable. +    extra_fields_type = None + +    def __init__(self, filter_data: dict, defaults: Defaults | None = None): +        self.id = filter_data["id"] +        self.content = filter_data["content"] +        self.description = filter_data["description"] +        self.created_at = arrow.get(filter_data["created_at"]) +        self.updated_at = arrow.get(filter_data["updated_at"]) +        self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults) +        if self.extra_fields_type: +            self.extra_fields = self.extra_fields_type.parse_obj(filter_data["additional_settings"]) +        else: +            self.extra_fields = None + +    @property +    def overrides(self) -> tuple[dict[str, Any], dict[str, Any]]: +        """Return a tuple of setting overrides and filter setting overrides.""" +        settings = {} +        if self.actions: +            settings = self.actions.overrides +        if self.validations: +            settings |= self.validations.overrides + +        filter_settings = {} +        if self.extra_fields: +            filter_settings = self.extra_fields.dict(exclude_unset=True) + +        return settings, filter_settings + +    @abstractmethod +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" + +    @classmethod +    def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, str | None]: +        """Validate whether the supplied fields are valid for the filter, and provide the error message if not.""" +        if cls.extra_fields_type is None: +            return True, None + +        try: +            cls.extra_fields_type(**extra_fields) +        except ValidationError as e: +            return False, repr(e) +        else: +            return True, None + +    @classmethod +    async def process_input(cls, content: str, description: str) -> tuple[str, str]: +        """ +        Process the content and description into a form which will work with the filtering. + +        A BadArgument should be raised if the content can't be used. +        """ +        return content, description + +    def __str__(self) -> str: +        """A string representation of the filter.""" +        string = f"{self.id}. `{self.content}`" +        if self.description: +            string += f" - {self.description}" +        return string + + +class UniqueFilter(Filter, ABC): +    """ +    Unique filters are ones that should only be run once in a given context. + +    This is as opposed to say running many domain filters on the same message. +    """ + +    events: tuple[Event, ...] = FieldRequiring.MUST_SET diff --git a/bot/exts/filtering/_filters/invite.py b/bot/exts/filtering/_filters/invite.py new file mode 100644 index 000000000..799a302b9 --- /dev/null +++ b/bot/exts/filtering/_filters/invite.py @@ -0,0 +1,48 @@ +from discord import NotFound +from discord.ext.commands import BadArgument +from pydis_core.utils.regex import DISCORD_INVITE + +import bot +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filters.filter import Filter + + +class InviteFilter(Filter): +    """ +    A filter which looks for invites to a specific guild in messages. + +    The filter stores the guild ID which is allowed or denied. +    """ + +    name = "invite" + +    def __init__(self, filter_data: dict, defaults_data: dict | None = None): +        super().__init__(filter_data, defaults_data) +        self.content = int(self.content) + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Searches for a guild ID in the context content, given as a set of IDs.""" +        return self.content in ctx.content + +    @classmethod +    async def process_input(cls, content: str, description: str) -> tuple[str, str]: +        """ +        Process the content and description into a form which will work with the filtering. + +        A BadArgument should be raised if the content can't be used. +        """ +        match = DISCORD_INVITE.fullmatch(content) +        if not match or not match.group("invite"): +            raise BadArgument(f"`{content}` is not a valid Discord invite.") +        invite_code = match.group("invite") +        try: +            invite = await bot.instance.fetch_invite(invite_code) +        except NotFound: +            raise BadArgument(f"`{invite_code}` is not a valid Discord invite code.") +        if not invite.guild: +            raise BadArgument("Did you just try to add a group DM?") + +        guild_name = invite.guild.name if hasattr(invite.guild, "name") else "" +        if guild_name.lower() not in description.lower(): +            description = " - ".join(part for part in (f'Guild "{guild_name}"', description) if part) +        return str(invite.guild.id), description diff --git a/bot/exts/filtering/_filters/token.py b/bot/exts/filtering/_filters/token.py new file mode 100644 index 000000000..3cd9b909d --- /dev/null +++ b/bot/exts/filtering/_filters/token.py @@ -0,0 +1,35 @@ +import re + +from discord.ext.commands import BadArgument + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filters.filter import Filter + + +class TokenFilter(Filter): +    """A filter which looks for a specific token given by regex.""" + +    name = "token" + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Searches for a regex pattern within a given context.""" +        pattern = self.content + +        match = re.search(pattern, ctx.content, flags=re.IGNORECASE) +        if match: +            ctx.matches.append(match[0]) +            return True +        return False + +    @classmethod +    async def process_input(cls, content: str, description: str) -> tuple[str, str]: +        """ +        Process the content and description into a form which will work with the filtering. + +        A BadArgument should be raised if the content can't be used. +        """ +        try: +            re.compile(content) +        except re.error as e: +            raise BadArgument(str(e)) +        return content, description diff --git a/bot/exts/filtering/_filters/unique/__init__.py b/bot/exts/filtering/_filters/unique/__init__.py new file mode 100644 index 000000000..ce78d6922 --- /dev/null +++ b/bot/exts/filtering/_filters/unique/__init__.py @@ -0,0 +1,9 @@ +from os.path import dirname + +from bot.exts.filtering._filters.filter import UniqueFilter +from bot.exts.filtering._utils import subclasses_in_package + +unique_filter_types = subclasses_in_package(dirname(__file__), f"{__name__}.", UniqueFilter) +unique_filter_types = {filter_.name: filter_ for filter_ in unique_filter_types} + +__all__ = [unique_filter_types] diff --git a/bot/exts/filters/token_remover.py b/bot/exts/filtering/_filters/unique/discord_token.py index 29f80671d..f4b9cc741 100644 --- a/bot/exts/filters/token_remover.py +++ b/bot/exts/filtering/_filters/unique/discord_token.py @@ -1,38 +1,34 @@  import base64  import re -import typing as t - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot import utils -from bot.bot import Bot -from bot.constants import Channels, Colours, Event, Icons +from collections.abc import Callable, Coroutine +from typing import ClassVar, NamedTuple + +import discord +from pydantic import BaseModel, Field +from pydis_core.utils.logging import get_logger +from pydis_core.utils.members import get_or_fetch_member + +import bot +from bot import constants, utils +from bot.constants import Guild +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter +from bot.exts.filtering._utils import resolve_mention  from bot.exts.moderation.modlog import ModLog -from bot.log import get_logger -from bot.utils.members import get_or_fetch_member  from bot.utils.messages import format_user  log = get_logger(__name__) +  LOG_MESSAGE = ( -    "Censored a seemingly valid token sent by {author} in {channel}, " -    "token was `{user_id}.{timestamp}.{hmac}`" +    "Censored a seemingly valid token sent by {author} in {channel}. " +    "Token was: `{user_id}.{timestamp}.{hmac}`."  )  UNKNOWN_USER_LOG_MESSAGE = "Decoded user ID: `{user_id}` (Not present in server)."  KNOWN_USER_LOG_MESSAGE = (      "Decoded user ID: `{user_id}` **(Present in server)**.\n"      "This matches `{user_name}` and means this is likely a valid **{kind}** token."  ) -DELETION_MESSAGE_TEMPLATE = ( -    "Hey {mention}! I noticed you posted a seemingly valid Discord API " -    "token in your message and have removed your message. " -    "This means that your token has been **compromised**. " -    "Please change your token **immediately** at: " -    "<https://discord.com/developers/applications>\n\n" -    "Feel free to re-post it with the token removed. " -    "If you believe this was a mistake, please let us know!" -)  DISCORD_EPOCH = 1_420_070_400  TOKEN_EPOCH = 1_293_840_000 @@ -43,7 +39,17 @@ TOKEN_EPOCH = 1_293_840_000  TOKEN_RE = re.compile(r"([\w-]{10,})\.([\w-]{5,})\.([\w-]{10,})") -class Token(t.NamedTuple): +class ExtraDiscordTokenSettings(BaseModel): +    """Extra settings for who should be pinged when a Discord token is detected.""" + +    pings_for_bot_description: ClassVar[str] = "A sequence. Who should be pinged if the token found belongs to a bot." +    pings_for_user_description: ClassVar[str] = "A sequence. Who should be pinged if the token found belongs to a user." + +    pings_for_bot: set[str] = Field(default_factory=set) +    pings_for_user: set[str] = Field(default_factory=lambda: {"Moderators"}) + + +class Token(NamedTuple):      """A Discord Bot token."""      user_id: str @@ -51,84 +57,64 @@ class Token(t.NamedTuple):      hmac: str -class TokenRemover(Cog): +class DiscordTokenFilter(UniqueFilter):      """Scans messages for potential discord client tokens and removes them.""" -    def __init__(self, bot: Bot): -        self.bot = bot +    name = "discord_token" +    events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) +    extra_fields_type = ExtraDiscordTokenSettings      @property -    def mod_log(self) -> ModLog: +    def mod_log(self) -> ModLog | None:          """Get currently loaded ModLog cog instance.""" -        return self.bot.get_cog("ModLog") - -    @Cog.listener() -    async def on_message(self, msg: Message) -> None: -        """ -        Check each message for a string that matches Discord's token pattern. +        return bot.instance.get_cog("ModLog") -        See: https://discordapp.com/developers/docs/reference#snowflakes -        """ -        # Ignore DMs; can't delete messages in there anyway. -        if not msg.guild or msg.author.bot: -            return - -        found_token = self.find_token_in_message(msg) -        if found_token: -            await self.take_action(msg, found_token) - -    @Cog.listener() -    async def on_message_edit(self, before: Message, after: Message) -> None: -        """ -        Check each edit for a string that matches Discord's token pattern. - -        See: https://discordapp.com/developers/docs/reference#snowflakes -        """ -        await self.on_message(after) - -    async def take_action(self, msg: Message, found_token: Token) -> None: -        """Remove the `msg` containing the `found_token` and send a mod log message.""" -        self.mod_log.ignore(Event.message_delete, msg.id) - -        try: -            await msg.delete() -        except NotFound: -            log.debug(f"Failed to remove token in message {msg.id}: message already deleted.") -            return - -        await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - -        log_message = self.format_log_message(msg, found_token) -        userid_message, mention_everyone = await self.format_userid_log_message(msg, found_token) -        log.debug(log_message) - -        # Send pretty mod log embed to mod-alerts -        await self.mod_log.send_log_message( -            icon_url=Icons.token_removed, -            colour=Colour(Colours.soft_red), -            title="Token removed!", -            text=log_message + "\n" + userid_message, -            thumbnail=msg.author.display_avatar.url, -            channel_id=Channels.mod_alerts, -            ping_everyone=mention_everyone, -        ) +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Return whether the message contains Discord client tokens.""" +        found_token = self.find_token_in_message(ctx.content) +        if not found_token: +            return False -        self.bot.stats.incr("tokens.removed_tokens") +        if ctx.message and (mod_log := self.mod_log): +            mod_log.ignore(constants.Event.message_delete, ctx.message.id) +        ctx.content = ctx.content.replace(found_token.hmac, self.censor_hmac(found_token.hmac)) +        ctx.additional_actions.append(self._create_token_alert_embed_wrapper(found_token)) +        return True + +    def _create_token_alert_embed_wrapper(self, found_token: Token) -> Callable[[FilterContext], Coroutine]: +        """Create the action to perform when an alert should be sent for a message containing a Discord token.""" +        async def _create_token_alert_embed(ctx: FilterContext) -> None: +            """Add an alert embed to the context with info about the token sent.""" +            userid_message, is_user = await self.format_userid_log_message(found_token) +            log_message = self.format_log_message(ctx.author, ctx.channel, found_token) +            log.debug(log_message) + +            if is_user: +                mentions = map(resolve_mention, self.extra_fields.pings_for_user) +                color = discord.Colour.red() +            else: +                mentions = map(resolve_mention, self.extra_fields.pings_for_bot) +                color = discord.Colour.blue() +            unmentioned = [mention for mention in mentions if mention not in ctx.alert_content] +            if unmentioned: +                ctx.alert_content = f"{' '.join(unmentioned)} {ctx.alert_content}" +            ctx.alert_embeds.append(discord.Embed(colour=color, description=userid_message)) + +        return _create_token_alert_embed      @classmethod -    async def format_userid_log_message(cls, msg: Message, token: Token) -> t.Tuple[str, bool]: +    async def format_userid_log_message(cls, token: Token) -> tuple[str, bool]:          """          Format the portion of the log message that includes details about the detected user ID.          If the user is resolved to a member, the format includes the user ID, name, and the          kind of user detected. - -        If we resolve to a member and it is not a bot, we also return True to ping everyone. - -        Returns a tuple of (log_message, mention_everyone) +        If it is resolved to a user or a member, and it is not a bot, also return True. +        Returns a tuple of (log_message, is_user)          """          user_id = cls.extract_user_id(token.user_id) -        user = await get_or_fetch_member(msg.guild, user_id) +        guild = bot.instance.get_guild(Guild.id) +        user = await get_or_fetch_member(guild, user_id)          if user:              return KNOWN_USER_LOG_MESSAGE.format( @@ -140,22 +126,27 @@ class TokenRemover(Cog):              return UNKNOWN_USER_LOG_MESSAGE.format(user_id=user_id), False      @staticmethod -    def format_log_message(msg: Message, token: Token) -> str: +    def censor_hmac(hmac: str) -> str: +        """Return a censored version of the hmac.""" +        return 'x' * (len(hmac) - 3) + hmac[-3:] + +    @classmethod +    def format_log_message(cls, author: discord.User, channel: discord.abc.GuildChannel, token: Token) -> str:          """Return the generic portion of the log message to send for `token` being censored in `msg`."""          return LOG_MESSAGE.format( -            author=format_user(msg.author), -            channel=msg.channel.mention, +            author=format_user(author), +            channel=channel.mention,              user_id=token.user_id,              timestamp=token.timestamp, -            hmac='x' * (len(token.hmac) - 3) + token.hmac[-3:], +            hmac=cls.censor_hmac(token.hmac),          )      @classmethod -    def find_token_in_message(cls, msg: Message) -> t.Optional[Token]: -        """Return a seemingly valid token found in `msg` or `None` if no token is found.""" +    def find_token_in_message(cls, content: str) -> Token | None: +        """Return a seemingly valid token found in `content` or `None` if no token is found."""          # Use finditer rather than search to guard against method calls prematurely returning the          # token check (e.g. `message.channel.send` also matches our token pattern) -        for match in TOKEN_RE.finditer(msg.content): +        for match in TOKEN_RE.finditer(content):              token = Token(*match.groups())              if (                  (cls.extract_user_id(token.user_id) is not None) @@ -169,7 +160,7 @@ class TokenRemover(Cog):          return None      @staticmethod -    def extract_user_id(b64_content: str) -> t.Optional[int]: +    def extract_user_id(b64_content: str) -> int | None:          """Return a user ID integer from part of a potential token, or None if it couldn't be decoded."""          b64_content = utils.pad_base64(b64_content) @@ -214,7 +205,7 @@ class TokenRemover(Cog):          """          Determine if a given HMAC portion of a token is potentially valid. -        If the HMAC has 3 or less characters, it's probably a dummy value like "xxxxxxxxxx", +        If the HMAC has 3 or fewer characters, it's probably a dummy value like "xxxxxxxxxx",          and thus the token can probably be skipped.          """          unique = len(set(b64_content.lower())) @@ -226,8 +217,3 @@ class TokenRemover(Cog):              return False          else:              return True - - -async def setup(bot: Bot) -> None: -    """Load the TokenRemover cog.""" -    await bot.add_cog(TokenRemover(bot)) diff --git a/bot/exts/filtering/_filters/unique/everyone.py b/bot/exts/filtering/_filters/unique/everyone.py new file mode 100644 index 000000000..e49ede82f --- /dev/null +++ b/bot/exts/filtering/_filters/unique/everyone.py @@ -0,0 +1,28 @@ +import re + +from bot.constants import Guild +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +EVERYONE_PING_RE = re.compile(rf"@everyone|<@&{Guild.id}>|@here") +CODE_BLOCK_RE = re.compile( +    r"(?P<delim>``?)[^`]+?(?P=delim)(?!`+)"  # Inline codeblock +    r"|```(.+?)```",  # Multiline codeblock +    re.DOTALL | re.MULTILINE +) + + +class EveryoneFilter(UniqueFilter): +    """Filter messages which contain `@everyone` and `@here` tags outside a codeblock.""" + +    name = "everyone" +    events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for the filter's content within a given context.""" +        # First pass to avoid running re.sub on every message +        if not EVERYONE_PING_RE.search(ctx.content): +            return False + +        content_without_codeblocks = CODE_BLOCK_RE.sub("", ctx.content) +        return bool(EVERYONE_PING_RE.search(content_without_codeblocks)) diff --git a/bot/exts/filtering/_filters/unique/rich_embed.py b/bot/exts/filtering/_filters/unique/rich_embed.py new file mode 100644 index 000000000..2ee469f51 --- /dev/null +++ b/bot/exts/filtering/_filters/unique/rich_embed.py @@ -0,0 +1,51 @@ +import re + +from pydis_core.utils.logging import get_logger + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter +from bot.utils.helpers import remove_subdomain_from_url + +log = get_logger(__name__) + +URL_RE = re.compile(r"(https?://\S+)", flags=re.IGNORECASE) + + +class RichEmbedFilter(UniqueFilter): +    """Filter messages which contain rich embeds not auto-generated from a URL.""" + +    name = "rich_embed" +    events = (Event.MESSAGE, Event.MESSAGE_EDIT) + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Determine if `msg` contains any rich embeds not auto-generated from a URL.""" +        if ctx.embeds: +            if ctx.event == Event.MESSAGE_EDIT: +                if not ctx.message.edited_at:  # This might happen, apparently. +                    return False +                # If the edit delta is less than 100 microseconds, it's probably a double filter trigger. +                delta = ctx.message.edited_at - (ctx.before_message.edited_at or ctx.before_message.created_at) +                if delta.total_seconds() < 0.0001: +                    return False + +            for embed in ctx.embeds: +                if embed.type == "rich": +                    urls = URL_RE.findall(ctx.content) +                    final_urls = set(urls) +                    # This is due to the way discord renders relative urls in Embeds +                    # if the following url is sent: https://mobile.twitter.com/something +                    # Discord renders it as https://twitter.com/something +                    for url in urls: +                        final_urls.add(remove_subdomain_from_url(url)) +                    if not embed.url or embed.url not in final_urls: +                        # If `embed.url` does not exist or if `embed.url` is not part of the content +                        # of the message, it's unlikely to be an auto-generated embed by Discord. +                        ctx.alert_embeds.extend(ctx.embeds) +                        return True +                    else: +                        log.trace( +                            "Found a rich embed sent by a regular user account, " +                            "but it was likely just an automatic URL embed." +                        ) + +        return False diff --git a/bot/exts/filtering/_filters/unique/webhook.py b/bot/exts/filtering/_filters/unique/webhook.py new file mode 100644 index 000000000..4e1e2e44d --- /dev/null +++ b/bot/exts/filtering/_filters/unique/webhook.py @@ -0,0 +1,63 @@ +import re +from collections.abc import Callable, Coroutine + +from pydis_core.utils.logging import get_logger + +import bot +from bot import constants +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter +from bot.exts.moderation.modlog import ModLog + +log = get_logger(__name__) + + +WEBHOOK_URL_RE = re.compile( +    r"((?:https?://)?(?:ptb\.|canary\.)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", +    re.IGNORECASE +) + + +class WebhookFilter(UniqueFilter): +    """Scan messages to detect Discord webhooks links.""" + +    name = "webhook" +    events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) + +    @property +    def mod_log(self) -> ModLog | None: +        """Get current instance of `ModLog`.""" +        return bot.instance.get_cog("ModLog") + +    async def triggered_on(self, ctx: FilterContext) -> bool: +        """Search for a webhook in the given content. If found, attempt to delete it.""" +        matches = set(WEBHOOK_URL_RE.finditer(ctx.content)) +        if not matches: +            return False + +        # Don't log this. +        if ctx.message and (mod_log := self.mod_log): +            mod_log.ignore(constants.Event.message_delete, ctx.message.id) + +        for i, match in enumerate(matches, start=1): +            extra = "" if len(matches) == 1 else f" ({i})" +            # Queue the webhook for deletion. +            ctx.additional_actions.append(self._delete_webhook_wrapper(match[0], extra)) +            # Don't show the full webhook in places such as the mod alert. +            ctx.content = ctx.content.replace(match[0], match[1] + "xxx") + +        return True + +    @staticmethod +    def _delete_webhook_wrapper(webhook_url: str, extra_message: str) -> Callable[[FilterContext], Coroutine]: +        """Create the action to perform when a webhook should be deleted.""" +        async def _delete_webhook(ctx: FilterContext) -> None: +            """Delete the given webhook and update the filter context.""" +            async with bot.instance.http_session.delete(webhook_url) as resp: +                # The Discord API Returns a 204 NO CONTENT response on success. +                if resp.status == 204: +                    ctx.action_descriptions.append("webhook deleted" + extra_message) +                else: +                    ctx.action_descriptions.append("failed to delete webhook" + extra_message) + +        return _delete_webhook diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py new file mode 100644 index 000000000..f51a42704 --- /dev/null +++ b/bot/exts/filtering/_settings.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import operator +import traceback +from abc import abstractmethod +from copy import copy +from functools import reduce +from typing import Any, NamedTuple, Optional, TypeVar + +from typing_extensions import Self + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types import settings_types +from bot.exts.filtering._settings_types.settings_entry import ActionEntry, SettingsEntry, ValidationEntry +from bot.exts.filtering._utils import FieldRequiring +from bot.log import get_logger + +TSettings = TypeVar("TSettings", bound="Settings") + +log = get_logger(__name__) + +_already_warned: set[str] = set() + +T = TypeVar("T", bound=SettingsEntry) + + +def create_settings( +    settings_data: dict, *, defaults: Defaults | None = None, keep_empty: bool = False +) -> tuple[Optional[ActionSettings], Optional[ValidationSettings]]: +    """ +    Create and return instances of the Settings subclasses from the given data. + +    Additionally, warn for data entries with no matching class. + +    In case these are setting overrides, the defaults can be provided to keep track of the correct values. +    """ +    action_data = {} +    validation_data = {} +    for entry_name, entry_data in settings_data.items(): +        if entry_name in settings_types["ActionEntry"]: +            action_data[entry_name] = entry_data +        elif entry_name in settings_types["ValidationEntry"]: +            validation_data[entry_name] = entry_data +        elif entry_name not in _already_warned: +            log.warning( +                f"A setting named {entry_name} was loaded from the database, but no matching class." +            ) +            _already_warned.add(entry_name) +    if defaults is None: +        default_actions = None +        default_validations = None +    else: +        default_actions, default_validations = defaults +    return ( +        ActionSettings.create(action_data, defaults=default_actions, keep_empty=keep_empty), +        ValidationSettings.create(validation_data, defaults=default_validations, keep_empty=keep_empty) +    ) + + +class Settings(FieldRequiring, dict[str, T]): +    """ +    A collection of settings. + +    For processing the settings parts in the database and evaluating them on given contexts. + +    Each filter list and filter has its own settings. + +    A filter doesn't have to have its own settings. For every undefined setting, it falls back to the value defined in +    the filter list which contains the filter. +    """ + +    entry_type: type[T] + +    _already_warned: set[str] = set() + +    @abstractmethod  # ABCs have to have at least once abstract method to actually count as such. +    def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False): +        super().__init__() + +        entry_classes = settings_types.get(self.entry_type.__name__) +        for entry_name, entry_data in settings_data.items(): +            try: +                entry_cls = entry_classes[entry_name] +            except KeyError: +                if entry_name not in self._already_warned: +                    log.warning( +                        f"A setting named {entry_name} was loaded from the database, " +                        f"but no matching {self.entry_type.__name__} class." +                    ) +                    self._already_warned.add(entry_name) +            else: +                try: +                    entry_defaults = None if defaults is None else defaults[entry_name] +                    new_entry = entry_cls.create( +                        entry_data, defaults=entry_defaults, keep_empty=keep_empty +                    ) +                    if new_entry: +                        self[entry_name] = new_entry +                except TypeError as e: +                    raise TypeError( +                        f"Attempted to load a {entry_name} setting, but the response is malformed: {entry_data}" +                    ) from e + +    @property +    def overrides(self) -> dict[str, Any]: +        """Return a dictionary of overrides across all entries.""" +        return reduce(operator.or_, (entry.overrides for entry in self.values() if entry), {}) + +    def copy(self: TSettings) -> TSettings: +        """Create a shallow copy of the object.""" +        return copy(self) + +    def get_setting(self, key: str, default: Optional[Any] = None) -> Any: +        """Get the setting matching the key, or fall back to the default value if the key is missing.""" +        for entry in self.values(): +            if hasattr(entry, key): +                return getattr(entry, key) +        return default + +    @classmethod +    def create( +        cls, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False +    ) -> Optional[Settings]: +        """ +        Returns a Settings object from `settings_data` if it holds any value, None otherwise. + +        Use this method to create Settings objects instead of the init. +        The None value is significant for how a filter list iterates over its filters. +        """ +        settings = cls(settings_data, defaults=defaults, keep_empty=keep_empty) +        # If an entry doesn't hold any values, its `create` method will return None. +        # If all entries are None, then the settings object holds no values. +        if not keep_empty and not any(settings.values()): +            return None + +        return settings + + +class ValidationSettings(Settings[ValidationEntry]): +    """ +    A collection of validation settings. + +    A filter is triggered only if all of its validation settings (e.g whether to invoke in DM) approve +    (the check returns True). +    """ + +    entry_type = ValidationEntry + +    def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False): +        super().__init__(settings_data, defaults=defaults, keep_empty=keep_empty) + +    def evaluate(self, ctx: FilterContext) -> tuple[set[str], set[str]]: +        """Evaluates for each setting whether the context is relevant to the filter.""" +        passed = set() +        failed = set() + +        for name, validation in self.items(): +            if validation: +                if validation.triggers_on(ctx): +                    passed.add(name) +                else: +                    failed.add(name) + +        return passed, failed + + +class ActionSettings(Settings[ActionEntry]): +    """ +    A collection of action settings. + +    If a filter is triggered, its action settings (e.g how to infract the user) are combined with the action settings of +    other triggered filters in the same event, and action is taken according to the combined action settings. +    """ + +    entry_type = ActionEntry + +    def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False): +        super().__init__(settings_data, defaults=defaults, keep_empty=keep_empty) + +    def union(self, other: Self) -> Self: +        """Combine the entries of two collections of settings into a new ActionsSettings.""" +        actions = {} +        # A settings object doesn't necessarily have all types of entries (e.g in the case of filter overrides). +        for entry in self: +            if entry in other: +                actions[entry] = self[entry].union(other[entry]) +            else: +                actions[entry] = self[entry] +        for entry in other: +            if entry not in actions: +                actions[entry] = other[entry] + +        result = ActionSettings({}) +        result.update(actions) +        return result + +    async def action(self, ctx: FilterContext) -> None: +        """Execute the action of every action entry stored, as well as any additional actions in the context.""" +        for entry in self.values(): +            try: +                await entry.action(ctx) +            # Filtering should not stop even if one type of action raised an exception. +            # For example, if deleting the message raised somehow, it should still try to infract the user. +            except Exception: +                log.exception(traceback.format_exc()) + +        for action in ctx.additional_actions: +            try: +                await action(ctx) +            except Exception: +                log.exception(traceback.format_exc()) + +    def fallback_to(self, fallback: ActionSettings) -> ActionSettings: +        """Fill in missing entries from `fallback`.""" +        new_actions = self.copy() +        for entry_name, entry_value in fallback.items(): +            if entry_name not in self: +                new_actions[entry_name] = entry_value +        return new_actions + + +class Defaults(NamedTuple): +    """Represents an atomic list's default settings.""" + +    actions: ActionSettings +    validations: ValidationSettings + +    def dict(self) -> dict[str, Any]: +        """Return a dict representation of the stored fields across all entries.""" +        dict_ = {} +        for settings in self: +            dict_ = reduce(operator.or_, (entry.dict() for entry in settings.values()), dict_) +        return dict_ diff --git a/bot/exts/filtering/_settings_types/__init__.py b/bot/exts/filtering/_settings_types/__init__.py new file mode 100644 index 000000000..61b5737d4 --- /dev/null +++ b/bot/exts/filtering/_settings_types/__init__.py @@ -0,0 +1,9 @@ +from bot.exts.filtering._settings_types.actions import action_types +from bot.exts.filtering._settings_types.validations import validation_types + +settings_types = { +    "ActionEntry": {settings_type.name: settings_type for settings_type in action_types}, +    "ValidationEntry": {settings_type.name: settings_type for settings_type in validation_types} +} + +__all__ = [settings_types] diff --git a/bot/exts/filtering/_settings_types/actions/__init__.py b/bot/exts/filtering/_settings_types/actions/__init__.py new file mode 100644 index 000000000..a8175b976 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname + +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import subclasses_in_package + +action_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ActionEntry) + +__all__ = [action_types] diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py new file mode 100644 index 000000000..508c09c2a --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -0,0 +1,255 @@ +from enum import Enum, auto +from typing import ClassVar + +import arrow +import discord.abc +from dateutil.relativedelta import relativedelta +from discord import Colour, Embed, Member, User +from discord.errors import Forbidden +from pydantic import validator +from pydis_core.utils.logging import get_logger +from pydis_core.utils.members import get_or_fetch_member +from typing_extensions import Self + +import bot as bot_module +from bot.constants import Channels +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import CustomIOField, FakeContext +from bot.utils.time import humanize_delta, parse_duration_string, relativedelta_to_timedelta + +log = get_logger(__name__) + +passive_form = { +    "BAN": "banned", +    "KICK": "kicked", +    "TIMEOUT": "timed out", +    "VOICE_MUTE": "voice muted", +    "SUPERSTAR": "superstarred", +    "WARNING": "warned", +    "WATCH": "watch", +    "NOTE": "noted", +} + + +class InfractionDuration(CustomIOField): +    """A field that converts a string to a duration and presents it in a human-readable format.""" + +    @classmethod +    def process_value(cls, v: str | relativedelta) -> relativedelta: +        """ +        Transform the given string into a relativedelta. + +        Raise a ValueError if the conversion is not possible. +        """ +        if isinstance(v, relativedelta): +            return v + +        try: +            v = float(v) +        except ValueError:  # Not a float. +            if not (delta := parse_duration_string(v)): +                raise ValueError(f"`{v}` is not a valid duration string.") +        else: +            delta = relativedelta(seconds=float(v)).normalized() + +        return delta + +    def serialize(self) -> float: +        """The serialized value is the total number of seconds this duration represents.""" +        return relativedelta_to_timedelta(self.value).total_seconds() + +    def __str__(self): +        """Represent the stored duration in a human-readable format.""" +        return humanize_delta(self.value, max_units=2) if self.value else "Permanent" + + +class Infraction(Enum): +    """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" + +    BAN = auto() +    KICK = auto() +    TIMEOUT = auto() +    VOICE_MUTE = auto() +    SUPERSTAR = auto() +    WARNING = auto() +    WATCH = auto() +    NOTE = auto() +    NONE = auto() + +    def __str__(self) -> str: +        return self.name + +    async def invoke( +        self, +        user: Member | User, +        message: discord.Message, +        channel: discord.abc.GuildChannel | discord.DMChannel, +        alerts_channel: discord.TextChannel, +        duration: InfractionDuration, +        reason: str +    ) -> None: +        """Invokes the command matching the infraction name.""" +        command_name = self.name.lower() +        command = bot_module.instance.get_command(command_name) +        if not command: +            await alerts_channel.send(f":warning: Could not apply {command_name} to {user.mention}: command not found.") +            log.warning(f":warning: Could not apply {command_name} to {user.mention}: command not found.") +            return + +        if isinstance(user, discord.User):  # For example because a message was sent in a DM. +            member = await get_or_fetch_member(channel.guild, user.id) +            if member: +                user = member +            else: +                log.warning( +                    f"The user {user} were set to receive an automatic {command_name}, " +                    "but they were not found in the guild." +                ) +                return + +        ctx = FakeContext(message, channel, command) +        if self.name in ("KICK", "WARNING", "WATCH", "NOTE"): +            await command(ctx, user, reason=reason or None) +        else: +            duration = arrow.utcnow().datetime + duration.value if duration.value else None +            await command(ctx, user, duration, reason=reason or None) + + +class InfractionAndNotification(ActionEntry): +    """ +    A setting entry which specifies what infraction to issue and the notification to DM the user. + +    Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. +    """ + +    name: ClassVar[str] = "infraction_and_notification" +    description: ClassVar[dict[str, str]] = { +        "infraction_type": ( +            "The type of infraction to issue when the filter triggers, or 'NONE'. " +            "If two infractions are triggered for the same message, " +            "the harsher one will be applied (by type or duration).\n\n" +            "Valid infraction types in order of harshness: " +        ) + ", ".join(infraction.name for infraction in Infraction), +        "infraction_duration": ( +            "How long the infraction should last for in seconds. 0 for permanent. " +            "Also supports durations as in an infraction invocation (such as `10d`)." +        ), +        "infraction_reason": "The reason delivered with the infraction.", +        "infraction_channel": ( +            "The channel ID in which to invoke the infraction (and send the confirmation message). " +            "If 0, the infraction will be sent in the context channel. If the ID otherwise fails to resolve, " +            "it will default to the mod-alerts channel." +        ), +        "dm_content": "The contents of a message to be DMed to the offending user. Doesn't send when invoked in DMs.", +        "dm_embed": "The contents of the embed to be DMed to the offending user. Doesn't send when invoked in DMs." +    } + +    dm_content: str +    dm_embed: str +    infraction_type: Infraction +    infraction_reason: str +    infraction_duration: InfractionDuration +    infraction_channel: int + +    @validator("infraction_type", pre=True) +    @classmethod +    def convert_infraction_name(cls, infr_type: str | Infraction) -> Infraction: +        """Convert the string to an Infraction by name.""" +        if isinstance(infr_type, Infraction): +            return infr_type +        return Infraction[infr_type.replace(" ", "_").upper()] + +    async def send_message(self, ctx: FilterContext) -> None: +        """Send the notification to the user.""" +        # If there is no infraction to apply, any DM contents already provided in the context take precedence. +        if self.infraction_type == Infraction.NONE and (ctx.dm_content or ctx.dm_embed): +            dm_content = ctx.dm_content +            dm_embed = ctx.dm_embed +        else: +            dm_content = self.dm_content +            dm_embed = self.dm_embed + +        if dm_content or dm_embed: +            formatting = {"domain": ctx.notification_domain} +            dm_content = f"Hey {ctx.author.mention}!\n{dm_content.format(**formatting)}" +            if dm_embed: +                dm_embed = Embed(description=dm_embed.format(**formatting), colour=Colour.og_blurple()) +            else: +                dm_embed = None + +            try: +                await ctx.author.send(dm_content, embed=dm_embed) +                ctx.action_descriptions.append("notified") +            except Forbidden: +                ctx.action_descriptions.append("failed to notify") + +    async def action(self, ctx: FilterContext) -> None: +        """Send the notification to the user, and apply any specified infractions.""" +        if ctx.in_guild:  # Don't DM the user for filters invoked in DMs. +            await self.send_message(ctx) + +        if self.infraction_type != Infraction.NONE: +            alerts_channel = bot_module.instance.get_channel(Channels.mod_alerts) +            if self.infraction_channel: +                channel = bot_module.instance.get_channel(self.infraction_channel) +                if not channel: +                    log.info(f"Could not find a channel with ID {self.infraction_channel}, infracting in mod-alerts.") +                    channel = alerts_channel +            elif not ctx.channel: +                channel = alerts_channel +            else: +                channel = ctx.channel +            if not channel:  # If somehow it's set to `alerts_channel` and it can't be found. +                log.error(f"Unable to apply infraction as the context channel {channel} can't be found.") +                return + +            await self.infraction_type.invoke( +                ctx.author, ctx.message, channel, alerts_channel, self.infraction_duration, self.infraction_reason +            ) +            ctx.action_descriptions.append(passive_form[self.infraction_type.name]) + +    def union(self, other: Self) -> Self: +        """ +        Combines two actions of the same type. Each type of action is executed once per filter. + +        If the infractions are different, take the data of the one higher up the hierarchy. + +        There is no clear way to properly combine several notification messages, especially when it's in two parts. +        To avoid bombarding the user with several notifications, the message with the more significant infraction +        is used. If the more significant infraction has no accompanying message, use the one from the other infraction, +        if it exists. +        """ +        # Lower number -> higher in the hierarchy +        if self.infraction_type is None: +            return other.copy() +        elif other.infraction_type is None: +            return self.copy() + +        if self.infraction_type.value < other.infraction_type.value: +            result = self.copy() +        elif self.infraction_type.value > other.infraction_type.value: +            result = other.copy() +            other = self +        else: +            now = arrow.utcnow().datetime +            if self.infraction_duration is None or ( +                other.infraction_duration is not None +                and now + self.infraction_duration.value > now + other.infraction_duration.value +            ): +                result = self.copy() +            else: +                result = other.copy() +                other = self + +        # If the winner has no message but the loser does, copy the message to the winner. +        result_overrides = result.overrides +        # Either take both or nothing, don't mix content from one filter and embed from another. +        if "dm_content" not in result_overrides and "dm_embed" not in result_overrides: +            other_overrides = other.overrides +            if "dm_content" in other_overrides: +                result.dm_content = other_overrides["dm_content"] +            if "dm_embed" in other_overrides: +                result.dm_content = other_overrides["dm_embed"] + +        return result diff --git a/bot/exts/filtering/_settings_types/actions/ping.py b/bot/exts/filtering/_settings_types/actions/ping.py new file mode 100644 index 000000000..ee40c54fe --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/ping.py @@ -0,0 +1,45 @@ +from typing import ClassVar + +from pydantic import validator +from typing_extensions import Self + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import resolve_mention + + +class Ping(ActionEntry): +    """A setting entry which adds the appropriate pings to the alert.""" + +    name: ClassVar[str] = "mentions" +    description: ClassVar[dict[str, str]] = { +        "guild_pings": ( +            "A list of role IDs/role names/user IDs/user names/here/everyone. " +            "If a mod-alert is generated for a filter triggered in a public channel, these will be pinged." +        ), +        "dm_pings": ( +            "A list of role IDs/role names/user IDs/user names/here/everyone. " +            "If a mod-alert is generated for a filter triggered in DMs, these will be pinged." +        ) +    } + +    guild_pings: set[str] +    dm_pings: set[str] + +    @validator("*", pre=True) +    @classmethod +    def init_sequence_if_none(cls, pings: list[str] | None) -> list[str]: +        """Initialize an empty sequence if the value is None.""" +        if pings is None: +            return [] +        return pings + +    async def action(self, ctx: FilterContext) -> None: +        """Add the stored pings to the alert message content.""" +        mentions = self.guild_pings if not ctx.channel or ctx.channel.guild else self.dm_pings +        new_content = " ".join([resolve_mention(mention) for mention in mentions]) +        ctx.alert_content = f"{new_content} {ctx.alert_content}" + +    def union(self, other: Self) -> Self: +        """Combines two actions of the same type. Each type of action is executed once per filter.""" +        return Ping(guild_pings=self.guild_pings | other.guild_pings, dm_pings=self.dm_pings | other.dm_pings) diff --git a/bot/exts/filtering/_settings_types/actions/remove_context.py b/bot/exts/filtering/_settings_types/actions/remove_context.py new file mode 100644 index 000000000..5ec2613f4 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/remove_context.py @@ -0,0 +1,113 @@ +from collections import defaultdict +from typing import ClassVar + +from discord import Message +from discord.errors import HTTPException +from pydis_core.utils import scheduling +from pydis_core.utils.logging import get_logger +from typing_extensions import Self + +import bot +from bot.constants import Channels +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import FakeContext +from bot.utils.messages import send_attachments + +log = get_logger(__name__) + +SUPERSTAR_REASON = ( +    "Your nickname was found to be in violation of our code of conduct. " +    "If you believe this is a mistake, please let us know." +) + + +async def upload_messages_attachments(ctx: FilterContext, messages: list[Message]) -> None: +    """Re-upload the messages' attachments for future logging.""" +    if not messages: +        return +    destination = messages[0].guild.get_channel(Channels.attachment_log) +    for message in messages: +        if message.attachments and message.id not in ctx.uploaded_attachments: +            ctx.uploaded_attachments[message.id] = await send_attachments(message, destination, link_large=False) + + +class RemoveContext(ActionEntry): +    """A setting entry which tells whether to delete the offending message(s).""" + +    name: ClassVar[str] = "remove_context" +    description: ClassVar[str] = ( +        "A boolean field. If True, the filter being triggered will cause the offending context to be removed. " +        "An offending message will be deleted, while an offending nickname will be superstarified." +    ) + +    remove_context: bool + +    async def action(self, ctx: FilterContext) -> None: +        """Remove the offending context.""" +        if not self.remove_context: +            return + +        if ctx.event in (Event.MESSAGE, Event.MESSAGE_EDIT): +            await self._handle_messages(ctx) +        elif ctx.event == Event.NICKNAME: +            await self._handle_nickname(ctx) + +    @staticmethod +    async def _handle_messages(ctx: FilterContext) -> None: +        """Delete any messages involved in this context.""" +        if not ctx.message or not ctx.message.guild: +            return + +        # If deletion somehow fails at least this will allow scheduling for deletion. +        ctx.messages_deletion = True +        channel_messages = defaultdict(set)  # Duplicates will cause batch deletion to fail. +        for message in {ctx.message} | ctx.related_messages: +            channel_messages[message.channel].add(message) + +        success = fail = 0 +        deleted = list() +        for channel, messages in channel_messages.items(): +            try: +                await channel.delete_messages(messages) +            except HTTPException: +                fail += len(messages) +            else: +                success += len(messages) +                deleted.extend(messages) +        scheduling.create_task(upload_messages_attachments(ctx, deleted)) + +        if not fail: +            if success == 1: +                ctx.action_descriptions.append("deleted") +            else: +                ctx.action_descriptions.append("deleted all") +        elif not success: +            if fail == 1: +                ctx.action_descriptions.append("failed to delete") +            else: +                ctx.action_descriptions.append("all failed to delete") +        else: +            ctx.action_descriptions.append(f"{success} deleted, {fail} failed to delete") + +    @staticmethod +    async def _handle_nickname(ctx: FilterContext) -> None: +        """Apply a superstar infraction to remove the user's nickname.""" +        alerts_channel = bot.instance.get_channel(Channels.mod_alerts) +        if not alerts_channel: +            log.error(f"Unable to apply superstar as the context channel {alerts_channel} can't be found.") +            return +        command = bot.instance.get_command("superstar") +        if not command: +            user = ctx.author +            await alerts_channel.send(f":warning: Could not apply superstar to {user.mention}: command not found.") +            log.warning(f":warning: Could not apply superstar to {user.mention}: command not found.") +            ctx.action_descriptions.append("failed to superstar") +            return + +        await command(FakeContext(ctx.message, alerts_channel, command), ctx.author, None, reason=SUPERSTAR_REASON) +        ctx.action_descriptions.append("superstar") + +    def union(self, other: Self) -> Self: +        """Combines two actions of the same type. Each type of action is executed once per filter.""" +        return RemoveContext(remove_context=self.remove_context or other.remove_context) diff --git a/bot/exts/filtering/_settings_types/actions/send_alert.py b/bot/exts/filtering/_settings_types/actions/send_alert.py new file mode 100644 index 000000000..f554cdd4d --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/send_alert.py @@ -0,0 +1,23 @@ +from typing import ClassVar + +from typing_extensions import Self + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class SendAlert(ActionEntry): +    """A setting entry which tells whether to send an alert message.""" + +    name: ClassVar[str] = "send_alert" +    description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created." + +    send_alert: bool + +    async def action(self, ctx: FilterContext) -> None: +        """Add the stored pings to the alert message content.""" +        ctx.send_alert = self.send_alert + +    def union(self, other: Self) -> Self: +        """Combines two actions of the same type. Each type of action is executed once per filter.""" +        return SendAlert(send_alert=self.send_alert or other.send_alert) diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py new file mode 100644 index 000000000..e41ef5c7a --- /dev/null +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, ClassVar, Union + +from pydantic import BaseModel, PrivateAttr +from typing_extensions import Self + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._utils import FieldRequiring + + +class SettingsEntry(BaseModel, FieldRequiring): +    """ +    A basic entry in the settings field appearing in every filter list and filter. + +    For a filter list, this is the default setting for it. For a filter, it's an override of the default entry. +    """ + +    # Each subclass must define a name matching the entry name we're expecting to receive from the database. +    # Names must be unique across all filter lists. +    name: ClassVar[str] = FieldRequiring.MUST_SET_UNIQUE +    # Each subclass must define a description of what it does. If the data an entry type receives comprises +    # several DB fields, the value should a dictionary of field names and their descriptions. +    description: ClassVar[Union[str, dict[str, str]]] = FieldRequiring.MUST_SET + +    _overrides: set[str] = PrivateAttr(default_factory=set) + +    def __init__(self, defaults: SettingsEntry | None = None, /, **data): +        overrides = set() +        if defaults: +            defaults_dict = defaults.dict() +            for field_name, field_value in list(data.items()): +                if field_value is None: +                    data[field_name] = defaults_dict[field_name] +                else: +                    overrides.add(field_name) +        super().__init__(**data) +        self._overrides |= overrides + +    @property +    def overrides(self) -> dict[str, Any]: +        """Return a dictionary of overrides.""" +        return {name: getattr(self, name) for name in self._overrides} + +    @classmethod +    def create( +        cls, entry_data: dict[str, Any] | None, *, defaults: SettingsEntry | None = None, keep_empty: bool = False +    ) -> SettingsEntry | None: +        """ +        Returns a SettingsEntry object from `entry_data` if it holds any value, None otherwise. + +        Use this method to create SettingsEntry objects instead of the init. +        The None value is significant for how a filter list iterates over its filters. +        """ +        if entry_data is None: +            return None +        if not keep_empty and hasattr(entry_data, "values") and all(value is None for value in entry_data.values()): +            return None + +        if not isinstance(entry_data, dict): +            entry_data = {cls.name: entry_data} +        return cls(defaults, **entry_data) + + +class ValidationEntry(SettingsEntry): +    """A setting entry to validate whether the filter should be triggered in the given context.""" + +    @abstractmethod +    def triggers_on(self, ctx: FilterContext) -> bool: +        """Return whether the filter should be triggered with this setting in the given context.""" +        ... + + +class ActionEntry(SettingsEntry): +    """A setting entry defining what the bot should do if the filter it belongs to is triggered.""" + +    @abstractmethod +    async def action(self, ctx: FilterContext) -> None: +        """Execute an action that should be taken when the filter this setting belongs to is triggered.""" +        ... + +    @abstractmethod +    def union(self, other: Self) -> Self: +        """ +        Combine two actions of the same type. Each type of action is executed once per filter. + +        The following condition must hold: if self == other, then self | other == self. +        """ +        ... diff --git a/bot/exts/filtering/_settings_types/validations/__init__.py b/bot/exts/filtering/_settings_types/validations/__init__.py new file mode 100644 index 000000000..5c44e8b27 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname + +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry +from bot.exts.filtering._utils import subclasses_in_package + +validation_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ValidationEntry) + +__all__ = [validation_types] diff --git a/bot/exts/filtering/_settings_types/validations/bypass_roles.py b/bot/exts/filtering/_settings_types/validations/bypass_roles.py new file mode 100644 index 000000000..d42e6407c --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/bypass_roles.py @@ -0,0 +1,24 @@ +from typing import ClassVar, Union + +from discord import Member + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class RoleBypass(ValidationEntry): +    """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" + +    name: ClassVar[str] = "bypass_roles" +    description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter." + +    bypass_roles: set[Union[int, str]] + +    def triggers_on(self, ctx: FilterContext) -> bool: +        """Return whether the filter should be triggered on this user given their roles.""" +        if not isinstance(ctx.author, Member): +            return True +        return all( +            member_role.id not in self.bypass_roles and member_role.name not in self.bypass_roles +            for member_role in ctx.author.roles +        ) diff --git a/bot/exts/filtering/_settings_types/validations/channel_scope.py b/bot/exts/filtering/_settings_types/validations/channel_scope.py new file mode 100644 index 000000000..45b769d29 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/channel_scope.py @@ -0,0 +1,72 @@ +from typing import ClassVar, Union + +from pydantic import validator + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class ChannelScope(ValidationEntry): +    """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" + +    name: ClassVar[str] = "channel_scope" +    description: ClassVar[dict[str, str]] = { +        "disabled_channels": ( +            "A list of channel IDs or channel names. " +            "The filter will not trigger in these channels even if the category is expressly enabled." +        ), +        "disabled_categories": ( +            "A list of category IDs or category names. The filter will not trigger in these categories." +        ), +        "enabled_channels": ( +            "A list of channel IDs or channel names. " +            "The filter can trigger in these channels even if the category is disabled or not expressly enabled." +        ), +        "enabled_categories": ( +            "A list of category IDs or category names. " +            "If the list is not empty, filters will trigger only in channels of these categories, " +            "unless the channel is expressly disabled." +        ) +    } + +    # NOTE: Don't change this to use the new 3.10 union syntax unless you ensure Pydantic type validation and coercion +    # work properly. At the time of writing this code there's a difference. +    disabled_channels: set[Union[int, str]] +    disabled_categories: set[Union[int, str]] +    enabled_channels: set[Union[int, str]] +    enabled_categories: set[Union[int, str]] + +    @validator("*", pre=True) +    @classmethod +    def init_if_sequence_none(cls, sequence: list[str] | None) -> list[str]: +        """Initialize an empty sequence if the value is None.""" +        if sequence is None: +            return [] +        return sequence + +    def triggers_on(self, ctx: FilterContext) -> bool: +        """ +        Return whether the filter should be triggered in the given channel. + +        The filter is invoked by default. +        If the channel is explicitly enabled, it bypasses the set disabled channels and categories. +        """ +        channel = ctx.channel + +        if not channel: +            return True +        if not ctx.in_guild:  # This is not a guild channel, outside the scope of this setting. +            return True +        if hasattr(channel, "parent"): +            channel = channel.parent + +        enabled_channel = channel.id in self.enabled_channels or channel.name in self.enabled_channels +        disabled_channel = channel.id in self.disabled_channels or channel.name in self.disabled_channels +        enabled_category = channel.category and (not self.enabled_categories or ( +                channel.category.id in self.enabled_categories or channel.category.name in self.enabled_categories +        )) +        disabled_category = channel.category and ( +            channel.category.id in self.disabled_categories or channel.category.name in self.disabled_categories +        ) + +        return enabled_channel or (enabled_category and not disabled_channel and not disabled_category) diff --git a/bot/exts/filtering/_settings_types/validations/enabled.py b/bot/exts/filtering/_settings_types/validations/enabled.py new file mode 100644 index 000000000..3b5e3e446 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/enabled.py @@ -0,0 +1,19 @@ +from typing import ClassVar + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class Enabled(ValidationEntry): +    """A setting entry which tells whether the filter is enabled.""" + +    name: ClassVar[str] = "enabled" +    description: ClassVar[str] = ( +        "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." +    ) + +    enabled: bool + +    def triggers_on(self, ctx: FilterContext) -> bool: +        """Return whether the filter is enabled.""" +        return self.enabled diff --git a/bot/exts/filtering/_settings_types/validations/filter_dm.py b/bot/exts/filtering/_settings_types/validations/filter_dm.py new file mode 100644 index 000000000..9961984d6 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/filter_dm.py @@ -0,0 +1,20 @@ +from typing import ClassVar + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class FilterDM(ValidationEntry): +    """A setting entry which tells whether to apply the filter to DMs.""" + +    name: ClassVar[str] = "filter_dm" +    description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." + +    filter_dm: bool + +    def triggers_on(self, ctx: FilterContext) -> bool: +        """Return whether the filter should be triggered even if it was triggered in DMs.""" +        if not ctx.channel:  # No channel - out of scope for this setting. +            return True + +        return ctx.channel.guild is not None or self.filter_dm diff --git a/bot/exts/filtering/_ui/__init__.py b/bot/exts/filtering/_ui/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/bot/exts/filtering/_ui/__init__.py diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py new file mode 100644 index 000000000..5b23b71e9 --- /dev/null +++ b/bot/exts/filtering/_ui/filter.py @@ -0,0 +1,464 @@ +from __future__ import annotations + +from typing import Any, Callable + +import discord +import discord.ui +from discord import Embed, Interaction, User +from discord.ext.commands import BadArgument +from discord.ui.select import SelectOption +from pydis_core.site_api import ResponseCodeError + +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._ui.ui import ( +    COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MAX_EMBED_DESCRIPTION, MISSING, SETTINGS_DELIMITER, +    SINGLE_SETTING_PATTERN, format_response_error, parse_value, populate_embed_from_dict +) +from bot.exts.filtering._utils import repr_equals, to_serializable +from bot.log import get_logger + +log = get_logger(__name__) + + +def build_filter_repr_dict( +    filter_list: FilterList, +    list_type: ListType, +    filter_type: type[Filter], +    settings_overrides: dict, +    extra_fields_overrides: dict +) -> dict: +    """Build a dictionary of field names and values to pass to `populate_embed_from_dict`.""" +    # Get filter list settings +    default_setting_values = {} +    for settings_group in filter_list[list_type].defaults: +        for _, setting in settings_group.items(): +            default_setting_values.update(to_serializable(setting.dict(), ui_repr=True)) + +    # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings. +    total_values = {} +    for name, value in default_setting_values.items(): +        if name not in settings_overrides or repr_equals(settings_overrides[name], value): +            total_values[name] = value +        else: +            total_values[f"{name}*"] = settings_overrides[name] + +    # Add the filter-specific settings. +    if filter_type.extra_fields_type: +        # This iterates over the default values of the extra fields model. +        for name, value in filter_type.extra_fields_type().dict().items(): +            if name not in extra_fields_overrides or repr_equals(extra_fields_overrides[name], value): +                total_values[f"{filter_type.name}/{name}"] = value +            else: +                total_values[f"{filter_type.name}/{name}*"] = extra_fields_overrides[name] + +    return total_values + + +class EditContentModal(discord.ui.Modal, title="Edit Content"): +    """A modal to input a filter's content.""" + +    content = discord.ui.TextInput(label="Content") + +    def __init__(self, embed_view: FilterEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new content.""" +        await interaction.response.defer() +        await self.embed_view.update_embed(self.message, content=self.content.value) + + +class EditDescriptionModal(discord.ui.Modal, title="Edit Description"): +    """A modal to input a filter's description.""" + +    description = discord.ui.TextInput(label="Description") + +    def __init__(self, embed_view: FilterEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new description.""" +        await interaction.response.defer() +        await self.embed_view.update_embed(self.message, description=self.description.value) + + +class TemplateModal(discord.ui.Modal, title="Template"): +    """A modal to enter a filter ID to copy its overrides over.""" + +    template = discord.ui.TextInput(label="Template Filter ID") + +    def __init__(self, embed_view: FilterEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new description.""" +        await self.embed_view.apply_template(self.template.value, self.message, interaction) + + +class FilterEditView(EditBaseView): +    """A view used to edit a filter's settings before updating the database.""" + +    class _REMOVE: +        """Sentinel value for when an override should be removed.""" + +    def __init__( +        self, +        filter_list: FilterList, +        list_type: ListType, +        filter_type: type[Filter], +        content: str | None, +        description: str | None, +        settings_overrides: dict, +        filter_settings_overrides: dict, +        loaded_settings: dict, +        loaded_filter_settings: dict, +        author: User, +        embed: Embed, +        confirm_callback: Callable +    ): +        super().__init__(author) +        self.filter_list = filter_list +        self.list_type = list_type +        self.filter_type = filter_type +        self.content = content +        self.description = description +        self.settings_overrides = settings_overrides +        self.filter_settings_overrides = filter_settings_overrides +        self.loaded_settings = loaded_settings +        self.loaded_filter_settings = loaded_filter_settings +        self.embed = embed +        self.confirm_callback = confirm_callback + +        all_settings_repr_dict = build_filter_repr_dict( +            filter_list, list_type, filter_type, settings_overrides, filter_settings_overrides +        ) +        populate_embed_from_dict(embed, all_settings_repr_dict) + +        self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()} +        self.type_per_setting_name.update({ +            f"{filter_type.name}/{name}": type_ +            for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items() +        }) + +        add_select = CustomCallbackSelect( +            self._prompt_new_value, +            placeholder="Select a setting to edit", +            options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)], +            row=1 +        ) +        self.add_item(add_select) + +        if settings_overrides or filter_settings_overrides: +            override_names = ( +                list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides] +            ) +            remove_select = CustomCallbackSelect( +                self._remove_override, +                placeholder="Select an override to remove", +                options=[SelectOption(label=name) for name in sorted(override_names)], +                row=2 +            ) +            self.add_item(remove_select) + +    @discord.ui.button(label="Edit Content", row=3) +    async def edit_content(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to edit the filter's content. Pressing the button invokes a modal.""" +        modal = EditContentModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="Edit Description", row=3) +    async def edit_description(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to edit the filter's description. Pressing the button invokes a modal.""" +        modal = EditDescriptionModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="Empty Description", row=3) +    async def empty_description(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to empty the filter's description.""" +        await self.update_embed(interaction, description=self._REMOVE) + +    @discord.ui.button(label="Template", row=3) +    async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to enter a filter template ID and copy its overrides over.""" +        modal = TemplateModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=4) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Confirm the content, description, and settings, and update the filters database.""" +        if self.content is None: +            await interaction.response.send_message( +                ":x: Cannot add a filter with no content.", ephemeral=True, reference=interaction.message +            ) +        if self.description is None: +            self.description = "" +        await interaction.response.edit_message(view=None)  # Make sure the interaction succeeds first. +        try: +            await self.confirm_callback( +                interaction.message, +                self.filter_list, +                self.list_type, +                self.filter_type, +                self.content, +                self.description, +                self.settings_overrides, +                self.filter_settings_overrides +            ) +        except ResponseCodeError as e: +            await interaction.message.reply(embed=format_response_error(e)) +            await interaction.message.edit(view=self) +        except BadArgument as e: +            await interaction.message.reply( +                embed=Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e)) +            ) +            await interaction.message.edit(view=self) +        else: +            self.stop() + +    @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=4) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the operation.""" +        await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None) +        self.stop() + +    def current_value(self, setting_name: str) -> Any: +        """Get the current value stored for the setting or MISSING if none found.""" +        if setting_name in self.settings_overrides: +            return self.settings_overrides[setting_name] +        if "/" in setting_name: +            _, setting_name = setting_name.split("/", maxsplit=1) +            if setting_name in self.filter_settings_overrides: +                return self.filter_settings_overrides[setting_name] +        return MISSING + +    async def update_embed( +        self, +        interaction_or_msg: discord.Interaction | discord.Message, +        *, +        content: str | None = None, +        description: str | type[FilterEditView._REMOVE] | None = None, +        setting_name: str | None = None, +        setting_value: str | type[FilterEditView._REMOVE] | None = None, +    ) -> None: +        """ +        Update the embed with the new information. + +        If a setting name is provided with a _REMOVE value, remove the override. +        If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. +        """ +        if content is not None or description is not None: +            if content is not None: +                filter_type = self.filter_list.get_filter_type(content) +                if not filter_type: +                    if isinstance(interaction_or_msg, discord.Message): +                        send_method = interaction_or_msg.channel.send +                    else: +                        send_method = interaction_or_msg.response.send_message +                    await send_method(f":x: Could not find a filter type appropriate for `{content}`.") +                    return +                self.content = content +                self.filter_type = filter_type +            else: +                content = self.content  # If there's no content or description, use the existing values. +            if description is self._REMOVE: +                self.description = None +            elif description is not None: +                self.description = description +            else: +                description = self.description + +            # Update the embed with the new content and/or description. +            self.embed.description = f"`{content}`" if content else "*No content*" +            if description and description is not self._REMOVE: +                self.embed.description += f" - {description}" +            if len(self.embed.description) > MAX_EMBED_DESCRIPTION: +                self.embed.description = self.embed.description[:MAX_EMBED_DESCRIPTION - 5] + "[...]" + +        if setting_name: +            # Find the right dictionary to update. +            if "/" in setting_name: +                filter_name, setting_name = setting_name.split("/", maxsplit=1) +                dict_to_edit = self.filter_settings_overrides +                default_value = self.filter_type.extra_fields_type().dict()[setting_name] +            else: +                dict_to_edit = self.settings_overrides +                default_value = self.filter_list[self.list_type].default(setting_name) +            # Update the setting override value or remove it +            if setting_value is not self._REMOVE: +                if not repr_equals(setting_value, default_value): +                    dict_to_edit[setting_name] = setting_value +                # If there's already an override, remove it, since the new value is the same as the default. +                elif setting_name in dict_to_edit: +                    dict_to_edit.pop(setting_name) +            elif setting_name in dict_to_edit: +                dict_to_edit.pop(setting_name) + +        # This is inefficient, but otherwise the selects go insane if the user attempts to edit the same setting +        # multiple times, even when replacing the select with a new one. +        self.embed.clear_fields() +        new_view = self.copy() + +        try: +            if isinstance(interaction_or_msg, discord.Interaction): +                await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view) +            else: +                await interaction_or_msg.edit(embed=self.embed, view=new_view) +        except discord.errors.HTTPException:  # Various unexpected errors. +            pass +        else: +            self.stop() + +    async def edit_setting_override(self, interaction: Interaction, setting_name: str, override_value: Any) -> None: +        """ +        Update the overrides with the new value and edit the embed. + +        The interaction needs to be the selection of the setting attached to the embed. +        """ +        await self.update_embed(interaction, setting_name=setting_name, setting_value=override_value) + +    async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None: +        """Replace any non-overridden settings with overrides from the given filter.""" +        try: +            settings, filter_settings = template_settings( +                template_id, self.filter_list, self.list_type, self.filter_type +            ) +        except BadArgument as e:  # The interaction object is necessary to send an ephemeral message. +            await interaction.response.send_message(f":x: {e}", ephemeral=True) +            return +        else: +            await interaction.response.defer() + +        self.settings_overrides = settings | self.settings_overrides +        self.filter_settings_overrides = filter_settings | self.filter_settings_overrides +        self.embed.clear_fields() +        await embed_message.edit(embed=self.embed, view=self.copy()) +        self.stop() + +    async def _remove_override(self, interaction: Interaction, select: discord.ui.Select) -> None: +        """ +        Remove the override for the setting the user selected, and edit the embed. + +        The interaction needs to be the selection of the setting attached to the embed. +        """ +        await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE) + +    def copy(self) -> FilterEditView: +        """Create a copy of this view.""" +        return FilterEditView( +            self.filter_list, +            self.list_type, +            self.filter_type, +            self.content, +            self.description, +            self.settings_overrides, +            self.filter_settings_overrides, +            self.loaded_settings, +            self.loaded_filter_settings, +            self.author, +            self.embed, +            self.confirm_callback +        ) + + +def description_and_settings_converter( +    filter_list: FilterList, +    list_type: ListType, +    filter_type: type[Filter], +    loaded_settings: dict, +    loaded_filter_settings: dict, +    input_data: str +) -> tuple[str, dict[str, Any], dict[str, Any]]: +    """Parse a string representing a possible description and setting overrides, and validate the setting names.""" +    if not input_data: +        return "", {}, {} + +    parsed = SETTINGS_DELIMITER.split(input_data) +    if not parsed: +        return "", {}, {} + +    description = "" +    if not SINGLE_SETTING_PATTERN.match(parsed[0]): +        description, *parsed = parsed + +    settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} +    template = None +    if "--template" in settings: +        template = settings.pop("--template") + +    filter_settings = {} +    for setting, _ in list(settings.items()): +        if setting in loaded_settings:  # It's a filter list setting +            type_ = loaded_settings[setting][2] +            try: +                parsed_value = parse_value(settings.pop(setting), type_) +                if not repr_equals(parsed_value, filter_list[list_type].default(setting)): +                    settings[setting] = parsed_value +            except (TypeError, ValueError) as e: +                raise BadArgument(e) +        elif "/" not in setting: +            raise BadArgument(f"{setting!r} is not a recognized setting.") +        else:  # It's a filter setting +            filter_name, filter_setting_name = setting.split("/", maxsplit=1) +            if filter_name.lower() != filter_type.name.lower(): +                raise BadArgument( +                    f"A setting for a {filter_name!r} filter was provided, but the filter name is {filter_type.name!r}" +                ) +            if filter_setting_name not in loaded_filter_settings[filter_type.name]: +                raise BadArgument(f"{setting!r} is not a recognized setting.") +            type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2] +            try: +                parsed_value = parse_value(settings.pop(setting), type_) +                if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): +                    filter_settings[filter_setting_name] = parsed_value +            except (TypeError, ValueError) as e: +                raise BadArgument(e) + +    # Pull templates settings and apply them. +    if template is not None: +        try: +            t_settings, t_filter_settings = template_settings(template, filter_list, list_type, filter_type) +        except ValueError as e: +            raise BadArgument(str(e)) +        else: +            # The specified settings go on top of the template +            settings = t_settings | settings +            filter_settings = t_filter_settings | filter_settings + +    return description, settings, filter_settings + + +def filter_overrides_for_ui(filter_: Filter) -> tuple[dict, dict]: +    """Get the filter's overrides in a format that can be displayed in the UI.""" +    overrides_values, extra_fields_overrides = filter_.overrides +    return to_serializable(overrides_values, ui_repr=True), to_serializable(extra_fields_overrides, ui_repr=True) + + +def template_settings( +    filter_id: str, filter_list: FilterList, list_type: ListType, filter_type: type[Filter] +) -> tuple[dict, dict]: +    """Find the filter with specified ID, and return its settings.""" +    try: +        filter_id = int(filter_id) +        if filter_id < 0: +            raise ValueError() +    except ValueError: +        raise BadArgument("Template value must be a non-negative integer.") + +    if filter_id not in filter_list[list_type].filters: +        raise BadArgument( +            f"Could not find filter with ID `{filter_id}` in the {list_type.name} {filter_list.name} list." +        ) +    filter_ = filter_list[list_type].filters[filter_id] + +    if not isinstance(filter_, filter_type): +        raise BadArgument( +            f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}" +        ) +    return filter_.overrides diff --git a/bot/exts/filtering/_ui/filter_list.py b/bot/exts/filtering/_ui/filter_list.py new file mode 100644 index 000000000..4d6f76a89 --- /dev/null +++ b/bot/exts/filtering/_ui/filter_list.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +from typing import Any, Callable + +import discord +from discord import Embed, Interaction, SelectOption, User +from discord.ext.commands import BadArgument +from pydis_core.site_api import ResponseCodeError + +from bot.exts.filtering._filter_lists import FilterList, ListType +from bot.exts.filtering._ui.ui import ( +    CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, format_response_error, parse_value, +    populate_embed_from_dict +) +from bot.exts.filtering._utils import repr_equals, to_serializable + + +def settings_converter(loaded_settings: dict, input_data: str) -> dict[str, Any]: +    """Parse a string representing settings, and validate the setting names.""" +    if not input_data: +        return {} + +    parsed = SETTINGS_DELIMITER.split(input_data) +    if not parsed: +        return {} + +    try: +        settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} +    except ValueError: +        raise BadArgument("The settings provided are not in the correct format.") + +    for setting in settings: +        if setting not in loaded_settings: +            raise BadArgument(f"{setting!r} is not a recognized setting.") +        else: +            type_ = loaded_settings[setting][2] +            try: +                parsed_value = parse_value(settings.pop(setting), type_) +                settings[setting] = parsed_value +            except (TypeError, ValueError) as e: +                raise BadArgument(e) + +    return settings + + +def build_filterlist_repr_dict(filter_list: FilterList, list_type: ListType, new_settings: dict) -> dict: +    """Build a dictionary of field names and values to pass to `_build_embed_from_dict`.""" +    # Get filter list settings +    default_setting_values = {} +    for settings_group in filter_list[list_type].defaults: +        for _, setting in settings_group.items(): +            default_setting_values.update(to_serializable(setting.dict(), ui_repr=True)) + +    # Add new values. It's done in this way to preserve field order, since the new_values won't have all settings. +    total_values = {} +    for name, value in default_setting_values.items(): +        if name not in new_settings or repr_equals(new_settings[name], value): +            total_values[name] = value +        else: +            total_values[f"{name}~"] = new_settings[name] + +    return total_values + + +class FilterListAddView(EditBaseView): +    """A view used to add a new filter list.""" + +    def __init__( +        self, +        list_name: str, +        list_type: ListType, +        settings: dict, +        loaded_settings: dict, +        author: User, +        embed: Embed, +        confirm_callback: Callable +    ): +        super().__init__(author) +        self.list_name = list_name +        self.list_type = list_type +        self.settings = settings +        self.loaded_settings = loaded_settings +        self.embed = embed +        self.confirm_callback = confirm_callback + +        self.settings_repr_dict = {name: to_serializable(value) for name, value in settings.items()} +        populate_embed_from_dict(embed, self.settings_repr_dict) + +        self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()} + +        edit_select = CustomCallbackSelect( +            self._prompt_new_value, +            placeholder="Select a setting to edit", +            options=[SelectOption(label=name) for name in sorted(settings)], +            row=0 +        ) +        self.add_item(edit_select) + +    @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=1) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Confirm the content, description, and settings, and update the filters database.""" +        await interaction.response.edit_message(view=None)  # Make sure the interaction succeeds first. +        try: +            await self.confirm_callback(interaction.message, self.list_name, self.list_type, self.settings) +        except ResponseCodeError as e: +            await interaction.message.reply(embed=format_response_error(e)) +            await interaction.message.edit(view=self) +        else: +            self.stop() + +    @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=1) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the operation.""" +        await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None) +        self.stop() + +    def current_value(self, setting_name: str) -> Any: +        """Get the current value stored for the setting or MISSING if none found.""" +        if setting_name in self.settings: +            return self.settings[setting_name] +        return MISSING + +    async def update_embed( +        self, +        interaction_or_msg: discord.Interaction | discord.Message, +        *, +        setting_name: str | None = None, +        setting_value: str | None = None, +    ) -> None: +        """ +        Update the embed with the new information. + +        If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. +        """ +        if not setting_name:  # Obligatory check to match the signature in the parent class. +            return + +        self.settings[setting_name] = setting_value + +        self.embed.clear_fields() +        new_view = self.copy() + +        try: +            if isinstance(interaction_or_msg, discord.Interaction): +                await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view) +            else: +                await interaction_or_msg.edit(embed=self.embed, view=new_view) +        except discord.errors.HTTPException:  # Various unexpected errors. +            pass +        else: +            self.stop() + +    def copy(self) -> FilterListAddView: +        """Create a copy of this view.""" +        return FilterListAddView( +            self.list_name, +            self.list_type, +            self.settings, +            self.loaded_settings, +            self.author, +            self.embed, +            self.confirm_callback +        ) + + +class FilterListEditView(EditBaseView): +    """A view used to edit a filter list's settings before updating the database.""" + +    def __init__( +        self, +        filter_list: FilterList, +        list_type: ListType, +        new_settings: dict, +        loaded_settings: dict, +        author: User, +        embed: Embed, +        confirm_callback: Callable +    ): +        super().__init__(author) +        self.filter_list = filter_list +        self.list_type = list_type +        self.settings = new_settings +        self.loaded_settings = loaded_settings +        self.embed = embed +        self.confirm_callback = confirm_callback + +        self.settings_repr_dict = build_filterlist_repr_dict(filter_list, list_type, new_settings) +        populate_embed_from_dict(embed, self.settings_repr_dict) + +        self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()} + +        edit_select = CustomCallbackSelect( +            self._prompt_new_value, +            placeholder="Select a setting to edit", +            options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)], +            row=0 +        ) +        self.add_item(edit_select) + +    @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=1) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Confirm the content, description, and settings, and update the filters database.""" +        await interaction.response.edit_message(view=None)  # Make sure the interaction succeeds first. +        try: +            await self.confirm_callback(interaction.message, self.filter_list, self.list_type, self.settings) +        except ResponseCodeError as e: +            await interaction.message.reply(embed=format_response_error(e)) +            await interaction.message.edit(view=self) +        else: +            self.stop() + +    @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=1) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the operation.""" +        await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None) +        self.stop() + +    def current_value(self, setting_name: str) -> Any: +        """Get the current value stored for the setting or MISSING if none found.""" +        if setting_name in self.settings: +            return self.settings[setting_name] +        if setting_name in self.settings_repr_dict: +            return self.settings_repr_dict[setting_name] +        return MISSING + +    async def update_embed( +        self, +        interaction_or_msg: discord.Interaction | discord.Message, +        *, +        setting_name: str | None = None, +        setting_value: str | None = None, +    ) -> None: +        """ +        Update the embed with the new information. + +        If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. +        """ +        if not setting_name:  # Obligatory check to match the signature in the parent class. +            return + +        default_value = self.filter_list[self.list_type].default(setting_name) +        if not repr_equals(setting_value, default_value): +            self.settings[setting_name] = setting_value +        # If there's already a new value, remove it, since the new value is the same as the default. +        elif setting_name in self.settings: +            self.settings.pop(setting_name) + +        self.embed.clear_fields() +        new_view = self.copy() + +        try: +            if isinstance(interaction_or_msg, discord.Interaction): +                await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view) +            else: +                await interaction_or_msg.edit(embed=self.embed, view=new_view) +        except discord.errors.HTTPException:  # Various errors such as embed description being too long. +            pass +        else: +            self.stop() + +    def copy(self) -> FilterListEditView: +        """Create a copy of this view.""" +        return FilterListEditView( +            self.filter_list, +            self.list_type, +            self.settings, +            self.loaded_settings, +            self.author, +            self.embed, +            self.confirm_callback +        ) diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py new file mode 100644 index 000000000..dba7f3cea --- /dev/null +++ b/bot/exts/filtering/_ui/search.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import discord +from discord import Interaction, SelectOption +from discord.ext.commands import BadArgument + +from bot.exts.filtering._filter_lists import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings_types.settings_entry import SettingsEntry +from bot.exts.filtering._ui.filter import filter_overrides_for_ui +from bot.exts.filtering._ui.ui import ( +    COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value, +    populate_embed_from_dict +) + + +def search_criteria_converter( +    filter_lists: dict, +    loaded_filters: dict, +    loaded_settings: dict, +    loaded_filter_settings: dict, +    filter_type: type[Filter] | None, +    input_data: str +) -> tuple[dict[str, Any], dict[str, Any], type[Filter]]: +    """Parse a string representing setting overrides, and validate the setting names.""" +    if not input_data: +        return {}, {}, filter_type + +    parsed = SETTINGS_DELIMITER.split(input_data) +    if not parsed: +        return {}, {}, filter_type + +    try: +        settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} +    except ValueError: +        raise BadArgument("The settings provided are not in the correct format.") + +    template = None +    if "--template" in settings: +        template = settings.pop("--template") + +    filter_settings = {} +    for setting, _ in list(settings.items()): +        if setting in loaded_settings:  # It's a filter list setting +            type_ = loaded_settings[setting][2] +            try: +                settings[setting] = parse_value(settings[setting], type_) +            except (TypeError, ValueError) as e: +                raise BadArgument(e) +        elif "/" not in setting: +            raise BadArgument(f"{setting!r} is not a recognized setting.") +        else:  # It's a filter setting +            filter_name, filter_setting_name = setting.split("/", maxsplit=1) +            if not filter_type: +                if filter_name in loaded_filters: +                    filter_type = loaded_filters[filter_name] +                else: +                    raise BadArgument(f"There's no filter type named {filter_name!r}.") +            if filter_name.lower() != filter_type.name.lower(): +                raise BadArgument( +                    f"A setting for a {filter_name!r} filter was provided, " +                    f"but the filter name is {filter_type.name!r}" +                ) +            if filter_setting_name not in loaded_filter_settings[filter_type.name]: +                raise BadArgument(f"{setting!r} is not a recognized setting.") +            type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2] +            try: +                filter_settings[filter_setting_name] = parse_value(settings.pop(setting), type_) +            except (TypeError, ValueError) as e: +                raise BadArgument(e) + +    # Pull templates settings and apply them. +    if template is not None: +        try: +            t_settings, t_filter_settings, filter_type = template_settings(template, filter_lists, filter_type) +        except ValueError as e: +            raise BadArgument(str(e)) +        else: +            # The specified settings go on top of the template +            settings = t_settings | settings +            filter_settings = t_filter_settings | filter_settings + +    return settings, filter_settings, filter_type + + +def get_filter(filter_id: int, filter_lists: dict) -> tuple[Filter, FilterList, ListType] | None: +    """Return a filter with the specific filter_id, if found.""" +    for filter_list in filter_lists.values(): +        for list_type, sublist in filter_list.items(): +            if filter_id in sublist.filters: +                return sublist.filters[filter_id], filter_list, list_type +    return None + + +def template_settings( +    filter_id: str, filter_lists: dict, filter_type: type[Filter] | None +) -> tuple[dict, dict, type[Filter]]: +    """Find a filter with the specified ID and filter type, and return its settings and (maybe newly found) type.""" +    try: +        filter_id = int(filter_id) +        if filter_id < 0: +            raise ValueError() +    except ValueError: +        raise BadArgument("Template value must be a non-negative integer.") + +    result = get_filter(filter_id, filter_lists) +    if not result: +        raise BadArgument(f"Could not find a filter with ID `{filter_id}`.") +    filter_, filter_list, list_type = result + +    if filter_type and not isinstance(filter_, filter_type): +        raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.") + +    settings, filter_settings = filter_overrides_for_ui(filter_) +    return settings, filter_settings, type(filter_) + + +def build_search_repr_dict( +    settings: dict[str, Any], filter_settings: dict[str, Any], filter_type: type[Filter] | None +) -> dict: +    """Build a dictionary of field names and values to pass to `populate_embed_from_dict`.""" +    total_values = settings.copy() +    if filter_type: +        for setting_name, value in filter_settings.items(): +            total_values[f"{filter_type.name}/{setting_name}"] = value + +    return total_values + + +class SearchEditView(EditBaseView): +    """A view used to edit the search criteria before performing the search.""" + +    class _REMOVE: +        """Sentinel value for when an override should be removed.""" + +    def __init__( +        self, +        filter_type: type[Filter] | None, +        settings: dict[str, Any], +        filter_settings: dict[str, Any], +        loaded_filter_lists: dict[str, FilterList], +        loaded_filters: dict[str, type[Filter]], +        loaded_settings: dict[str, tuple[str, SettingsEntry, type]], +        loaded_filter_settings: dict[str, dict[str, tuple[str, SettingsEntry, type]]], +        author: discord.User | discord.Member, +        embed: discord.Embed, +        confirm_callback: Callable +    ): +        super().__init__(author) +        self.filter_type = filter_type +        self.settings = settings +        self.filter_settings = filter_settings +        self.loaded_filter_lists = loaded_filter_lists +        self.loaded_filters = loaded_filters +        self.loaded_settings = loaded_settings +        self.loaded_filter_settings = loaded_filter_settings +        self.embed = embed +        self.confirm_callback = confirm_callback + +        title = "Filters Search" +        if filter_type: +            title += f" - {filter_type.name.title()}" +        embed.set_author(name=title) + +        settings_repr_dict = build_search_repr_dict(settings, filter_settings, filter_type) +        populate_embed_from_dict(embed, settings_repr_dict) + +        self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()} +        if filter_type: +            self.type_per_setting_name.update({ +                f"{filter_type.name}/{name}": type_ +                for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items() +            }) + +        add_select = CustomCallbackSelect( +            self._prompt_new_value, +            placeholder="Add or edit criterion", +            options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)], +            row=0 +        ) +        self.add_item(add_select) + +        if settings_repr_dict: +            remove_select = CustomCallbackSelect( +                self._remove_criterion, +                placeholder="Select a criterion to remove", +                options=[SelectOption(label=name) for name in sorted(settings_repr_dict)], +                row=1 +            ) +            self.add_item(remove_select) + +    @discord.ui.button(label="Template", row=2) +    async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to enter a filter template ID and copy its overrides over.""" +        modal = TemplateModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="Filter Type", row=2) +    async def enter_filter_type(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to enter a filter type.""" +        modal = FilterTypeModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=3) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Confirm the search criteria and perform the search.""" +        await interaction.response.edit_message(view=None)  # Make sure the interaction succeeds first. +        try: +            await self.confirm_callback(interaction.message, self.filter_type, self.settings, self.filter_settings) +        except BadArgument as e: +            await interaction.message.reply( +                embed=discord.Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e)) +            ) +            await interaction.message.edit(view=self) +        else: +            self.stop() + +    @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=3) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the operation.""" +        await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None) +        self.stop() + +    def current_value(self, setting_name: str) -> Any: +        """Get the current value stored for the setting or MISSING if none found.""" +        if setting_name in self.settings: +            return self.settings[setting_name] +        if "/" in setting_name: +            _, setting_name = setting_name.split("/", maxsplit=1) +            if setting_name in self.filter_settings: +                return self.filter_settings[setting_name] +        return MISSING + +    async def update_embed( +        self, +        interaction_or_msg: discord.Interaction | discord.Message, +        *, +        setting_name: str | None = None, +        setting_value: str | type[SearchEditView._REMOVE] | None = None, +    ) -> None: +        """ +        Update the embed with the new information. + +        If a setting name is provided with a _REMOVE value, remove the override. +        If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. +        """ +        if not setting_name:  # Can be None just to make the function signature compatible with the parent class. +            return + +        if "/" in setting_name: +            filter_name, setting_name = setting_name.split("/", maxsplit=1) +            dict_to_edit = self.filter_settings +        else: +            dict_to_edit = self.settings + +        # Update the criterion value or remove it +        if setting_value is not self._REMOVE: +            dict_to_edit[setting_name] = setting_value +        elif setting_name in dict_to_edit: +            dict_to_edit.pop(setting_name) + +        self.embed.clear_fields() +        new_view = self.copy() + +        try: +            if isinstance(interaction_or_msg, discord.Interaction): +                await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view) +            else: +                await interaction_or_msg.edit(embed=self.embed, view=new_view) +        except discord.errors.HTTPException:  # Just in case of faulty input. +            pass +        else: +            self.stop() + +    async def _remove_criterion(self, interaction: Interaction, select: discord.ui.Select) -> None: +        """ +        Remove the criterion the user selected, and edit the embed. + +        The interaction needs to be the selection of the setting attached to the embed. +        """ +        await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE) + +    async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None: +        """Set any unset criteria with settings values from the given filter.""" +        try: +            settings, filter_settings, self.filter_type = template_settings( +                template_id, self.loaded_filter_lists, self.filter_type +            ) +        except BadArgument as e:  # The interaction object is necessary to send an ephemeral message. +            await interaction.response.send_message(f":x: {e}", ephemeral=True) +            return +        else: +            await interaction.response.defer() + +        self.settings = settings | self.settings +        self.filter_settings = filter_settings | self.filter_settings +        self.embed.clear_fields() +        await embed_message.edit(embed=self.embed, view=self.copy()) +        self.stop() + +    async def apply_filter_type(self, type_name: str, embed_message: discord.Message, interaction: Interaction) -> None: +        """Set a new filter type and reset any criteria for settings of the old filter type.""" +        if type_name.lower() not in self.loaded_filters: +            if type_name.lower()[:-1] not in self.loaded_filters:  # In case the user entered the plural form. +                await interaction.response.send_message(f":x: No such filter type {type_name!r}.", ephemeral=True) +                return +            type_name = type_name[:-1] +        type_name = type_name.lower() +        await interaction.response.defer() + +        if self.filter_type and type_name == self.filter_type.name: +            return +        self.filter_type = self.loaded_filters[type_name] +        self.filter_settings = {} +        self.embed.clear_fields() +        await embed_message.edit(embed=self.embed, view=self.copy()) +        self.stop() + +    def copy(self) -> SearchEditView: +        """Create a copy of this view.""" +        return SearchEditView( +            self.filter_type, +            self.settings, +            self.filter_settings, +            self.loaded_filter_lists, +            self.loaded_filters, +            self.loaded_settings, +            self.loaded_filter_settings, +            self.author, +            self.embed, +            self.confirm_callback +        ) + + +class TemplateModal(discord.ui.Modal, title="Template"): +    """A modal to enter a filter ID to copy its overrides over.""" + +    template = discord.ui.TextInput(label="Template Filter ID", required=False) + +    def __init__(self, embed_view: SearchEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new description.""" +        await self.embed_view.apply_template(self.template.value, self.message, interaction) + + +class FilterTypeModal(discord.ui.Modal, title="Template"): +    """A modal to enter a filter ID to copy its overrides over.""" + +    filter_type = discord.ui.TextInput(label="Filter Type") + +    def __init__(self, embed_view: SearchEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new description.""" +        await self.embed_view.apply_filter_type(self.filter_type.value, self.message, interaction) diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py new file mode 100644 index 000000000..0de511f03 --- /dev/null +++ b/bot/exts/filtering/_ui/ui.py @@ -0,0 +1,572 @@ +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from collections.abc import Iterable +from enum import EnumMeta +from functools import partial +from typing import Any, Callable, Coroutine, Optional, TypeVar, get_origin + +import discord +from discord import Embed, Interaction +from discord.ext.commands import Context +from discord.ui.select import MISSING as SELECT_MISSING, SelectOption +from discord.utils import escape_markdown +from pydis_core.site_api import ResponseCodeError +from pydis_core.utils import scheduling +from pydis_core.utils.logging import get_logger +from pydis_core.utils.members import get_or_fetch_member + +import bot +from bot.constants import Colours +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._filter_lists import FilterList +from bot.exts.filtering._utils import FakeContext, normalize_type +from bot.utils.messages import format_channel, format_user, upload_log + +log = get_logger(__name__) + + +# Max number of characters in a Discord embed field value, minus 6 characters for a placeholder. +MAX_FIELD_SIZE = 1018 +# Max number of characters for an embed field's value before it should take its own line. +MAX_INLINE_SIZE = 50 +# Number of seconds before a settings editing view timeout. +EDIT_TIMEOUT = 600 +# Number of seconds before timeout of an editing component. +COMPONENT_TIMEOUT = 180 +# Amount of seconds to confirm the operation. +DELETION_TIMEOUT = 60 +# Max length of modal title +MAX_MODAL_TITLE_LENGTH = 45 +# Max number of items in a select +MAX_SELECT_ITEMS = 25 +MAX_EMBED_DESCRIPTION = 4080 +# Number of seconds before timeout of the alert view +ALERT_VIEW_TIMEOUT = 3600 + +SETTINGS_DELIMITER = re.compile(r"\s+(?=\S+=\S+)") +SINGLE_SETTING_PATTERN = re.compile(r"(--)?[\w/]+=.+") + +EDIT_CONFIRMED_MESSAGE = "✅ Edit for `{0}` confirmed" + +# Sentinel value to denote that a value is missing +MISSING = object() + +T = TypeVar('T') + + +async def _build_alert_message_content(ctx: FilterContext, current_message_length: int) -> str: +    """Build the content section of the alert.""" +    # For multiple messages and those with attachments or excessive newlines, use the logs API +    if ctx.messages_deletion and ctx.upload_deletion_logs and any(( +        ctx.related_messages, +        len(ctx.uploaded_attachments) > 0, +        ctx.content.count('\n') > 15 +    )): +        url = await upload_log(ctx.related_messages, bot.instance.user.id, ctx.uploaded_attachments) +        return f"A complete log of the offending messages can be found [here]({url})" + +    alert_content = escape_markdown(ctx.content) +    remaining_chars = MAX_EMBED_DESCRIPTION - current_message_length + +    if len(alert_content) > remaining_chars: +        if ctx.messages_deletion and ctx.upload_deletion_logs: +            url = await upload_log([ctx.message], bot.instance.user.id, ctx.uploaded_attachments) +            log_site_msg = f"The full message can be found [here]({url})" +            # 7 because that's the length of "[...]\n\n" +            return alert_content[:remaining_chars - (7 + len(log_site_msg))] + "[...]\n\n" + log_site_msg +        else: +            return alert_content[:remaining_chars - 5] + "[...]" + +    return alert_content + + +async def build_mod_alert(ctx: FilterContext, triggered_filters: dict[FilterList, Iterable[str]]) -> Embed: +    """Build an alert message from the filter context.""" +    embed = Embed(color=Colours.soft_orange) +    embed.set_thumbnail(url=ctx.author.display_avatar.url) +    triggered_by = f"**Triggered by:** {format_user(ctx.author)}" +    if ctx.channel: +        if ctx.channel.guild: +            triggered_in = f"**Triggered in:** {format_channel(ctx.channel)}\n" +        else: +            triggered_in = "**Triggered in:** :warning:**DM**:warning:\n" +        if len(ctx.related_channels) > 1: +            triggered_in += f"**Channels:** {', '.join(channel.mention for channel in ctx.related_channels)}\n" +    else: +        triggered_by += "\n" +        triggered_in = "" + +    filters = [] +    for filter_list, list_message in triggered_filters.items(): +        if list_message: +            filters.append(f"**{filter_list.name.title()} Filters:** {', '.join(list_message)}") +    filters = "\n".join(filters) + +    matches = "**Matches:** " + escape_markdown(", ".join(repr(match) for match in ctx.matches)) if ctx.matches else "" +    actions = "\n**Actions Taken:** " + (", ".join(ctx.action_descriptions) if ctx.action_descriptions else "-") + +    mod_alert_message = "\n".join(part for part in (triggered_by, triggered_in, filters, matches, actions) if part) +    log.debug(f"{ctx.event.name} Filter:\n{mod_alert_message}") + +    if ctx.message: +        mod_alert_message += f"\n**[Original Content]({ctx.message.jump_url})**:\n" +    else: +        mod_alert_message += "\n**Original Content**:\n" +    mod_alert_message += await _build_alert_message_content(ctx, len(mod_alert_message)) + +    embed.description = mod_alert_message +    return embed + + +def populate_embed_from_dict(embed: Embed, data: dict) -> None: +    """Populate a Discord embed by populating fields from the given dict.""" +    for setting, value in data.items(): +        if setting.startswith("_"): +            continue +        if isinstance(value, (list, set, tuple)): +            value = f"[{', '.join(map(str, value))}]" +        else: +            value = str(value) if value not in ("", None) else "-" +        if len(value) > MAX_FIELD_SIZE: +            value = value[:MAX_FIELD_SIZE] + " [...]" +        embed.add_field(name=setting, value=value, inline=len(value) < MAX_INLINE_SIZE) + + +def parse_value(value: str, type_: type[T]) -> T: +    """Parse the value provided in the CLI and attempt to convert it to the provided type.""" +    blank = value == '""' +    type_ = normalize_type(type_, prioritize_nonetype=blank) + +    if blank or isinstance(None, type_): +        return type_() +    if type_ in (tuple, list, set): +        return list(value.split(",")) +    if type_ is bool: +        return value.lower() == "true" or value == "1" +    if isinstance(type_, EnumMeta): +        return type_[value.upper()] + +    return type_(value) + + +def format_response_error(e: ResponseCodeError) -> Embed: +    """Format the response error into an embed.""" +    description = "" +    if isinstance(e.response_json, list): +        description = "\n".join(f"• {error}" for error in e.response_json) +    elif isinstance(e.response_json, dict): +        if "non_field_errors" in e.response_json: +            non_field_errors = e.response_json.pop("non_field_errors") +            description += "\n".join(f"• {error}" for error in non_field_errors) + "\n" +        for field, errors in e.response_json.items(): +            description += "\n".join(f"• {field} - {error}" for error in errors) + "\n" + +    description = description.strip() +    if len(description) > MAX_EMBED_DESCRIPTION: +        description = description[:MAX_EMBED_DESCRIPTION] + "[...]" +    if not description: +        description = "Something unexpected happened, check the logs." + +    embed = Embed(colour=discord.Colour.red(), title="Oops...", description=description) +    return embed + + +class ArgumentCompletionSelect(discord.ui.Select): +    """A select detailing the options that can be picked to assign to a missing argument.""" + +    def __init__( +        self, +        ctx: Context, +        args: list, +        arg_name: str, +        options: list[str], +        position: int, +        converter: Optional[Callable] = None +    ): +        super().__init__( +            placeholder=f"Select a value for {arg_name!r}", +            options=[discord.SelectOption(label=option) for option in options] +        ) +        self.ctx = ctx +        self.args = args +        self.position = position +        self.converter = converter + +    async def callback(self, interaction: discord.Interaction) -> None: +        """re-invoke the context command with the completed argument value.""" +        await interaction.response.defer() +        value = interaction.data["values"][0] +        if self.converter: +            value = self.converter(value) +        args = self.args.copy()  # This makes the view reusable. +        args.insert(self.position, value) +        log.trace(f"Argument filled with the value {value}. Re-invoking command") +        await self.ctx.invoke(self.ctx.command, *args) + + +class ArgumentCompletionView(discord.ui.View): +    """A view used to complete a missing argument in an in invoked command.""" + +    def __init__( +        self, +        ctx: Context, +        args: list, +        arg_name: str, +        options: list[str], +        position: int, +        converter: Optional[Callable] = None +    ): +        super().__init__() +        log.trace(f"The {arg_name} argument was designated missing in the invocation {ctx.view.buffer!r}") +        self.add_item(ArgumentCompletionSelect(ctx, args, arg_name, options, position, converter)) +        self.ctx = ctx + +    async def interaction_check(self, interaction: discord.Interaction) -> bool: +        """Check to ensure that the interacting user is the user who invoked the command.""" +        if interaction.user != self.ctx.author: +            embed = discord.Embed(description="Sorry, but this dropdown menu can only be used by the original author.") +            await interaction.response.send_message(embed=embed, ephemeral=True) +            return False +        return True + + +class CustomCallbackSelect(discord.ui.Select): +    """A selection which calls the provided callback on interaction.""" + +    def __init__( +        self, +        callback: Callable[[Interaction, discord.ui.Select], Coroutine[None]], +        *, +        custom_id: str = SELECT_MISSING, +        placeholder: str | None = None, +        min_values: int = 1, +        max_values: int = 1, +        options: list[SelectOption] = SELECT_MISSING, +        disabled: bool = False, +        row: int | None = None, +    ): +        super().__init__( +            custom_id=custom_id, +            placeholder=placeholder, +            min_values=min_values, +            max_values=max_values, +            options=options, +            disabled=disabled, +            row=row +        ) +        self.custom_callback = callback + +    async def callback(self, interaction: Interaction) -> Any: +        """Invoke the provided callback.""" +        await self.custom_callback(interaction, self) + + +class BooleanSelectView(discord.ui.View): +    """A view containing an instance of BooleanSelect.""" + +    class BooleanSelect(discord.ui.Select): +        """Select a true or false value and send it to the supplied callback.""" + +        def __init__(self, setting_name: str, update_callback: Callable): +            super().__init__(options=[SelectOption(label="True"), SelectOption(label="False")]) +            self.setting_name = setting_name +            self.update_callback = update_callback + +        async def callback(self, interaction: Interaction) -> Any: +            """Respond to the interaction by sending the boolean value to the update callback.""" +            value = self.values[0] == "True" +            await self.update_callback(setting_name=self.setting_name, setting_value=value) +            await interaction.response.edit_message(content=EDIT_CONFIRMED_MESSAGE.format(self.setting_name), view=None) + +    def __init__(self, setting_name: str, update_callback: Callable): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.add_item(self.BooleanSelect(setting_name, update_callback)) + + +class FreeInputModal(discord.ui.Modal): +    """A modal to freely enter a value for a setting.""" + +    def __init__(self, setting_name: str, type_: type, update_callback: Callable): +        title = f"{setting_name} Input" if len(setting_name) < MAX_MODAL_TITLE_LENGTH - 6 else "Setting Input" +        super().__init__(timeout=COMPONENT_TIMEOUT, title=title) + +        self.setting_name = setting_name +        self.type_ = type_ +        self.update_callback = update_callback + +        label = setting_name if len(setting_name) < MAX_MODAL_TITLE_LENGTH else "Value" +        self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=False) +        self.add_item(self.setting_input) + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the setting with the new value in the embed.""" +        try: +            if not self.setting_input.value: +                value = self.type_() +            else: +                value = self.type_(self.setting_input.value) +        except (ValueError, TypeError): +            await interaction.response.send_message( +                f"Could not process the input value for `{self.setting_name}`.", ephemeral=True +            ) +        else: +            await self.update_callback(setting_name=self.setting_name, setting_value=value) +            await interaction.response.send_message( +                content=EDIT_CONFIRMED_MESSAGE.format(self.setting_name), ephemeral=True +            ) + + +class SequenceEditView(discord.ui.View): +    """A view to modify the contents of a sequence of values.""" + +    class SingleItemModal(discord.ui.Modal): +        """A modal to enter a single list item.""" + +        new_item = discord.ui.TextInput(label="New Item") + +        def __init__(self, view: SequenceEditView): +            super().__init__(title="Item Addition", timeout=COMPONENT_TIMEOUT) +            self.view = view + +        async def on_submit(self, interaction: Interaction) -> None: +            """Send the submitted value to be added to the list.""" +            await self.view.apply_addition(interaction, self.new_item.value) + +    class NewListModal(discord.ui.Modal): +        """A modal to enter new contents for the list.""" + +        new_value = discord.ui.TextInput(label="Enter comma separated values", style=discord.TextStyle.paragraph) + +        def __init__(self, view: SequenceEditView): +            super().__init__(title="New List", timeout=COMPONENT_TIMEOUT) +            self.view = view + +        async def on_submit(self, interaction: Interaction) -> None: +            """Send the submitted value to be added to the list.""" +            await self.view.apply_edit(interaction, self.new_value.value) + +    def __init__(self, setting_name: str, starting_value: list, update_callback: Callable): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.setting_name = setting_name +        self.stored_value = starting_value +        self.update_callback = update_callback + +        options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]] +        self.removal_select = CustomCallbackSelect( +            self.apply_removal, placeholder="Enter an item to remove", options=options, row=1 +        ) +        if self.stored_value: +            self.add_item(self.removal_select) + +    async def apply_removal(self, interaction: Interaction, select: discord.ui.Select) -> None: +        """Remove an item from the list.""" +        # The value might not be stored as a string. +        _i = len(self.stored_value) +        for _i, element in enumerate(self.stored_value): +            if str(element) == select.values[0]: +                break +        if _i != len(self.stored_value): +            self.stored_value.pop(_i) + +        await interaction.response.edit_message( +            content=f"Current list: [{', '.join(self.stored_value)}]", view=self.copy() +        ) +        self.stop() + +    async def apply_addition(self, interaction: Interaction, item: str) -> None: +        """Add an item to the list.""" +        if item in self.stored_value:  # Ignore duplicates +            await interaction.response.defer() +            return + +        self.stored_value.append(item) +        await interaction.response.edit_message( +            content=f"Current list: [{', '.join(self.stored_value)}]", view=self.copy() +        ) +        self.stop() + +    async def apply_edit(self, interaction: Interaction, new_list: str) -> None: +        """Change the contents of the list.""" +        self.stored_value = list(set(part.strip() for part in new_list.split(",") if part.strip())) +        await interaction.response.edit_message( +            content=f"Current list: [{', '.join(self.stored_value)}]", view=self.copy() +        ) +        self.stop() + +    @discord.ui.button(label="Add Value") +    async def add_value(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to add an item to the list.""" +        await interaction.response.send_modal(self.SingleItemModal(self)) + +    @discord.ui.button(label="Free Input") +    async def free_input(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to change the entire list.""" +        await interaction.response.send_modal(self.NewListModal(self)) + +    @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Send the final value to the embed editor.""" +        # Edit first, it might time out otherwise. +        await self.update_callback(setting_name=self.setting_name, setting_value=self.stored_value) +        await interaction.response.edit_message(content=EDIT_CONFIRMED_MESSAGE.format(self.setting_name), view=None) +        self.stop() + +    @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the list editing.""" +        await interaction.response.edit_message(content="🚫 Canceled", view=None) +        self.stop() + +    def copy(self) -> SequenceEditView: +        """Return a copy of this view.""" +        return SequenceEditView(self.setting_name, self.stored_value, self.update_callback) + + +class EnumSelectView(discord.ui.View): +    """A view containing an instance of EnumSelect.""" + +    class EnumSelect(discord.ui.Select): +        """Select an enum value and send it to the supplied callback.""" + +        def __init__(self, setting_name: str, enum_cls: EnumMeta, update_callback: Callable): +            super().__init__(options=[SelectOption(label=elem.name) for elem in enum_cls]) +            self.setting_name = setting_name +            self.enum_cls = enum_cls +            self.update_callback = update_callback + +        async def callback(self, interaction: Interaction) -> Any: +            """Respond to the interaction by sending the enum value to the update callback.""" +            await self.update_callback(setting_name=self.setting_name, setting_value=self.values[0]) +            await interaction.response.edit_message(content=EDIT_CONFIRMED_MESSAGE.format(self.setting_name), view=None) + +    def __init__(self, setting_name: str, enum_cls: EnumMeta, update_callback: Callable): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.add_item(self.EnumSelect(setting_name, enum_cls, update_callback)) + + +class EditBaseView(ABC, discord.ui.View): +    """A view used to edit embed fields based on a provided type.""" + +    def __init__(self, author: discord.User): +        super().__init__(timeout=EDIT_TIMEOUT) +        self.author = author +        self.type_per_setting_name = {} + +    async def interaction_check(self, interaction: Interaction) -> bool: +        """Only allow interactions from the command invoker.""" +        return interaction.user.id == self.author.id + +    async def _prompt_new_value(self, interaction: Interaction, select: discord.ui.Select) -> None: +        """Prompt the user to give an override value for the setting they selected, and respond to the interaction.""" +        setting_name = select.values[0] +        type_ = self.type_per_setting_name[setting_name] +        if origin := get_origin(type_):  # In case this is a types.GenericAlias or a typing._GenericAlias +            type_ = origin +        new_view = self.copy() +        # This is in order to not block the interaction response. There's a potential race condition here, since +        # a view's method is used without guaranteeing the task completed, but since it depends on user input +        # realistically it shouldn't happen. +        scheduling.create_task(interaction.message.edit(view=new_view)) +        update_callback = partial(new_view.update_embed, interaction_or_msg=interaction.message) +        if type_ is bool: +            view = BooleanSelectView(setting_name, update_callback) +            await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True) +        elif type_ in (set, list, tuple): +            if (current_value := self.current_value(setting_name)) is not MISSING: +                current_list = [str(elem) for elem in current_value] +            else: +                current_list = [] +            await interaction.response.send_message( +                f"Current list: [{', '.join(current_list)}]", +                view=SequenceEditView(setting_name, current_list, update_callback), +                ephemeral=True +            ) +        elif isinstance(type_, EnumMeta): +            view = EnumSelectView(setting_name, type_, update_callback) +            await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True) +        else: +            await interaction.response.send_modal(FreeInputModal(setting_name, type_, update_callback)) +        self.stop() + +    @abstractmethod +    def current_value(self, setting_name: str) -> Any: +        """Get the current value stored for the setting or MISSING if none found.""" + +    @abstractmethod +    async def update_embed(self, interaction_or_msg: Interaction | discord.Message) -> None: +        """ +        Update the embed with the new information. + +        If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. +        """ + +    @abstractmethod +    def copy(self) -> EditBaseView: +        """Create a copy of this view.""" + + +class DeleteConfirmationView(discord.ui.View): +    """A view to confirm a deletion.""" + +    def __init__(self, author: discord.Member | discord.User, callback: Callable): +        super().__init__(timeout=DELETION_TIMEOUT) +        self.author = author +        self.callback = callback + +    async def interaction_check(self, interaction: Interaction) -> bool: +        """Only allow interactions from the command invoker.""" +        return interaction.user.id == self.author.id + +    @discord.ui.button(label="Delete", style=discord.ButtonStyle.red, row=0) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Invoke the filter list deletion.""" +        await interaction.response.edit_message(view=None) +        await self.callback() + +    @discord.ui.button(label="Cancel", row=0) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the filter list deletion.""" +        await interaction.response.edit_message(content="🚫 Operation canceled.", view=None) + + +class AlertView(discord.ui.View): +    """A view providing info about the offending user.""" + +    def __init__(self, ctx: FilterContext): +        super().__init__(timeout=ALERT_VIEW_TIMEOUT) +        self.ctx = ctx + +    @discord.ui.button(label="ID") +    async def user_id(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Reply with the ID of the offending user.""" +        await interaction.response.send_message(self.ctx.author.id, ephemeral=True) + +    @discord.ui.button(emoji="👤") +    async def user_info(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Send the info embed of the offending user.""" +        command = bot.instance.get_command("user") +        if not command: +            await interaction.response.send_message("The command `user` is not loaded.", ephemeral=True) +            return + +        await interaction.response.defer() +        fake_ctx = FakeContext(interaction.message, interaction.channel, command, author=interaction.user) +        # Get the most updated user/member object every time the button is pressed. +        author = await get_or_fetch_member(interaction.guild, self.ctx.author.id) +        if author is None: +            author = await bot.instance.fetch_user(self.ctx.author.id) +        await command(fake_ctx, author) + +    @discord.ui.button(emoji="⚠") +    async def user_infractions(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Send the infractions embed of the offending user.""" +        command = bot.instance.get_command("infraction search") +        if not command: +            await interaction.response.send_message("The command `infraction search` is not loaded.", ephemeral=True) +            return + +        await interaction.response.defer() +        fake_ctx = FakeContext(interaction.message, interaction.channel, command, author=interaction.user) +        await command(fake_ctx, self.ctx.author) diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py new file mode 100644 index 000000000..97a0fa8d4 --- /dev/null +++ b/bot/exts/filtering/_utils.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import importlib +import importlib.util +import inspect +import pkgutil +import types +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass +from functools import cache +from typing import Any, Iterable, TypeVar, Union, get_args, get_origin + +import discord +import regex +from discord.ext.commands import Command +from typing_extensions import Self + +import bot +from bot.bot import Bot +from bot.constants import Guild + +VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF" +INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1) +ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1) + + +T = TypeVar('T') + +Serializable = Union[bool, int, float, str, list, dict, None] + + +def subclasses_in_package(package: str, prefix: str, parent: T) -> set[T]: +    """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path.""" +    subclasses = set() + +    # Find all modules in the package. +    for module_info in pkgutil.iter_modules([package], prefix): +        if not module_info.ispkg: +            module = importlib.import_module(module_info.name) +            # Find all classes in each module... +            for _, class_ in inspect.getmembers(module, inspect.isclass): +                # That are a subclass of the given class. +                if parent in class_.__mro__: +                    subclasses.add(class_) + +    return subclasses + + +def clean_input(string: str) -> str: +    """Remove zalgo and invisible characters from `string`.""" +    # For future consideration: remove characters in the Mc, Sk, and Lm categories too. +    # Can be normalised with form C to merge char + combining char into a single char to avoid +    # removing legit diacritics, but this would open up a way to bypass _filters. +    no_zalgo = ZALGO_RE.sub("", string) +    return INVISIBLE_RE.sub("", no_zalgo) + + +def past_tense(word: str) -> str: +    """Return the past tense form of the input word.""" +    if not word: +        return word +    if word.endswith("e"): +        return word + "d" +    if word.endswith("y") and len(word) > 1 and word[-2] not in "aeiou": +        return word[:-1] + "ied" +    return word + "ed" + + +def to_serializable(item: Any, *, ui_repr: bool = False) -> Serializable: +    """ +    Convert the item into an object that can be converted to JSON. + +    `ui_repr` dictates whether to use the UI representation of `CustomIOField` instances (if any) +    or the DB-oriented representation. +    """ +    if isinstance(item, (bool, int, float, str, type(None))): +        return item +    if isinstance(item, dict): +        result = {} +        for key, value in item.items(): +            if not isinstance(key, (bool, int, float, str, type(None))): +                key = str(key) +            result[key] = to_serializable(value, ui_repr=ui_repr) +        return result +    if isinstance(item, Iterable): +        return [to_serializable(subitem, ui_repr=ui_repr) for subitem in item] +    if not ui_repr and hasattr(item, "serialize"): +        return item.serialize() +    return str(item) + + +@cache +def resolve_mention(mention: str) -> str: +    """Return the appropriate formatting for the mention, be it a literal, a user ID, or a role ID.""" +    guild = bot.instance.get_guild(Guild.id) +    if mention in ("here", "everyone"): +        return f"@{mention}" +    try: +        mention = int(mention)  # It's an ID. +    except ValueError: +        pass +    else: +        if any(mention == role.id for role in guild.roles): +            return f"<@&{mention}>" +        else: +            return f"<@{mention}>" + +    # It's a name +    for role in guild.roles: +        if role.name == mention: +            return role.mention +    for member in guild.members: +        if str(member) == mention: +            return member.mention +    return mention + + +def repr_equals(override: Any, default: Any) -> bool: +    """Return whether the override and the default have the same representation.""" +    if override is None:  # It's not an override +        return True + +    override_is_sequence = isinstance(override, (tuple, list, set)) +    default_is_sequence = isinstance(default, (tuple, list, set)) +    if override_is_sequence != default_is_sequence:  # One is a sequence and the other isn't. +        return False +    if override_is_sequence: +        if len(override) != len(default): +            return False +        return all(str(item1) == str(item2) for item1, item2 in zip(set(override), set(default))) +    return str(override) == str(default) + + +def normalize_type(type_: type, *, prioritize_nonetype: bool = True) -> type: +    """Reduce a given type to one that can be initialized.""" +    if get_origin(type_) in (Union, types.UnionType):  # In case of a Union +        args = get_args(type_) +        if type(None) in args: +            if prioritize_nonetype: +                return type(None) +            else: +                args = tuple(set(args) - {type(None)}) +        type_ = args[0]  # Pick one, doesn't matter +    if origin := get_origin(type_):  # In case of a parameterized List, Set, Dict etc. +        return origin +    return type_ + + +def starting_value(type_: type[T]) -> T: +    """Return a value of the given type.""" +    type_ = normalize_type(type_) +    try: +        return type_() +    except TypeError:  # In case it all fails, return a string and let the user handle it. +        return "" + + +class FieldRequiring(ABC): +    """A mixin class that can force its concrete subclasses to set a value for specific class attributes.""" + +    # Sentinel value that mustn't remain in a concrete subclass. +    MUST_SET = object() + +    # Sentinel value that mustn't remain in a concrete subclass. +    # Overriding value must be unique in the subclasses of the abstract class in which the attribute was set. +    MUST_SET_UNIQUE = object() + +    # A mapping of the attributes which must be unique, and their unique values, per FieldRequiring subclass. +    __unique_attributes: defaultdict[type, dict[str, set]] = defaultdict(dict) + +    @abstractmethod +    def __init__(self): +        ... + +    def __init_subclass__(cls, **kwargs): +        def inherited(attr: str) -> bool: +            """True if `attr` was inherited from a parent class.""" +            for parent in cls.__mro__[1:-1]:  # The first element is the class itself, last element is object. +                if hasattr(parent, attr):  # The attribute was inherited. +                    return True +            return False + +        # If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it. +        if inspect.isabstract(cls): +            for attribute in dir(cls): +                if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE: +                    if not inherited(attribute): +                        # A new attribute with the value MUST_SET_UNIQUE. +                        FieldRequiring.__unique_attributes[cls][attribute] = set() +            return + +        for attribute in dir(cls): +            if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"): +                continue +            value = getattr(cls, attribute) +            if value is FieldRequiring.MUST_SET and inherited(attribute): +                raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}") +            elif value is FieldRequiring.MUST_SET_UNIQUE and inherited(attribute): +                raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}") +            else: +                # Check if the value needs to be unique. +                for parent in cls.__mro__[1:-1]: +                    # Find the parent class the attribute was first defined in. +                    if attribute in FieldRequiring.__unique_attributes[parent]: +                        if value in FieldRequiring.__unique_attributes[parent][attribute]: +                            raise ValueError(f"Value of {attribute!r} in {cls!r} is not unique for parent {parent!r}.") +                        else: +                            # Add to the set of unique values for that field. +                            FieldRequiring.__unique_attributes[parent][attribute].add(value) + + +@dataclass +class FakeContext: +    """ +    A class representing a context-like object that can be sent to infraction commands. + +    The goal is to be able to apply infractions without depending on the existence of a message or an interaction +    (which are the two ways to create a Context), e.g. in API events which aren't message-driven, or in custom filtering +    events. +    """ + +    message: discord.Message +    channel: discord.abc.Messageable +    command: Command | None +    bot: Bot | None = None +    guild: discord.Guild | None = None +    author: discord.Member | discord.User | None = None +    me: discord.Member | None = None + +    def __post_init__(self): +        """Initialize the missing information.""" +        if not self.bot: +            self.bot = bot.instance +        if not self.guild: +            self.guild = self.bot.get_guild(Guild.id) +        if not self.me: +            self.me = self.guild.me +        if not self.author: +            self.author = self.me + +    async def send(self, *args, **kwargs) -> discord.Message: +        """A wrapper for channel.send.""" +        return await self.channel.send(*args, **kwargs) + + +class CustomIOField: +    """ +    A class to be used as a data type in SettingEntry subclasses. + +    Its subclasses can have custom methods to read and represent the value, which will be used by the UI. +    """ + +    def __init__(self, value: Any): +        self.value = self.process_value(value) + +    @classmethod +    def __get_validators__(cls): +        """Boilerplate for Pydantic.""" +        yield cls.validate + +    @classmethod +    def validate(cls, v: Any) -> Self: +        """Takes the given value and returns a class instance with that value.""" +        if isinstance(v, CustomIOField): +            return cls(v.value) + +        return cls(v) + +    def __eq__(self, other: CustomIOField): +        if not isinstance(other, CustomIOField): +            return NotImplemented +        return self.value == other.value + +    @classmethod +    def process_value(cls, v: str) -> Any: +        """ +        Perform any necessary transformations before the value is stored in a new instance. + +        Override this method to customize the input behavior. +        """ +        return v + +    def serialize(self) -> Serializable: +        """Override this method to customize how the value will be serialized.""" +        return self.value + +    def __str__(self): +        """Override this method to change how the value will be displayed by the UI.""" +        return self.value diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py new file mode 100644 index 000000000..392428bb0 --- /dev/null +++ b/bot/exts/filtering/filtering.py @@ -0,0 +1,1431 @@ +import datetime +import json +import re +import unicodedata +from collections import defaultdict +from collections.abc import Iterable, Mapping +from functools import partial, reduce +from io import BytesIO +from operator import attrgetter +from typing import Literal, Optional, get_type_hints + +import arrow +import discord +from async_rediscache import RedisCache +from discord import Colour, Embed, HTTPException, Message, MessageType +from discord.ext import commands, tasks +from discord.ext.commands import BadArgument, Cog, Context, command, has_any_role +from pydis_core.site_api import ResponseCodeError +from pydis_core.utils import scheduling + +import bot +import bot.exts.filtering._ui.filter as filters_ui +from bot import constants +from bot.bot import Bot +from bot.constants import Channels, Guild, MODERATION_ROLES, Roles +from bot.exts.backend.branding._repository import HEADERS, PARAMS +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import FilterList, ListType, filter_list_types, list_type_converter +from bot.exts.filtering._filter_lists.filter_list import AtomicList +from bot.exts.filtering._filters.filter import Filter, UniqueFilter +from bot.exts.filtering._settings import ActionSettings +from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction +from bot.exts.filtering._ui.filter import ( +    build_filter_repr_dict, description_and_settings_converter, filter_overrides_for_ui, populate_embed_from_dict +) +from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter +from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter +from bot.exts.filtering._ui.ui import ( +    AlertView, ArgumentCompletionView, DeleteConfirmationView, build_mod_alert, format_response_error +) +from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable +from bot.exts.moderation.infraction.infractions import COMP_BAN_DURATION, COMP_BAN_REASON +from bot.exts.utils.snekbox._io import FileAttachment +from bot.log import get_logger +from bot.pagination import LinePaginator +from bot.utils.channel import is_mod_channel +from bot.utils.lock import lock_arg +from bot.utils.message_cache import MessageCache + +log = get_logger(__name__) + +WEBHOOK_ICON_URL = r"https://github.com/python-discord/branding/raw/main/icons/filter/filter_pfp.png" +WEBHOOK_NAME = "Filtering System" +CACHE_SIZE = 1000 +HOURS_BETWEEN_NICKNAME_ALERTS = 1 +OFFENSIVE_MSG_DELETE_TIME = datetime.timedelta(days=7) +WEEKLY_REPORT_ISO_DAY = 3  # 1=Monday, 7=Sunday + + +class Filtering(Cog): +    """Filtering and alerting for content posted on the server.""" + +    # A set of filter list names with missing implementations that already caused a warning. +    already_warned = set() + +    # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent. +    name_alerts = RedisCache() + +    # region: init + +    def __init__(self, bot: Bot): +        self.bot = bot +        self.filter_lists: dict[str, FilterList] = {} +        self._subscriptions: defaultdict[Event, list[FilterList]] = defaultdict(list) +        self.delete_scheduler = scheduling.Scheduler(self.__class__.__name__) +        self.webhook: discord.Webhook | None = None + +        self.loaded_settings = {} +        self.loaded_filters = {} +        self.loaded_filter_settings = {} + +        self.message_cache = MessageCache(CACHE_SIZE, newest_first=True) + +    async def cog_load(self) -> None: +        """ +        Fetch the filter data from the API, parse it, and load it to the appropriate data structures. + +        Additionally, fetch the alerting webhook. +        """ +        await self.bot.wait_until_guild_available() + +        log.trace("Loading filtering information from the database.") +        raw_filter_lists = await self.bot.api_client.get("bot/filter/filter_lists") +        example_list = None +        for raw_filter_list in raw_filter_lists: +            loaded_list = self._load_raw_filter_list(raw_filter_list) +            if not example_list and loaded_list: +                example_list = loaded_list + +        # The webhook must be generated by the bot to send messages with components through it. +        self.webhook = await self._fetch_or_generate_filtering_webhook() + +        self.collect_loaded_types(example_list) +        await self.schedule_offending_messages_deletion() +        self.weekly_auto_infraction_report_task.start() + +    def subscribe(self, filter_list: FilterList, *events: Event) -> None: +        """ +        Subscribe a filter list to the given events. + +        The filter list is added to a list for each event. When the event is triggered, the filter context will be +        dispatched to the subscribed filter lists. + +        While it's possible to just make each filter list check the context's event, these are only the events a filter +        list expects to receive from the filtering cog, there isn't an actual limitation on the kinds of events a filter +        list can handle as long as the filter context is built properly. If for whatever reason we want to invoke a +        filter list outside of the usual procedure with the filtering cog, it will be more problematic if the events are +        hard-coded into each filter list. +        """ +        for event in events: +            if filter_list not in self._subscriptions[event]: +                self._subscriptions[event].append(filter_list) + +    def unsubscribe(self, filter_list: FilterList, *events: Event) -> None: +        """Unsubscribe a filter list from the given events. If no events given, unsubscribe from every event.""" +        if not events: +            events = list(self._subscriptions) + +        for event in events: +            if filter_list in self._subscriptions.get(event, []): +                self._subscriptions[event].remove(filter_list) + +    def collect_loaded_types(self, example_list: AtomicList) -> None: +        """ +        Go over the classes used in initialization and collect them to dictionaries. + +        The information that is collected is about the types actually used to load the API response, not all types +        available in the filtering extension. + +        Any filter list has the fields for all settings in the DB schema, so picking any one of them is enough. +        """ +        # Get the filter types used by each filter list. +        for filter_list in self.filter_lists.values(): +            self.loaded_filters.update({filter_type.name: filter_type for filter_type in filter_list.filter_types}) + +        # Get the setting types used by each filter list. +        if self.filter_lists: +            settings_entries = set() +            # The settings are split between actions and validations. +            for settings_group in example_list.defaults: +                settings_entries.update(type(setting) for _, setting in settings_group.items()) + +            for setting_entry in settings_entries: +                type_hints = get_type_hints(setting_entry) +                # The description should be either a string or a dictionary. +                if isinstance(setting_entry.description, str): +                    # If it's a string, then the settings entry matches a single field in the DB, +                    # and its name is the setting type's name attribute. +                    self.loaded_settings[setting_entry.name] = ( +                        setting_entry.description, setting_entry, type_hints[setting_entry.name] +                    ) +                else: +                    # Otherwise, the setting entry works with compound settings. +                    self.loaded_settings.update({ +                        subsetting: (description, setting_entry, type_hints[subsetting]) +                        for subsetting, description in setting_entry.description.items() +                    }) + +        # Get the settings per filter as well. +        for filter_name, filter_type in self.loaded_filters.items(): +            extra_fields_type = filter_type.extra_fields_type +            if not extra_fields_type: +                continue +            type_hints = get_type_hints(extra_fields_type) +            # A class var with a `_description` suffix is expected per field name. +            self.loaded_filter_settings[filter_name] = { +                field_name: ( +                    getattr(extra_fields_type, f"{field_name}_description", ""), +                    extra_fields_type, +                    type_hints[field_name] +                ) +                for field_name in extra_fields_type.__fields__ +            } + +    async def schedule_offending_messages_deletion(self) -> None: +        """Load the messages that need to be scheduled for deletion from the database.""" +        response = await self.bot.api_client.get('bot/offensive-messages') + +        now = arrow.utcnow() +        for msg in response: +            delete_at = arrow.get(msg['delete_date']) +            if delete_at < now: +                await self._delete_offensive_msg(msg) +            else: +                self._schedule_msg_delete(msg) + +    async def cog_check(self, ctx: Context) -> bool: +        """Only allow moderators to invoke the commands in this cog.""" +        return await has_any_role(*MODERATION_ROLES).predicate(ctx) + +    # endregion +    # region: listeners and event handlers + +    @Cog.listener() +    async def on_message(self, msg: Message) -> None: +        """Filter the contents of a sent message.""" +        if msg.author.bot or msg.webhook_id or msg.type == MessageType.auto_moderation_action: +            return +        self.message_cache.append(msg) + +        ctx = FilterContext.from_message(Event.MESSAGE, msg, None, self.message_cache) +        result_actions, list_messages, triggers = await self._resolve_action(ctx) +        self.message_cache.update(msg, metadata=triggers) +        if result_actions: +            await result_actions.action(ctx) +        if ctx.send_alert: +            await self._send_alert(ctx, list_messages) + +        nick_ctx = FilterContext.from_message(Event.NICKNAME, msg) +        nick_ctx.content = msg.author.display_name +        await self._check_bad_name(nick_ctx) + +        await self._maybe_schedule_msg_delete(ctx, result_actions) +        self._increment_stats(triggers) + +    @Cog.listener() +    async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: +        """Filter the contents of an edited message. Don't reinvoke filters already invoked on the `before` version.""" +        # Only check changes to the message contents/attachments and embed additions, not pin status etc. +        if all(( +            before.content == after.content,  # content hasn't changed +            before.attachments == after.attachments,  # attachments haven't changed +            len(before.embeds) >= len(after.embeds)  # embeds haven't been added +        )): +            return + +        # Update the cache first, it might be used by the antispam filter. +        # No need to update the triggers, they're going to be updated inside the sublists if necessary. +        self.message_cache.update(after) +        ctx = FilterContext.from_message(Event.MESSAGE_EDIT, after, before, self.message_cache) +        result_actions, list_messages, triggers = await self._resolve_action(ctx) +        if result_actions: +            await result_actions.action(ctx) +        if ctx.send_alert: +            await self._send_alert(ctx, list_messages) +        await self._maybe_schedule_msg_delete(ctx, result_actions) +        self._increment_stats(triggers) + +    @Cog.listener() +    async def on_voice_state_update(self, member: discord.Member, *_) -> None: +        """Checks for bad words in usernames when users join, switch or leave a voice channel.""" +        ctx = FilterContext(Event.NICKNAME, member, None, member.display_name, None) +        await self._check_bad_name(ctx) + +    async def filter_snekbox_output( +        self, stdout: str, files: list[FileAttachment], msg: Message +    ) -> tuple[bool, set[str]]: +        """ +        Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. + +        Also requires the original message, to check whether to filter and for alerting. +        Any action (deletion, infraction) will be applied in the context of the original message. + +        Returns whether the output should be blocked, as well as a list of blocked file extensions. +        """ +        content = stdout +        if files:  # Filter the filenames as well. +            content += "\n\n" + "\n".join(file.filename for file in files) +        ctx = FilterContext.from_message(Event.SNEKBOX, msg).replace(content=content, attachments=files) + +        result_actions, list_messages, triggers = await self._resolve_action(ctx) +        if result_actions: +            await result_actions.action(ctx) +        if ctx.send_alert: +            await self._send_alert(ctx, list_messages) + +        self._increment_stats(triggers) +        return result_actions is not None, ctx.blocked_exts + +    # endregion +    # region: blacklist commands + +    @commands.group(aliases=("bl", "blacklist", "denylist", "dl")) +    async def blocklist(self, ctx: Context) -> None: +        """Group for managing blacklisted items.""" +        if not ctx.invoked_subcommand: +            await ctx.send_help(ctx.command) + +    @blocklist.command(name="list", aliases=("get",)) +    async def bl_list(self, ctx: Context, list_name: Optional[str] = None) -> None: +        """List the contents of a specified blacklist.""" +        result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name, exclude="list_type") +        if not result: +            return +        list_type, filter_list = result +        await self._send_list(ctx, filter_list, list_type) + +    @blocklist.command(name="add", aliases=("a",)) +    async def bl_add( +        self, +        ctx: Context, +        noui: Optional[Literal["noui"]], +        list_name: Optional[str], +        content: str, +        *, +        description_and_settings: Optional[str] = None +    ) -> None: +        """ +        Add a blocked filter to the specified filter list. + +        Unless `noui` is specified, a UI will be provided to edit the content, description, and settings +        before confirmation. + +        The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the +        equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces. +        """ +        result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name, exclude="list_type") +        if result is None: +            return +        list_type, filter_list = result +        await self._add_filter(ctx, noui, list_type, filter_list, content, description_and_settings) + +    # endregion +    # region: whitelist commands + +    @commands.group(aliases=("wl", "whitelist", "al")) +    async def allowlist(self, ctx: Context) -> None: +        """Group for managing blacklisted items.""" +        if not ctx.invoked_subcommand: +            await ctx.send_help(ctx.command) + +    @allowlist.command(name="list", aliases=("get",)) +    async def al_list(self, ctx: Context, list_name: Optional[str] = None) -> None: +        """List the contents of a specified whitelist.""" +        result = await self._resolve_list_type_and_name(ctx, ListType.ALLOW, list_name, exclude="list_type") +        if not result: +            return +        list_type, filter_list = result +        await self._send_list(ctx, filter_list, list_type) + +    @allowlist.command(name="add", aliases=("a",)) +    async def al_add( +        self, +        ctx: Context, +        noui: Optional[Literal["noui"]], +        list_name: Optional[str], +        content: str, +        *, +        description_and_settings: Optional[str] = None +    ) -> None: +        """ +        Add an allowed filter to the specified filter list. + +        Unless `noui` is specified, a UI will be provided to edit the content, description, and settings +        before confirmation. + +        The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the +        equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces. +        """ +        result = await self._resolve_list_type_and_name(ctx, ListType.ALLOW, list_name, exclude="list_type") +        if result is None: +            return +        list_type, filter_list = result +        await self._add_filter(ctx, noui, list_type, filter_list, content, description_and_settings) + +    # endregion +    # region: filter commands + +    @commands.group(aliases=("filters", "f"), invoke_without_command=True) +    async def filter(self, ctx: Context, id_: Optional[int] = None) -> None: +        """ +        Group for managing filters. + +        If a valid filter ID is provided, an embed describing the filter will be posted. +        """ +        if not ctx.invoked_subcommand and not id_: +            await ctx.send_help(ctx.command) +            return + +        result = self._get_filter_by_id(id_) +        if result is None: +            await ctx.send(f":x: Could not find a filter with ID `{id_}`.") +            return +        filter_, filter_list, list_type = result + +        overrides_values, extra_fields_overrides = filter_overrides_for_ui(filter_) + +        all_settings_repr_dict = build_filter_repr_dict( +            filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides +        ) +        embed = Embed(colour=Colour.blue()) +        populate_embed_from_dict(embed, all_settings_repr_dict) +        embed.description = f"`{filter_.content}`" +        if filter_.description: +            embed.description += f" - {filter_.description}" +        embed.set_author(name=f"Filter {id_} - " + f"{filter_list[list_type].label}".title()) +        embed.set_footer(text=( +            "Field names with an asterisk have values which override the defaults of the containing filter list. " +            f"To view all defaults of the list, " +            f"run `{constants.Bot.prefix}filterlist describe {list_type.name} {filter_list.name}`." +        )) +        await ctx.send(embed=embed) + +    @filter.command(name="list", aliases=("get",)) +    async def f_list( +        self, ctx: Context, list_type: Optional[list_type_converter] = None, list_name: Optional[str] = None +    ) -> None: +        """List the contents of a specified list of filters.""" +        result = await self._resolve_list_type_and_name(ctx, list_type, list_name) +        if result is None: +            return +        list_type, filter_list = result + +        await self._send_list(ctx, filter_list, list_type) + +    @filter.command(name="describe", aliases=("explain", "manual")) +    async def f_describe(self, ctx: Context, filter_name: Optional[str]) -> None: +        """Show a description of the specified filter, or a list of possible values if no name is specified.""" +        if not filter_name: +            filter_names = [f"» {f}" for f in self.loaded_filters] +            embed = Embed(colour=Colour.blue()) +            embed.set_author(name="List of filter names") +            await LinePaginator.paginate(filter_names, ctx, embed, max_lines=10, empty=False) +        else: +            filter_type = self.loaded_filters.get(filter_name) +            if not filter_type: +                filter_type = self.loaded_filters.get(filter_name[:-1])  # A plural form or a typo. +                if not filter_type: +                    await ctx.send(f":x: There's no filter type named {filter_name!r}.") +                    return +            # Use the class's docstring, and ignore single newlines. +            embed = Embed(description=re.sub(r"(?<!\n)\n(?!\n)", " ", filter_type.__doc__), colour=Colour.blue()) +            embed.set_author(name=f"Description of the {filter_name} filter") +            await ctx.send(embed=embed) + +    @filter.command(name="add", aliases=("a",)) +    async def f_add( +        self, +        ctx: Context, +        noui: Optional[Literal["noui"]], +        list_type: Optional[list_type_converter], +        list_name: Optional[str], +        content: str, +        *, +        description_and_settings: Optional[str] = None +    ) -> None: +        """ +        Add a filter to the specified filter list. + +        Unless `noui` is specified, a UI will be provided to edit the content, description, and settings +        before confirmation. + +        The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the +        equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces. + +        A template filter can be specified in the settings area to copy overrides from. The setting name is "--template" +        and the value is the filter ID. The template will be used before applying any other override. + +        Example: `!filter add denied token "Scaleios is great" remove_context=True send_alert=False --template=100` +        """ +        result = await self._resolve_list_type_and_name(ctx, list_type, list_name) +        if result is None: +            return +        list_type, filter_list = result +        await self._add_filter(ctx, noui, list_type, filter_list, content, description_and_settings) + +    @filter.command(name="edit", aliases=("e",)) +    async def f_edit( +        self, +        ctx: Context, +        noui: Optional[Literal["noui"]], +        filter_id: int, +        *, +        description_and_settings: Optional[str] = None +    ) -> None: +        """ +        Edit a filter specified by its ID. + +        Unless `noui` is specified, a UI will be provided to edit the content, description, and settings +        before confirmation. + +        The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the +        equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces. + +        A template filter can be specified in the settings area to copy overrides from. The setting name is "--template" +        and the value is the filter ID. The template will be used before applying any other override. + +        To edit the filter's content, use the UI. +        """ +        result = self._get_filter_by_id(filter_id) +        if result is None: +            await ctx.send(f":x: Could not find a filter with ID `{filter_id}`.") +            return +        filter_, filter_list, list_type = result +        filter_type = type(filter_) +        settings, filter_settings = filter_overrides_for_ui(filter_) +        description, new_settings, new_filter_settings = description_and_settings_converter( +            filter_list, +            list_type, filter_type, +            self.loaded_settings, +            self.loaded_filter_settings, +            description_and_settings +        ) + +        content = filter_.content +        description = description or filter_.description +        settings.update(new_settings) +        filter_settings.update(new_filter_settings) +        patch_func = partial(self._patch_filter, filter_) + +        if noui: +            try: +                await patch_func( +                    ctx.message, filter_list, list_type, filter_type, content, description, settings, filter_settings +                ) +            except ResponseCodeError as e: +                await ctx.reply(embed=format_response_error(e)) +            return + +        embed = Embed(colour=Colour.blue()) +        embed.description = f"`{filter_.content}`" +        if description: +            embed.description += f" - {description}" +        embed.set_author( +            name=f"Filter {filter_id} - {filter_list[list_type].label}".title()) +        embed.set_footer(text=( +            "Field names with an asterisk have values which override the defaults of the containing filter list. " +            f"To view all defaults of the list, " +            f"run `{constants.Bot.prefix}filterlist describe {list_type.name} {filter_list.name}`." +        )) + +        view = filters_ui.FilterEditView( +            filter_list, +            list_type, +            filter_type, +            content, +            description, +            settings, +            filter_settings, +            self.loaded_settings, +            self.loaded_filter_settings, +            ctx.author, +            embed, +            patch_func +        ) +        await ctx.send(embed=embed, reference=ctx.message, view=view) + +    @filter.command(name="delete", aliases=("d", "remove")) +    async def f_delete(self, ctx: Context, filter_id: int) -> None: +        """Delete the filter specified by its ID.""" +        async def delete_list() -> None: +            """The actual removal routine.""" +            await bot.instance.api_client.delete(f'bot/filter/filters/{filter_id}') +            log.info(f"Successfully deleted filter with ID {filter_id}.") +            filter_list[list_type].filters.pop(filter_id) +            await ctx.reply(f"✅ Deleted filter: {filter_}") + +        result = self._get_filter_by_id(filter_id) +        if result is None: +            await ctx.send(f":x: Could not find a filter with ID `{filter_id}`.") +            return +        filter_, filter_list, list_type = result +        await ctx.reply( +            f"Are you sure you want to delete filter {filter_}?", +            view=DeleteConfirmationView(ctx.author, delete_list) +        ) + +    @filter.command(aliases=("settings",)) +    async def setting(self, ctx: Context, setting_name: str | None) -> None: +        """Show a description of the specified setting, or a list of possible settings if no name is specified.""" +        if not setting_name: +            settings_list = [f"» {setting_name}" for setting_name in self.loaded_settings] +            for filter_name, filter_settings in self.loaded_filter_settings.items(): +                settings_list.extend(f"» {filter_name}/{setting}" for setting in filter_settings) +            embed = Embed(colour=Colour.blue()) +            embed.set_author(name="List of setting names") +            await LinePaginator.paginate(settings_list, ctx, embed, max_lines=10, empty=False) + +        else: +            # The setting is either in a SettingsEntry subclass, or a pydantic model. +            setting_data = self.loaded_settings.get(setting_name) +            description = None +            if setting_data: +                description = setting_data[0] +            elif "/" in setting_name:  # It's a filter specific setting. +                filter_name, filter_setting_name = setting_name.split("/", maxsplit=1) +                if filter_name in self.loaded_filter_settings: +                    if filter_setting_name in self.loaded_filter_settings[filter_name]: +                        description = self.loaded_filter_settings[filter_name][filter_setting_name][0] +            if description is None: +                await ctx.send(f":x: There's no setting type named {setting_name!r}.") +                return +            embed = Embed(colour=Colour.blue(), description=description) +            embed.set_author(name=f"Description of the {setting_name} setting") +            await ctx.send(embed=embed) + +    @filter.command(name="match") +    async def f_match( +        self, ctx: Context, no_user: bool | None, message: Message | None, *, string: str | None +    ) -> None: +        """ +        List the filters triggered for the given message or string. + +        If there's a `message`, the `string` will be ignored. Note that if a `message` is provided, it will go through +        all validations appropriate to where it was sent and who sent it. To check for matches regardless of the author +        (for example if the message was sent by another staff member or yourself) set `no_user` to '1' or 'True'. + +        If a `string` is provided, it will be validated in the context of a user with no roles in python-general. +        """ +        if not message and not string: +            raise BadArgument("Please provide input.") +        if message: +            user = None if no_user else message.author +            filter_ctx = FilterContext(Event.MESSAGE, user, message.channel, message.content, message, message.embeds) +        else: +            python_general = ctx.guild.get_channel(Channels.python_general) +            filter_ctx = FilterContext(Event.MESSAGE, None, python_general, string, None) + +        _, _, triggers = await self._resolve_action(filter_ctx) +        lines = [] +        for sublist, sublist_triggers in triggers.items(): +            if sublist_triggers: +                triggers_repr = map(str, sublist_triggers) +                lines.extend([f"**{sublist.label.title()}s**", *triggers_repr, "\n"]) +        lines = lines[:-1]  # Remove last newline. + +        embed = Embed(colour=Colour.blue(), title="Match results") +        await LinePaginator.paginate(lines, ctx, embed, max_lines=10, empty=False) + +    @filter.command(name="search") +    async def f_search( +        self, +        ctx: Context, +        noui: Literal["noui"] | None, +        filter_type_name: str | None, +        *, +        settings: str = "" +    ) -> None: +        """ +        Find filters with the provided settings. The format is identical to that of the add and edit commands. + +        If a list type and/or a list name are provided, the search will be limited to those parameters. A list name must +        be provided in order to search by filter-specific settings. +        """ +        filter_type = None +        if filter_type_name: +            filter_type_name = filter_type_name.lower() +            filter_type = self.loaded_filters.get(filter_type_name) +            if not filter_type: +                self.loaded_filters.get(filter_type_name[:-1])  # In case the user tried to specify the plural form. +        # If settings were provided with no filter_type, discord.py will capture the first word as the filter type. +        if filter_type is None and filter_type_name is not None: +            if settings: +                settings = f"{filter_type_name} {settings}" +            else: +                settings = filter_type_name +            filter_type_name = None + +        settings, filter_settings, filter_type = search_criteria_converter( +            self.filter_lists, +            self.loaded_filters, +            self.loaded_settings, +            self.loaded_filter_settings, +            filter_type, +            settings +        ) + +        if noui: +            await self._search_filters(ctx.message, filter_type, settings, filter_settings) +            return + +        embed = Embed(colour=Colour.blue()) +        view = SearchEditView( +            filter_type, +            settings, +            filter_settings, +            self.filter_lists, +            self.loaded_filters, +            self.loaded_settings, +            self.loaded_filter_settings, +            ctx.author, +            embed, +            self._search_filters +        ) +        await ctx.send(embed=embed, reference=ctx.message, view=view) + +    @filter.command(root_aliases=("compfilter", "compf")) +    async def compadd( +        self, ctx: Context, list_name: Optional[str], content: str, *, description: Optional[str] = "Phishing" +    ) -> None: +        """Add a filter to detect a compromised account. Will apply the equivalent of a compban if triggered.""" +        result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name, exclude="list_type") +        if result is None: +            return +        list_type, filter_list = result + +        settings = ( +            "remove_context=True " +            "dm_pings=Moderators " +            "infraction_type=BAN " +            "infraction_channel=1 "  # Post the ban in #mod-alerts +            f"infraction_duration={COMP_BAN_DURATION.total_seconds()} " +            f"infraction_reason={COMP_BAN_REASON}" +        ) +        description_and_settings = f"{description} {settings}" +        await self._add_filter(ctx, "noui", list_type, filter_list, content, description_and_settings) + +    # endregion +    # region: filterlist group + +    @commands.group(aliases=("fl",)) +    async def filterlist(self, ctx: Context) -> None: +        """Group for managing filter lists.""" +        if not ctx.invoked_subcommand: +            await ctx.send_help(ctx.command) + +    @filterlist.command(name="describe", aliases=("explain", "manual", "id")) +    async def fl_describe( +        self, ctx: Context, list_type: Optional[list_type_converter] = None, list_name: Optional[str] = None +    ) -> None: +        """Show a description of the specified filter list, or a list of possible values if no values are provided.""" +        if not list_type and not list_name: +            list_names = [f"» {fl}" for fl in self.filter_lists] +            embed = Embed(colour=Colour.blue()) +            embed.set_author(name="List of filter lists names") +            await LinePaginator.paginate(list_names, ctx, embed, max_lines=10, empty=False) +            return + +        result = await self._resolve_list_type_and_name(ctx, list_type, list_name) +        if result is None: +            return +        list_type, filter_list = result + +        setting_values = {} +        for settings_group in filter_list[list_type].defaults: +            for _, setting in settings_group.items(): +                setting_values.update(to_serializable(setting.dict(), ui_repr=True)) + +        embed = Embed(colour=Colour.blue()) +        populate_embed_from_dict(embed, setting_values) +        # Use the class's docstring, and ignore single newlines. +        embed.description = re.sub(r"(?<!\n)\n(?!\n)", " ", filter_list.__doc__) +        embed.set_author( +            name=f"Description of the {filter_list[list_type].label} filter list" +        ) +        await ctx.send(embed=embed) + +    @filterlist.command(name="add", aliases=("a",)) +    @has_any_role(Roles.admins) +    async def fl_add(self, ctx: Context, list_type: list_type_converter, list_name: str) -> None: +        """Add a new filter list.""" +        # Check if there's an implementation. +        if list_name.lower() not in filter_list_types: +            if list_name.lower()[:-1] not in filter_list_types:  # Maybe the name was given with uppercase or in plural? +                await ctx.reply(f":x: Cannot add a `{list_name}` filter list, as there is no matching implementation.") +                return +            else: +                list_name = list_name.lower()[:-1] + +        # Check it doesn't already exist. +        list_description = f"{past_tense(list_type.name.lower())} {list_name.lower()}" +        if list_name in self.filter_lists: +            filter_list = self.filter_lists[list_name] +            if list_type in filter_list: +                await ctx.reply(f":x: The {list_description} filter list already exists.") +                return + +        embed = Embed(colour=Colour.blue()) +        embed.set_author(name=f"New Filter List - {list_description.title()}") +        settings = {name: starting_value(value[2]) for name, value in self.loaded_settings.items()} + +        view = FilterListAddView( +            list_name, +            list_type, +            settings, +            self.loaded_settings, +            ctx.author, +            embed, +            self._post_filter_list +        ) +        await ctx.send(embed=embed, reference=ctx.message, view=view) + +    @filterlist.command(name="edit", aliases=("e",)) +    @has_any_role(Roles.admins) +    async def fl_edit( +        self, +        ctx: Context, +        noui: Optional[Literal["noui"]], +        list_type: Optional[list_type_converter] = None, +        list_name: Optional[str] = None, +        *, +        settings: str | None +    ) -> None: +        """ +        Edit the filter list. + +        Unless `noui` is specified, a UI will be provided to edit the settings before confirmation. + +        The settings can be provided in the command itself, in the format of `setting_name=value` (no spaces around the +        equal sign). The value doesn't need to (shouldn't) be surrounded in quotes even if it contains spaces. +        """ +        result = await self._resolve_list_type_and_name(ctx, list_type, list_name) +        if result is None: +            return +        list_type, filter_list = result +        settings = settings_converter(self.loaded_settings, settings) +        if noui: +            try: +                await self._patch_filter_list(ctx.message, filter_list, list_type, settings) +            except ResponseCodeError as e: +                await ctx.reply(embed=format_response_error(e)) +            return + +        embed = Embed(colour=Colour.blue()) +        embed.set_author(name=f"{filter_list[list_type].label.title()} Filter List") +        embed.set_footer(text="Field names with a ~ have values which change the existing value in the filter list.") + +        view = FilterListEditView( +            filter_list, +            list_type, +            settings, +            self.loaded_settings, +            ctx.author, +            embed, +            self._patch_filter_list +        ) +        await ctx.send(embed=embed, reference=ctx.message, view=view) + +    @filterlist.command(name="delete", aliases=("remove",)) +    @has_any_role(Roles.admins) +    async def fl_delete( +        self, ctx: Context, list_type: Optional[list_type_converter] = None, list_name: Optional[str] = None +    ) -> None: +        """Remove the filter list and all of its filters from the database.""" +        async def delete_list() -> None: +            """The actual removal routine.""" +            list_data = await bot.instance.api_client.get(f"bot/filter/filter_lists/{list_id}") +            file = discord.File(BytesIO(json.dumps(list_data, indent=4).encode("utf-8")), f"{list_description}.json") +            message = await ctx.send("⏳ Annihilation in progress, please hold...", file=file) +            # Unload the filter list. +            filter_list.pop(list_type) +            if not filter_list:  # There's nothing left, remove from the cog. +                self.filter_lists.pop(filter_list.name) +                self.unsubscribe(filter_list) + +            await bot.instance.api_client.delete(f"bot/filter/filter_lists/{list_id}") +            log.info(f"Successfully deleted the {filter_list[list_type].label} filterlist.") +            await message.edit(content=f"✅ The {list_description} list has been deleted.") + +        result = await self._resolve_list_type_and_name(ctx, list_type, list_name) +        if result is None: +            return +        list_type, filter_list = result +        list_id = filter_list[list_type].id +        list_description = filter_list[list_type].label +        await ctx.reply( +            f"Are you sure you want to delete the {list_description} list?", +            view=DeleteConfirmationView(ctx.author, delete_list) +        ) + +    # endregion +    # region: utility commands + +    @command(name="filter_report") +    async def force_send_weekly_report(self, ctx: Context) -> None: +        """Respond with a list of auto-infractions added in the last 7 days.""" +        await self.send_weekly_auto_infraction_report(ctx.channel) + +    # endregion +    # region: helper functions + +    def _load_raw_filter_list(self, list_data: dict) -> AtomicList | None: +        """Load the raw list data to the cog.""" +        list_name = list_data["name"] +        if list_name not in self.filter_lists: +            if list_name not in filter_list_types: +                if list_name not in self.already_warned: +                    log.warning( +                        f"A filter list named {list_name} was loaded from the database, but no matching class." +                    ) +                    self.already_warned.add(list_name) +                return None +            self.filter_lists[list_name] = filter_list_types[list_name](self) +        return self.filter_lists[list_name].add_list(list_data) + +    async def _fetch_or_generate_filtering_webhook(self) -> discord.Webhook | None: +        """Generate a webhook with the filtering avatar.""" +        alerts_channel = self.bot.get_guild(Guild.id).get_channel(Channels.mod_alerts) +        # Try to find an existing webhook. +        for webhook in await alerts_channel.webhooks(): +            if webhook.name == WEBHOOK_NAME and webhook.user == self.bot.user and webhook.is_authenticated(): +                log.trace(f"Found existing filters webhook with ID {webhook.id}.") +                return webhook + +        # Download the filtering avatar from the branding repository. +        webhook_icon = None +        async with self.bot.http_session.get(WEBHOOK_ICON_URL, params=PARAMS, headers=HEADERS) as response: +            if response.status == 200: +                log.debug("Successfully fetched filtering webhook icon, reading payload.") +                webhook_icon = await response.read() +            else: +                log.warning(f"Failed to fetch filtering webhook icon due to status: {response.status}") + +        # Generate a new webhook. +        try: +            webhook = await alerts_channel.create_webhook(name=WEBHOOK_NAME, avatar=webhook_icon) +            log.trace(f"Generated new filters webhook with ID {webhook.id},") +            return webhook +        except HTTPException as e: +            log.error(f"Failed to create filters webhook: {e}") +            return None + +    async def _resolve_action( +        self, ctx: FilterContext +    ) -> tuple[ActionSettings | None, dict[FilterList, list[str]], dict[AtomicList, list[Filter]]]: +        """ +        Return the actions that should be taken for all filter lists in the given context. + +        Additionally, a message is possibly provided from each filter list describing the triggers, +        which should be relayed to the moderators. +        """ +        actions = [] +        messages = {} +        triggers = {} +        for filter_list in self._subscriptions[ctx.event]: +            list_actions, list_message, list_triggers = await filter_list.actions_for(ctx) +            triggers.update({filter_list[list_type]: filters for list_type, filters in list_triggers.items()}) +            if list_actions: +                actions.append(list_actions) +            if list_message: +                messages[filter_list] = list_message + +        result_actions = None +        if actions: +            result_actions = reduce(ActionSettings.union, actions) + +        return result_actions, messages, triggers + +    async def _send_alert(self, ctx: FilterContext, triggered_filters: dict[FilterList, Iterable[str]]) -> None: +        """Build an alert message from the filter context, and send it via the alert webhook.""" +        if not self.webhook: +            return + +        name = f"{ctx.event.name.replace('_', ' ').title()} Filter" +        embed = await build_mod_alert(ctx, triggered_filters) +        # There shouldn't be more than 10, but if there are it's not very useful to send them all. +        await self.webhook.send( +            username=name, content=ctx.alert_content, embeds=[embed, *ctx.alert_embeds][:10], view=AlertView(ctx) +        ) + +    def _increment_stats(self, triggered_filters: dict[AtomicList, list[Filter]]) -> None: +        """Increment the stats for every filter triggered.""" +        for filters in triggered_filters.values(): +            for filter_ in filters: +                if isinstance(filter_, UniqueFilter): +                    self.bot.stats.incr(f"filters.{filter_.name}") + +    async def _recently_alerted_name(self, member: discord.Member) -> bool: +        """When it hasn't been `HOURS_BETWEEN_NICKNAME_ALERTS` since last alert, return False, otherwise True.""" +        if last_alert := await self.name_alerts.get(member.id): +            last_alert = arrow.get(last_alert) +            if arrow.utcnow() - last_alert < datetime.timedelta(days=HOURS_BETWEEN_NICKNAME_ALERTS): +                log.trace(f"Last alert was too recent for {member}'s nickname.") +                return True + +        return False + +    @lock_arg("filtering.check_bad_name", "ctx", attrgetter("author.id")) +    async def _check_bad_name(self, ctx: FilterContext) -> None: +        """Check filter triggers in the passed context - a member's display name.""" +        if await self._recently_alerted_name(ctx.author): +            return + +        name = ctx.content +        normalised_name = unicodedata.normalize("NFKC", name) +        cleaned_normalised_name = "".join([c for c in normalised_name if not unicodedata.combining(c)]) + +        # Run filters against normalised, cleaned normalised and the original name, +        # in case there are filters for one but not another. +        names_to_check = (name, normalised_name, cleaned_normalised_name) + +        new_ctx = ctx.replace(content=" ".join(names_to_check)) +        result_actions, list_messages, triggers = await self._resolve_action(new_ctx) +        if result_actions: +            await result_actions.action(ctx) +        if ctx.send_alert: +            await self._send_alert(ctx, list_messages)  # `ctx` has the original content. +            # Update time when alert sent +            await self.name_alerts.set(ctx.author.id, arrow.utcnow().timestamp()) +        self._increment_stats(triggers) + +    async def _resolve_list_type_and_name( +        self, ctx: Context, list_type: ListType | None = None, list_name: str | None = None, *, exclude: str = "" +    ) -> tuple[ListType, FilterList] | None: +        """Prompt the user to complete the list type or list name if one of them is missing.""" +        if list_name is None: +            args = [list_type] if exclude != "list_type" else [] +            await ctx.send( +                "The **list_name** argument is unspecified. Please pick a value from the options below:", +                view=ArgumentCompletionView(ctx, args, "list_name", list(self.filter_lists), 1, None) +            ) +            return None + +        filter_list = self._get_list_by_name(list_name) +        if list_type is None: +            if len(filter_list) > 1: +                args = [list_name] if exclude != "list_name" else [] +                await ctx.send( +                    "The **list_type** argument is unspecified. Please pick a value from the options below:", +                    view=ArgumentCompletionView( +                        ctx, args, "list_type", [option.name for option in ListType], 0, list_type_converter +                    ) +                ) +                return None +            list_type = list(filter_list)[0] +        return list_type, filter_list + +    def _get_list_by_name(self, list_name: str) -> FilterList: +        """Get a filter list by its name, or raise an error if there's no such list.""" +        log.trace(f"Getting the filter list matching the name {list_name}") +        filter_list = self.filter_lists.get(list_name) +        if not filter_list: +            if list_name.endswith("s"):  # The user may have attempted to use the plural form. +                filter_list = self.filter_lists.get(list_name[:-1]) +            if not filter_list: +                raise BadArgument(f"There's no filter list named {list_name!r}.") +        log.trace(f"Found list named {filter_list.name}") +        return filter_list + +    @staticmethod +    async def _send_list(ctx: Context, filter_list: FilterList, list_type: ListType) -> None: +        """Show the list of filters identified by the list name and type.""" +        if list_type not in filter_list: +            await ctx.send(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.") +            return + +        lines = list(map(str, filter_list[list_type].filters.values())) +        log.trace(f"Sending a list of {len(lines)} filters.") + +        embed = Embed(colour=Colour.blue()) +        embed.set_author(name=f"List of {filter_list[list_type].label}s ({len(lines)} total)") + +        await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True) + +    def _get_filter_by_id(self, id_: int) -> Optional[tuple[Filter, FilterList, ListType]]: +        """Get the filter object corresponding to the provided ID, along with its containing list and list type.""" +        for filter_list in self.filter_lists.values(): +            for list_type, sublist in filter_list.items(): +                if id_ in sublist.filters: +                    return sublist.filters[id_], filter_list, list_type + +    async def _add_filter( +        self, +        ctx: Context, +        noui: Optional[Literal["noui"]], +        list_type: ListType, +        filter_list: FilterList, +        content: str, +        description_and_settings: Optional[str] = None +    ) -> None: +        """Add a filter to the database.""" +        # Validations. +        if list_type not in filter_list: +            await ctx.reply(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.") +            return +        filter_type = filter_list.get_filter_type(content) +        if not filter_type: +            await ctx.reply(f":x: Could not find a filter type appropriate for `{content}`.") +            return +        # Parse the description and settings. +        description, settings, filter_settings = description_and_settings_converter( +            filter_list, +            list_type, +            filter_type, +            self.loaded_settings, +            self.loaded_filter_settings, +            description_and_settings +        ) + +        if noui:  # Add directly with no UI. +            try: +                await self._post_new_filter( +                    ctx.message, filter_list, list_type, filter_type, content, description, settings, filter_settings +                ) +            except ResponseCodeError as e: +                await ctx.reply(embed=format_response_error(e)) +            except ValueError as e: +                raise BadArgument(str(e)) +            return +        # Bring up the UI. +        embed = Embed(colour=Colour.blue()) +        embed.description = f"`{content}`" if content else "*No content*" +        if description: +            embed.description += f" - {description}" +        embed.set_author( +            name=f"New Filter - {filter_list[list_type].label}".title()) +        embed.set_footer(text=( +            "Field names with an asterisk have values which override the defaults of the containing filter list. " +            f"To view all defaults of the list, " +            f"run `{constants.Bot.prefix}filterlist describe {list_type.name} {filter_list.name}`." +        )) + +        view = filters_ui.FilterEditView( +            filter_list, +            list_type, +            filter_type, +            content, +            description, +            settings, +            filter_settings, +            self.loaded_settings, +            self.loaded_filter_settings, +            ctx.author, +            embed, +            self._post_new_filter +        ) +        await ctx.send(embed=embed, reference=ctx.message, view=view) + +    @staticmethod +    def _identical_filters_message(content: str, filter_list: FilterList, list_type: ListType, filter_: Filter) -> str: +        """Returns all the filters in the list with content identical to the content supplied.""" +        if list_type not in filter_list: +            return "" +        duplicates = [ +            f for f in filter_list[list_type].filters.values() +            if f.content == content and f.id != filter_.id +        ] +        msg = "" +        if duplicates: +            msg = f"\n:warning: The filter(s) #{', #'.join(str(dup.id) for dup in duplicates)} have the same content. " +            msg += "Please make sure this is intentional." + +        return msg + +    @staticmethod +    async def _maybe_alert_auto_infraction( +        filter_list: FilterList, list_type: ListType, filter_: Filter, old_filter: Filter | None = None +    ) -> None: +        """If the filter is new and applies an auto-infraction, or was edited to apply a different one, log it.""" +        infraction_type = filter_.overrides[0].get("infraction_type") +        if not infraction_type: +            infraction_type = filter_list[list_type].default("infraction_type") +        if old_filter: +            old_infraction_type = old_filter.overrides[0].get("infraction_type") +            if not old_infraction_type: +                old_infraction_type = filter_list[list_type].default("infraction_type") +            if infraction_type == old_infraction_type: +                return + +        if infraction_type != Infraction.NONE: +            filter_log = bot.instance.get_channel(Channels.filter_log) +            if filter_log: +                await filter_log.send( +                    f":warning: Heads up! The new {filter_list[list_type].label} filter " +                    f"({filter_}) will automatically {infraction_type.name.lower()} users." +                ) + +    async def _post_new_filter( +        self, +        msg: Message, +        filter_list: FilterList, +        list_type: ListType, +        filter_type: type[Filter], +        content: str, +        description: str | None, +        settings: dict, +        filter_settings: dict +    ) -> None: +        """POST the data of the new filter to the site API.""" +        valid, error_msg = filter_type.validate_filter_settings(filter_settings) +        if not valid: +            raise BadArgument(f"Error while validating filter-specific settings: {error_msg}") + +        content, description = await filter_type.process_input(content, description) + +        list_id = filter_list[list_type].id +        description = description or None +        payload = { +            "filter_list": list_id, "content": content, "description": description, +            "additional_settings": filter_settings, **settings +        } +        response = await bot.instance.api_client.post('bot/filter/filters', json=to_serializable(payload)) +        new_filter = filter_list.add_filter(list_type, response) +        log.info(f"Added new filter: {new_filter}.") +        if new_filter: +            await self._maybe_alert_auto_infraction(filter_list, list_type, new_filter) +            extra_msg = Filtering._identical_filters_message(content, filter_list, list_type, new_filter) +            await msg.reply(f"✅ Added filter: {new_filter}" + extra_msg) +        else: +            await msg.reply(":x: Could not create the filter. Are you sure it's implemented?") + +    async def _patch_filter( +        self, +        filter_: Filter, +        msg: Message, +        filter_list: FilterList, +        list_type: ListType, +        filter_type: type[Filter], +        content: str, +        description: str | None, +        settings: dict, +        filter_settings: dict +    ) -> None: +        """PATCH the new data of the filter to the site API.""" +        valid, error_msg = filter_type.validate_filter_settings(filter_settings) +        if not valid: +            raise BadArgument(f"Error while validating filter-specific settings: {error_msg}") + +        if content != filter_.content: +            content, description = await filter_type.process_input(content, description) + +        # If the setting is not in `settings`, the override was either removed, or there wasn't one in the first place. +        for current_settings in (filter_.actions, filter_.validations): +            if current_settings: +                for setting_entry in current_settings.values(): +                    settings.update({setting: None for setting in setting_entry.dict() if setting not in settings}) + +        # Even though the list ID remains unchanged, it still needs to be provided for correct serializer validation. +        list_id = filter_list[list_type].id +        description = description or None +        payload = { +            "filter_list": list_id, "content": content, "description": description, +            "additional_settings": filter_settings, **settings +        } +        response = await bot.instance.api_client.patch( +            f'bot/filter/filters/{filter_.id}', json=to_serializable(payload) +        ) +        # Return type can be None, but if it's being edited then it's not supposed to be. +        edited_filter = filter_list.add_filter(list_type, response) +        log.info(f"Successfully patched filter {edited_filter}.") +        await self._maybe_alert_auto_infraction(filter_list, list_type, edited_filter, filter_) +        extra_msg = Filtering._identical_filters_message(content, filter_list, list_type, edited_filter) +        await msg.reply(f"✅ Edited filter: {edited_filter}" + extra_msg) + +    async def _post_filter_list(self, msg: Message, list_name: str, list_type: ListType, settings: dict) -> None: +        """POST the new data of the filter list to the site API.""" +        payload = {"name": list_name, "list_type": list_type.value, **to_serializable(settings)} +        filterlist_name = f"{past_tense(list_type.name.lower())} {list_name}" +        response = await bot.instance.api_client.post('bot/filter/filter_lists', json=payload) +        log.info(f"Successfully posted the new {filterlist_name} filterlist.") +        self._load_raw_filter_list(response) +        await msg.reply(f"✅ Added a new filter list: {filterlist_name}") + +    @staticmethod +    async def _patch_filter_list(msg: Message, filter_list: FilterList, list_type: ListType, settings: dict) -> None: +        """PATCH the new data of the filter list to the site API.""" +        list_id = filter_list[list_type].id +        response = await bot.instance.api_client.patch( +            f'bot/filter/filter_lists/{list_id}', json=to_serializable(settings) +        ) +        log.info(f"Successfully patched the {filter_list[list_type].label} filterlist, reloading...") +        filter_list.pop(list_type, None) +        filter_list.add_list(response) +        await msg.reply(f"✅ Edited filter list: {filter_list[list_type].label}") + +    def _filter_match_query( +        self, filter_: Filter, settings_query: dict, filter_settings_query: dict, differ_by_default: set[str] +    ) -> bool: +        """Return whether the given filter matches the query.""" +        override_matches = set() +        overrides, _ = filter_.overrides +        for setting_name, setting_value in settings_query.items(): +            if setting_name not in overrides: +                continue +            if repr_equals(overrides[setting_name], setting_value): +                override_matches.add(setting_name) +            else:  # If an override doesn't match then the filter doesn't match. +                return False +        if not (differ_by_default <= override_matches):  # The overrides didn't cover for the default mismatches. +            return False + +        filter_settings = filter_.extra_fields.dict() if filter_.extra_fields else {} +        # If the dict changes then some fields were not the same. +        return (filter_settings | filter_settings_query) == filter_settings + +    def _search_filter_list( +        self, atomic_list: AtomicList, filter_type: type[Filter] | None, settings: dict, filter_settings: dict +    ) -> list[Filter]: +        """Find all filters in the filter list which match the settings.""" +        # If the default answers are known, only the overrides need to be checked for each filter. +        all_defaults = atomic_list.defaults.dict() +        match_by_default = set() +        differ_by_default = set() +        for setting_name, setting_value in settings.items(): +            if repr_equals(all_defaults[setting_name], setting_value): +                match_by_default.add(setting_name) +            else: +                differ_by_default.add(setting_name) + +        result_filters = [] +        for filter_ in atomic_list.filters.values(): +            if filter_type and not isinstance(filter_, filter_type): +                continue +            if self._filter_match_query(filter_, settings, filter_settings, differ_by_default): +                result_filters.append(filter_) + +        return result_filters + +    async def _search_filters( +        self, message: Message, filter_type: type[Filter] | None, settings: dict, filter_settings: dict +    ) -> None: +        """Find all filters which match the settings and display them.""" +        lines = [] +        result_count = 0 +        for filter_list in self.filter_lists.values(): +            if filter_type and filter_type not in filter_list.filter_types: +                continue +            for atomic_list in filter_list.values(): +                list_results = self._search_filter_list(atomic_list, filter_type, settings, filter_settings) +                if list_results: +                    lines.append(f"**{atomic_list.label.title()}**") +                    lines.extend(map(str, list_results)) +                    lines.append("") +                    result_count += len(list_results) + +        embed = Embed(colour=Colour.blue()) +        embed.set_author(name=f"Search Results ({result_count} total)") +        ctx = await bot.instance.get_context(message) +        await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True) + +    async def _delete_offensive_msg(self, msg: Mapping[str, int]) -> None: +        """Delete an offensive message, and then delete it from the DB.""" +        try: +            channel = self.bot.get_channel(msg['channel_id']) +            if channel: +                msg_obj = await channel.fetch_message(msg['id']) +                await msg_obj.delete() +        except discord.NotFound: +            log.info( +                f"Tried to delete message {msg['id']}, but the message can't be found " +                f"(it has been probably already deleted)." +            ) +        except HTTPException as e: +            log.warning(f"Failed to delete message {msg['id']}: status {e.status}") + +        await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') +        log.info(f"Deleted the offensive message with id {msg['id']}.") + +    def _schedule_msg_delete(self, msg: dict) -> None: +        """Delete an offensive message once its deletion date is reached.""" +        delete_at = arrow.get(msg['delete_date']).datetime +        self.delete_scheduler.schedule_at(delete_at, msg['id'], self._delete_offensive_msg(msg)) + +    async def _maybe_schedule_msg_delete(self, ctx: FilterContext, actions: ActionSettings | None) -> None: +        """Post the message to the database and schedule it for deletion if it's not set to be deleted already.""" +        msg = ctx.message +        if not msg or not actions or actions.get_setting("remove_context", True): +            return + +        delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() +        data = { +            'id': msg.id, +            'channel_id': msg.channel.id, +            'delete_date': delete_date +        } + +        try: +            await self.bot.api_client.post('bot/offensive-messages', json=data) +        except ResponseCodeError as e: +            if e.status == 400 and "already exists" in e.response_json.get("id", [""])[0]: +                log.debug(f"Offensive message {msg.id} already exists.") +            else: +                log.error(f"Offensive message {msg.id} failed to post: {e}") +        else: +            self._schedule_msg_delete(data) +            log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") + +    # endregion +    # region: tasks + +    @tasks.loop(time=datetime.time(hour=18)) +    async def weekly_auto_infraction_report_task(self) -> None: +        """Trigger an auto-infraction report to be sent if it is the desired day of the week (WEEKLY_REPORT_ISO_DAY).""" +        if arrow.utcnow().isoweekday() != WEEKLY_REPORT_ISO_DAY: +            return + +        await self.send_weekly_auto_infraction_report() + +    async def send_weekly_auto_infraction_report(self, channel: discord.TextChannel | discord.Thread = None) -> None: +        """ +        Send a list of auto-infractions added in the last 7 days to the specified channel. + +        If `channel` is not specified, it is sent to #mod-meta. +        """ +        log.trace("Preparing weekly auto-infraction report.") +        seven_days_ago = arrow.utcnow().shift(days=-7) +        if not channel: +            log.info("Auto-infraction report: the channel to report to is missing.") +            channel = self.bot.get_channel(Channels.mod_meta) +        elif not is_mod_channel(channel): +            # Silently fail if output is going to be a non-mod channel. +            log.info(f"Auto-infraction report: the channel {channel} is not a mod channel.") +            return + +        found_filters = defaultdict(list) +        # Extract all auto-infraction filters added in the past 7 days from each filter type +        for filter_list in self.filter_lists.values(): +            for sublist in filter_list.values(): +                default_infraction_type = sublist.default("infraction_type") +                for filter_ in sublist.filters.values(): +                    if max(filter_.created_at, filter_.updated_at) < seven_days_ago: +                        continue +                    infraction_type = filter_.overrides[0].get("infraction_type") +                    if ( +                        (infraction_type and infraction_type != Infraction.NONE) +                        or (not infraction_type and default_infraction_type != Infraction.NONE) +                    ): +                        found_filters[sublist.label].append((filter_, infraction_type or default_infraction_type)) + +        # Nicely format the output so each filter list type is grouped +        lines = [f"**Auto-infraction filters added since {seven_days_ago.format('YYYY-MM-DD')}**"] +        for list_label, filters in found_filters.items(): +            lines.append("\n".join([f"**{list_label.title()}**"]+[f"{filter_} ({infr})" for filter_, infr in filters])) + +        if len(lines) == 1: +            lines.append("Nothing to show") + +        await channel.send("\n\n".join(lines)) +        log.info("Successfully sent auto-infraction report.") + +    # endregion + +    async def cog_unload(self) -> None: +        """Cancel the weekly auto-infraction filter report and deletion scheduling on cog unload.""" +        self.weekly_auto_infraction_report_task.cancel() +        self.delete_scheduler.cancel_all() + + +async def setup(bot: Bot) -> None: +    """Load the Filtering cog.""" +    await bot.add_cog(Filtering(bot)) diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py deleted file mode 100644 index 0a72a6db7..000000000 --- a/bot/exts/filters/antimalware.py +++ /dev/null @@ -1,106 +0,0 @@ -import typing as t -from os.path import splitext - -from discord import Embed, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import BaseURLs, Channels, Filter -from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME -from bot.log import get_logger - -log = get_logger(__name__) - -PY_EMBED_DESCRIPTION = ( -    "It looks like you tried to attach a Python file - " -    f"please use a code-pasting service such as {BaseURLs.site_paste}" -) - -TXT_LIKE_FILES = {".txt", ".csv", ".json"} -TXT_EMBED_DESCRIPTION = ( -    "You either uploaded a `{blocked_extension}` file or entered a message that was too long. " -    f"Please use our [paste bin]({BaseURLs.site_paste}) instead." -) - -DISALLOWED_EMBED_DESCRIPTION = ( -    "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " -    "We currently allow the following file types: **{joined_whitelist}**.\n\n" -    "Feel free to ask in {meta_channel_mention} if you think this is a mistake." -) - - -class AntiMalware(Cog): -    """Delete messages which contain attachments with non-whitelisted file extensions.""" - -    def __init__(self, bot: Bot): -        self.bot = bot - -    def _get_whitelisted_file_formats(self) -> list: -        """Get the file formats currently on the whitelist.""" -        return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() - -    def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: -        """Get an iterable containing all the disallowed extensions of attachments.""" -        file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} -        extensions_blocked = file_extensions - set(self._get_whitelisted_file_formats()) -        return extensions_blocked - -    @Cog.listener() -    async def on_message(self, message: Message) -> None: -        """Identify messages with prohibited attachments.""" -        # Return when message don't have attachment and don't moderate DMs -        if not message.attachments or not message.guild: -            return - -        # Ignore webhook and bot messages -        if message.webhook_id or message.author.bot: -            return - -        # Ignore code jam channels -        if getattr(message.channel, "category", None) and message.channel.category.name == JAM_CATEGORY_NAME: -            return - -        # Check if user is staff, if is, return -        # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance -        if hasattr(message.author, "roles") and any(role.id in Filter.role_whitelist for role in message.author.roles): -            return - -        embed = Embed() -        extensions_blocked = self._get_disallowed_extensions(message) -        blocked_extensions_str = ', '.join(extensions_blocked) -        if ".py" in extensions_blocked: -            # Short-circuit on *.py files to provide a pastebin link -            embed.description = PY_EMBED_DESCRIPTION -        elif extensions := TXT_LIKE_FILES.intersection(extensions_blocked): -            # Work around Discord AutoConversion of messages longer than 2000 chars to .txt -            cmd_channel = self.bot.get_channel(Channels.bot_commands) -            embed.description = TXT_EMBED_DESCRIPTION.format( -                blocked_extension=extensions.pop(), -                cmd_channel_mention=cmd_channel.mention -            ) -        elif extensions_blocked: -            meta_channel = self.bot.get_channel(Channels.meta) -            embed.description = DISALLOWED_EMBED_DESCRIPTION.format( -                joined_whitelist=', '.join(self._get_whitelisted_file_formats()), -                blocked_extensions_str=blocked_extensions_str, -                meta_channel_mention=meta_channel.mention, -            ) - -        if embed.description: -            log.info( -                f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", -                extra={"attachment_list": [attachment.filename for attachment in message.attachments]} -            ) - -            await message.channel.send(f"Hey {message.author.mention}!", embed=embed) - -            # Delete the offending message: -            try: -                await message.delete() -            except NotFound: -                log.info(f"Tried to delete message `{message.id}`, but message could not be found.") - - -async def setup(bot: Bot) -> None: -    """Load the AntiMalware cog.""" -    await bot.add_cog(AntiMalware(bot)) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py deleted file mode 100644 index 0d02edabf..000000000 --- a/bot/exts/filters/antispam.py +++ /dev/null @@ -1,326 +0,0 @@ -import asyncio -from collections import defaultdict -from collections.abc import Mapping -from dataclasses import dataclass, field -from datetime import timedelta -from itertools import takewhile -from operator import attrgetter, itemgetter -from typing import Dict, Iterable, List, Set - -import arrow -from discord import Colour, Member, Message, MessageType, NotFound, TextChannel -from discord.ext.commands import Cog -from pydis_core.utils import scheduling - -from bot import rules -from bot.bot import Bot -from bot.constants import ( -    AntiSpam as AntiSpamConfig, Channels, Colours, DEBUG_MODE, Event, Filter, Guild as GuildConfig, Icons -) -from bot.converters import Duration -from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME -from bot.exts.moderation.modlog import ModLog -from bot.log import get_logger -from bot.utils import lock -from bot.utils.message_cache import MessageCache -from bot.utils.messages import format_user, send_attachments - -log = get_logger(__name__) - -RULE_FUNCTION_MAPPING = { -    'attachments': rules.apply_attachments, -    'burst': rules.apply_burst, -    # burst shared is temporarily disabled due to a bug -    # 'burst_shared': rules.apply_burst_shared, -    'chars': rules.apply_chars, -    'discord_emojis': rules.apply_discord_emojis, -    'duplicates': rules.apply_duplicates, -    'links': rules.apply_links, -    'mentions': rules.apply_mentions, -    'newlines': rules.apply_newlines, -    'role_mentions': rules.apply_role_mentions, -} - -ANTI_SPAM_RULES = AntiSpamConfig.rules.dict() - - -@dataclass -class DeletionContext: -    """Represents a Deletion Context for a single spam event.""" - -    members: frozenset[Member] -    triggered_in: TextChannel -    channels: set[TextChannel] = field(default_factory=set) -    rules: Set[str] = field(default_factory=set) -    messages: Dict[int, Message] = field(default_factory=dict) -    attachments: List[List[str]] = field(default_factory=list) - -    async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None: -        """Adds new rule violation events to the deletion context.""" -        self.rules.add(rule_name) - -        self.channels.update(channels) - -        for message in messages: -            if message.id not in self.messages: -                self.messages[message.id] = message - -                # Re-upload attachments -                destination = message.guild.get_channel(Channels.attachment_log) -                urls = await send_attachments(message, destination, link_large=False) -                self.attachments.append(urls) - -    async def upload_messages(self, actor_id: int, modlog: ModLog) -> None: -        """Method that takes care of uploading the queue and posting modlog alert.""" -        triggered_by_users = ", ".join(format_user(m) for m in self.members) -        triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else "" -        channels_description = ", ".join(channel.mention for channel in self.channels) - -        mod_alert_message = ( -            f"**Triggered by:** {triggered_by_users}\n" -            f"{triggered_in_channel}" -            f"**Channels:** {channels_description}\n" -            f"**Rules:** {', '.join(rule for rule in self.rules)}\n" -        ) - -        messages_as_list = list(self.messages.values()) -        first_message = messages_as_list[0] -        # For multiple messages and those with attachments or excessive newlines, use the logs API -        if any(( -            len(messages_as_list) > 1, -            len(first_message.attachments) > 0, -            first_message.content.count('\n') > 15 -        )): -            url = await modlog.upload_log(self.messages.values(), actor_id, self.attachments) -            mod_alert_message += f"A complete log of the offending messages can be found [here]({url})" -        else: -            mod_alert_message += "Message:\n" -            content = first_message.clean_content -            remaining_chars = 4080 - len(mod_alert_message) - -            if len(content) > remaining_chars: -                url = await modlog.upload_log([first_message], actor_id, self.attachments) -                log_site_msg = f"The full message can be found [here]({url})" -                content = content[:remaining_chars - (3 + len(log_site_msg))] + "..." - -            mod_alert_message += content - -        await modlog.send_log_message( -            content=", ".join(str(m.id) for m in self.members),  # quality-of-life improvement for mobile moderators -            icon_url=Icons.filtering, -            colour=Colour(Colours.soft_red), -            title="Spam detected!", -            text=mod_alert_message, -            thumbnail=first_message.author.display_avatar.url, -            channel_id=Channels.mod_alerts, -            ping_everyone=AntiSpamConfig.ping_everyone -        ) - - -class AntiSpam(Cog): -    """Cog that controls our anti-spam measures.""" - -    def __init__(self, bot: Bot, validation_errors: Dict[str, str]) -> None: -        self.bot = bot -        self.validation_errors = validation_errors -        self.expiration_date_converter = Duration() - -        self.message_deletion_queue = dict() - -        # Fetch the rule configuration with the highest rule interval. -        max_interval_config = max( -            ANTI_SPAM_RULES.values(), -            key=itemgetter('interval') -        ) -        self.max_interval = max_interval_config['interval'] -        self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) - -    @property -    def mod_log(self) -> ModLog: -        """Allows for easy access of the ModLog cog.""" -        return self.bot.get_cog("ModLog") - -    async def cog_load(self) -> None: -        """Unloads the cog and alerts admins if configuration validation failed.""" -        await self.bot.wait_until_guild_available() -        if self.validation_errors: -            body = "**The following errors were encountered:**\n" -            body += "\n".join(f"- {error}" for error in self.validation_errors.values()) -            body += "\n\n**The cog has been unloaded.**" - -            await self.mod_log.send_log_message( -                title="Error: AntiSpam configuration validation failed!", -                text=body, -                ping_everyone=True, -                icon_url=Icons.token_removed, -                colour=Colour.red() -            ) - -            await self.bot.remove_cog(self.__class__.__name__) -            return - -    @Cog.listener() -    async def on_message(self, message: Message) -> None: -        """Applies the antispam rules to each received message.""" -        if ( -            not message.guild -            or message.guild.id != GuildConfig.id -            or message.author.bot -            or (getattr(message.channel, "category", None) and message.channel.category.name == JAM_CATEGORY_NAME) -            or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) -            or (any(role.id in Filter.role_whitelist for role in message.author.roles) and not DEBUG_MODE) -            or message.type == MessageType.auto_moderation_action -        ): -            return - -        self.cache.append(message) - -        earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.max_interval) -        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache)) - -        for rule_name, rule_config in ANTI_SPAM_RULES.items(): -            rule_function = RULE_FUNCTION_MAPPING[rule_name] - -            # Create a list of messages that were sent in the interval that the rule cares about. -            latest_interesting_stamp = arrow.utcnow() - timedelta(seconds=rule_config['interval']) -            messages_for_rule = list( -                takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages)  # noqa: B023 -            ) - -            result = await rule_function(message, messages_for_rule, rule_config) - -            # If the rule returns `None`, that means the message didn't violate it. -            # If it doesn't, it returns a tuple in the form `(str, Iterable[discord.Member])` -            # which contains the reason for why the message violated the rule and -            # an iterable of all members that violated the rule. -            if result is not None: -                self.bot.stats.incr(f"mod_alerts.{rule_name}") -                reason, members, relevant_messages = result -                full_reason = f"`{rule_name}` rule: {reason}" - -                # If there's no spam event going on for this channel, start a new Message Deletion Context -                authors_set = frozenset(members) -                if authors_set not in self.message_deletion_queue: -                    log.trace(f"Creating queue for members `{authors_set}`") -                    self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel) -                    scheduling.create_task( -                        self._process_deletion_context(authors_set), -                        name=f"AntiSpam._process_deletion_context({authors_set})" -                    ) - -                # Add the relevant of this trigger to the Deletion Context -                await self.message_deletion_queue[authors_set].add( -                    rule_name=rule_name, -                    channels=set(message.channel for message in relevant_messages), -                    messages=relevant_messages -                ) - -                for member in members: -                    scheduling.create_task( -                        self.punish(message, member, full_reason), -                        name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})" -                    ) - -                await self.maybe_delete_messages(relevant_messages) -                break - -    @lock.lock_arg("antispam.punish", "member", attrgetter("id")) -    async def punish(self, msg: Message, member: Member, reason: str) -> None: -        """Punishes the given member for triggering an antispam rule.""" -        if not member.is_timed_out(): -            remove_timeout_after = AntiSpamConfig.remove_timeout_after - -            # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes -            context = await self.bot.get_context(msg) -            command = self.bot.get_command("timeout") -            context.author = context.guild.get_member(self.bot.user.id) -            context.command = command - -            # Since we're going to invoke the timeout command directly, we need to manually call the converter. -            dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_timeout_after}S") -            await context.invoke( -                command, -                member, -                dt_remove_role_after, -                reason=reason -            ) - -    async def maybe_delete_messages(self, messages: List[Message]) -> None: -        """Cleans the messages if cleaning is configured.""" -        if AntiSpamConfig.clean_offending: -            # If we have more than one message, we can use bulk delete. -            if len(messages) > 1: -                message_ids = [message.id for message in messages] -                self.mod_log.ignore(Event.message_delete, *message_ids) -                channel_messages = defaultdict(list) -                for message in messages: -                    channel_messages[message.channel].append(message) -                for channel, messages in channel_messages.items(): -                    try: -                        await channel.delete_messages(messages) -                    except NotFound: -                        # In the rare case where we found messages matching the -                        # spam filter across multiple channels, it is possible -                        # that a single channel will only contain a single message -                        # to delete. If that should be the case, discord.py will -                        # use the "delete single message" endpoint instead of the -                        # bulk delete endpoint, and the single message deletion -                        # endpoint will complain if you give it that does not exist. -                        # As this means that we have no other message to delete in -                        # this channel (and message deletes work per-channel), -                        # we can just log an exception and carry on with business. -                        log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - -            # Otherwise, the bulk delete endpoint will throw up. -            # Delete the message directly instead. -            else: -                self.mod_log.ignore(Event.message_delete, messages[0].id) -                try: -                    await messages[0].delete() -                except NotFound: -                    log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") - -    async def _process_deletion_context(self, context_id: frozenset) -> None: -        """Processes the Deletion Context queue.""" -        log.trace("Sleeping before processing message deletion queue.") -        await asyncio.sleep(10) - -        if context_id not in self.message_deletion_queue: -            log.error(f"Started processing deletion queue for context `{context_id}`, but it was not found!") -            return - -        deletion_context = self.message_deletion_queue.pop(context_id) -        await deletion_context.upload_messages(self.bot.user.id, self.mod_log) - -    @Cog.listener() -    async def on_message_edit(self, before: Message, after: Message) -> None: -        """Updates the message in the cache, if it's cached.""" -        self.cache.update(after) - - -def validate_config(rules_: Mapping = ANTI_SPAM_RULES) -> Dict[str, str]: -    """Validates the antispam configs.""" -    validation_errors = {} -    for name, config in rules_.items(): -        config = config -        if name not in RULE_FUNCTION_MAPPING: -            log.error( -                f"Unrecognized antispam rule `{name}`. " -                f"Valid rules are: {', '.join(RULE_FUNCTION_MAPPING)}" -            ) -            validation_errors[name] = f"`{name}` is not recognized as an antispam rule." -            continue -        for required_key in ('interval', 'max'): -            if required_key not in config: -                log.error( -                    f"`{required_key}` is required but was not " -                    f"set in rule `{name}`'s configuration." -                ) -                validation_errors[name] = f"Key `{required_key}` is required but not set for rule `{name}`" -    return validation_errors - - -async def setup(bot: Bot) -> None: -    """Validate the AntiSpam configs and load the AntiSpam cog.""" -    validation_errors = validate_config() -    await bot.add_cog(AntiSpam(bot, validation_errors)) diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py deleted file mode 100644 index 538744204..000000000 --- a/bot/exts/filters/filter_lists.py +++ /dev/null @@ -1,359 +0,0 @@ -import datetime -import re -from collections import defaultdict -from typing import Optional - -import arrow -import discord -from discord.ext import tasks -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, command, group, has_any_role -from pydis_core.site_api import ResponseCodeError - -from bot import constants -from bot.bot import Bot -from bot.constants import Channels, Colours -from bot.converters import ValidDiscordServerInvite, ValidFilterListType -from bot.log import get_logger -from bot.pagination import LinePaginator -from bot.utils.channel import is_mod_channel - -log = get_logger(__name__) -WEEKLY_REPORT_ISO_DAY = 3  # 1=Monday, 7=Sunday - - -class FilterLists(Cog): -    """Commands for blacklisting and whitelisting things.""" - -    methods_with_filterlist_types = [ -        "allow_add", -        "allow_delete", -        "allow_get", -        "deny_add", -        "deny_delete", -        "deny_get", -    ] - -    def __init__(self, bot: Bot) -> None: -        self.bot = bot - -    async def cog_load(self) -> None: -        """Add the valid FilterList types to the docstrings, so they'll appear in !help invocations.""" -        await self.bot.wait_until_guild_available() -        self.weekly_autoban_report_task.start() - -        # Add valid filterlist types to the docstrings -        valid_types = await ValidFilterListType.get_valid_types(self.bot) -        valid_types = [f"`{type_.lower()}`" for type_ in valid_types] - -        for method_name in self.methods_with_filterlist_types: -            command = getattr(self, method_name) -            command.help = ( -                f"{command.help}\n\nValid **list_type** values are {', '.join(valid_types)}." -            ) - -    async def _add_data( -        self, -        ctx: Context, -        allowed: bool, -        list_type: ValidFilterListType, -        content: str, -        comment: Optional[str] = None, -    ) -> None: -        """Add an item to a filterlist.""" -        allow_type = "whitelist" if allowed else "blacklist" - -        # If this is a guild invite, we gotta validate it. -        if list_type == "GUILD_INVITE": -            guild_data = await self._validate_guild_invite(ctx, content) -            content = guild_data.get("id") - -            # Some guild invites are autoban filters, which require the mod -            # to set a comment which includes [autoban]. -            # Having the guild name in the comment is still useful when reviewing -            # filter list, so prepend it to the set comment in case some mod forgets. -            guild_name_part = f'Guild "{guild_data["name"]}"' if "name" in guild_data else None - -            comment = " - ".join( -                comment_part -                for comment_part in (guild_name_part, comment) -                if comment_part -            ) - -        # If it's a file format, let's make sure it has a leading dot. -        elif list_type == "FILE_FORMAT" and not content.startswith("."): -            content = f".{content}" - -        # If it's a filter token, validate the passed regex -        elif list_type == "FILTER_TOKEN": -            try: -                re.compile(content) -            except re.error as e: -                await ctx.message.add_reaction("❌") -                await ctx.send( -                    f"{ctx.author.mention} that's not a valid regex! " -                    f"Regex error message: {e.msg}." -                ) -                return - -        # Try to add the item to the database -        log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") -        payload = { -            "allowed": allowed, -            "type": list_type, -            "content": content, -            "comment": comment, -        } - -        try: -            item = await self.bot.api_client.post( -                "bot/filter-lists", -                json=payload -            ) -        except ResponseCodeError as e: -            if e.status == 400: -                await ctx.message.add_reaction("❌") -                log.debug( -                    f"{ctx.author} tried to add data to a {allow_type}, but the API returned 400, " -                    "probably because the request violated the UniqueConstraint." -                ) -                raise BadArgument( -                    f"Unable to add the item to the {allow_type}. " -                    "The item probably already exists. Keep in mind that a " -                    "blacklist and a whitelist for the same item cannot co-exist, " -                    "and we do not permit any duplicates." -                ) -            raise - -        # If it is an autoban trigger we send a warning in #filter-log -        if comment and "[autoban]" in comment: -            await self.bot.get_channel(Channels.filter_log).send( -                f":warning: Heads-up! The new `{list_type}` filter " -                f"`{content}` (`{comment}`) will automatically ban users." -            ) - -        # Insert the item into the cache -        self.bot.insert_item_into_filter_list_cache(item) -        await ctx.message.add_reaction("✅") - -    async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: -        """Remove an item from a filterlist.""" -        allow_type = "whitelist" if allowed else "blacklist" - -        # If this is a server invite, we need to convert it. -        if list_type == "GUILD_INVITE" and not IDConverter()._get_id_match(content): -            guild_data = await self._validate_guild_invite(ctx, content) -            content = guild_data.get("id") - -        # If it's a file format, let's make sure it has a leading dot. -        elif list_type == "FILE_FORMAT" and not content.startswith("."): -            content = f".{content}" - -        # Find the content and delete it. -        log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") -        item = self.bot.filter_list_cache[f"{list_type}.{allowed}"].get(content) - -        if item is not None: -            try: -                await self.bot.api_client.delete( -                    f"bot/filter-lists/{item['id']}" -                ) -                del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] -                await ctx.message.add_reaction("✅") -            except ResponseCodeError as e: -                log.debug( -                    f"{ctx.author} tried to delete an item with the id {item['id']}, but " -                    f"the API raised an unexpected error: {e}" -                ) -                await ctx.message.add_reaction("❌") -        else: -            await ctx.message.add_reaction("❌") - -    async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: -        """Paginate and display all items in a filterlist.""" -        allow_type = "whitelist" if allowed else "blacklist" -        result = self.bot.filter_list_cache[f"{list_type}.{allowed}"] - -        # Build a list of lines we want to show in the paginator -        lines = [] -        for content, metadata in result.items(): -            line = f"• `{content}`" - -            if comment := metadata.get("comment"): -                line += f" - {comment}" - -            lines.append(line) -        lines = sorted(lines) - -        # Build the embed -        list_type_plural = list_type.lower().replace("_", " ").title() + "s" -        embed = discord.Embed( -            title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", -            colour=Colours.blue -        ) -        log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - -        if result: -            await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) -        else: -            embed.description = "Hmmm, seems like there's nothing here yet." -            await ctx.send(embed=embed) -            await ctx.message.add_reaction("❌") - -    async def _sync_data(self, ctx: Context) -> None: -        """Syncs the filterlists with the API.""" -        try: -            log.trace("Attempting to sync FilterList cache with data from the API.") -            await self.bot.cache_filter_list_data() -            await ctx.message.add_reaction("✅") -        except ResponseCodeError as e: -            log.debug( -                f"{ctx.author} tried to sync FilterList cache data but " -                f"the API raised an unexpected error: {e}" -            ) -            await ctx.message.add_reaction("❌") - -    @staticmethod -    async def _validate_guild_invite(ctx: Context, invite: str) -> dict: -        """ -        Validates a guild invite, and returns the guild info as a dict. - -        Will raise a BadArgument if the guild invite is invalid. -        """ -        log.trace(f"Attempting to validate whether or not {invite} is a guild invite.") -        validator = ValidDiscordServerInvite() -        guild_data = await validator.convert(ctx, invite) - -        # If we make it this far without raising a BadArgument, the invite is -        # valid. Let's return a dict of guild information. -        log.trace(f"{invite} validated as server invite. Converting to ID.") -        return guild_data - -    @group(aliases=("allowlist", "allow", "al", "wl")) -    async def whitelist(self, ctx: Context) -> None: -        """Group for whitelisting commands.""" -        if not ctx.invoked_subcommand: -            await ctx.send_help(ctx.command) - -    @group(aliases=("denylist", "deny", "bl", "dl")) -    async def blacklist(self, ctx: Context) -> None: -        """Group for blacklisting commands.""" -        if not ctx.invoked_subcommand: -            await ctx.send_help(ctx.command) - -    @whitelist.command(name="add", aliases=("a", "set")) -    async def allow_add( -        self, -        ctx: Context, -        list_type: ValidFilterListType, -        content: str, -        *, -        comment: Optional[str] = None, -    ) -> None: -        """Add an item to the specified allowlist.""" -        await self._add_data(ctx, True, list_type, content, comment) - -    @blacklist.command(name="add", aliases=("a", "set")) -    async def deny_add( -        self, -        ctx: Context, -        list_type: ValidFilterListType, -        content: str, -        *, -        comment: Optional[str] = None, -    ) -> None: -        """Add an item to the specified denylist.""" -        await self._add_data(ctx, False, list_type, content, comment) - -    @whitelist.command(name="remove", aliases=("delete", "rm",)) -    async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: -        """Remove an item from the specified allowlist.""" -        await self._delete_data(ctx, True, list_type, content) - -    @blacklist.command(name="remove", aliases=("delete", "rm",)) -    async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: -        """Remove an item from the specified denylist.""" -        await self._delete_data(ctx, False, list_type, content) - -    @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) -    async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: -        """Get the contents of a specified allowlist.""" -        await self._list_all_data(ctx, True, list_type) - -    @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) -    async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: -        """Get the contents of a specified denylist.""" -        await self._list_all_data(ctx, False, list_type) - -    @whitelist.command(name="sync", aliases=("s",)) -    async def allow_sync(self, ctx: Context) -> None: -        """Syncs both allowlists and denylists with the API.""" -        await self._sync_data(ctx) - -    @blacklist.command(name="sync", aliases=("s",)) -    async def deny_sync(self, ctx: Context) -> None: -        """Syncs both allowlists and denylists with the API.""" -        await self._sync_data(ctx) - -    @command(name="filter_report") -    async def force_send_weekly_report(self, ctx: Context) -> None: -        """Respond with a list of autobans added in the last 7 days.""" -        await self.send_weekly_autoban_report(ctx.channel) - -    @tasks.loop(time=datetime.time(hour=18)) -    async def weekly_autoban_report_task(self) -> None: -        """Trigger autoban report to be sent if it is the desired day of the week (WEEKLY_REPORT_ISO_DAY).""" -        if arrow.utcnow().isoweekday() != WEEKLY_REPORT_ISO_DAY: -            return - -        await self.send_weekly_autoban_report() - -    async def send_weekly_autoban_report(self, channel: discord.abc.Messageable = None) -> None: -        """ -        Send a list of autobans added in the last 7 days to the specified channel. - -        If chanel is not specified, it is sent to #mod-meta. -        """ -        seven_days_ago = arrow.utcnow().shift(days=-7) -        if not channel: -            channel = self.bot.get_channel(Channels.mod_meta) -        elif not is_mod_channel(channel): -            # Silently fail if output is going to be a non-mod channel. -            return - -        added_autobans = defaultdict(list) -        # Extract all autoban filters added in the past 7 days from each filter type -        for filter_list, filters in self.bot.filter_list_cache.items(): -            filter_type, allow = filter_list.split(".") -            allow_type = "Allow list" if allow.lower() == "true" else "Deny list" - -            for filter_content, filter_details in filters.items(): -                created_at = arrow.get(filter_details["created_at"]) -                updated_at = arrow.get(filter_details["updated_at"]) -                # Default to empty string so that the in check below doesn't error on None type -                comment = filter_details["comment"] or "" -                if max(created_at, updated_at) > seven_days_ago and "[autoban]" in comment: -                    line = f"`{filter_content}`: {comment}" -                    added_autobans[f"**{filter_type} {allow_type}**"].append(line) - -        # Nicely format the output so each filter list type is grouped -        lines = [f"**Autoban filters added since {seven_days_ago.format('YYYY-MM-DD')}**"] -        for filter_list, recently_added_autobans in added_autobans.items(): -            lines.append("\n".join([filter_list]+recently_added_autobans)) - -        if len(lines) == 1: -            lines.append("Nothing to show") - -        await channel.send("\n\n".join(lines)) - -    async def cog_check(self, ctx: Context) -> bool: -        """Only allow moderators to invoke the commands in this cog.""" -        return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx) - -    async def cog_unload(self) -> None: -        """Cancel the weekly autoban filter report on cog unload.""" -        self.weekly_autoban_report_task.cancel() - - -async def setup(bot: Bot) -> None: -    """Load the FilterLists cog.""" -    await bot.add_cog(FilterLists(bot)) diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py deleted file mode 100644 index 23a6f2d92..000000000 --- a/bot/exts/filters/filtering.py +++ /dev/null @@ -1,743 +0,0 @@ -import asyncio -import re -import unicodedata -import urllib.parse -from datetime import timedelta -from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union - -import arrow -import dateutil.parser -import regex -import tldextract -from async_rediscache import RedisCache -from dateutil.relativedelta import relativedelta -from discord import ChannelType, Colour, Embed, Forbidden, HTTPException, Member, Message, NotFound, TextChannel -from discord.ext.commands import Cog -from discord.utils import escape_markdown -from pydis_core.site_api import ResponseCodeError -from pydis_core.utils import scheduling -from pydis_core.utils.regex import DISCORD_INVITE - -from bot.bot import Bot -from bot.constants import Bot as BotConfig, Channels, Colours, Filter, Guild, Icons, URLs -from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME -from bot.exts.moderation.modlog import ModLog -from bot.log import get_logger -from bot.utils.helpers import remove_subdomain_from_url -from bot.utils.messages import format_user - -log = get_logger(__name__) - - -# Regular expressions -CODE_BLOCK_RE = re.compile( -    r"(?P<delim>``?)[^`]+?(?P=delim)(?!`+)"  # Inline codeblock -    r"|```(.+?)```",  # Multiline codeblock -    re.DOTALL | re.MULTILINE -) -EVERYONE_PING_RE = re.compile(rf"@everyone|<@&{Guild.id}>|@here") -SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) -URL_RE = re.compile(r"(https?://[^\s]+)", flags=re.IGNORECASE) - -# Exclude variation selectors from zalgo because they're actually invisible. -VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF" -INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1) -ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1) - -# Other constants. -DAYS_BETWEEN_ALERTS = 3 -OFFENSIVE_MSG_DELETE_TIME = timedelta(days=Filter.offensive_msg_delete_days) - -# Autoban -LINK_PASSWORD = "https://support.discord.com/hc/en-us/articles/218410947-I-forgot-my-Password-Where-can-I-set-a-new-one" -LINK_2FA = "https://support.discord.com/hc/en-us/articles/219576828-Setting-up-Two-Factor-Authentication" -AUTO_BAN_REASON = ( -    "Your account has been used to send links to a phishing website. You have been automatically banned. " -    "If you are not aware of sending them, that means your account has been compromised.\n\n" - -    f"Here is a guide from Discord on [how to change your password]({LINK_PASSWORD}).\n\n" - -    f"We also highly recommend that you [enable 2 factor authentication on your account]({LINK_2FA}), " -    "for heightened security.\n\n" - -    "Once you have changed your password, feel free to follow the instructions at the bottom of " -    "this message to appeal your ban." -) -AUTO_BAN_DURATION = timedelta(days=4) - -FilterMatch = Union[re.Match, dict, bool, List[Embed]] - - -class Stats(NamedTuple): -    """Additional stats on a triggered filter to append to a mod log.""" - -    message_content: str -    additional_embeds: Optional[List[Embed]] - - -class Filtering(Cog): -    """Filtering out invites, blacklisting domains, and warning us of certain regular expressions.""" - -    # Redis cache mapping a user ID to the last timestamp a bad nickname alert was sent -    name_alerts = RedisCache() - -    def __init__(self, bot: Bot): -        self.bot = bot -        self.scheduler = scheduling.Scheduler(self.__class__.__name__) -        self.name_lock = asyncio.Lock() - -        staff_mistake_str = "If you believe this was a mistake, please let staff know!" -        self.filters = { -            "filter_zalgo": { -                "enabled": Filter.filter_zalgo, -                "function": self._has_zalgo, -                "type": "filter", -                "content_only": True, -                "user_notification": Filter.notify_user_zalgo, -                "notification_msg": ( -                    "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " -                    f"{staff_mistake_str}" -                ), -                "schedule_deletion": False -            }, -            "filter_invites": { -                "enabled": Filter.filter_invites, -                "function": self._has_invites, -                "type": "filter", -                "content_only": True, -                "user_notification": Filter.notify_user_invites, -                "notification_msg": ( -                    f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" -                    r"Our server rules can be found here: <https://pythondiscord.com/pages/rules>" -                ), -                "schedule_deletion": False -            }, -            "filter_domains": { -                "enabled": Filter.filter_domains, -                "function": self._has_urls, -                "type": "filter", -                "content_only": True, -                "user_notification": Filter.notify_user_domains, -                "notification_msg": ( -                    f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" -                ), -                "schedule_deletion": False -            }, -            "watch_regex": { -                "enabled": Filter.watch_regex, -                "function": self._has_watch_regex_match, -                "type": "watchlist", -                "content_only": True, -                "schedule_deletion": True -            }, -            "watch_rich_embeds": { -                "enabled": Filter.watch_rich_embeds, -                "function": self._has_rich_embed, -                "type": "watchlist", -                "content_only": False, -                "schedule_deletion": False -            }, -            "filter_everyone_ping": { -                "enabled": Filter.filter_everyone_ping, -                "function": self._has_everyone_ping, -                "type": "filter", -                "content_only": True, -                "user_notification": Filter.notify_user_everyone_ping, -                "notification_msg": ( -                    "Please don't try to ping `@everyone` or `@here`. " -                    f"Your message has been removed. {staff_mistake_str}" -                ), -                "schedule_deletion": False, -                "ping_everyone": False -            }, -        } - -    async def cog_unload(self) -> None: -        """Cancel scheduled tasks.""" -        self.scheduler.cancel_all() - -    def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: -        """Fetch items from the filter_list_cache.""" -        return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() - -    def _get_filterlist_value(self, list_type: str, value: Any, *, allowed: bool) -> dict: -        """Fetch one specific value from filter_list_cache.""" -        return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"][value] - -    @staticmethod -    def _expand_spoilers(text: str) -> str: -        """Return a string containing all interpretations of a spoilered message.""" -        split_text = SPOILER_RE.split(text) -        return ''.join( -            split_text[0::2] + split_text[1::2] + split_text -        ) - -    @property -    def mod_log(self) -> ModLog: -        """Get currently loaded ModLog cog instance.""" -        return self.bot.get_cog("ModLog") - -    @Cog.listener() -    async def on_message(self, msg: Message) -> None: -        """Invoke message filter for new messages.""" -        await self._filter_message(msg) - -        # Ignore webhook messages. -        if msg.webhook_id is None: -            await self.check_bad_words_in_name(msg.author) - -    @Cog.listener() -    async def on_message_edit(self, before: Message, after: Message) -> None: -        """ -        Invoke message filter for message edits. - -        Also calculates the time delta from the previous edit or when message was sent if there's no prior edits. -        """ -        # We only care about changes to the message contents/attachments and embed additions, not pin status etc. -        if all(( -            before.content == after.content,  # content hasn't changed -            before.attachments == after.attachments,  # attachments haven't changed -            len(before.embeds) >= len(after.embeds)  # embeds haven't been added -        )): -            return - -        if not before.edited_at: -            delta = relativedelta(after.edited_at, before.created_at).microseconds -        else: -            delta = relativedelta(after.edited_at, before.edited_at).microseconds -        await self._filter_message(after, delta) - -    @Cog.listener() -    async def on_voice_state_update(self, member: Member, *_) -> None: -        """Checks for bad words in usernames when users join, switch or leave a voice channel.""" -        await self.check_bad_words_in_name(member) - -    def get_name_match(self, name: str) -> Optional[re.Match]: -        """Check bad words from passed string (name). Return the first match found.""" -        normalised_name = unicodedata.normalize("NFKC", name) -        cleaned_normalised_name = "".join([c for c in normalised_name if not unicodedata.combining(c)]) - -        # Run filters against normalised, cleaned normalised and the original name, -        # in case we have filters for one but not the other. -        names_to_check = (name, normalised_name, cleaned_normalised_name) - -        watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) -        for pattern in watchlist_patterns: -            for name in names_to_check: -                if match := re.search(pattern, name, flags=re.IGNORECASE): -                    return match -        return None - -    async def check_send_alert(self, member: Member) -> bool: -        """When there is less than 3 days after last alert, return `False`, otherwise `True`.""" -        if last_alert := await self.name_alerts.get(member.id): -            last_alert = arrow.get(last_alert) -            if arrow.utcnow() - timedelta(days=DAYS_BETWEEN_ALERTS) < last_alert: -                log.trace(f"Last alert was too recent for {member}'s nickname.") -                return False - -        return True - -    async def check_bad_words_in_name(self, member: Member) -> None: -        """Send a mod alert every 3 days if a username still matches a watchlist pattern.""" -        # Use lock to avoid race conditions -        async with self.name_lock: -            # Check if we recently alerted about this user first, -            # to avoid running all the filter tokens against their name again. -            if not await self.check_send_alert(member): -                return - -            # Check whether the users display name contains any words in our blacklist -            match = self.get_name_match(member.display_name) -            if not match: -                return - -            log.info(f"Sending bad nickname alert for '{member.display_name}' ({member.id}).") - -            log_string = ( -                f"**User:** {format_user(member)}\n" -                f"**Display Name:** {escape_markdown(member.display_name)}\n" -                f"**Bad Match:** {match.group()}" -            ) - -            await self.mod_log.send_log_message( -                content=str(member.id),  # quality-of-life improvement for mobile moderators -                icon_url=Icons.token_removed, -                colour=Colours.soft_red, -                title="Username filtering alert", -                text=log_string, -                channel_id=Channels.mod_alerts, -                thumbnail=member.display_avatar.url, -                ping_everyone=True -            ) - -            # Update time when alert sent -            await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - -    async def filter_snekbox_output(self, result: str, msg: Message) -> bool: -        """ -        Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. - -        Also requires the original message, to check whether to filter and for mod logs. -        Returns whether a filter was triggered or not. -        """ -        filter_triggered = False -        # Should we filter this message? -        if self._check_filter(msg): -            for filter_name, _filter in self.filters.items(): -                # Is this specific filter enabled in the config? -                # We also do not need to worry about filters that take the full message, -                # since all we have is an arbitrary string. -                if _filter["enabled"] and _filter["content_only"]: -                    filter_result = await _filter["function"](result) -                    reason = None - -                    if isinstance(filter_result, tuple): -                        match, reason = filter_result -                    else: -                        match = filter_result - -                    if match: -                        # If this is a filter (not a watchlist), we set the variable so we know -                        # that it has been triggered -                        if _filter["type"] == "filter": -                            filter_triggered = True - -                        stats = self._add_stats(filter_name, match, result) -                        await self._send_log(filter_name, _filter, msg, stats, reason, is_eval=True) - -                        break  # We don't want multiple filters to trigger - -        return filter_triggered - -    async def _filter_message(self, msg: Message, delta: Optional[int] = None) -> None: -        """Filter the input message to see if it violates any of our rules, and then respond accordingly.""" -        # Should we filter this message? -        if self._check_filter(msg): -            for filter_name, _filter in self.filters.items(): -                # Is this specific filter enabled in the config? -                if _filter["enabled"]: -                    # Double trigger check for the embeds filter -                    if filter_name == "watch_rich_embeds": -                        # If the edit delta is less than 0.001 seconds, then we're probably dealing -                        # with a double filter trigger. -                        if delta is not None and delta < 100: -                            continue - -                    if filter_name in ("filter_invites", "filter_everyone_ping"): -                        # Disable invites filter in codejam team channels -                        category = getattr(msg.channel, "category", None) -                        if category and category.name == JAM_CATEGORY_NAME: -                            continue - -                    # Does the filter only need the message content or the full message? -                    if _filter["content_only"]: -                        payload = msg.content -                    else: -                        payload = msg - -                    result = await _filter["function"](payload) -                    reason = None - -                    if isinstance(result, tuple): -                        match, reason = result -                    else: -                        match = result - -                    if match: -                        is_private = msg.channel.type is ChannelType.private - -                        # If this is a filter (not a watchlist) and not in a DM, delete the message. -                        if _filter["type"] == "filter" and not is_private: -                            try: -                                # Embeds (can?) trigger both the `on_message` and `on_message_edit` -                                # event handlers, triggering filtering twice for the same message. -                                # -                                # If `on_message`-triggered filtering already deleted the message -                                # then `on_message_edit`-triggered filtering will raise exception -                                # since the message no longer exists. -                                # -                                # In addition, to avoid sending two notifications to the user, the -                                # logs, and mod_alert, we return if the message no longer exists. -                                await msg.delete() -                            except NotFound: -                                return - -                            # Notify the user if the filter specifies -                            if _filter["user_notification"]: -                                await self.notify_member(msg.author, _filter["notification_msg"], msg.channel) - -                        # If the message is classed as offensive, we store it in the site db and -                        # it will be deleted after one week. -                        if _filter["schedule_deletion"] and not is_private: -                            delete_date = (msg.created_at + OFFENSIVE_MSG_DELETE_TIME).isoformat() -                            data = { -                                'id': msg.id, -                                'channel_id': msg.channel.id, -                                'delete_date': delete_date -                            } - -                            try: -                                await self.bot.api_client.post('bot/offensive-messages', json=data) -                            except ResponseCodeError as e: -                                if e.status == 400 and "already exists" in e.response_json.get("id", [""])[0]: -                                    log.debug(f"Offensive message {msg.id} already exists.") -                                else: -                                    log.error(f"Offensive message {msg.id} failed to post: {e}") -                            else: -                                self.schedule_msg_delete(data) -                                log.trace(f"Offensive message {msg.id} will be deleted on {delete_date}") - -                        stats = self._add_stats(filter_name, match, msg.content) - -                        # If the filter reason contains `[autoban]`, we want to auto-ban the user. -                        # Also pass this to _send_log so mods are not pinged filter matches that are auto-actioned -                        autoban = reason and "[autoban]" in reason.lower() -                        if not autoban and filter_name == "filter_invites" and isinstance(result, dict): -                            autoban = any( -                                "[autoban]" in invite_info["reason"].lower() -                                for invite_info in result.values() -                                if invite_info.get("reason") -                            ) - -                        await self._send_log(filter_name, _filter, msg, stats, reason, autoban=autoban) - -                        if autoban: -                            # Create a new context, with the author as is the bot, and the channel as #mod-alerts. -                            # This sends the ban confirmation directly under watchlist trigger embed, to inform -                            # mods that the user was auto-banned for the message. -                            context = await self.bot.get_context(msg) -                            context.guild = self.bot.get_guild(Guild.id) -                            context.author = context.guild.get_member(self.bot.user.id) -                            context.channel = self.bot.get_channel(Channels.mod_alerts) -                            context.command = self.bot.get_command("tempban") - -                            await context.invoke( -                                context.command, -                                msg.author, -                                (arrow.utcnow() + AUTO_BAN_DURATION).datetime, -                                reason=AUTO_BAN_REASON -                            ) - -                        break  # We don't want multiple filters to trigger - -    async def _send_log( -        self, -        filter_name: str, -        _filter: Dict[str, Any], -        msg: Message, -        stats: Stats, -        reason: Optional[str] = None, -        *, -        is_eval: bool = False, -        autoban: bool = False, -    ) -> None: -        """Send a mod log for a triggered filter.""" -        if msg.channel.type is ChannelType.private: -            channel_str = "via DM" -            ping_everyone = False -        else: -            channel_str = f"in {msg.channel.mention}" -            # Allow specific filters to override ping_everyone -            ping_everyone = Filter.ping_everyone and _filter.get("ping_everyone", True) - -        content = str(msg.author.id)  # quality-of-life improvement for mobile moderators - -        # If we are going to autoban, we don't want to ping and don't need the user ID -        if autoban: -            ping_everyone = False -            content = None - -        eval_msg = f"using {BotConfig.prefix}eval " if is_eval else "" -        footer = f"Reason: {reason}" if reason else None -        message = ( -            f"The {filter_name} {_filter['type']} was triggered by {format_user(msg.author)} " -            f"{channel_str} {eval_msg}with [the following message]({msg.jump_url}):\n\n" -            f"{stats.message_content}" -        ) - -        log.debug(message) - -        # Send pretty mod log embed to mod-alerts -        await self.mod_log.send_log_message( -            content=content, -            icon_url=Icons.filtering, -            colour=Colour(Colours.soft_red), -            title=f"{_filter['type'].title()} triggered!", -            text=message, -            thumbnail=msg.author.display_avatar.url, -            channel_id=Channels.mod_alerts, -            ping_everyone=ping_everyone, -            additional_embeds=stats.additional_embeds, -            footer=footer, -        ) - -    def _add_stats(self, name: str, match: FilterMatch, content: str) -> Stats: -        """Adds relevant statistical information to the relevant filter and increments the bot's stats.""" -        # Word and match stats for watch_regex -        if name == "watch_regex": -            surroundings = match.string[max(match.start() - 10, 0): match.end() + 10] -            message_content = ( -                f"**Match:** '{match[0]}'\n" -                f"**Location:** '...{escape_markdown(surroundings)}...'\n" -                f"\n**Original Message:**\n{escape_markdown(content)}" -            ) -        else:  # Use original content -            message_content = content - -        additional_embeds = None - -        self.bot.stats.incr(f"filters.{name}") - -        # The function returns True for invalid invites. -        # They have no data so additional embeds can't be created for them. -        if name == "filter_invites" and match is not True: -            additional_embeds = [] -            for _, data in match.items(): -                reason = f"Reason: {data['reason']} | " if data.get('reason') else "" -                embed = Embed(description=( -                    f"**Members:**\n{data['members']}\n" -                    f"**Active:**\n{data['active']}" -                )) -                embed.set_author(name=data["name"]) -                embed.set_thumbnail(url=data["icon"]) -                embed.set_footer(text=f"{reason}Guild ID: {data['id']}") -                additional_embeds.append(embed) - -        elif name == "watch_rich_embeds": -            additional_embeds = match - -        return Stats(message_content, additional_embeds) - -    @staticmethod -    def _check_filter(msg: Message) -> bool: -        """Check whitelists to see if we should filter this message.""" -        role_whitelisted = False - -        if type(msg.author) is Member:  # Only Member has roles, not User. -            for role in msg.author.roles: -                if role.id in Filter.role_whitelist: -                    role_whitelisted = True - -        return ( -            msg.channel.id not in Filter.channel_whitelist  # Channel not in whitelist -            and not role_whitelisted                        # Role not in whitelist -            and not msg.author.bot                          # Author not a bot -        ) - -    async def _has_watch_regex_match(self, text: str) -> Tuple[Union[bool, re.Match], Optional[str]]: -        """ -        Return True if `text` matches any regex from `word_watchlist` or `token_watchlist` configs. - -        `word_watchlist`'s patterns are placed between word boundaries while `token_watchlist` is -        matched as-is. Spoilers are expanded, if any, and URLs are ignored. -        Second return value is a reason written to database about blacklist entry (can be None). -        """ -        if SPOILER_RE.search(text): -            text = self._expand_spoilers(text) - -        text = self.clean_input(text) - -        watchlist_patterns = self._get_filterlist_items('filter_token', allowed=False) -        for pattern in watchlist_patterns: -            match = re.search(pattern, text, flags=re.IGNORECASE) -            if match: -                return match, self._get_filterlist_value('filter_token', pattern, allowed=False)['comment'] - -        return False, None - -    async def _has_urls(self, text: str) -> Tuple[bool, Optional[str]]: -        """ -        Returns True if the text contains one of the blacklisted URLs from the config file. - -        Second return value is a reason of URL blacklisting (can be None). -        """ -        text = self.clean_input(text) - -        domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) -        for match in URL_RE.finditer(text): -            for url in domain_blacklist: -                if url.lower() in match.group(1).lower(): -                    blacklisted_parsed = tldextract.extract(url.lower()) -                    url_parsed = tldextract.extract(match.group(1).lower()) -                    if blacklisted_parsed.registered_domain == url_parsed.registered_domain: -                        return True, self._get_filterlist_value("domain_name", url, allowed=False)["comment"] -        return False, None - -    @staticmethod -    async def _has_zalgo(text: str) -> bool: -        """ -        Returns True if the text contains zalgo characters. - -        Zalgo range is \u0300 – \u036F and \u0489. -        """ -        return bool(ZALGO_RE.search(text)) - -    async def _has_invites(self, text: str) -> Union[dict, bool]: -        """ -        Checks if there's any invites in the text content that aren't in the guild whitelist. - -        If any are detected, a dictionary of invite data is returned, with a key per invite. -        If none are detected, False is returned. -        If we are unable to process an invite, True is returned. - -        Attempts to catch some of common ways to try to cheat the system. -        """ -        text = self.clean_input(text) - -        # Remove backslashes to prevent escape character fuckaroundery like -        # discord\.gg/gdudes-pony-farm -        text = text.replace("\\", "") - -        invites = [m.group("invite") for m in DISCORD_INVITE.finditer(text)] -        invite_data = dict() -        for invite in invites: -            invite = urllib.parse.quote_plus(invite.rstrip("/")) -            if invite in invite_data: -                continue - -            response = await self.bot.http_session.get( -                f"{URLs.discord_invite_api}/{invite}", params={"with_counts": "true"} -            ) -            response = await response.json() -            guild = response.get("guild") -            if guild is None: -                # Lack of a "guild" key in the JSON response indicates either an group DM invite, an -                # expired invite, or an invalid invite. The API does not currently differentiate -                # between invalid and expired invites -                return True - -            guild_id = guild.get("id") -            guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) -            guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) - -            # Is this invite allowed? -            guild_partnered_or_verified = ( -                'PARTNERED' in guild.get("features", []) -                or 'VERIFIED' in guild.get("features", []) -            ) -            invite_not_allowed = ( -                guild_id in guild_invite_blacklist           # Blacklisted guilds are never permitted. -                or guild_id not in guild_invite_whitelist    # Whitelisted guilds are always permitted. -                and not guild_partnered_or_verified          # Otherwise guilds have to be Verified or Partnered. -            ) - -            if invite_not_allowed: -                reason = None -                if guild_id in guild_invite_blacklist: -                    reason = self._get_filterlist_value("guild_invite", guild_id, allowed=False)["comment"] - -                guild_icon_hash = guild["icon"] -                guild_icon = ( -                    "https://cdn.discordapp.com/icons/" -                    f"{guild_id}/{guild_icon_hash}.png?size=512" -                ) - -                invite_data[invite] = { -                    "name": guild["name"], -                    "id": guild['id'], -                    "icon": guild_icon, -                    "members": response["approximate_member_count"], -                    "active": response["approximate_presence_count"], -                    "reason": reason -                } - -        return invite_data if invite_data else False - -    @staticmethod -    async def _has_rich_embed(msg: Message) -> Union[bool, List[Embed]]: -        """Determines if `msg` contains any rich embeds not auto-generated from a URL.""" -        if msg.embeds: -            for embed in msg.embeds: -                if embed.type == "rich": -                    urls = URL_RE.findall(msg.content) -                    final_urls = set(urls) -                    # This is due to way discord renders relative urls in Embdes -                    # if we send the following url: https://mobile.twitter.com/something -                    # Discord renders it as https://twitter.com/something -                    for url in urls: -                        final_urls.add(remove_subdomain_from_url(url)) -                    if not embed.url or embed.url not in final_urls: -                        # If `embed.url` does not exist or if `embed.url` is not part of the content -                        # of the message, it's unlikely to be an auto-generated embed by Discord. -                        return msg.embeds -                    else: -                        log.trace( -                            "Found a rich embed sent by a regular user account, " -                            "but it was likely just an automatic URL embed." -                        ) -                        return False -        return False - -    @staticmethod -    async def _has_everyone_ping(text: str) -> bool: -        """Determines if `msg` contains an @everyone or @here ping outside of a codeblock.""" -        # First pass to avoid running re.sub on every message -        if not EVERYONE_PING_RE.search(text): -            return False - -        content_without_codeblocks = CODE_BLOCK_RE.sub("", text) -        return bool(EVERYONE_PING_RE.search(content_without_codeblocks)) - -    async def notify_member(self, filtered_member: Member, reason: str, channel: TextChannel) -> None: -        """ -        Notify filtered_member about a moderation action with the reason str. - -        First attempts to DM the user, fall back to in-channel notification if user has DMs disabled -        """ -        try: -            await filtered_member.send(reason) -        except Forbidden: -            await channel.send(f"{filtered_member.mention} {reason}") - -    def schedule_msg_delete(self, msg: dict) -> None: -        """Delete an offensive message once its deletion date is reached.""" -        delete_at = dateutil.parser.isoparse(msg['delete_date']) -        self.scheduler.schedule_at(delete_at, msg['id'], self.delete_offensive_msg(msg)) - -    async def cog_load(self) -> None: -        """Get all the pending message deletion from the API and reschedule them.""" -        await self.bot.wait_until_ready() -        response = await self.bot.api_client.get('bot/offensive-messages',) - -        now = arrow.utcnow() - -        for msg in response: -            delete_at = dateutil.parser.isoparse(msg['delete_date']) - -            if delete_at < now: -                await self.delete_offensive_msg(msg) -            else: -                self.schedule_msg_delete(msg) - -    async def delete_offensive_msg(self, msg: Mapping[str, int]) -> None: -        """Delete an offensive message, and then delete it from the db.""" -        try: -            channel = self.bot.get_channel(msg['channel_id']) -            if channel: -                msg_obj = await channel.fetch_message(msg['id']) -                await msg_obj.delete() -        except NotFound: -            log.info( -                f"Tried to delete message {msg['id']}, but the message can't be found " -                f"(it has been probably already deleted)." -            ) -        except HTTPException as e: -            log.warning(f"Failed to delete message {msg['id']}: status {e.status}") - -        await self.bot.api_client.delete(f'bot/offensive-messages/{msg["id"]}') -        log.info(f"Deleted the offensive message with id {msg['id']}.") - -    @staticmethod -    def clean_input(string: str) -> str: -        """Remove zalgo and invisible characters from `string`.""" -        # For future consideration: remove characters in the Mc, Sk, and Lm categories too. -        # Can be normalised with form C to merge char + combining char into a single char to avoid -        # removing legit diacritics, but this would open up a way to bypass filters. -        no_zalgo = ZALGO_RE.sub("", string) -        return INVISIBLE_RE.sub("", no_zalgo) - - -async def setup(bot: Bot) -> None: -    """Load the Filtering cog.""" -    await bot.add_cog(Filtering(bot)) diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py deleted file mode 100644 index b42613804..000000000 --- a/bot/exts/filters/webhook_remover.py +++ /dev/null @@ -1,94 +0,0 @@ -import re - -from discord import Colour, Message, NotFound -from discord.ext.commands import Cog - -from bot.bot import Bot -from bot.constants import Channels, Colours, Event, Icons -from bot.exts.moderation.modlog import ModLog -from bot.log import get_logger -from bot.utils.messages import format_user - -WEBHOOK_URL_RE = re.compile( -    r"((?:https?:\/\/)?(?:ptb\.|canary\.)?discord(?:app)?\.com\/api\/webhooks\/\d+\/)\S+\/?", -    re.IGNORECASE -) - -ALERT_MESSAGE_TEMPLATE = ( -    "{user}, looks like you posted a Discord webhook URL. Therefore, your " -    "message has been removed, and your webhook has been deleted. " -    "You can re-create it if you wish to. If you believe this was a " -    "mistake, please let us know." -) - -log = get_logger(__name__) - - -class WebhookRemover(Cog): -    """Scan messages to detect Discord webhooks links.""" - -    def __init__(self, bot: Bot): -        self.bot = bot - -    @property -    def mod_log(self) -> ModLog: -        """Get current instance of `ModLog`.""" -        return self.bot.get_cog("ModLog") - -    async def delete_and_respond(self, msg: Message, redacted_url: str, *, webhook_deleted: bool) -> None: -        """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`.""" -        # Don't log this, due internal delete, not by user. Will make different entry. -        self.mod_log.ignore(Event.message_delete, msg.id) - -        try: -            await msg.delete() -        except NotFound: -            log.debug(f"Failed to remove webhook in message {msg.id}: message already deleted.") -            return - -        await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) -        if webhook_deleted: -            delete_state = "The webhook was successfully deleted." -        else: -            delete_state = "There was an error when deleting the webhook, it might have already been removed." -        message = ( -            f"{format_user(msg.author)} posted a Discord webhook URL to {msg.channel.mention}. {delete_state} " -            f"Webhook URL was `{redacted_url}`" -        ) -        log.debug(message) - -        # Send entry to moderation alerts. -        await self.mod_log.send_log_message( -            icon_url=Icons.token_removed, -            colour=Colour(Colours.soft_red), -            title="Discord webhook URL removed!", -            text=message, -            thumbnail=msg.author.display_avatar.url, -            channel_id=Channels.mod_alerts -        ) - -        self.bot.stats.incr("tokens.removed_webhooks") - -    @Cog.listener() -    async def on_message(self, msg: Message) -> None: -        """Check if a Discord webhook URL is in `message`.""" -        # Ignore DMs; can't delete messages in there anyway. -        if not msg.guild or msg.author.bot: -            return - -        matches = WEBHOOK_URL_RE.search(msg.content) -        if matches: -            async with self.bot.http_session.delete(matches[0]) as resp: -                # The Discord API Returns a 204 NO CONTENT response on success. -                deleted_successfully = resp.status == 204 -            await self.delete_and_respond(msg, matches[1] + "xxx", webhook_deleted=deleted_successfully) - -    @Cog.listener() -    async def on_message_edit(self, before: Message, after: Message) -> None: -        """Check if a Discord webhook URL is in the edited message `after`.""" -        await self.on_message(after) - - -async def setup(bot: Bot) -> None: -    """Load `WebhookRemover` cog.""" -    await bot.add_cog(WebhookRemover(bot)) diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index 073a91a53..e72f32887 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -8,8 +8,8 @@ from pydis_core.utils import scheduling  from bot import constants  from bot.bot import Bot -from bot.exts.filters.token_remover import TokenRemover -from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE +from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter +from bot.exts.filtering._filters.unique.webhook import WEBHOOK_URL_RE  from bot.exts.help_channels._channel import is_help_forum_post  from bot.exts.info.codeblock._instructions import get_instructions  from bot.log import get_logger @@ -135,7 +135,7 @@ class CodeBlockCog(Cog, name="Code Block"):              not message.author.bot              and self.is_valid_channel(message.channel)              and has_lines(message.content, constants.CodeBlock.minimum_lines) -            and not TokenRemover.find_token_in_message(message) +            and not DiscordTokenFilter.find_token_in_message(message.content)              and not WEBHOOK_URL_RE.search(message.content)          ) diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index fd9404b1a..aee751345 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -19,6 +19,7 @@ from bot.converters import Age, ISODateTime  from bot.exts.moderation.modlog import ModLog  from bot.log import get_logger  from bot.utils.channel import is_mod_channel +from bot.utils.messages import upload_log  log = get_logger(__name__) @@ -351,7 +352,7 @@ class Clean(Cog):          # Reverse the list to have reverse chronological order          log_messages = reversed(messages) -        log_url = await self.mod_log.upload_log(log_messages, ctx.author.id) +        log_url = await upload_log(log_messages, ctx.author.id)          # Build the embed and send it          if channels == "*": diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index d61a3fa5c..e785216c9 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -14,7 +14,6 @@ from bot.bot import Bot  from bot.constants import Channels, Event  from bot.converters import Age, Duration, DurationOrExpiry, MemberOrUser, UnambiguousMemberOrUser  from bot.decorators import ensure_future_timestamp, respect_role_hierarchy -from bot.exts.filters.filtering import AUTO_BAN_DURATION, AUTO_BAN_REASON  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.log import get_logger @@ -30,6 +29,23 @@ if t.TYPE_CHECKING:      from bot.exts.moderation.watchchannels.bigbrother import BigBrother +# Comp ban +LINK_PASSWORD = "https://support.discord.com/hc/en-us/articles/218410947-I-forgot-my-Password-Where-can-I-set-a-new-one" +LINK_2FA = "https://support.discord.com/hc/en-us/articles/219576828-Setting-up-Two-Factor-Authentication" +COMP_BAN_REASON = ( +    "Your account has been used to send links to a phishing website. You have been automatically banned. " +    "If you are not aware of sending them, that means your account has been compromised.\n\n" + +    f"Here is a guide from Discord on [how to change your password]({LINK_PASSWORD}).\n\n" + +    f"We also highly recommend that you [enable 2 factor authentication on your account]({LINK_2FA}), " +    "for heightened security.\n\n" + +    "Once you have changed your password, feel free to follow the instructions at the bottom of " +    "this message to appeal your ban." +) +COMP_BAN_DURATION = timedelta(days=4) +# Timeout  MAXIMUM_TIMEOUT_DAYS = timedelta(days=28)  TIMEOUT_CAP_MESSAGE = (      f"The timeout for {{0}} can't be longer than {MAXIMUM_TIMEOUT_DAYS.days} days." @@ -51,7 +67,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent infractions -    @command() +    @command(aliases=("warning",))      async def warn(self, ctx: Context, user: UnambiguousMemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Warn a user for the given reason."""          if not isinstance(user, Member): @@ -147,7 +163,7 @@ class Infractions(InfractionScheduler, commands.Cog):      @command()      async def compban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None:          """Same as cleanban, but specifically with the ban reason and duration used for compromised accounts.""" -        await self.cleanban(ctx, user, duration=(arrow.utcnow() + AUTO_BAN_DURATION).datetime, reason=AUTO_BAN_REASON) +        await self.cleanban(ctx, user, duration=(arrow.utcnow() + COMP_BAN_DURATION).datetime, reason=COMP_BAN_REASON)      @command(aliases=("vban",))      async def voiceban(self, ctx: Context) -> None: diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index 2c94d1af8..47a21753c 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -3,7 +3,6 @@ import difflib  import itertools  import typing as t  from datetime import datetime, timezone -from itertools import zip_longest  import discord  from dateutil.relativedelta import relativedelta @@ -12,14 +11,12 @@ from discord import Colour, Message, Thread  from discord.abc import GuildChannel  from discord.ext.commands import Cog, Context  from discord.utils import escape_markdown, format_dt, snowflake_time -from pydis_core.site_api import ResponseCodeError -from sentry_sdk import add_breadcrumb  from bot.bot import Bot -from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs +from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles  from bot.log import get_logger  from bot.utils import time -from bot.utils.messages import format_user +from bot.utils.messages import format_user, upload_log  log = get_logger(__name__) @@ -45,48 +42,6 @@ class ModLog(Cog, name="ModLog"):          self._cached_edits = [] -    async def upload_log( -        self, -        messages: t.Iterable[discord.Message], -        actor_id: int, -        attachments: t.Iterable[t.List[str]] = None -    ) -> str: -        """Upload message logs to the database and return a URL to a page for viewing the logs.""" -        if attachments is None: -            attachments = [] - -        deletedmessage_set = [ -            { -                "id": message.id, -                "author": message.author.id, -                "channel_id": message.channel.id, -                "content": message.content.replace("\0", ""),  # Null chars cause 400. -                "embeds": [embed.to_dict() for embed in message.embeds], -                "attachments": attachment, -            } -            for message, attachment in zip_longest(messages, attachments, fillvalue=[]) -        ] - -        try: -            response = await self.bot.api_client.post( -                "bot/deleted-messages", -                json={ -                    "actor": actor_id, -                    "creation": datetime.now(timezone.utc).isoformat(), -                    "deletedmessage_set": deletedmessage_set, -                } -            ) -        except ResponseCodeError as e: -            add_breadcrumb( -                category="api_error", -                message=str(e), -                level="error", -                data=deletedmessage_set, -            ) -            raise - -        return f"{URLs.site_logs_view}/{response['id']}" -      def ignore(self, event: Event, *items: int) -> None:          """Add event to ignored events to suppress log emission."""          for item in items: @@ -604,7 +559,7 @@ class ModLog(Cog, name="ModLog"):          remaining_chars = 4090 - len(response)          if len(content) > remaining_chars: -            botlog_url = await self.upload_log(messages=[message], actor_id=message.author.id) +            botlog_url = await upload_log(messages=[message], actor_id=message.author.id)              ending = f"\n\nMessage truncated, [full message here]({botlog_url})."              truncation_point = remaining_chars - len(ending)              content = f"{content[:truncation_point]}...{ending}" diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index bc70a8c1d..7566021c5 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -14,8 +14,8 @@ from pydis_core.utils import scheduling  from bot.bot import Bot  from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons -from bot.exts.filters.token_remover import TokenRemover -from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE +from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter +from bot.exts.filtering._filters.unique.webhook import WEBHOOK_URL_RE  from bot.exts.moderation.modlog import ModLog  from bot.log import CustomLogger, get_logger  from bot.pagination import LinePaginator @@ -235,7 +235,7 @@ class WatchChannel(metaclass=CogABCMeta):              await self.send_header(msg) -        if TokenRemover.find_token_in_message(msg) or WEBHOOK_URL_RE.search(msg.content): +        if DiscordTokenFilter.find_token_in_message(msg.content) or WEBHOOK_URL_RE.search(msg.content):              cleaned_content = "Content is censored because it contains a bot or webhook token."          elif cleaned_content := msg.clean_content:              # Put all non-media URLs in a code block to prevent embeds diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py index b48fcf592..d7e8bc93c 100644 --- a/bot/exts/utils/snekbox/_cog.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -14,10 +14,9 @@ from pydis_core.utils import interactions  from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX  from bot.bot import Bot -from bot.constants import Channels, Emojis, Filter, MODERATION_ROLES, Roles, URLs +from bot.constants import Channels, Emojis, MODERATION_ROLES, Roles, URLs  from bot.decorators import redirect_output -from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME -from bot.exts.filters.antimalware import TXT_LIKE_FILES +from bot.exts.filtering._filter_lists.extension import TXT_LIKE_FILES  from bot.exts.help_channels._channel import is_help_forum_post  from bot.exts.utils.snekbox._eval import EvalJob, EvalResult  from bot.exts.utils.snekbox._io import FileAttachment @@ -27,7 +26,7 @@ from bot.utils.lock import LockedResourceError, lock_arg  from bot.utils.services import PasteTooLongError, PasteUploadError  if TYPE_CHECKING: -    from bot.exts.filters.filtering import Filtering +    from bot.exts.filtering.filtering import Filtering  log = get_logger(__name__) @@ -288,37 +287,22 @@ class Snekbox(Cog):          return output, paste_link -    def get_extensions_whitelist(self) -> set[str]: -        """Return a set of whitelisted file extensions.""" -        return set(self.bot.filter_list_cache['FILE_FORMAT.True'].keys()) | TXT_LIKE_FILES - -    def _filter_files(self, ctx: Context, files: list[FileAttachment]) -> FilteredFiles: +    def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles:          """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" -        # Check if user is staff, if is, return -        # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance -        if hasattr(ctx.author, "roles") and any(role.id in Filter.role_whitelist for role in ctx.author.roles): -            return FilteredFiles(files, []) -        # Ignore code jam channels -        if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME: -            return FilteredFiles(files, []) - -        # Get whitelisted extensions -        whitelist = self.get_extensions_whitelist() -          # Filter files into allowed and blocked          blocked = []          allowed = []          for file in files: -            if file.suffix in whitelist: -                allowed.append(file) -            else: +            if file.suffix in blocked_exts:                  blocked.append(file) +            else: +                allowed.append(file)          if blocked:              blocked_str = ", ".join(f.suffix for f in blocked)              log.info(                  f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}", -                extra={"attachment_list": [f.path for f in files]} +                extra={"attachment_list": [f.filename for f in files]}              )          return FilteredFiles(allowed, blocked) @@ -365,31 +349,8 @@ class Snekbox(Cog):              else:                  self.bot.stats.incr("snekbox.python.success") -            # Filter file extensions -            allowed, blocked = self._filter_files(ctx, result.files) -            # Also scan failed files for blocked extensions -            failed_files = [FileAttachment(name, b"") for name in result.failed_files] -            blocked.extend(self._filter_files(ctx, failed_files).blocked) -            # Add notice if any files were blocked -            if blocked: -                blocked_sorted = sorted(set(f.suffix for f in blocked)) -                # Only no extension -                if len(blocked_sorted) == 1 and blocked_sorted[0] == "": -                    blocked_msg = "Files with no extension can't be uploaded." -                # Both -                elif "" in blocked_sorted: -                    blocked_str = ", ".join(ext for ext in blocked_sorted if ext) -                    blocked_msg = ( -                        f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" -                    ) -                else: -                    blocked_str = ", ".join(blocked_sorted) -                    blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" - -                msg += f"\n{Emojis.failed_file} {blocked_msg}" -              # Split text files -            text_files = [f for f in allowed if f.suffix in TXT_LIKE_FILES] +            text_files = [f for f in result.files if f.suffix in TXT_LIKE_FILES]              # Inline until budget, then upload to paste service              # Budget is shared with stdout, so subtract what we've already used              budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) @@ -417,8 +378,35 @@ class Snekbox(Cog):                          budget_chars -= len(file_text)              filter_cog: Filtering | None = self.bot.get_cog("Filtering") -            if filter_cog and (await filter_cog.filter_snekbox_output(msg, ctx.message)): -                return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") +            blocked_exts = set() +            # Include failed files in the scan. +            failed_files = [FileAttachment(name, b"") for name in result.failed_files] +            total_files = result.files + failed_files +            if filter_cog: +                block_output, blocked_exts = await filter_cog.filter_snekbox_output(msg, total_files, ctx.message) +                if block_output: +                    return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + +            # Filter file extensions +            allowed, blocked = self._filter_files(ctx, result.files, blocked_exts) +            blocked.extend(self._filter_files(ctx, failed_files, blocked_exts).blocked) +            # Add notice if any files were blocked +            if blocked: +                blocked_sorted = sorted(set(f.suffix for f in blocked)) +                # Only no extension +                if len(blocked_sorted) == 1 and blocked_sorted[0] == "": +                    blocked_msg = "Files with no extension can't be uploaded." +                # Both +                elif "" in blocked_sorted: +                    blocked_str = ", ".join(ext for ext in blocked_sorted if ext) +                    blocked_msg = ( +                        f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" +                    ) +                else: +                    blocked_str = ", ".join(blocked_sorted) +                    blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + +                msg += f"\n{Emojis.failed_file} {blocked_msg}"              # Upload remaining non-text files              files = [f.to_file() for f in allowed if f not in text_files] diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py index 9be396335..a45ecec1a 100644 --- a/bot/exts/utils/snekbox/_io.py +++ b/bot/exts/utils/snekbox/_io.py @@ -53,23 +53,23 @@ def normalize_discord_file_name(name: str) -> str:  class FileAttachment:      """File Attachment from Snekbox eval.""" -    path: str +    filename: str      content: bytes      def __repr__(self) -> str:          """Return the content as a string."""          content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content -        return f"FileAttachment(path={self.path!r}, content={content})" +        return f"FileAttachment(path={self.filename!r}, content={content})"      @property      def suffix(self) -> str:          """Return the file suffix.""" -        return PurePosixPath(self.path).suffix +        return PurePosixPath(self.filename).suffix      @property      def name(self) -> str:          """Return the file name.""" -        return PurePosixPath(self.path).name +        return PurePosixPath(self.filename).name      @classmethod      def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: @@ -92,7 +92,7 @@ class FileAttachment:              content = content.encode("utf-8")          return { -            "path": self.path, +            "path": self.filename,              "content": b64encode(content).decode("ascii"),          } diff --git a/bot/pagination.py b/bot/pagination.py index c39ce211b..679108933 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -204,6 +204,7 @@ class LinePaginator(Paginator):          footer_text: str = None,          url: str = None,          exception_on_empty_embed: bool = False, +        reply: bool = False,      ) -> t.Optional[discord.Message]:          """          Use a paginator and set of reactions to provide pagination over a set of lines. @@ -254,6 +255,8 @@ class LinePaginator(Paginator):          embed.description = paginator.pages[current_page] +        reference = ctx.message if reply else None +          if len(paginator.pages) <= 1:              if footer_text:                  embed.set_footer(text=footer_text) @@ -264,9 +267,10 @@ class LinePaginator(Paginator):                  log.trace(f"Setting embed url to '{url}'")              log.debug("There's less than two pages, so we won't paginate - sending single page on its own") +              if isinstance(ctx, discord.Interaction):                  return await ctx.response.send_message(embed=embed) -            return await ctx.send(embed=embed) +            return await ctx.send(embed=embed, reference=reference)          else:              if footer_text:                  embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") @@ -279,11 +283,12 @@ class LinePaginator(Paginator):                  log.trace(f"Setting embed url to '{url}'")              log.debug("Sending first page to channel...") +              if isinstance(ctx, discord.Interaction):                  await ctx.response.send_message(embed=embed)                  message = await ctx.original_response()              else: -                message = await ctx.send(embed=embed) +                message = await ctx.send(embed=embed, reference=reference)          log.debug("Adding emoji reactions to message...") diff --git a/bot/rules/__init__.py b/bot/rules/__init__.py deleted file mode 100644 index a01ceae73..000000000 --- a/bot/rules/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# flake8: noqa - -from .attachments import apply as apply_attachments -from .burst import apply as apply_burst -from .burst_shared import apply as apply_burst_shared -from .chars import apply as apply_chars -from .discord_emojis import apply as apply_discord_emojis -from .duplicates import apply as apply_duplicates -from .links import apply as apply_links -from .mentions import apply as apply_mentions -from .newlines import apply as apply_newlines -from .role_mentions import apply as apply_role_mentions diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py deleted file mode 100644 index 8903c385c..000000000 --- a/bot/rules/attachments.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total attachments exceeding the limit sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if ( -            msg.author == last_message.author -            and len(msg.attachments) > 0 -        ) -    ) -    total_recent_attachments = sum(len(msg.attachments) for msg in relevant_messages) - -    if total_recent_attachments > config['max']: -        return ( -            f"sent {total_recent_attachments} attachments in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/rules/burst.py b/bot/rules/burst.py deleted file mode 100644 index 25c5a2f33..000000000 --- a/bot/rules/burst.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects repeated messages sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if msg.author == last_message.author -    ) -    total_relevant = len(relevant_messages) - -    if total_relevant > config['max']: -        return ( -            f"sent {total_relevant} messages in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/rules/burst_shared.py b/bot/rules/burst_shared.py deleted file mode 100644 index bbe9271b3..000000000 --- a/bot/rules/burst_shared.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects repeated messages sent by multiple users.""" -    total_recent = len(recent_messages) - -    if total_recent > config['max']: -        return ( -            f"sent {total_recent} messages in {config['interval']}s", -            set(msg.author for msg in recent_messages), -            recent_messages -        ) -    return None diff --git a/bot/rules/chars.py b/bot/rules/chars.py deleted file mode 100644 index 1f587422c..000000000 --- a/bot/rules/chars.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total message char count exceeding the limit sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if msg.author == last_message.author -    ) - -    total_recent_chars = sum(len(msg.content) for msg in relevant_messages) - -    if total_recent_chars > config['max']: -        return ( -            f"sent {total_recent_chars} characters in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/rules/discord_emojis.py b/bot/rules/discord_emojis.py deleted file mode 100644 index d979ac5e7..000000000 --- a/bot/rules/discord_emojis.py +++ /dev/null @@ -1,34 +0,0 @@ -import re -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message -from emoji import demojize - -DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:") -CODE_BLOCK_RE = re.compile(r"```.*?```", flags=re.DOTALL) - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total Discord emojis exceeding the limit sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if msg.author == last_message.author -    ) - -    # Get rid of code blocks in the message before searching for emojis. -    # Convert Unicode emojis to :emoji: format to get their count. -    total_emojis = sum( -        len(DISCORD_EMOJI_RE.findall(demojize(CODE_BLOCK_RE.sub("", msg.content)))) -        for msg in relevant_messages -    ) - -    if total_emojis > config['max']: -        return ( -            f"sent {total_emojis} emojis in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/rules/duplicates.py b/bot/rules/duplicates.py deleted file mode 100644 index 8e4fbc12d..000000000 --- a/bot/rules/duplicates.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects duplicated messages sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if ( -            msg.author == last_message.author -            and msg.content == last_message.content -            and msg.content -        ) -    ) - -    total_duplicated = len(relevant_messages) - -    if total_duplicated > config['max']: -        return ( -            f"sent {total_duplicated} duplicated messages in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/rules/links.py b/bot/rules/links.py deleted file mode 100644 index c46b783c5..000000000 --- a/bot/rules/links.py +++ /dev/null @@ -1,36 +0,0 @@ -import re -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - -LINK_RE = re.compile(r"(https?://[^\s]+)") - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total links exceeding the limit sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if msg.author == last_message.author -    ) -    total_links = 0 -    messages_with_links = 0 - -    for msg in relevant_messages: -        total_matches = len(LINK_RE.findall(msg.content)) -        if total_matches: -            messages_with_links += 1 -            total_links += total_matches - -    # Only apply the filter if we found more than one message with -    # links to prevent wrongfully firing the rule on users posting -    # e.g. an installation log of pip packages from GitHub. -    if total_links > config['max'] and messages_with_links > 1: -        return ( -            f"sent {total_links} links in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/rules/newlines.py b/bot/rules/newlines.py deleted file mode 100644 index 4e66e1359..000000000 --- a/bot/rules/newlines.py +++ /dev/null @@ -1,45 +0,0 @@ -import re -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total newlines exceeding the set limit sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if msg.author == last_message.author -    ) - -    # Identify groups of newline characters and get group & total counts -    exp = r"(\n+)" -    newline_counts = [] -    for msg in relevant_messages: -        newline_counts += [len(group) for group in re.findall(exp, msg.content)] -    total_recent_newlines = sum(newline_counts) - -    # Get maximum newline group size -    if newline_counts: -        max_newline_group = max(newline_counts) -    else: -        # If no newlines are found, newline_counts will be an empty list, which will error out max() -        max_newline_group = 0 - -    # Check first for total newlines, if this passes then check for large groupings -    if total_recent_newlines > config['max']: -        return ( -            f"sent {total_recent_newlines} newlines in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    elif max_newline_group > config['max_consecutive']: -        return ( -            f"sent {max_newline_group} consecutive newlines in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) - -    return None diff --git a/bot/rules/role_mentions.py b/bot/rules/role_mentions.py deleted file mode 100644 index 0649540b6..000000000 --- a/bot/rules/role_mentions.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Member, Message - - -async def apply( -    last_message: Message, recent_messages: List[Message], config: Dict[str, int] -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total role mentions exceeding the limit sent by a single user.""" -    relevant_messages = tuple( -        msg -        for msg in recent_messages -        if msg.author == last_message.author -    ) - -    total_recent_mentions = sum(len(msg.role_mentions) for msg in relevant_messages) - -    if total_recent_mentions > config['max']: -        return ( -            f"sent {total_recent_mentions} role mentions in {config['interval']}s", -            (last_message.author,), -            relevant_messages -        ) -    return None diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py index f68d280c9..5deb2376b 100644 --- a/bot/utils/message_cache.py +++ b/bot/utils/message_cache.py @@ -31,20 +31,23 @@ class MessageCache:          self._start = 0          self._end = 0 -        self._messages: list[t.Optional[Message]] = [None] * self.maxlen +        self._messages: list[Message | None] = [None] * self.maxlen          self._message_id_mapping = {} +        self._message_metadata = {} -    def append(self, message: Message) -> None: +    def append(self, message: Message, *, metadata: dict | None = None) -> None:          """Add the received message to the cache, depending on the order of messages defined by `newest_first`."""          if self.newest_first:              self._appendleft(message)          else:              self._appendright(message) +        self._message_metadata[message.id] = metadata      def _appendright(self, message: Message) -> None:          """Add the received message to the end of the cache."""          if self._is_full():              del self._message_id_mapping[self._messages[self._start].id] +            del self._message_metadata[self._messages[self._start].id]              self._start = (self._start + 1) % self.maxlen          self._messages[self._end] = message @@ -56,6 +59,7 @@ class MessageCache:          if self._is_full():              self._end = (self._end - 1) % self.maxlen              del self._message_id_mapping[self._messages[self._end].id] +            del self._message_metadata[self._messages[self._end].id]          self._start = (self._start - 1) % self.maxlen          self._messages[self._start] = message @@ -69,6 +73,7 @@ class MessageCache:          self._end = (self._end - 1) % self.maxlen          message = self._messages[self._end]          del self._message_id_mapping[message.id] +        del self._message_metadata[message.id]          self._messages[self._end] = None          return message @@ -80,6 +85,7 @@ class MessageCache:          message = self._messages[self._start]          del self._message_id_mapping[message.id] +        del self._message_metadata[message.id]          self._messages[self._start] = None          self._start = (self._start + 1) % self.maxlen @@ -89,16 +95,21 @@ class MessageCache:          """Remove all messages from the cache."""          self._messages = [None] * self.maxlen          self._message_id_mapping = {} +        self._message_metadata = {}          self._start = 0          self._end = 0 -    def get_message(self, message_id: int) -> t.Optional[Message]: +    def get_message(self, message_id: int) -> Message | None:          """Return the message that has the given message ID, if it is cached."""          index = self._message_id_mapping.get(message_id, None)          return self._messages[index] if index is not None else None -    def update(self, message: Message) -> bool: +    def get_message_metadata(self, message_id: int) -> dict | None: +        """Return the metadata of the message that has the given message ID, if it is cached.""" +        return self._message_metadata.get(message_id, None) + +    def update(self, message: Message, *, metadata: dict | None = None) -> bool:          """          Update a cached message with new contents. @@ -108,13 +119,15 @@ class MessageCache:          if index is None:              return False          self._messages[index] = message +        if metadata is not None: +            self._message_metadata[message.id] = metadata          return True      def __contains__(self, message_id: int) -> bool:          """Return True if the cache contains a message with the given ID ."""          return message_id in self._message_id_mapping -    def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: +    def __getitem__(self, item: int | slice) -> Message | list[Message]:          """          Return the message(s) in the index or slice provided. diff --git a/bot/utils/messages.py b/bot/utils/messages.py index f6bdceaef..8d765ebfc 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -1,16 +1,21 @@  import asyncio  import random  import re +from collections.abc import Iterable +from datetime import datetime, timezone  from functools import partial  from io import BytesIO  from typing import Callable, List, Optional, Sequence, Union  import discord +from discord import Message  from discord.ext.commands import Context +from pydis_core.site_api import ResponseCodeError  from pydis_core.utils import scheduling +from sentry_sdk import add_breadcrumb  import bot -from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES +from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES, URLs  from bot.log import get_logger  log = get_logger(__name__) @@ -241,6 +246,55 @@ async def send_denial(ctx: Context, reason: str) -> discord.Message:      return await ctx.send(embed=embed) -def format_user(user: discord.abc.User) -> str: +def format_user(user: discord.User | discord.Member) -> str:      """Return a string for `user` which has their mention and ID."""      return f"{user.mention} (`{user.id}`)" + + +def format_channel(channel: discord.abc.Messageable) -> str: +    """Return a string for `channel` with its mention, ID, and the parent channel if it is a thread.""" +    formatted = f"{channel.mention} ({channel.category}/#{channel}" +    if hasattr(channel, "parent"): +        formatted += f"/{channel.parent}" +    formatted += ")" +    return formatted + + +async def upload_log(messages: Iterable[Message], actor_id: int, attachments: dict[int, list[str]] = None) -> str: +    """Upload message logs to the database and return a URL to a page for viewing the logs.""" +    if attachments is None: +        attachments = [] +    else: +        attachments = [attachments.get(message.id, []) for message in messages] + +    deletedmessage_set = [ +        { +            "id": message.id, +            "author": message.author.id, +            "channel_id": message.channel.id, +            "content": message.content.replace("\0", ""),  # Null chars cause 400. +            "embeds": [embed.to_dict() for embed in message.embeds], +            "attachments": attachment, +        } +        for message, attachment in zip(messages, attachments) +    ] + +    try: +        response = await bot.instance.api_client.post( +            "bot/deleted-messages", +            json={ +                "actor": actor_id, +                "creation": datetime.now(timezone.utc).isoformat(), +                "deletedmessage_set": deletedmessage_set, +            } +        ) +    except ResponseCodeError as e: +        add_breadcrumb( +            category="api_error", +            message=str(e), +            level="error", +            data=deletedmessage_set, +        ) +        raise + +    return f"{URLs.site_logs_view}/{response['id']}" diff --git a/tests/bot/exts/filtering/__init__.py b/tests/bot/exts/filtering/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/filtering/__init__.py diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py new file mode 100644 index 000000000..a5cddf8d9 --- /dev/null +++ b/tests/bot/exts/filtering/test_discord_token_filter.py @@ -0,0 +1,276 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock, patch + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.unique import discord_token +from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter, Token +from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel, autospec + + +class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase): +    """Tests the DiscordTokenFilter class.""" + +    def setUp(self): +        """Adds the filter, a bot, and a message to the instance for usage in tests.""" +        now = arrow.utcnow().timestamp() +        self.filter = DiscordTokenFilter({ +            "id": 1, +            "content": "discord_token", +            "description": None, +            "settings": {}, +            "additional_settings": {}, +            "created_at": now, +            "updated_at": now +        }) + +        self.msg = MockMessage(id=555, content="hello world") +        self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.msg) + +    def test_extract_user_id_valid(self): +        """Should consider user IDs valid if they decode into an integer ID.""" +        id_pairs = ( +            ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), +            ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), +            ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), +        ) + +        for token_id, user_id in id_pairs: +            with self.subTest(token_id=token_id): +                result = DiscordTokenFilter.extract_user_id(token_id) +                self.assertEqual(result, user_id) + +    def test_extract_user_id_invalid(self): +        """Should consider non-digit and non-ASCII IDs invalid.""" +        ids = ( +            ("SGVsbG8gd29ybGQ", "non-digit ASCII"), +            ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), +            ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), +            ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), +            ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for user_id, msg in ids: +            with self.subTest(msg=msg): +                result = DiscordTokenFilter.extract_user_id(user_id) +                self.assertIsNone(result) + +    def test_is_valid_timestamp_valid(self): +        """Should consider timestamps valid if they're greater than the Discord epoch.""" +        timestamps = ( +            "XsyRkw", +            "Xrim9Q", +            "XsyR-w", +            "XsySD_", +            "Dn9r_A", +        ) + +        for timestamp in timestamps: +            with self.subTest(timestamp=timestamp): +                result = DiscordTokenFilter.is_valid_timestamp(timestamp) +                self.assertTrue(result) + +    def test_is_valid_timestamp_invalid(self): +        """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" +        timestamps = ( +            ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), +            ("ew", "123"), +            ("AoIKgA", "42076800"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for timestamp, msg in timestamps: +            with self.subTest(msg=msg): +                result = DiscordTokenFilter.is_valid_timestamp(timestamp) +                self.assertFalse(result) + +    def test_is_valid_hmac_valid(self): +        """Should consider an HMAC valid if it has at least 3 unique characters.""" +        valid_hmacs = ( +            "VXmErH7j511turNpfURmb0rVNm8", +            "Ysnu2wacjaKs7qnoo46S8Dm2us8", +            "sJf6omBPORBPju3WJEIAcwW9Zds", +            "s45jqDV_Iisn-symw0yDRrk_jf4", +        ) + +        for hmac in valid_hmacs: +            with self.subTest(msg=hmac): +                result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) +                self.assertTrue(result) + +    def test_is_invalid_hmac_invalid(self): +        """Should consider an HMAC invalid if has fewer than 3 unique characters.""" +        invalid_hmacs = ( +            ("xxxxxxxxxxxxxxxxxx", "Single character"), +            ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), +            ("ASFasfASFasfASFASsf", "Three characters alternating-case"), +            ("asdasdasdasdasdasdasd", "Three characters one case"), +        ) + +        for hmac, msg in invalid_hmacs: +            with self.subTest(msg=msg): +                result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) +                self.assertFalse(result) + +    async def test_no_trigger_when_no_token(self): +        """False should be returned if the message doesn't contain a Discord token.""" +        return_value = await self.filter.triggered_on(self.ctx) + +        self.assertFalse(return_value) + +    @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") +    def test_find_token_valid_match( +        self, +        token_re, +        token_cls, +        extract_user_id, +        is_valid_timestamp, +        is_maybe_valid_hmac, +    ): +        """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" +        matches = [ +            mock.create_autospec(Match, spec_set=True, instance=True), +            mock.create_autospec(Match, spec_set=True, instance=True), +        ] +        tokens = [ +            mock.create_autospec(Token, spec_set=True, instance=True), +            mock.create_autospec(Token, spec_set=True, instance=True), +        ] + +        token_re.finditer.return_value = matches +        token_cls.side_effect = tokens +        extract_user_id.side_effect = (None, True)  # The 1st match will be invalid, 2nd one valid. +        is_valid_timestamp.return_value = True +        is_maybe_valid_hmac.return_value = True + +        return_value = DiscordTokenFilter.find_token_in_message(self.msg) + +        self.assertEqual(tokens[1], return_value) + +    @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") +    def test_find_token_invalid_matches( +        self, +        token_re, +        token_cls, +        extract_user_id, +        is_valid_timestamp, +        is_maybe_valid_hmac, +    ): +        """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" +        token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] +        token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) +        extract_user_id.return_value = None +        is_valid_timestamp.return_value = False +        is_maybe_valid_hmac.return_value = False + +        return_value = DiscordTokenFilter.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) + +    def test_regex_invalid_tokens(self): +        """Messages without anything looking like a token are not matched.""" +        tokens = ( +            "", +            "lemon wins", +            "..", +            "x.y", +            "x.y.", +            ".y.z", +            ".y.", +            "..z", +            "x..z", +            " . . ", +            "\n.\n.\n", +            "hellö.world.bye", +            "base64.nötbåse64.morebase64", +            "19jd3J.dfkm3d.€víł§tüff", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = discord_token.TOKEN_RE.findall(token) +                self.assertEqual(len(results), 0) + +    def test_regex_valid_tokens(self): +        """Messages that look like tokens should be matched.""" +        # Don't worry, these tokens have been invalidated. +        tokens = ( +            "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", +            "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", +            "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", +            "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = discord_token.TOKEN_RE.fullmatch(token) +                self.assertIsNotNone(results, f"{token} was not matched by the regex") + +    def test_regex_matches_multiple_valid(self): +        """Should support multiple matches in the middle of a string.""" +        token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" +        token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" +        message = f"garbage {token_1} hello {token_2} world" + +        results = discord_token.TOKEN_RE.finditer(message) +        results = [match[0] for match in results] +        self.assertCountEqual((token_1, token_2), results) + +    @autospec("bot.exts.filtering._filters.unique.discord_token", "LOG_MESSAGE") +    def test_format_log_message(self, log_message): +        """Should correctly format the log message with info from the message and token.""" +        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        log_message.format.return_value = "Howdy" + +        return_value = DiscordTokenFilter.format_log_message(self.msg.author, self.msg.channel, token) + +        self.assertEqual(return_value, log_message.format.return_value) + +    @patch("bot.instance", MockBot()) +    @autospec("bot.exts.filtering._filters.unique.discord_token", "UNKNOWN_USER_LOG_MESSAGE") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "get_or_fetch_member") +    async def test_format_userid_log_message_unknown(self, get_or_fetch_member, unknown_user_log_message): +        """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" +        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        unknown_user_log_message.format.return_value = " Partner" +        get_or_fetch_member.return_value = None + +        return_value = await DiscordTokenFilter.format_userid_log_message(token) + +        self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) + +    @patch("bot.instance", MockBot()) +    @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") +    async def test_format_userid_log_message_bot(self, known_user_log_message): +        """Should correctly format the user ID portion when the ID belongs to a known bot.""" +        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        known_user_log_message.format.return_value = " Partner" + +        return_value = await DiscordTokenFilter.format_userid_log_message(token) + +        self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) + +    @patch("bot.instance", MockBot()) +    @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") +    async def test_format_log_message_user_token_user(self, user_token_message): +        """Should correctly format the user ID portion when the ID belongs to a known user.""" +        token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        user_token_message.format.return_value = "Partner" + +        return_value = await DiscordTokenFilter.format_userid_log_message(token) + +        self.assertEqual(return_value, (user_token_message.format.return_value, True)) diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py new file mode 100644 index 000000000..827d267d2 --- /dev/null +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -0,0 +1,139 @@ +import unittest +from unittest.mock import MagicMock, patch + +import arrow + +from bot.constants import Channels +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import extension +from bot.exts.filtering._filter_lists.extension import ExtensionsList +from bot.exts.filtering._filter_lists.filter_list import ListType +from tests.helpers import MockAttachment, MockBot, MockMember, MockMessage, MockTextChannel + +BOT = MockBot() + + +class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): +    """Test the ExtensionsList class.""" + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.filter_list = ExtensionsList(MagicMock()) +        now = arrow.utcnow().timestamp() +        filters = [] +        self.whitelist = [".first", ".second", ".third"] +        for i, filter_content in enumerate(self.whitelist, start=1): +            filters.append({ +                "id": i, "content": filter_content, "description": None, "settings": {}, +                "additional_settings": {}, "created_at": now, "updated_at": now  # noqa: P103 +            }) +        self.filter_list.add_list({ +            "id": 1, +            "list_type": 1, +            "created_at": now, +            "updated_at": now, +            "settings": {}, +            "filters": filters +        }) + +        self.message = MockMessage() +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.message) + +    @patch("bot.instance", BOT) +    async def test_message_with_allowed_attachment(self): +        """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" +        attachment = MockAttachment(filename="python.first") +        ctx = self.ctx.replace(attachments=[attachment]) + +        result = await self.filter_list.actions_for(ctx) + +        self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) + +    @patch("bot.instance", BOT) +    async def test_message_without_attachment(self): +        """Messages without attachments should return no triggers, messages, or actions.""" +        result = await self.filter_list.actions_for(self.ctx) + +        self.assertEqual(result, (None, [], {})) + +    @patch("bot.instance", BOT) +    async def test_message_with_illegal_extension(self): +        """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" +        attachment = MockAttachment(filename="python.disallowed") +        ctx = self.ctx.replace(attachments=[attachment]) + +        result = await self.filter_list.actions_for(ctx) + +        self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) + +    @patch("bot.instance", BOT) +    async def test_python_file_redirect_embed_description(self): +        """A message containing a .py file should result in an embed redirecting the user to our paste site.""" +        attachment = MockAttachment(filename="python.py") +        ctx = self.ctx.replace(attachments=[attachment]) + +        await self.filter_list.actions_for(ctx) + +        self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + +    @patch("bot.instance", BOT) +    async def test_txt_file_redirect_embed_description(self): +        """A message containing a .txt/.json/.csv file should result in the correct embed.""" +        test_values = ( +            ("text", ".txt"), +            ("json", ".json"), +            ("csv", ".csv"), +        ) + +        for file_name, disallowed_extension in test_values: +            with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): + +                attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") +                ctx = self.ctx.replace(attachments=[attachment]) + +                await self.filter_list.actions_for(ctx) + +                self.assertEqual( +                    ctx.dm_embed, +                    extension.TXT_EMBED_DESCRIPTION.format( +                        blocked_extension=disallowed_extension, +                    ) +                ) + +    @patch("bot.instance", BOT) +    async def test_other_disallowed_extension_embed_description(self): +        """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" +        attachment = MockAttachment(filename="python.disallowed") +        ctx = self.ctx.replace(attachments=[attachment]) + +        await self.filter_list.actions_for(ctx) +        meta_channel = BOT.get_channel(Channels.meta) + +        self.assertEqual( +            ctx.dm_embed, +            extension.DISALLOWED_EMBED_DESCRIPTION.format( +                joined_whitelist=", ".join(self.whitelist), +                joined_blacklist=".disallowed", +                meta_channel_mention=meta_channel.mention +            ) +        ) + +    @patch("bot.instance", BOT) +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            ([], []), +            (self.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], ["`.disallowed`"]), +            ([".disallowed"], ["`.disallowed`"]), +            ([".disallowed", ".illegal"], ["`.disallowed`", "`.illegal`"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions]) +                result = await self.filter_list.actions_for(ctx) +                self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py new file mode 100644 index 000000000..5a289c1cf --- /dev/null +++ b/tests/bot/exts/filtering/test_settings.py @@ -0,0 +1,20 @@ +import unittest + +import bot.exts.filtering._settings +from bot.exts.filtering._settings import create_settings + + +class FilterTests(unittest.TestCase): +    """Test functionality of the Settings class and its subclasses.""" + +    def test_create_settings_returns_none_for_empty_data(self): +        """`create_settings` should return a tuple of two Nones when passed an empty dict.""" +        result = create_settings({}) + +        self.assertEqual(result, (None, None)) + +    def test_unrecognized_entry_makes_a_warning(self): +        """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`.""" +        create_settings({"abcd": {}}) + +        self.assertIn("abcd", bot.exts.filtering._settings._already_warned) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py new file mode 100644 index 000000000..3ae0b5ab5 --- /dev/null +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -0,0 +1,218 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( +    Infraction, InfractionAndNotification, InfractionDuration +) +from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass +from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope +from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM +from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel + + +class FilterTests(unittest.TestCase): +    """Test functionality of the Settings class and its subclasses.""" + +    def setUp(self) -> None: +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        message = MockMessage(author=member, channel=channel) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + +    def test_role_bypass_is_off_for_user_without_roles(self): +        """The role bypass should trigger when a user has no roles.""" +        member = MockMember() +        self.ctx.author = member +        bypass_entry = RoleBypass(bypass_roles=["123"]) + +        result = bypass_entry.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_role_bypass_is_on_for_a_user_with_the_right_role(self): +        """The role bypass should not trigger when the user has one of its roles.""" +        cases = ( +            ([123], ["123"]), +            ([123, 234], ["123"]), +            ([123], ["123", "234"]), +            ([123, 234], ["123", "234"]) +        ) + +        for user_role_ids, bypasses in cases: +            with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses): +                user_roles = [MockRole(id=role_id) for role_id in user_role_ids] +                member = MockMember(roles=user_roles) +                self.ctx.author = member +                bypass_entry = RoleBypass(bypass_roles=bypasses) + +                result = bypass_entry.triggers_on(self.ctx) + +                self.assertFalse(result) + +    def test_context_doesnt_trigger_for_empty_channel_scope(self): +        """A filter is enabled for all channels by default.""" +        channel = MockTextChannel() +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_context_doesnt_trigger_for_disabled_channel(self): +        """A filter shouldn't trigger if it's been disabled in the channel.""" +        channel = MockTextChannel(id=123) +        scope = ChannelScope( +            disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_context_doesnt_trigger_in_disabled_category(self): +        """A filter shouldn't trigger if it's been disabled in the category.""" +        channel = MockTextChannel(category=MockCategoryChannel(id=456)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=["456"], enabled_channels=None, enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_context_triggers_in_enabled_channel_in_disabled_category(self): +        """A filter should trigger in an enabled channel even if it's been disabled in the category.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"], enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_context_triggers_inside_enabled_category(self): +        """A filter shouldn't trigger outside enabled categories, if there are any.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["234"] +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_context_doesnt_trigger_outside_enabled_category(self): +        """A filter shouldn't trigger outside enabled categories, if there are any.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["789"] +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_context_doesnt_trigger_inside_disabled_channel_in_enabled_category(self): +        """A filter shouldn't trigger outside enabled categories, if there are any.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=["234"] +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_filtering_dms_when_necessary(self): +        """A filter correctly ignores or triggers in a channel depending on the value of FilterDM.""" +        cases = ( +            (True, MockDMChannel(), True), +            (False, MockDMChannel(), False), +            (True, MockTextChannel(), True), +            (False, MockTextChannel(), True) +        ) + +        for apply_in_dms, channel, expected in cases: +            with self.subTest(apply_in_dms=apply_in_dms, channel=channel): +                filter_dms = FilterDM(filter_dm=apply_in_dms) +                self.ctx.channel = channel + +                result = filter_dms.triggers_on(self.ctx) + +                self.assertEqual(expected, result) + +    def test_infraction_merge_of_same_infraction_type(self): +        """When both infractions are of the same type, the one with the longer duration wins.""" +        infraction1 = InfractionAndNotification( +            infraction_type="TIMEOUT", +            infraction_reason="hi", +            infraction_duration=InfractionDuration(10), +            dm_content="how", +            dm_embed="what is", +            infraction_channel=0 +        ) +        infraction2 = InfractionAndNotification( +            infraction_type="TIMEOUT", +            infraction_reason="there", +            infraction_duration=InfractionDuration(20), +            dm_content="are you", +            dm_embed="your name", +            infraction_channel=0 +        ) + +        result = infraction1.union(infraction2) + +        self.assertDictEqual( +            result.dict(), +            { +                "infraction_type": Infraction.TIMEOUT, +                "infraction_reason": "there", +                "infraction_duration": InfractionDuration(20.0), +                "dm_content": "are you", +                "dm_embed": "your name", +                "infraction_channel": 0 +            } +        ) + +    def test_infraction_merge_of_different_infraction_types(self): +        """If there are two different infraction types, the one higher up the hierarchy should be picked.""" +        infraction1 = InfractionAndNotification( +            infraction_type="TIMEOUT", +            infraction_reason="hi", +            infraction_duration=InfractionDuration(20), +            dm_content="", +            dm_embed="", +            infraction_channel=0 +        ) +        infraction2 = InfractionAndNotification( +            infraction_type="BAN", +            infraction_reason="", +            infraction_duration=InfractionDuration(10), +            dm_content="there", +            dm_embed="", +            infraction_channel=0 +        ) + +        result = infraction1.union(infraction2) + +        self.assertDictEqual( +            result.dict(), +            { +                "infraction_type": Infraction.BAN, +                "infraction_reason": "", +                "infraction_duration": InfractionDuration(10), +                "dm_content": "there", +                "dm_embed": "", +                "infraction_channel": 0 +            } +        ) diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py new file mode 100644 index 000000000..03fa6b4b9 --- /dev/null +++ b/tests/bot/exts/filtering/test_token_filter.py @@ -0,0 +1,49 @@ +import unittest + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.token import TokenFilter +from tests.helpers import MockMember, MockMessage, MockTextChannel + + +class TokenFilterTests(unittest.IsolatedAsyncioTestCase): +    """Test functionality of the token filter.""" + +    def setUp(self) -> None: +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        message = MockMessage(author=member, channel=channel) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + +    async def test_token_filter_triggers(self): +        """The filter should evaluate to True only if its token is found in the context content.""" +        test_cases = ( +            (r"hi", "oh hi there", True), +            (r"hi", "goodbye", False), +            (r"bla\d{2,4}", "bla18", True), +            (r"bla\d{2,4}", "bla1", False), +            # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 +            (r"TOKEN", "https://google.com TOKEN", True), +            (r"TOKEN", "https://google.com something else", False) +        ) +        now = arrow.utcnow().timestamp() + +        for pattern, content, expected in test_cases: +            with self.subTest( +                pattern=pattern, +                content=content, +                expected=expected, +            ): +                filter_ = TokenFilter({ +                    "id": 1, +                    "content": pattern, +                    "description": None, +                    "settings": {}, +                    "additional_settings": {}, +                    "created_at": now, +                    "updated_at": now +                }) +                self.ctx.content = content +                result = await filter_.triggered_on(self.ctx) +                self.assertEqual(result, expected) diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py deleted file mode 100644 index 7282334e2..000000000 --- a/tests/bot/exts/filters/test_antimalware.py +++ /dev/null @@ -1,202 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.constants import Channels, STAFF_ROLES -from bot.exts.filters import antimalware -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): -    """Test the AntiMalware cog.""" - -    def setUp(self): -        """Sets up fresh objects for each test.""" -        self.bot = MockBot() -        self.bot.filter_list_cache = { -            "FILE_FORMAT.True": { -                ".first": {}, -                ".second": {}, -                ".third": {}, -            } -        } -        self.cog = antimalware.AntiMalware(self.bot) -        self.message = MockMessage() -        self.message.webhook_id = None -        self.message.author.bot = None -        self.whitelist = [".first", ".second", ".third"] - -    async def test_message_with_allowed_attachment(self): -        """Messages with allowed extensions should not be deleted""" -        attachment = MockAttachment(filename="python.first") -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) -        self.message.delete.assert_not_called() - -    async def test_message_without_attachment(self): -        """Messages without attachments should result in no action.""" -        await self.cog.on_message(self.message) -        self.message.delete.assert_not_called() - -    async def test_direct_message_with_attachment(self): -        """Direct messages should have no action taken.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] -        self.message.guild = None - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_webhook_message_with_illegal_extension(self): -        """A webhook message containing an illegal extension should be ignored.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.webhook_id = 697140105563078727 -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_bot_message_with_illegal_extension(self): -        """A bot message containing an illegal extension should be ignored.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.author.bot = 409107086526644234 -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_message_with_illegal_extension_gets_deleted(self): -        """A message containing an illegal extension should send an embed.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_called_once() - -    async def test_message_send_by_staff(self): -        """A message send by a member of staff should be ignored.""" -        staff_role = MockRole(id=STAFF_ROLES[0]) -        self.message.author.roles.append(staff_role) -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_python_file_redirect_embed_description(self): -        """A message containing a .py file should result in an embed redirecting the user to our paste site""" -        attachment = MockAttachment(filename="python.py") -        self.message.attachments = [attachment] -        self.message.channel.send = AsyncMock() - -        await self.cog.on_message(self.message) -        self.message.channel.send.assert_called_once() -        args, kwargs = self.message.channel.send.call_args -        embed = kwargs.pop("embed") - -        self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - -    async def test_txt_file_redirect_embed_description(self): -        """A message containing a .txt/.json/.csv file should result in the correct embed.""" -        test_values = ( -            ("text", ".txt"), -            ("json", ".json"), -            ("csv", ".csv"), -        ) - -        for file_name, disallowed_extension in test_values: -            with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): - -                attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") -                self.message.attachments = [attachment] -                self.message.channel.send = AsyncMock() -                antimalware.TXT_EMBED_DESCRIPTION = Mock() -                antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - -                await self.cog.on_message(self.message) -                self.message.channel.send.assert_called_once() -                args, kwargs = self.message.channel.send.call_args -                embed = kwargs.pop("embed") -                cmd_channel = self.bot.get_channel(Channels.bot_commands) - -                self.assertEqual( -                    embed.description, -                    antimalware.TXT_EMBED_DESCRIPTION.format.return_value -                ) -                antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with( -                    blocked_extension=disallowed_extension, -                    cmd_channel_mention=cmd_channel.mention -                ) - -    async def test_other_disallowed_extension_embed_description(self): -        """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] -        self.message.channel.send = AsyncMock() -        antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() -        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - -        await self.cog.on_message(self.message) -        self.message.channel.send.assert_called_once() -        args, kwargs = self.message.channel.send.call_args -        embed = kwargs.pop("embed") -        meta_channel = self.bot.get_channel(Channels.meta) - -        self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) -        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( -            joined_whitelist=", ".join(self.whitelist), -            blocked_extensions_str=".disallowed", -            meta_channel_mention=meta_channel.mention -        ) - -    async def test_removing_deleted_message_logs(self): -        """Removing an already deleted message logs the correct message""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] -        self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - -        with self.assertLogs(logger=antimalware.log, level="INFO"): -            await self.cog.on_message(self.message) -        self.message.delete.assert_called_once() - -    async def test_message_with_illegal_attachment_logs(self): -        """Deleting a message with an illegal attachment should result in a log.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] - -        with self.assertLogs(logger=antimalware.log, level="INFO"): -            await self.cog.on_message(self.message) - -    async def test_get_disallowed_extensions(self): -        """The return value should include all non-whitelisted extensions.""" -        test_values = ( -            ([], []), -            (self.whitelist, []), -            ([".first"], []), -            ([".first", ".disallowed"], [".disallowed"]), -            ([".disallowed"], [".disallowed"]), -            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), -        ) - -        for extensions, expected_disallowed_extensions in test_values: -            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): -                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] -                disallowed_extensions = self.cog._get_disallowed_extensions(self.message) -                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): -    """Tests setup of the `AntiMalware` cog.""" - -    async def test_setup(self): -        """Setup of the extension should call add_cog.""" -        bot = MockBot() -        await antimalware.setup(bot) -        bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py deleted file mode 100644 index 6a0e4fded..000000000 --- a/tests/bot/exts/filters/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.exts.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): -    """Tests validation of the antispam cog configuration.""" - -    def test_default_antispam_config_is_valid(self): -        """The default antispam configuration is valid.""" -        validation_errors = antispam.validate_config() -        self.assertEqual(validation_errors, {}) - -    def test_unknown_rule_returns_error(self): -        """Configuring an unknown rule returns an error.""" -        self.assertEqual( -            antispam.validate_config({'invalid-rule': {}}), -            {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} -        ) - -    def test_missing_keys_returns_error(self): -        """Not configuring required keys returns an error.""" -        keys = (('interval', 'max'), ('max', 'interval')) -        for configured_key, unconfigured_key in keys: -            with self.subTest( -                configured_key=configured_key, -                unconfigured_key=unconfigured_key -            ): -                config = {'burst': {configured_key: 10}} -                error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - -                self.assertEqual( -                    antispam.validate_config(config), -                    {'burst': error} -                ) diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py deleted file mode 100644 index e47cf627b..000000000 --- a/tests/bot/exts/filters/test_filtering.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot.exts.filters import filtering -from tests.helpers import MockBot, autospec - - -class FilteringCogTests(unittest.IsolatedAsyncioTestCase): -    """Tests the `Filtering` cog.""" - -    def setUp(self): -        """Instantiate the bot and cog.""" -        self.bot = MockBot() -        with patch("pydis_core.utils.scheduling.create_task", new=lambda task, **_: task.close()): -            self.cog = filtering.Filtering(self.bot) - -    @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) -    async def test_token_filter(self): -        """Ensure that a filter token is correctly detected in a message.""" -        messages = { -            "": False, -            "no matches": False, -            "TOKEN": True, - -            # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 -            "https://google.com TOKEN": True, -            "https://google.com something else": False, -        } - -        for message, match in messages.items(): -            with self.subTest(input=message, match=match): -                result, _ = await self.cog._has_watch_regex_match(message) - -                self.assertEqual( -                    match, -                    bool(result), -                    msg=f"Hit was {'expected' if match else 'not expected'} for this input." -                ) -                if result: -                    self.assertEqual("TOKEN", result.group()) diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py deleted file mode 100644 index c1f3762ac..000000000 --- a/tests/bot/exts/filters/test_token_remover.py +++ /dev/null @@ -1,409 +0,0 @@ -import unittest -from re import Match -from unittest import mock -from unittest.mock import MagicMock - -from discord import Colour, NotFound - -from bot import constants -from bot.exts.filters import token_remover -from bot.exts.filters.token_remover import Token, TokenRemover -from bot.exts.moderation.modlog import ModLog -from bot.utils.messages import format_user -from tests.helpers import MockBot, MockMessage, autospec - - -class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): -    """Tests the `TokenRemover` cog.""" - -    def setUp(self): -        """Adds the cog, a bot, and a message to the instance for usage in tests.""" -        self.bot = MockBot() -        self.cog = TokenRemover(bot=self.bot) - -        self.msg = MockMessage(id=555, content="hello world") -        self.msg.channel.mention = "#lemonade-stand" -        self.msg.guild.get_member.return_value.bot = False -        self.msg.guild.get_member.return_value.__str__.return_value = "Woody" -        self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) -        self.msg.author.display_avatar.url = "picture-lemon.png" - -    def test_extract_user_id_valid(self): -        """Should consider user IDs valid if they decode into an integer ID.""" -        id_pairs = ( -            ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), -            ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), -            ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), -        ) - -        for token_id, user_id in id_pairs: -            with self.subTest(token_id=token_id): -                result = TokenRemover.extract_user_id(token_id) -                self.assertEqual(result, user_id) - -    def test_extract_user_id_invalid(self): -        """Should consider non-digit and non-ASCII IDs invalid.""" -        ids = ( -            ("SGVsbG8gd29ybGQ", "non-digit ASCII"), -            ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), -            ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), -            ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), -            ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), -            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), -            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), -        ) - -        for user_id, msg in ids: -            with self.subTest(msg=msg): -                result = TokenRemover.extract_user_id(user_id) -                self.assertIsNone(result) - -    def test_is_valid_timestamp_valid(self): -        """Should consider timestamps valid if they're greater than the Discord epoch.""" -        timestamps = ( -            "XsyRkw", -            "Xrim9Q", -            "XsyR-w", -            "XsySD_", -            "Dn9r_A", -        ) - -        for timestamp in timestamps: -            with self.subTest(timestamp=timestamp): -                result = TokenRemover.is_valid_timestamp(timestamp) -                self.assertTrue(result) - -    def test_is_valid_timestamp_invalid(self): -        """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" -        timestamps = ( -            ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), -            ("ew", "123"), -            ("AoIKgA", "42076800"), -            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), -            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), -        ) - -        for timestamp, msg in timestamps: -            with self.subTest(msg=msg): -                result = TokenRemover.is_valid_timestamp(timestamp) -                self.assertFalse(result) - -    def test_is_valid_hmac_valid(self): -        """Should consider an HMAC valid if it has at least 3 unique characters.""" -        valid_hmacs = ( -            "VXmErH7j511turNpfURmb0rVNm8", -            "Ysnu2wacjaKs7qnoo46S8Dm2us8", -            "sJf6omBPORBPju3WJEIAcwW9Zds", -            "s45jqDV_Iisn-symw0yDRrk_jf4", -        ) - -        for hmac in valid_hmacs: -            with self.subTest(msg=hmac): -                result = TokenRemover.is_maybe_valid_hmac(hmac) -                self.assertTrue(result) - -    def test_is_invalid_hmac_invalid(self): -        """Should consider an HMAC invalid if has fewer than 3 unique characters.""" -        invalid_hmacs = ( -            ("xxxxxxxxxxxxxxxxxx", "Single character"), -            ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), -            ("ASFasfASFasfASFASsf", "Three characters alternating-case"), -            ("asdasdasdasdasdasdasd", "Three characters one case"), -        ) - -        for hmac, msg in invalid_hmacs: -            with self.subTest(msg=msg): -                result = TokenRemover.is_maybe_valid_hmac(hmac) -                self.assertFalse(result) - -    def test_mod_log_property(self): -        """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -        self.bot.get_cog.return_value = 'lemon' -        self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) -        self.bot.get_cog.assert_called_once_with('ModLog') - -    async def test_on_message_edit_uses_on_message(self): -        """The edit listener should delegate handling of the message to the normal listener.""" -        self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) - -        await self.cog.on_message_edit(MockMessage(), self.msg) -        self.cog.on_message.assert_awaited_once_with(self.msg) - -    @autospec(TokenRemover, "find_token_in_message", "take_action") -    async def test_on_message_takes_action(self, find_token_in_message, take_action): -        """Should take action if a valid token is found when a message is sent.""" -        cog = TokenRemover(self.bot) -        found_token = "foobar" -        find_token_in_message.return_value = found_token - -        await cog.on_message(self.msg) - -        find_token_in_message.assert_called_once_with(self.msg) -        take_action.assert_awaited_once_with(cog, self.msg, found_token) - -    @autospec(TokenRemover, "find_token_in_message", "take_action") -    async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): -        """Shouldn't take action if a valid token isn't found when a message is sent.""" -        cog = TokenRemover(self.bot) -        find_token_in_message.return_value = False - -        await cog.on_message(self.msg) - -        find_token_in_message.assert_called_once_with(self.msg) -        take_action.assert_not_awaited() - -    @autospec(TokenRemover, "find_token_in_message") -    async def test_on_message_ignores_dms_bots(self, find_token_in_message): -        """Shouldn't parse a message if it is a DM or authored by a bot.""" -        cog = TokenRemover(self.bot) -        dm_msg = MockMessage(guild=None) -        bot_msg = MockMessage(author=MagicMock(bot=True)) - -        for msg in (dm_msg, bot_msg): -            await cog.on_message(msg) -            find_token_in_message.assert_not_called() - -    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") -    def test_find_token_no_matches(self, token_re): -        """None should be returned if the regex matches no tokens in a message.""" -        token_re.finditer.return_value = () - -        return_value = TokenRemover.find_token_in_message(self.msg) - -        self.assertIsNone(return_value) -        token_re.finditer.assert_called_once_with(self.msg.content) - -    @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") -    @autospec("bot.exts.filters.token_remover", "Token") -    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") -    def test_find_token_valid_match( -        self, -        token_re, -        token_cls, -        extract_user_id, -        is_valid_timestamp, -        is_maybe_valid_hmac, -    ): -        """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" -        matches = [ -            mock.create_autospec(Match, spec_set=True, instance=True), -            mock.create_autospec(Match, spec_set=True, instance=True), -        ] -        tokens = [ -            mock.create_autospec(Token, spec_set=True, instance=True), -            mock.create_autospec(Token, spec_set=True, instance=True), -        ] - -        token_re.finditer.return_value = matches -        token_cls.side_effect = tokens -        extract_user_id.side_effect = (None, True)  # The 1st match will be invalid, 2nd one valid. -        is_valid_timestamp.return_value = True -        is_maybe_valid_hmac.return_value = True - -        return_value = TokenRemover.find_token_in_message(self.msg) - -        self.assertEqual(tokens[1], return_value) -        token_re.finditer.assert_called_once_with(self.msg.content) - -    @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") -    @autospec("bot.exts.filters.token_remover", "Token") -    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") -    def test_find_token_invalid_matches( -        self, -        token_re, -        token_cls, -        extract_user_id, -        is_valid_timestamp, -        is_maybe_valid_hmac, -    ): -        """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" -        token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] -        token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) -        extract_user_id.return_value = None -        is_valid_timestamp.return_value = False -        is_maybe_valid_hmac.return_value = False - -        return_value = TokenRemover.find_token_in_message(self.msg) - -        self.assertIsNone(return_value) -        token_re.finditer.assert_called_once_with(self.msg.content) - -    def test_regex_invalid_tokens(self): -        """Messages without anything looking like a token are not matched.""" -        tokens = ( -            "", -            "lemon wins", -            "..", -            "x.y", -            "x.y.", -            ".y.z", -            ".y.", -            "..z", -            "x..z", -            " . . ", -            "\n.\n.\n", -            "hellö.world.bye", -            "base64.nötbåse64.morebase64", -            "19jd3J.dfkm3d.€víł§tüff", -        ) - -        for token in tokens: -            with self.subTest(token=token): -                results = token_remover.TOKEN_RE.findall(token) -                self.assertEqual(len(results), 0) - -    def test_regex_valid_tokens(self): -        """Messages that look like tokens should be matched.""" -        # Don't worry, these tokens have been invalidated. -        tokens = ( -            "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", -            "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", -            "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", -            "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", -        ) - -        for token in tokens: -            with self.subTest(token=token): -                results = token_remover.TOKEN_RE.fullmatch(token) -                self.assertIsNotNone(results, f"{token} was not matched by the regex") - -    def test_regex_matches_multiple_valid(self): -        """Should support multiple matches in the middle of a string.""" -        token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" -        token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" -        message = f"garbage {token_1} hello {token_2} world" - -        results = token_remover.TOKEN_RE.finditer(message) -        results = [match[0] for match in results] -        self.assertCountEqual((token_1, token_2), results) - -    @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") -    def test_format_log_message(self, log_message): -        """Should correctly format the log message with info from the message and token.""" -        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        log_message.format.return_value = "Howdy" - -        return_value = TokenRemover.format_log_message(self.msg, token) - -        self.assertEqual(return_value, log_message.format.return_value) -        log_message.format.assert_called_once_with( -            author=format_user(self.msg.author), -            channel=self.msg.channel.mention, -            user_id=token.user_id, -            timestamp=token.timestamp, -            hmac="xxxxxxxxxxxxxxxxxxxxxxxxjf4", -        ) - -    @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE") -    async def test_format_userid_log_message_unknown(self, unknown_user_log_message,): -        """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" -        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        unknown_user_log_message.format.return_value = " Partner" -        msg = MockMessage(id=555, content="hello world") -        msg.guild.get_member.return_value = None -        msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found") - -        return_value = await TokenRemover.format_userid_log_message(msg, token) - -        self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) -        unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332) - -    @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") -    async def test_format_userid_log_message_bot(self, known_user_log_message): -        """Should correctly format the user ID portion when the ID belongs to a known bot.""" -        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        known_user_log_message.format.return_value = " Partner" -        msg = MockMessage(id=555, content="hello world") -        msg.guild.get_member.return_value.__str__.return_value = "Sam" -        msg.guild.get_member.return_value.bot = True - -        return_value = await TokenRemover.format_userid_log_message(msg, token) - -        self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) - -        known_user_log_message.format.assert_called_once_with( -            user_id=472265943062413332, -            user_name="Sam", -            kind="BOT", -        ) - -    @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") -    async def test_format_log_message_user_token_user(self, user_token_message): -        """Should correctly format the user ID portion when the ID belongs to a known user.""" -        token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        user_token_message.format.return_value = "Partner" - -        return_value = await TokenRemover.format_userid_log_message(self.msg, token) - -        self.assertEqual(return_value, (user_token_message.format.return_value, True)) -        user_token_message.format.assert_called_once_with( -            user_id=467223230650777641, -            user_name="Woody", -            kind="USER", -        ) - -    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) -    @autospec("bot.exts.filters.token_remover", "log") -    @autospec(TokenRemover, "format_log_message", "format_userid_log_message") -    async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property): -        """Should delete the message and send a mod log.""" -        cog = TokenRemover(self.bot) -        mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) -        token = mock.create_autospec(Token, spec_set=True, instance=True) -        token.user_id = "no-id" -        log_msg = "testing123" -        userid_log_message = "userid-log-message" - -        mod_log_property.return_value = mod_log -        format_log_message.return_value = log_msg -        format_userid_log_message.return_value = (userid_log_message, True) - -        await cog.take_action(self.msg, token) - -        self.msg.delete.assert_called_once_with() -        self.msg.channel.send.assert_called_once_with( -            token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) -        ) - -        format_log_message.assert_called_once_with(self.msg, token) -        format_userid_log_message.assert_called_once_with(self.msg, token) -        logger.debug.assert_called_with(log_msg) -        self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - -        mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) -        mod_log.send_log_message.assert_called_once_with( -            icon_url=constants.Icons.token_removed, -            colour=Colour(constants.Colours.soft_red), -            title="Token removed!", -            text=log_msg + "\n" + userid_log_message, -            thumbnail=self.msg.author.display_avatar.url, -            channel_id=constants.Channels.mod_alerts, -            ping_everyone=True, -        ) - -    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) -    async def test_take_action_delete_failure(self, mod_log_property): -        """Shouldn't send any messages if the token message can't be deleted.""" -        cog = TokenRemover(self.bot) -        mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) -        self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) - -        token = mock.create_autospec(Token, spec_set=True, instance=True) -        await cog.take_action(self.msg, token) - -        self.msg.delete.assert_called_once_with() -        self.msg.channel.send.assert_not_awaited() - - -class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): -    """Tests for the token_remover extension.""" - -    @autospec("bot.exts.filters.token_remover", "TokenRemover") -    async def test_extension_setup(self, cog): -        """The TokenRemover cog should be added.""" -        bot = MockBot() -        await token_remover.setup(bot) - -        cog.assert_called_once_with(bot) -        bot.add_cog.assert_awaited_once() -        self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9dcf7fd8c..79ac8ea2c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.upload_output = AsyncMock()  # Should not be called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code('MyAwesomeCode') @@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.upload_output = AsyncMock()  # This function isn't called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.upload_output = AsyncMock()  # This function isn't called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"]))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py deleted file mode 100644 index 0d570f5a3..000000000 --- a/tests/bot/rules/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest -from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple - -from tests.helpers import MockMessage - - -class DisallowedCase(NamedTuple): -    """Encapsulation for test cases expected to fail.""" -    recent_messages: List[MockMessage] -    culprits: Iterable[str] -    n_violations: int - - -class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): -    """ -    Abstract class for antispam rule test cases. - -    Tests for specific rules should inherit from `RuleTest` and implement -    `relevant_messages` and `get_report`. Each instance should also set the -    `apply` and `config` attributes as necessary. - -    The execution of test cases can then be delegated to the `run_allowed` -    and `run_disallowed` methods. -    """ - -    apply: Callable  # The tested rule's apply function -    config: Dict[str, int] - -    async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: -        """Run all `cases` against `self.apply` expecting them to pass.""" -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -            ): -                self.assertIsNone( -                    await self.apply(last_message, recent_messages, self.config) -                ) - -    async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: -        """Run all `cases` against `self.apply` expecting them to fail.""" -        for case in cases: -            recent_messages, culprits, n_violations = case -            last_message = recent_messages[0] -            relevant_messages = self.relevant_messages(case) -            desired_output = ( -                self.get_report(case), -                culprits, -                relevant_messages, -            ) - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                n_violations=n_violations, -                config=self.config, -            ): -                self.assertTupleEqual( -                    await self.apply(last_message, recent_messages, self.config), -                    desired_output, -                ) - -    @abstractmethod -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        """Give expected relevant messages for `case`.""" -        raise NotImplementedError  # pragma: no cover - -    @abstractmethod -    def get_report(self, case: DisallowedCase) -> str: -        """Give expected error report for `case`.""" -        raise NotImplementedError  # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py deleted file mode 100644 index d7e779221..000000000 --- a/tests/bot/rules/test_attachments.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Iterable - -from bot.rules import attachments -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_attachments: int) -> MockMessage: -    """Builds a message with `total_attachments` attachments.""" -    return MockMessage(author=author, attachments=list(range(total_attachments))) - - -class AttachmentRuleTests(RuleTest): -    """Tests applying the `attachments` antispam rule.""" - -    def setUp(self): -        self.apply = attachments.apply -        self.config = {"max": 5, "interval": 10} - -    async def test_allows_messages_without_too_many_attachments(self): -        """Messages without too many attachments are allowed as-is.""" -        cases = ( -            [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], -            [make_msg("bob", 2), make_msg("bob", 2)], -            [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_with_too_many_attachments(self): -        """Messages with too many attachments trigger the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], -                ("bob",), -                10, -            ), -            DisallowedCase( -                [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], -                ("bob",), -                6, -            ), -            DisallowedCase( -                [make_msg("alice", 6)], -                ("alice",), -                6, -            ), -            DisallowedCase( -                [make_msg("alice", 1) for _ in range(6)], -                ("alice",), -                6, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if ( -                msg.author == last_message.author -                and len(msg.attachments) > 0 -            ) -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py deleted file mode 100644 index 03682966b..000000000 --- a/tests/bot/rules/test_burst.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Iterable - -from bot.rules import burst -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: -    """ -    Init a MockMessage instance with author set to `author`. - -    This serves as a shorthand / alias to keep the test cases visually clean. -    """ -    return MockMessage(author=author) - - -class BurstRuleTests(RuleTest): -    """Tests the `burst` antispam rule.""" - -    def setUp(self): -        self.apply = burst.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases which do not violate the rule.""" -        cases = ( -            [make_msg("bob"), make_msg("bob")], -            [make_msg("bob"), make_msg("alice"), make_msg("bob")], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases where the amount of messages exceeds the limit, triggering the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("bob")], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], -                ("bob",), -                3, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py deleted file mode 100644 index 3275143d5..000000000 --- a/tests/bot/rules/test_burst_shared.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Iterable - -from bot.rules import burst_shared -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: -    """ -    Init a MockMessage instance with the passed arg. - -    This serves as a shorthand / alias to keep the test cases visually clean. -    """ -    return MockMessage(author=author) - - -class BurstSharedRuleTests(RuleTest): -    """Tests the `burst_shared` antispam rule.""" - -    def setUp(self): -        self.apply = burst_shared.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """ -        Cases that do not violate the rule. - -        There really isn't more to test here than a single case. -        """ -        cases = ( -            [make_msg("spongebob"), make_msg("patrick")], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases where the amount of messages exceeds the limit, triggering the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("bob")], -                {"bob"}, -                3, -            ), -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], -                {"bob", "alice"}, -                4, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        return case.recent_messages - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py deleted file mode 100644 index f1e3c76a7..000000000 --- a/tests/bot/rules/test_chars.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import chars -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_chars: int) -> MockMessage: -    """Build a message with arbitrary content of `n_chars` length.""" -    return MockMessage(author=author, content="A" * n_chars) - - -class CharsRuleTests(RuleTest): -    """Tests the `chars` antispam rule.""" - -    def setUp(self): -        self.apply = chars.apply -        self.config = { -            "max": 20,  # Max allowed sum of chars per user -            "interval": 10, -        } - -    async def test_allows_messages_within_limit(self): -        """Cases with a total amount of chars within limit.""" -        cases = ( -            [make_msg("bob", 0)], -            [make_msg("bob", 20)], -            [make_msg("bob", 15), make_msg("alice", 15)], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases where the total amount of chars exceeds the limit, triggering the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 21)], -                ("bob",), -                21, -            ), -            DisallowedCase( -                [make_msg("bob", 15), make_msg("bob", 15)], -                ("bob",), -                30, -            ), -            DisallowedCase( -                [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], -                ("alice",), -                30, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py deleted file mode 100644 index 66c2d9f92..000000000 --- a/tests/bot/rules/test_discord_emojis.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Iterable - -from bot.rules import discord_emojis -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - -discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> -unicode_emoji = "🧪" - - -def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage: -    """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" -    return MockMessage(author=author, content=emoji * n_emojis) - - -class DiscordEmojisRuleTests(RuleTest): -    """Tests for the `discord_emojis` antispam rule.""" - -    def setUp(self): -        self.apply = discord_emojis.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases with a total amount of discord and unicode emojis within limit.""" -        cases = ( -            [make_msg("bob", 2)], -            [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], -            [make_msg("bob", 2, unicode_emoji)], -            [ -                make_msg("alice", 1, unicode_emoji), -                make_msg("bob", 2, unicode_emoji), -                make_msg("alice", 1, unicode_emoji) -            ], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases with more than the allowed amount of discord and unicode emojis.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 3)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], -                ("alice",), -                4, -            ), -            DisallowedCase( -                [make_msg("bob", 3, unicode_emoji)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [ -                    make_msg("alice", 2, unicode_emoji), -                    make_msg("bob", 2, unicode_emoji), -                    make_msg("alice", 2, unicode_emoji) -                ], -                ("alice",), -                4 -            ) -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py deleted file mode 100644 index 9bd886a77..000000000 --- a/tests/bot/rules/test_duplicates.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import duplicates -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, content: str) -> MockMessage: -    """Give a MockMessage instance with `author` and `content` attrs.""" -    return MockMessage(author=author, content=content) - - -class DuplicatesRuleTests(RuleTest): -    """Tests the `duplicates` antispam rule.""" - -    def setUp(self): -        self.apply = duplicates.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases which do not violate the rule.""" -        cases = ( -            [make_msg("alice", "A"), make_msg("alice", "A")], -            [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")],  # Non-duplicate -            [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")],  # Different author -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases with too many duplicate messages from the same author.""" -        cases = ( -            DisallowedCase( -                [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], -                ("alice",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], -                ("bob",), -                3,  # 4 duplicate messages, but only 3 from bob -            ), -            DisallowedCase( -                [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], -                ("bob",), -                3,  # 4 message from bob, but only 3 duplicates -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if ( -                msg.author == last_message.author -                and msg.content == last_message.content -            ) -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py deleted file mode 100644 index b091bd9d7..000000000 --- a/tests/bot/rules/test_links.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Iterable - -from bot.rules import links -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_links: int) -> MockMessage: -    """Makes a message with `total_links` links.""" -    content = " ".join(["https://pydis.com"] * total_links) -    return MockMessage(author=author, content=content) - - -class LinksTests(RuleTest): -    """Tests applying the `links` rule.""" - -    def setUp(self): -        self.apply = links.apply -        self.config = { -            "max": 2, -            "interval": 10 -        } - -    async def test_links_within_limit(self): -        """Messages with an allowed amount of links.""" -        cases = ( -            [make_msg("bob", 0)], -            [make_msg("bob", 2)], -            [make_msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 -            [make_msg("bob", 1), make_msg("bob", 1)], -            [make_msg("bob", 2), make_msg("alice", 2)]  # Only messages from latest author count -        ) - -        await self.run_allowed(cases) - -    async def test_links_exceeding_limit(self): -        """Messages with a a higher than allowed amount of links.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 1), make_msg("bob", 2)], -                ("bob",), -                3 -            ), -            DisallowedCase( -                [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], -                ("alice",), -                3 -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], -                ("alice",), -                3 -            ) -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py deleted file mode 100644 index e1f904917..000000000 --- a/tests/bot/rules/test_mentions.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Iterable, Optional - -import discord - -from bot.rules import mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage, MockMessageReference - - -def make_msg( -    author: str, -    total_user_mentions: int, -    total_bot_mentions: int = 0, -    *, -    reference: Optional[MockMessageReference] = None -) -> MockMessage: -    """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" -    user_mentions = [MockMember() for _ in range(total_user_mentions)] -    bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - -    mentions = user_mentions + bot_mentions -    if reference is not None: -        # For the sake of these tests we assume that all references are mentions. -        mentions.append(reference.resolved.author) -        msg_type = discord.MessageType.reply -    else: -        msg_type = discord.MessageType.default - -    return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) - - -class TestMentions(RuleTest): -    """Tests applying the `mentions` antispam rule.""" - -    def setUp(self): -        self.apply = mentions.apply -        self.config = { -            "max": 2, -            "interval": 10, -        } - -    async def test_mentions_within_limit(self): -        """Messages with an allowed amount of mentions.""" -        cases = ( -            [make_msg("bob", 0)], -            [make_msg("bob", 2)], -            [make_msg("bob", 1), make_msg("bob", 1)], -            [make_msg("bob", 1), make_msg("alice", 2)], -        ) - -        await self.run_allowed(cases) - -    async def test_mentions_exceeding_limit(self): -        """Messages with a higher than allowed amount of mentions.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 3)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], -                ("alice",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], -                ("bob",), -                4, -            ), -            DisallowedCase( -                [make_msg("bob", 3, 1)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", 3, reference=MockMessageReference())], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], -                ("bob",), -                3 -            ) -        ) - -        await self.run_disallowed(cases) - -    async def test_ignore_bot_mentions(self): -        """Messages with an allowed amount of mentions, also containing bot mentions.""" -        cases = ( -            [make_msg("bob", 0, 3)], -            [make_msg("bob", 2, 1)], -            [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], -            [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] -        ) - -        await self.run_allowed(cases) - -    async def test_ignore_reply_mentions(self): -        """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" -        cases = ( -            [ -                make_msg("bob", 2, reference=MockMessageReference()) -            ], -            [ -                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) -            ], -            [ -                make_msg("bob", 2, reference=MockMessageReference()), -                make_msg("bob", 0, reference=MockMessageReference()) -            ], -            [ -                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), -                make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) -            ] -        ) - -        await self.run_allowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py deleted file mode 100644 index e35377773..000000000 --- a/tests/bot/rules/test_newlines.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Iterable, List - -from bot.rules import newlines -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, newline_groups: List[int]) -> MockMessage: -    """Init a MockMessage instance with `author` and content configured by `newline_groups". - -    Configure content by passing a list of ints, where each int `n` will generate -    a separate group of `n` newlines. - -    Example: -        newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" -    """ -    content = " ".join("\n" * n for n in newline_groups) -    return MockMessage(author=author, content=content) - - -class TotalNewlinesRuleTests(RuleTest): -    """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" - -    def setUp(self): -        self.apply = newlines.apply -        self.config = { -            "max": 5,  # Max sum of newlines in relevant messages -            "max_consecutive": 3,  # Max newlines in one group, in one message -            "interval": 10, -        } - -    async def test_allows_messages_within_limit(self): -        """Cases which do not violate the rule.""" -        cases = ( -            [make_msg("alice", [])],  # Single message with no newlines -            [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])],  # 5 newlines in 2 messages -            [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])],  # 5 newlines from each author -            [make_msg("bob", [1]), make_msg("alice", [5])],  # Alice breaks the rule, but only bob is relevant -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_total(self): -        """Cases which violate the rule by having too many newlines in total.""" -        cases = ( -            DisallowedCase(  # Alice sends a total of 6 newlines (disallowed) -                [make_msg("alice", [2, 2]), make_msg("alice", [2])], -                ("alice",), -                6, -            ), -            DisallowedCase(  # Here we test that only alice's newlines count in the sum -                [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], -                ("alice",), -                7, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_author = case.recent_messages[0].author -        return tuple(msg for msg in case.recent_messages if msg.author == last_author) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} newlines in {self.config['interval']}s" - - -class GroupNewlinesRuleTests(RuleTest): -    """ -    Tests the `newlines` antispam rule against max consecutive newline violations. - -    As these violations yield a different error report, they require a different -    `get_report` implementation. -    """ - -    def setUp(self): -        self.apply = newlines.apply -        self.config = {"max": 5, "max_consecutive": 3, "interval": 10} - -    async def test_disallows_messages_consecutive(self): -        """Cases which violate the rule due to having too many consecutive newlines.""" -        cases = ( -            DisallowedCase(  # Bob sends a group of newlines too large -                [make_msg("bob", [4])], -                ("bob",), -                4, -            ), -            DisallowedCase(  # Alice sends 5 in total (allowed), but 4 in one group (disallowed) -                [make_msg("alice", [1]), make_msg("alice", [4])], -                ("alice",), -                4, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_author = case.recent_messages[0].author -        return tuple(msg for msg in case.recent_messages if msg.author == last_author) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py deleted file mode 100644 index 26c05d527..000000000 --- a/tests/bot/rules/test_role_mentions.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Iterable - -from bot.rules import role_mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_mentions: int) -> MockMessage: -    """Build a MockMessage instance with `n_mentions` role mentions.""" -    return MockMessage(author=author, role_mentions=[None] * n_mentions) - - -class RoleMentionsRuleTests(RuleTest): -    """Tests for the `role_mentions` antispam rule.""" - -    def setUp(self): -        self.apply = role_mentions.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases with a total amount of role mentions within limit.""" -        cases = ( -            [make_msg("bob", 2)], -            [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases with more than the allowed amount of role mentions.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 3)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], -                ("alice",), -                4, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/helpers.py b/tests/helpers.py index 1a71f210a..020f1aee5 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data  class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      """ -    A MagicMock subclass to mock TextChannel objects. +    A MagicMock subclass to mock DMChannel objects. -    Instances of this class will follow the specifications of `discord.TextChannel` instances. For +    Instances of this class will follow the specifications of `discord.DMChannel` instances. For      more information, see the `MockGuild` docstring.      """      spec_set = dm_channel_instance      def __init__(self, **kwargs) -> None: -        default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} +        default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None}          super().__init__(**collections.ChainMap(kwargs, default_kwargs)) @@ -423,7 +423,7 @@ category_channel_instance = discord.CategoryChannel(  class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id)} -        super().__init__(**collections.ChainMap(default_kwargs, kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))  # Create a Message instance to get a realistic MagicMock of `discord.Message` | 
