diff options
author | 2021-12-09 00:30:41 +0200 | |
---|---|---|
committer | 2022-07-16 01:45:39 +0300 | |
commit | 8095800ae8f38928ab8c406e622ec79ea93b21c3 (patch) | |
tree | 519976086843c54e762388715a30a349fd722294 | |
parent | Tear down the old filtering system (diff) |
New filtering backbone and regex filtering migration
This commit provides the basis of the new filtering system:
- The filtering cog consists of several filter lists loaded from the database (filtering.py).
- Each filter list contains a list of filters, which are run in response to events (message posting, reaction, thread creation). Each filter list may choose to respond to different events (the subscribe method in filtering.py).
- Each filter has settings (settings.py) which decide when it is going to be run (e.g it might be disabled in a specific channel), and what will happen if it triggers (e.g delete the offending message).
- Not every filter has a value for every setting (the _settings_types package) . It will use the default settings specified by its filter list as a fallback.
- Since each filter might have a different effect when triggered, we must check all relevant filters even if we found a triggered filter already, unlike in the old system.
- Two triggered filters may specify different values for the same setting, therefore each entry has a rule for combining two different values (the __or__ method in each file in _settings_types).
To avoid having to prefix each file with an underscore (or the bot will try to load it as a cog), the loading script was changed to ignore packages with names starting with an underscore.
Alert sending is done via a webhook so that several embeds can be sent in the same message (will be useful for example for guild invite alerts).
Filter lists and setting entries classes are loaded dynamically from their respective packages.
In order to be able to test the new features, this commit also includes a migration of the regex-based filtering.
29 files changed, 1494 insertions, 0 deletions
diff --git a/bot/constants.py b/bot/constants.py index c39f9d2b8..65791daa3 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -477,6 +477,7 @@ class Webhooks(metaclass=YAMLGetter): duck_pond: int incidents: int incidents_archive: int + filters: int class Roles(metaclass=YAMLGetter): diff --git a/bot/exts/filtering/README.md b/bot/exts/filtering/README.md new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/bot/exts/filtering/README.md diff --git a/bot/exts/filtering/__init__.py b/bot/exts/filtering/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/bot/exts/filtering/__init__.py diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py new file mode 100644 index 000000000..ee9e87f56 --- /dev/null +++ b/bot/exts/filtering/_filter_context.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from enum import Enum, auto +from typing import Optional, Union + +from discord import DMChannel, Embed, Message, TextChannel, Thread, User + + +class Event(Enum): + """Types of events that can trigger filtering. Note this does not have to align with gateway event types.""" + + MESSAGE = auto() + MESSAGE_EDIT = auto() + + +@dataclass +class FilterContext: + """A dataclass containing the information that should be filtered, and output information of the filtering.""" + + # Input context + event: Event # The type of event + author: User # Who triggered the event + channel: Union[TextChannel, Thread, DMChannel] # The channel involved + content: str # What actually needs filtering + message: Optional[Message] # The message involved + embeds: list = field(default_factory=list) # Any embeds involved + # Output context + dm_content: str = field(default_factory=str) # The content to DM the invoker + dm_embed: Embed = field(default_factory=Embed) # The embed to DM the invoker + send_alert: bool = field(default=True) # Whether to send an alert for the moderators + alert_content: str = field(default_factory=str) # The content of the alert + alert_embeds: list = field(default_factory=list) # Any embeds to add to the alert + action_descriptions: list = field(default_factory=list) # What actions were taken + matches: list = field(default_factory=list) # What exactly was found + + def replace(self, **changes) -> FilterContext: + """Return a new context object assigning new values to the specified fields.""" + return replace(self, **changes) diff --git a/bot/exts/filtering/_filter_lists/__init__.py b/bot/exts/filtering/_filter_lists/__init__.py new file mode 100644 index 000000000..415e3a6bf --- /dev/null +++ b/bot/exts/filtering/_filter_lists/__init__.py @@ -0,0 +1,9 @@ +from os.path import dirname + +from bot.exts.filtering._filter_lists.filter_list import FilterList +from bot.exts.filtering._utils import subclasses_in_package + +filter_list_types = subclasses_in_package(dirname(__file__), f"{__name__}.", FilterList) +filter_list_types = {filter_list.name: filter_list for filter_list in filter_list_types} + +__all__ = [filter_list_types, FilterList] diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py new file mode 100644 index 000000000..f9e304b59 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -0,0 +1,79 @@ +from abc import abstractmethod +from enum import Enum +from typing import Dict, List, Type + +from bot.exts.filtering._settings import Settings, ValidationSettings, create_settings +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._utils import FieldRequiring +from bot.log import get_logger + +log = get_logger(__name__) + + +class ListType(Enum): + DENY = 0 + ALLOW = 1 + + +class FilterList(FieldRequiring): + """Dispatches events to lists of _filters, and aggregates the responses into a single list of actions to take.""" + + # Each subclass must define a name matching the filter_list name we're expecting to receive from the database. + # Names must be unique across all filter lists. + name = FieldRequiring.MUST_SET_UNIQUE + + def __init__(self, filter_type: Type[Filter]): + self._filter_lists: dict[ListType, list[Filter]] = {} + self._defaults: dict[ListType, dict[str, Settings]] = {} + + self.filter_type = filter_type + + def add_list(self, list_data: Dict) -> None: + """Add a new type of list (such as a whitelist or a blacklist) this filter list.""" + actions, validations = create_settings(list_data["settings"]) + list_type = ListType(list_data["list_type"]) + self._defaults[list_type] = {"actions": actions, "validations": validations} + + filters = [] + for filter_data in list_data["filters"]: + try: + filters.append(self.filter_type(filter_data, actions)) + except TypeError as e: + log.warning(e) + self._filter_lists[list_type] = filters + + @abstractmethod + def triggers_for(self, ctx: FilterContext) -> list[Filter]: + """Dispatch the given event to the list's filters, and return filters triggered.""" + + @staticmethod + def filter_list_result(ctx: FilterContext, filters: List[Filter], defaults: ValidationSettings) -> list[Filter]: + """ + Sift through the list of filters, and return only the ones which apply to the given context. + + The strategy is as follows: + 1. The default settings are evaluated on the given context. The default answer for whether the filter is + relevant in the given context is whether there aren't any validation settings which returned False. + 2. For each filter, its overrides are considered: + - If there are no overrides, then the filter is relevant if that is the default answer. + - Otherwise it is relevant if there are no failed overrides, and any failing default is overridden by a + successful override. + + If the filter is relevant in context, see if it actually triggers. + """ + passed_by_default, failed_by_default = defaults.evaluate(ctx) + default_answer = not bool(failed_by_default) + + relevant_filters = [] + for filter_ in filters: + if not filter_.validations: + if default_answer and filter_.triggered_on(ctx): + relevant_filters.append(filter_) + else: + passed, failed = filter_.validations.evaluate(ctx) + if not failed and failed_by_default < passed: + if filter_.triggered_on(ctx): + relevant_filters.append(filter_) + + return relevant_filters diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py new file mode 100644 index 000000000..4495f4414 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/token.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import re +import typing + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filters.token import TokenFilter +from bot.exts.filtering._utils import clean_input + +if typing.TYPE_CHECKING: + from bot.exts.filtering.filtering import Filtering + +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) + + +class TokensList(FilterList): + """A list of filters, each looking for a specific token given by regex.""" + + name = "token" + + def __init__(self, filtering_cog: Filtering): + super().__init__(TokenFilter) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT) + + def triggers_for(self, ctx: FilterContext) -> list[Filter]: + """Dispatch the given event to the list's filters, and return filters triggered.""" + text = ctx.content + if SPOILER_RE.search(text): + text = self._expand_spoilers(text) + text = clean_input(text) + ctx = ctx.replace(content=text) + + return self.filter_list_result( + ctx, self._filter_lists[ListType.DENY], self._defaults[ListType.DENY]["validations"] + ) + + @staticmethod + def _expand_spoilers(text: str) -> str: + """Return a string containing all interpretations of a spoilered message.""" + split_text = SPOILER_RE.split(text) + return ''.join( + split_text[0::2] + split_text[1::2] + split_text + ) diff --git a/bot/exts/filtering/_filters/__init__.py b/bot/exts/filtering/_filters/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/bot/exts/filtering/_filters/__init__.py diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py new file mode 100644 index 000000000..484e506fc --- /dev/null +++ b/bot/exts/filtering/_filters/filter.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings import ActionSettings, create_settings + + +class Filter(ABC): + """ + A class representing a filter. + + Each filter looks for a specific attribute within an event (such as message sent), + and defines what action should be performed if it is triggered. + """ + + def __init__(self, filter_data: Dict, action_defaults: Optional[ActionSettings] = None): + self.id = filter_data["id"] + self.content = filter_data["content"] + self.description = filter_data["description"] + self.actions, self.validations = create_settings(filter_data["settings"]) + if not self.actions: + self.actions = action_defaults + elif action_defaults: + self.actions.fallback_to(action_defaults) + self.exact = filter_data["additional_field"] + + @abstractmethod + def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" diff --git a/bot/exts/filtering/_filters/token.py b/bot/exts/filtering/_filters/token.py new file mode 100644 index 000000000..07590c54b --- /dev/null +++ b/bot/exts/filtering/_filters/token.py @@ -0,0 +1,20 @@ +import re + +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filter_context import FilterContext + + +class TokenFilter(Filter): + """A filter which looks for a specific token given by regex.""" + + def triggered_on(self, ctx: FilterContext) -> bool: + """Searches for a regex pattern within a given context.""" + pattern = self.content + + match = re.search(pattern, ctx.content, flags=re.IGNORECASE) + if match: + ctx.matches.append(match[0]) + return True + return False + + diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py new file mode 100644 index 000000000..96e1c1f7f --- /dev/null +++ b/bot/exts/filtering/_settings.py @@ -0,0 +1,180 @@ +from __future__ import annotations +from abc import abstractmethod +from typing import Iterator, Mapping, Optional + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types import settings_types +from bot.exts.filtering._settings_types.settings_entry import ActionEntry, ValidationEntry +from bot.exts.filtering._utils import FieldRequiring +from bot.log import get_logger + +log = get_logger(__name__) + +_already_warned: set[str] = set() + + +def create_settings(settings_data: dict) -> tuple[Optional[ActionSettings], Optional[ValidationSettings]]: + """ + Create and return instances of the Settings subclasses from the given data + + Additionally, warn for data entries with no matching class. + """ + action_data = {} + validation_data = {} + for entry_name, entry_data in settings_data.items(): + if entry_name in settings_types["ActionEntry"]: + action_data[entry_name] = entry_data + elif entry_name in settings_types["ValidationEntry"]: + validation_data[entry_name] = entry_data + else: + log.warning( + f"A setting named {entry_name} was loaded from the database, but no matching class." + ) + _already_warned.add(entry_name) + return ActionSettings.create(action_data), ValidationSettings.create(validation_data) + + +class Settings(FieldRequiring): + """ + A collection of settings. + + For processing the settings parts in the database and evaluating them on given contexts. + + Each filter list and filter has its own settings. + + A filter doesn't have to have its own settings. For every undefined setting, it falls back to the value defined in + the filter list which contains the filter. + """ + + entry_type = FieldRequiring.MUST_SET + + _already_warned: set[str] = set() + + @abstractmethod + def __init__(self, settings_data: dict): + self._entries: dict[str, Settings.entry_type] = {} + + entry_classes = settings_types.get(self.entry_type.__name__) + for entry_name, entry_data in settings_data.items(): + try: + entry_cls = entry_classes[entry_name] + except KeyError: + if entry_name not in self._already_warned: + log.warning( + f"A setting named {entry_name} was loaded from the database, " + f"but no matching {self.entry_type.__name__} class." + ) + self._already_warned.add(entry_name) + else: + try: + new_entry = entry_cls.create(entry_data) + if new_entry: + self._entries[entry_name] = new_entry + except TypeError as e: + raise TypeError( + f"Attempted to load a {entry_name} setting, but the response is malformed: {entry_data}" + ) from e + + def __contains__(self, item) -> bool: + return item in self._entries + + def __setitem__(self, key: str, value: entry_type) -> None: + self._entries[key] = value + + def copy(self): + copy = self.__class__({}) + copy._entries = self._entries + return copy + + def items(self) -> Iterator[tuple[str, entry_type]]: + yield from self._entries.items() + + def update(self, mapping: Mapping[str, entry_type], **kwargs: entry_type) -> None: + self._entries.update(mapping, **kwargs) + + @classmethod + def create(cls, settings_data: dict) -> Optional[Settings]: + """ + Returns a Settings object from `settings_data` if it holds any value, None otherwise. + + Use this method to create Settings objects instead of the init. + The None value is significant for how a filter list iterates over its filters. + """ + settings = cls(settings_data) + # If an entry doesn't hold any values, its `create` method will return None. + # If all entries are None, then the settings object holds no values. + if not any(settings._entries.values()): + return None + + return settings + + +class ValidationSettings(Settings): + """ + A collection of validation settings. + + A filter is triggered only if all of its validation settings (e.g whether to invoke in DM) approve + (the check returns True). + """ + + entry_type = ValidationEntry + + def __init__(self, settings_data: dict): + super().__init__(settings_data) + + def evaluate(self, ctx: FilterContext) -> tuple[set[str], set[str]]: + """Evaluates for each setting whether the context is relevant to the filter.""" + passed = set() + failed = set() + + self._entries: dict[str, ValidationEntry] + for name, validation in self._entries.items(): + if validation: + if validation.triggers_on(ctx): + passed.add(name) + else: + failed.add(name) + + return passed, failed + + +class ActionSettings(Settings): + """ + A collection of action settings. + + If a filter is triggered, its action settings (e.g how to infract the user) are combined with the action settings of + other triggered filters in the same event, and action is taken according to the combined action settings. + """ + + entry_type = ActionEntry + + def __init__(self, settings_data: dict): + super().__init__(settings_data) + + def __or__(self, other: ActionSettings) -> ActionSettings: + """Combine the entries of two collections of settings into a new ActionsSettings""" + actions = {} + # A settings object doesn't necessarily have all types of entries (e.g in the case of filter overrides). + for entry in self._entries: + if entry in other._entries: + actions[entry] = self._entries[entry] | other._entries[entry] + else: + actions[entry] = self._entries[entry] + for entry in other._entries: + if entry not in actions: + actions[entry] = other._entries[entry] + + result = ActionSettings({}) + result.update(actions) + return result + + async def action(self, ctx: FilterContext) -> None: + """Execute the action of every action entry stored.""" + for entry in self._entries.values(): + await entry.action(ctx) + + def fallback_to(self, fallback: ActionSettings) -> None: + """Fill in missing entries from `fallback`.""" + for entry_name, entry_value in fallback.items(): + if entry_name not in self._entries: + self._entries[entry_name] = entry_value diff --git a/bot/exts/filtering/_settings_types/__init__.py b/bot/exts/filtering/_settings_types/__init__.py new file mode 100644 index 000000000..620290cb2 --- /dev/null +++ b/bot/exts/filtering/_settings_types/__init__.py @@ -0,0 +1,14 @@ +from os.path import dirname + +from bot.exts.filtering._settings_types.settings_entry import ActionEntry, ValidationEntry +from bot.exts.filtering._utils import subclasses_in_package + +action_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ActionEntry) +validation_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ValidationEntry) + +settings_types = { + "ActionEntry": {settings_type.name: settings_type for settings_type in action_types}, + "ValidationEntry": {settings_type.name: settings_type for settings_type in validation_types} +} + +__all__ = [settings_types] diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py new file mode 100644 index 000000000..9665283ff --- /dev/null +++ b/bot/exts/filtering/_settings_types/bypass_roles.py @@ -0,0 +1,29 @@ +from typing import Any + +from discord import Member + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry +from bot.exts.filtering._utils import ROLE_LITERALS + + +class RoleBypass(ValidationEntry): + """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" + + name = "bypass_roles" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.roles = set() + for role in entry_data: + if role in ROLE_LITERALS: + self.roles.add(ROLE_LITERALS[role]) + elif role.isdigit(): + self.roles.add(int(role)) + # Ignore entries that can't be resolved. + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered on this user given their roles.""" + if not isinstance(ctx.author, Member): + return True + return all(member_role.id not in self.roles for member_role in ctx.author.roles) diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py new file mode 100644 index 000000000..b17914f2f --- /dev/null +++ b/bot/exts/filtering/_settings_types/channel_scope.py @@ -0,0 +1,45 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class ChannelScope(ValidationEntry): + """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" + + name = "channel_scope" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + if entry_data["disabled_channels"]: + self.disabled_channels = set(entry_data["disabled_channels"]) + else: + self.disabled_channels = set() + + if entry_data["disabled_categories"]: + self.disabled_categories = set(entry_data["disabled_categories"]) + else: + self.disabled_categories = set() + + if entry_data["enabled_channels"]: + self.enabled_channels = set(entry_data["enabled_channels"]) + else: + self.enabled_channels = set() + + def triggers_on(self, ctx: FilterContext) -> bool: + """ + Return whether the filter should be triggered in the given channel. + + The filter is invoked by default. + If the channel is explicitly enabled, it bypasses the set disabled channels and categories. + """ + channel = ctx.channel + if hasattr(channel, "parent"): + channel = channel.parent + return ( + channel.id in self.enabled_channels + or ( + channel.id not in self.disabled_channels + and (not channel.category or channel.category.id not in self.disabled_categories) + ) + ) diff --git a/bot/exts/filtering/_settings_types/delete_messages.py b/bot/exts/filtering/_settings_types/delete_messages.py new file mode 100644 index 000000000..b0a018433 --- /dev/null +++ b/bot/exts/filtering/_settings_types/delete_messages.py @@ -0,0 +1,35 @@ +from contextlib import suppress +from typing import Any + +from discord.errors import NotFound + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class DeleteMessages(ActionEntry): + """A setting entry which tells whether to delete the offending message(s).""" + + name = "delete_messages" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.delete: bool = entry_data + + async def action(self, ctx: FilterContext) -> None: + """Delete the context message(s).""" + if not self.delete or ctx.event not in (Event.MESSAGE, Event.MESSAGE_EDIT): + return + + with suppress(NotFound): + if ctx.message.guild: + await ctx.message.delete() + ctx.action_descriptions.append("deleted") + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, DeleteMessages): + return NotImplemented + + return DeleteMessages(self.delete or other.delete) + diff --git a/bot/exts/filtering/_settings_types/enabled.py b/bot/exts/filtering/_settings_types/enabled.py new file mode 100644 index 000000000..553dccc9c --- /dev/null +++ b/bot/exts/filtering/_settings_types/enabled.py @@ -0,0 +1,18 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class Enabled(ValidationEntry): + """A setting entry which tells whether the filter is enabled.""" + + name = "enabled" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.enabled = entry_data + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter is enabled.""" + return self.enabled diff --git a/bot/exts/filtering/_settings_types/filter_dm.py b/bot/exts/filtering/_settings_types/filter_dm.py new file mode 100644 index 000000000..54f19e4d1 --- /dev/null +++ b/bot/exts/filtering/_settings_types/filter_dm.py @@ -0,0 +1,18 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class FilterDM(ValidationEntry): + """A setting entry which tells whether to apply the filter to DMs.""" + + name = "filter_dm" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.apply_in_dm = entry_data + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered even if it was triggered in DMs.""" + return hasattr(ctx.channel, "guild") or self.apply_in_dm diff --git a/bot/exts/filtering/_settings_types/infraction_and_notification.py b/bot/exts/filtering/_settings_types/infraction_and_notification.py new file mode 100644 index 000000000..263fd851c --- /dev/null +++ b/bot/exts/filtering/_settings_types/infraction_and_notification.py @@ -0,0 +1,180 @@ +from collections import namedtuple +from datetime import timedelta +from enum import Enum, auto +from typing import Any, Optional + +import arrow +from discord import Colour +from discord.errors import Forbidden + +import bot +from bot.constants import Channels, Guild +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class Infraction(Enum): + """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" + + BAN = auto() + KICK = auto() + MUTE = auto() + VOICE_BAN = auto() + SUPERSTAR = auto() + WARNING = auto() + WATCH = auto() + NOTE = auto() + NONE = auto() # Allows making operations on an entry with no infraction without checking for None. + + def __bool__(self) -> bool: + """ + Make the NONE value false-y. + + This is useful for Settings.create to evaluate whether the entry contains anything. + """ + return self != Infraction.NONE + + +superstar = namedtuple("superstar", ["reason", "duration"]) + + +class InfractionAndNotification(ActionEntry): + """ + A setting entry which specifies what infraction to issue and the notification to DM the user. + + Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. + """ + + name = "infraction_and_notification" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + + if entry_data["infraction_type"]: + self.infraction_type = entry_data["infraction_type"] + if isinstance(self.infraction_type, str): + self.infraction_type = Infraction[self.infraction_type.replace(" ", "_").upper()] + self.infraction_reason = entry_data["infraction_reason"] + if entry_data["infraction_duration"] is not None: + self.infraction_duration = float(entry_data["infraction_duration"]) + else: + self.infraction_duration = None + else: + self.infraction_type = Infraction.NONE + self.infraction_reason = None + self.infraction_duration = 0 + + self.dm_content = entry_data["dm_content"] + self.dm_embed = entry_data["dm_embed"] + + self._superstar = entry_data.get("superstar", None) + + async def action(self, ctx: FilterContext) -> None: + """Send the notification to the user, and apply any specified infractions.""" + # If there is no infraction to apply, any DM contents already provided in the context take precedence. + if self.infraction_type == Infraction.NONE and (ctx.dm_content or ctx.dm_embed): + dm_content = ctx.dm_content + dm_embed = ctx.dm_embed.description + else: + dm_content = self.dm_content + dm_embed = self.dm_embed + + if dm_content or dm_embed: + dm_content = f"Hey {ctx.author.mention}!\n{dm_content}" + ctx.dm_embed.description = dm_embed + if not ctx.dm_embed.colour: + ctx.dm_embed.colour = Colour.og_blurple() + + try: + await ctx.author.send(dm_content, embed=ctx.dm_embed) + except Forbidden: + await ctx.channel.send(ctx.dm_content, embed=ctx.dm_embed) + ctx.action_descriptions.append("notified") + + msg_ctx = await bot.instance.get_context(ctx.message) + msg_ctx.guild = bot.instance.get_guild(Guild.id) + msg_ctx.author = ctx.author + msg_ctx.channel = ctx.channel + if self._superstar: + msg_ctx.command = bot.instance.get_command("superstarify") + await msg_ctx.invoke( + msg_ctx.command, + ctx.author, + arrow.utcnow() + timedelta(seconds=self._superstar.duration) + if self._superstar.duration is not None else None, + reason=self._superstar.reason + ) + ctx.action_descriptions.append("superstar") + + if self.infraction_type != Infraction.NONE: + if self.infraction_type == Infraction.BAN or not hasattr(ctx.channel, "guild"): + msg_ctx.channel = bot.instance.get_channel(Channels.mod_alerts) + msg_ctx.command = bot.instance.get_command(self.infraction_type.name) + await msg_ctx.invoke( + msg_ctx.command, + ctx.author, + arrow.utcnow() + timedelta(seconds=self.infraction_duration) + if self.infraction_duration is not None else None, + reason=self.infraction_reason + ) + ctx.action_descriptions.append(self.infraction_type.name.lower()) + + def __or__(self, other: ActionEntry): + """ + Combines two actions of the same type. Each type of action is executed once per filter. + + If the infractions are different, take the data of the one higher up the hierarchy. + + A special case is made for superstar infractions. Even if we decide to auto-mute a user, if they have a + particularly problematic username we will still want to superstarify them. + + This is a "best attempt" implementation. Trying to account for any type of combination would create an + extremely complex ruleset. For example, we could special-case watches as well. + + There is no clear way to properly combine several notification messages, especially when it's in two parts. + To avoid bombarding the user with several notifications, the message with the more significant infraction + is used. + """ + if not isinstance(other, InfractionAndNotification): + return NotImplemented + + # Lower number -> higher in the hierarchy + if self.infraction_type.value < other.infraction_type.value and other.infraction_type != Infraction.SUPERSTAR: + result = InfractionAndNotification(self.to_dict()) + result._superstar = self._merge_superstars(self._superstar, other._superstar) + return result + elif self.infraction_type.value > other.infraction_type.value and self.infraction_type != Infraction.SUPERSTAR: + result = InfractionAndNotification(other.to_dict()) + result._superstar = self._merge_superstars(self._superstar, other._superstar) + return result + + if self.infraction_type == other.infraction_type: + if self.infraction_duration is None or ( + other.infraction_duration is not None and self.infraction_duration > other.infraction_duration + ): + result = InfractionAndNotification(self.to_dict()) + else: + result = InfractionAndNotification(other.to_dict()) + result._superstar = self._merge_superstars(self._superstar, other._superstar) + return result + + # At this stage the infraction types are different, and the lower one is a superstar. + if self.infraction_type.value < other.infraction_type.value: + result = InfractionAndNotification(self.to_dict()) + result._superstar = superstar(other.infraction_reason, other.infraction_duration) + else: + result = InfractionAndNotification(other.to_dict()) + result._superstar = superstar(self.infraction_reason, self.infraction_duration) + return result + + @staticmethod + def _merge_superstars(superstar1: Optional[superstar], superstar2: Optional[superstar]) -> Optional[superstar]: + """Take the superstar with the greater duration.""" + if not superstar1: + return superstar2 + if not superstar2: + return superstar1 + + if superstar1.duration is None or superstar1.duration > superstar2.duration: + return superstar1 + return superstar2 diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py new file mode 100644 index 000000000..857e4a7e8 --- /dev/null +++ b/bot/exts/filtering/_settings_types/ping.py @@ -0,0 +1,52 @@ +from functools import cache +from typing import Any + +from discord import Guild + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import ROLE_LITERALS + + +class Ping(ActionEntry): + """A setting entry which adds the appropriate pings to the alert.""" + + name = "mentions" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.guild_mentions = set(entry_data["guild_pings"]) + self.dm_mentions = set(entry_data["dm_pings"]) + + async def action(self, ctx: FilterContext) -> None: + """Add the stored pings to the alert message content.""" + mentions = self.guild_mentions if ctx.channel.guild else self.dm_mentions + new_content = " ".join([self._resolve_mention(mention, ctx.channel.guild) for mention in mentions]) + ctx.alert_content = f"{new_content} {ctx.alert_content}" + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, Ping): + return NotImplemented + + return Ping({ + "ping_type": self.guild_mentions | other.guild_mentions, + "dm_ping_type": self.dm_mentions | other.dm_mentions + }) + + @staticmethod + @cache + def _resolve_mention(mention: str, guild: Guild) -> str: + """Return the appropriate formatting for the formatting, be it a literal, a user ID, or a role ID.""" + if mention in ("here", "everyone"): + return f"@{mention}" + if mention in ROLE_LITERALS: + return f"<@&{ROLE_LITERALS[mention]}>" + if not mention.isdigit(): + return mention + + mention = int(mention) + if any(mention == role.id for role in guild.roles): + return f"<@&{mention}>" + else: + return f"<@{mention}>" diff --git a/bot/exts/filtering/_settings_types/send_alert.py b/bot/exts/filtering/_settings_types/send_alert.py new file mode 100644 index 000000000..e332494eb --- /dev/null +++ b/bot/exts/filtering/_settings_types/send_alert.py @@ -0,0 +1,26 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class SendAlert(ActionEntry): + """A setting entry which tells whether to send an alert message.""" + + name = "send_alert" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.send_alert: bool = entry_data + + async def action(self, ctx: FilterContext) -> None: + """Add the stored pings to the alert message content.""" + ctx.send_alert = self.send_alert + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, SendAlert): + return NotImplemented + + return SendAlert(self.send_alert or other.send_alert) + diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py new file mode 100644 index 000000000..b0d54fac3 --- /dev/null +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Optional + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._utils import FieldRequiring + + +class SettingsEntry(FieldRequiring): + """ + A basic entry in the settings field appearing in every filter list and filter. + + For a filter list, this is the default setting for it. For a filter, it's an override of the default entry. + """ + + # Each subclass must define a name matching the entry name we're expecting to receive from the database. + # Names must be unique across all filter lists. + name = FieldRequiring.MUST_SET_UNIQUE + + @abstractmethod + def __init__(self, entry_data: Any): + super().__init__() + self._dict = {} + + def __setattr__(self, key: str, value: Any) -> None: + super().__setattr__(key, value) + if key == "_dict": + return + self._dict[key] = value + + def __eq__(self, other: SettingsEntry) -> bool: + if not isinstance(other, SettingsEntry): + return NotImplemented + return self._dict == other._dict + + def to_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the entry.""" + return self._dict.copy() + + def copy(self) -> SettingsEntry: + """Return a new entry object with the same parameters.""" + return self.__class__(self.to_dict()) + + @classmethod + def create(cls, entry_data: Optional[dict[str, Any]]) -> Optional[SettingsEntry]: + """ + Returns a SettingsEntry object from `entry_data` if it holds any value, None otherwise. + + Use this method to create SettingsEntry objects instead of the init. + The None value is significant for how a filter list iterates over its filters. + """ + if entry_data is None: + return None + if hasattr(entry_data, "values") and not any(value for value in entry_data.values()): + return None + + return cls(entry_data) + + +class ValidationEntry(SettingsEntry): + """A setting entry to validate whether the filter should be triggered in the given context.""" + + @abstractmethod + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered with this setting in the given context.""" + ... + + +class ActionEntry(SettingsEntry): + """A setting entry defining what the bot should do if the filter it belongs to is triggered.""" + + @abstractmethod + async def action(self, ctx: FilterContext) -> None: + """Execute an action that should be taken when the filter this setting belongs to is triggered.""" + ... + + @abstractmethod + def __or__(self, other: ActionEntry): + """ + Combine two actions of the same type. Each type of action is executed once per filter. + + The following condition must hold: if self == other, then self | other == self. + """ + ... diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py new file mode 100644 index 000000000..a769001f6 --- /dev/null +++ b/bot/exts/filtering/_utils.py @@ -0,0 +1,97 @@ +import importlib +import importlib.util +import inspect +import pkgutil +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Set + +import regex + +from bot.constants import Roles + +ROLE_LITERALS = { + "admins": Roles.admins, + "onduty": Roles.moderators, + "staff": Roles.helpers +} + +VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF" +INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1) +ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1) + + +def subclasses_in_package(package: str, prefix: str, parent: type) -> Set[type]: + """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path.""" + subclasses = set() + + # Find all modules in the package. + for module_info in pkgutil.iter_modules([package], prefix): + if not module_info.ispkg: + module = importlib.import_module(module_info.name) + # Find all classes in each module... + for _, class_ in inspect.getmembers(module, inspect.isclass): + # That are a subclass of the given class. + if parent in class_.__bases__: + subclasses.add(class_) + + return subclasses + + +def clean_input(string: str) -> str: + """Remove zalgo and invisible characters from `string`.""" + # For future consideration: remove characters in the Mc, Sk, and Lm categories too. + # Can be normalised with form C to merge char + combining char into a single char to avoid + # removing legit diacritics, but this would open up a way to bypass _filters. + no_zalgo = ZALGO_RE.sub("", string) + return INVISIBLE_RE.sub("", no_zalgo) + + +class FieldRequiring(ABC): + """A mixin class that can force its concrete subclasses to set a value for specific class attributes.""" + + # Sentinel value that mustn't remain in a concrete subclass. + MUST_SET = object() + + # Sentinel value that mustn't remain in a concrete subclass. + # Overriding value must be unique in the subclasses of the abstract class in which the attribute was set. + MUST_SET_UNIQUE = object() + + # A mapping of the attributes which must be unique, and their unique values, per FieldRequiring subclass. + __unique_attributes: defaultdict[type, dict[str, set]] = defaultdict(dict) + + @abstractmethod + def __init__(self): + ... + + def __init_subclass__(cls, **kwargs): + # If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it. + if inspect.isabstract(cls): + for attribute in dir(cls): + if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE: + for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object. + if hasattr(parent, attribute): # The attribute was inherited. + break + else: + # A new attribute with the value MUST_SET_UNIQUE. + FieldRequiring.__unique_attributes[cls][attribute] = set() + return + + for attribute in dir(cls): + if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"): + continue + value = getattr(cls, attribute) + if value is FieldRequiring.MUST_SET: + raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}") + elif value is FieldRequiring.MUST_SET_UNIQUE: + raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}") + else: + # Check if the value needs to be unique. + for parent in cls.__mro__[1:-1]: + # Find the parent class the attribute was first defined in. + if attribute in FieldRequiring.__unique_attributes[parent]: + if value in FieldRequiring.__unique_attributes[parent][attribute]: + raise ValueError(f"Value of {attribute!r} in {cls!r} is not unique for parent {parent!r}.") + else: + # Add to the set of unique values for that field. + FieldRequiring.__unique_attributes[parent][attribute].add(value) diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py new file mode 100644 index 000000000..c74b85698 --- /dev/null +++ b/bot/exts/filtering/filtering.py @@ -0,0 +1,150 @@ +import operator +from collections import defaultdict +from functools import reduce +from typing import Optional + +from discord import Embed, HTTPException, Message +from discord.ext.commands import Cog +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.constants import Colours, Webhooks +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import FilterList, filter_list_types +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings import ActionSettings +from bot.log import get_logger +from bot.utils.messages import format_channel, format_user + +log = get_logger(__name__) + + +class Filtering(Cog): + """Filtering and alerting for content posted on the server.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.filter_lists: dict[str, FilterList] = {} + self._subscriptions: defaultdict[Event, list[FilterList]] = defaultdict(list) + self.webhook = None + + async def cog_load(self) -> None: + """ + Fetch the filter data from the API, parse it, and load it to the appropriate data structures. + + Additionally, fetch the alerting webhook. + """ + await self.bot.wait_until_guild_available() + already_warned = set() + + raw_filter_lists = await self.bot.api_client.get("bot/filter/filter_lists") + for raw_filter_list in raw_filter_lists: + list_name = raw_filter_list["name"] + if list_name not in self.filter_lists: + if list_name not in filter_list_types: + if list_name not in already_warned: + log.warning( + f"A filter list named {list_name} was loaded from the database, but no matching class." + ) + already_warned.add(list_name) + continue + self.filter_lists[list_name] = filter_list_types[list_name](self) + self.filter_lists[list_name].add_list(raw_filter_list) + + try: + self.webhook = await self.bot.fetch_webhook(Webhooks.filters) + except HTTPException: + log.error(f"Failed to fetch incidents webhook with id `{Webhooks.incidents}`.") + + def subscribe(self, filter_list: FilterList, *events: Event) -> None: + """ + Subscribe a filter list to the given events. + + The filter list is added to a list for each event. When the event is triggered, the filter context will be + dispatched to the subscribed filter lists. + + While it's possible to just make each filter list check the context's event, these are only the events a filter + list expects to receive from the filtering cog, there isn't an actual limitation on the kinds of events a filter + list can handle as long as the filter context is built properly. If for whatever reason we want to invoke a + filter list outside of the usual procedure with the filtering cog, it will be more problematic if the events are + hard-coded into each filter list. + """ + for event in events: + if filter_list not in self._subscriptions[event]: + self._subscriptions[event].append(filter_list) + + async def _resolve_action( + self, ctx: FilterContext + ) -> tuple[dict[FilterList, list[Filter]], Optional[ActionSettings]]: + """Get the filters triggered per list, and resolve from them the action that needs to be taken for the event.""" + triggered = {} + for filter_list in self._subscriptions[ctx.event]: + triggered[filter_list] = filter_list.triggers_for(ctx) + + result_actions = None + if triggered: + result_actions = reduce( + operator.or_, (filter_.actions for filters in triggered.values() for filter_ in filters) + ) + + return triggered, result_actions + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Filter the contents of a sent message.""" + if msg.author.bot: + return + + ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds) + + triggered, result_actions = await self._resolve_action(ctx) + if result_actions: + await result_actions.action(ctx) + if ctx.send_alert: + await self._send_alert(ctx, triggered) + + async def _send_alert(self, ctx: FilterContext, triggered_filters: dict[FilterList, list[Filter]]) -> None: + """Build an alert message from the filter context, and send it via the alert webhook.""" + if not self.webhook: + return + + name = f"{ctx.event.name.replace('_', ' ').title()} Filter" + + embed = Embed(color=Colours.soft_orange) + embed.set_thumbnail(url=ctx.author.display_avatar.url) + triggered_by = f"**Triggered by:** {format_user(ctx.author)}" + if ctx.channel.guild: + triggered_in = f"**Triggered in:** {format_channel(ctx.channel)}" + else: + triggered_in = "**DM**" + if len(triggered_filters) == 1 and len(list(triggered_filters.values())[0]) == 1: + filter_list, (filter_,) = next(iter(triggered_filters.items())) + filters = f"**{filter_list.name.title()} Filter:** #{filter_.id} (`{filter_.content}`)" + if filter_.description: + filters += f" - {filter_.description}" + else: + filters = [] + for filter_list, list_filters in triggered_filters.items(): + filters.append( + (f"**{filter_list.name.title()} Filters:** " + ", ".join(f"#{filter_.id} (`{filter_.content}`)" for filter_ in list_filters)) + ) + filters = "\n".join(filters) + + matches = "**Matches:** " + ", ".join(repr(match) for match in ctx.matches) + actions = "**Actions Taken:** " + (", ".join(ctx.action_descriptions) if ctx.action_descriptions else "-") + content = f"**[Original Content]({ctx.message.jump_url})**: {escape_markdown(ctx.content)}" + + embed_content = "\n".join( + part for part in (triggered_by, triggered_in, filters, matches, actions, content) if part + ) + if len(embed_content) > 4000: + embed_content = embed_content[:4000] + " [...]" + embed.description = embed_content + + await self.webhook.send(username=name, content=ctx.alert_content, embeds=[embed, *ctx.alert_embeds]) + + +async def setup(bot: Bot) -> None: + """Load the Filtering cog.""" + await bot.add_cog(Filtering(bot)) diff --git a/bot/utils/messages.py b/bot/utils/messages.py index a5ed84351..63929cd0b 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -238,3 +238,12 @@ async def send_denial(ctx: Context, reason: str) -> discord.Message: def format_user(user: discord.abc.User) -> str: """Return a string for `user` which has their mention and ID.""" return f"{user.mention} (`{user.id}`)" + + +def format_channel(channel: discord.abc.Messageable) -> str: + """Return a string for `channel` with its mention, ID, and the parent channel if it is a thread.""" + formatted = f"{channel.mention} ({channel.category}/#{channel}" + if hasattr(channel, "parent"): + formatted += f"/{channel.parent}" + formatted += ")" + return formatted diff --git a/config-default.yml b/config-default.yml index 91945e2b8..1815b8ed7 100644 --- a/config-default.yml +++ b/config-default.yml @@ -317,6 +317,7 @@ guild: incidents: 816650601844572212 incidents_archive: 720671599790915702 python_news: &PYNEWS_WEBHOOK 704381182279942324 + filters: 926442964463521843 filter: diff --git a/tests/bot/exts/filtering/__init__.py b/tests/bot/exts/filtering/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/filtering/__init__.py diff --git a/tests/bot/exts/filtering/test_filters.py b/tests/bot/exts/filtering/test_filters.py new file mode 100644 index 000000000..214637b52 --- /dev/null +++ b/tests/bot/exts/filtering/test_filters.py @@ -0,0 +1,41 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.token import TokenFilter +from tests.helpers import MockMember, MockMessage, MockTextChannel + + +class FilterTests(unittest.TestCase): + """Test functionality of the token filter.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + def test_token_filter_triggers(self): + """The filter should evaluate to True only if its token is found in the context content.""" + test_cases = ( + (r"hi", "oh hi there", True), + (r"hi", "goodbye", False), + (r"bla\d{2,4}", "bla18", True), + (r"bla\d{2,4}", "bla1", False) + ) + + for pattern, content, expected in test_cases: + with self.subTest( + pattern=pattern, + content=content, + expected=expected, + ): + filter_ = TokenFilter({ + "id": 1, + "content": pattern, + "description": None, + "settings": {}, + "additional_field": "{}" # noqa: P103 + }) + self.ctx.content = content + result = filter_.triggered_on(self.ctx) + self.assertEqual(result, expected) diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py new file mode 100644 index 000000000..ac21a5d47 --- /dev/null +++ b/tests/bot/exts/filtering/test_settings.py @@ -0,0 +1,20 @@ +import unittest + +import bot.exts.filtering._settings +from bot.exts.filtering._settings import create_settings + + +class FilterTests(unittest.TestCase): + """Test functionality of the Settings class and its subclasses.""" + + def test_create_settings_returns_none_for_empty_data(self): + """`create_settings` should return a tuple of two Nones when passed an empty dict.""" + result = create_settings({}) + + self.assertEquals(result, (None, None)) + + def test_unrecognized_entry_makes_a_warning(self): + """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`.""" + create_settings({"abcd": {}}) + + self.assertIn("abcd", bot.exts.filtering._settings._already_warned) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py new file mode 100644 index 000000000..4db6438ab --- /dev/null +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -0,0 +1,272 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.bypass_roles import RoleBypass +from bot.exts.filtering._settings_types.channel_scope import ChannelScope +from bot.exts.filtering._settings_types.filter_dm import FilterDM +from bot.exts.filtering._settings_types.infraction_and_notification import ( + Infraction, InfractionAndNotification, superstar +) +from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel + + +class FilterTests(unittest.TestCase): + """Test functionality of the Settings class and its subclasses.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + def test_role_bypass_is_off_for_user_without_roles(self): + """The role bypass should trigger when a user has no roles.""" + member = MockMember() + self.ctx.author = member + bypass_entry = RoleBypass(["123"]) + + result = bypass_entry.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_role_bypass_is_on_for_a_user_with_the_right_role(self): + """The role bypass should not trigger when the user has one of its roles.""" + cases = ( + ([123], ["123"]), + ([123, 234], ["123"]), + ([123], ["123", "234"]), + ([123, 234], ["123", "234"]) + ) + + for user_role_ids, bypasses in cases: + with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses): + user_roles = [MockRole(id=role_id) for role_id in user_role_ids] + member = MockMember(roles=user_roles) + self.ctx.author = member + bypass_entry = RoleBypass(bypasses) + + result = bypass_entry.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_for_empty_channel_scope(self): + """A filter is enabled for all channels by default.""" + channel = MockTextChannel() + scope = ChannelScope({"disabled_channels": None, "disabled_categories": None, "enabled_channels": None}) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_context_doesnt_trigger_for_disabled_channel(self): + """A filter shouldn't trigger if it's been disabled in the channel.""" + channel = MockTextChannel(id=123) + scope = ChannelScope({"disabled_channels": [123], "disabled_categories": None, "enabled_channels": None}) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_in_disabled_category(self): + """A filter shouldn't trigger if it's been disabled in the category.""" + channel = MockTextChannel() + scope = ChannelScope({ + "disabled_channels": None, "disabled_categories": [channel.category.id], "enabled_channels": None + }) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_triggers_in_enabled_channel_in_disabled_category(self): + """A filter should trigger in an enabled channel even if it's been disabled in the category.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope({"disabled_channels": None, "disabled_categories": [234], "enabled_channels": [123]}) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_filtering_dms_when_necessary(self): + """A filter correctly ignores or triggers in a channel depending on the value of FilterDM.""" + cases = ( + (True, MockDMChannel(), True), + (False, MockDMChannel(), False), + (True, MockTextChannel(), True), + (False, MockTextChannel(), True) + ) + + for apply_in_dms, channel, expected in cases: + with self.subTest(apply_in_dms=apply_in_dms, channel=channel): + filter_dms = FilterDM(apply_in_dms) + self.ctx.channel = channel + + result = filter_dms.triggers_on(self.ctx) + + self.assertEqual(expected, result) + + def test_infraction_merge_of_same_infraction_type(self): + """When both infractions are of the same type, the one with the longer duration wins.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "hi", + "infraction_duration": 10, + "dm_content": "how", + "dm_embed": "what is" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "there", + "infraction_duration": 20, + "dm_content": "are you", + "dm_embed": "your name" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.MUTE, + "infraction_reason": "there", + "infraction_duration": 20.0, + "dm_content": "are you", + "dm_embed": "your name", + "_superstar": None + } + ) + + def test_infraction_merge_of_different_infraction_types(self): + """If there are two different infraction types, the one higher up the hierarchy should be picked.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "ban", + "infraction_reason": "", + "infraction_duration": 10, + "dm_content": "there", + "dm_embed": "" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.BAN, + "infraction_reason": "", + "infraction_duration": 10.0, + "dm_content": "there", + "dm_embed": "", + "_superstar": None + } + ) + + def test_infraction_merge_with_a_superstar(self): + """If there is a superstar infraction, it should be added to a separate field.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "there", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "hello", + "infraction_duration": 10, + "dm_content": "you", + "dm_embed": "" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.MUTE, + "infraction_reason": "hi", + "infraction_duration": 20.0, + "dm_content": "there", + "dm_embed": "", + "_superstar": superstar("hello", 10.0) + } + ) + + def test_merge_two_superstar_infractions(self): + """When two superstar infractions are merged, the infraction type remains a superstar.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "", + "infraction_duration": 10, + "dm_content": "there", + "dm_embed": "" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.SUPERSTAR, + "infraction_reason": "hi", + "infraction_duration": 20.0, + "dm_content": "", + "dm_embed": "", + "_superstar": None + } + ) + + def test_merge_a_voiceban_and_a_superstar_with_another_superstar(self): + """An infraction with a superstar merged with a superstar should combine under `_superstar`.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "voice ban", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "hello", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "bla", + "infraction_duration": 10, + "dm_content": "there", + "dm_embed": "" + }) + infraction3 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "blabla", + "infraction_duration": 20, + "dm_content": "there", + "dm_embed": "" + }) + + result = infraction1 | infraction2 | infraction3 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.VOICE_BAN, + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "hello", + "dm_embed": "", + "_superstar": superstar("blabla", 20) + } + ) |