diff options
| -rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 17 | ||||
| -rw-r--r-- | bot/exts/filtering/_filters/filter.py | 28 | ||||
| -rw-r--r-- | bot/exts/filtering/_filters/invite.py | 4 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings.py | 109 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings_types/actions/ping.py | 15 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings_types/settings_entry.py | 31 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/filter.py | 23 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/ui.py | 24 | ||||
| -rw-r--r-- | bot/exts/filtering/filtering.py | 6 |
9 files changed, 143 insertions, 114 deletions
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index 9eb907fc1..f993665f2 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -6,7 +6,7 @@ from discord.ext.commands import BadArgument from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._filters.filter import Filter -from bot.exts.filtering._settings import ActionSettings, ValidationSettings, create_settings +from bot.exts.filtering._settings import ActionSettings, Defaults, create_settings from bot.exts.filtering._utils import FieldRequiring, past_tense from bot.log import get_logger @@ -36,13 +36,6 @@ def list_type_converter(argument: str) -> ListType: raise BadArgument(f"No matching list type found for {argument!r}.") -class Defaults(NamedTuple): - """Represents an atomic list's default settings.""" - - actions: ActionSettings - validations: ValidationSettings - - class AtomicList(NamedTuple): """ Represents the atomic structure of a single filter list as it appears in the database. @@ -117,14 +110,14 @@ class FilterList(FieldRequiring, dict[ListType, AtomicList]): filters = {} for filter_data in list_data["filters"]: - filters[filter_data["id"]] = self._create_filter(filter_data) + filters[filter_data["id"]] = self._create_filter(filter_data, defaults) self[list_type] = AtomicList(list_data["id"], self.name, list_type, defaults, filters) return self[list_type] def add_filter(self, list_type: ListType, filter_data: dict) -> Filter: """Add a filter to the list of the specified type.""" - new_filter = self._create_filter(filter_data) + new_filter = self._create_filter(filter_data, self[list_type].defaults) self[list_type].filters[filter_data["id"]] = new_filter return new_filter @@ -141,11 +134,11 @@ class FilterList(FieldRequiring, dict[ListType, AtomicList]): async def actions_for(self, ctx: FilterContext) -> tuple[ActionSettings | None, list[str]]: """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" - def _create_filter(self, filter_data: dict) -> Filter: + def _create_filter(self, filter_data: dict, defaults: Defaults) -> Filter: """Create a filter from the given data.""" try: filter_type = self.get_filter_type(filter_data["content"]) - new_filter = filter_type(filter_data) + new_filter = filter_type(filter_data, defaults) except TypeError as e: log.warning(e) else: diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py index b4a2bfe5e..0d11d5b3c 100644 --- a/bot/exts/filtering/_filters/filter.py +++ b/bot/exts/filtering/_filters/filter.py @@ -1,10 +1,10 @@ from abc import abstractmethod -from typing import Optional +from typing import Any, Optional from pydantic import ValidationError from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings import create_settings +from bot.exts.filtering._settings import Defaults, create_settings from bot.exts.filtering._utils import FieldRequiring @@ -22,14 +22,30 @@ class Filter(FieldRequiring): # If a subclass uses extra fields, it should assign the pydantic model type to this variable. extra_fields_type = None - def __init__(self, filter_data: dict): + def __init__(self, filter_data: dict, defaults: Defaults | None = None): self.id = filter_data["id"] self.content = filter_data["content"] self.description = filter_data["description"] - self.actions, self.validations = create_settings(filter_data["settings"]) - self.extra_fields = filter_data["additional_field"] or "{}" # noqa: P103 + self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults) if self.extra_fields_type: - self.extra_fields = self.extra_fields_type.parse_raw(self.extra_fields) + self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"]) + else: + self.extra_fields = None + + @property + def overrides(self) -> tuple[dict[str, Any], dict[str, Any]]: + """Return a tuple of setting overrides and filter setting overrides.""" + settings = {} + if self.actions: + settings = self.actions.overrides + if self.validations: + settings |= self.validations.overrides + + filter_settings = {} + if self.extra_fields: + filter_settings = self.extra_fields.dict(exclude_unset=True) + + return settings, filter_settings @abstractmethod def triggered_on(self, ctx: FilterContext) -> bool: diff --git a/bot/exts/filtering/_filters/invite.py b/bot/exts/filtering/_filters/invite.py index 0463b0032..ac4f62cb6 100644 --- a/bot/exts/filtering/_filters/invite.py +++ b/bot/exts/filtering/_filters/invite.py @@ -16,8 +16,8 @@ class InviteFilter(Filter): name = "invite" - def __init__(self, filter_data: dict): - super().__init__(filter_data) + def __init__(self, filter_data: dict, defaults_data: dict | None = None): + super().__init__(filter_data, defaults_data) self.content = int(self.content) def triggered_on(self, ctx: FilterContext) -> bool: diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py index 7b09e3c52..4c2114f07 100644 --- a/bot/exts/filtering/_settings.py +++ b/bot/exts/filtering/_settings.py @@ -1,11 +1,13 @@ from __future__ import annotations +import operator from abc import abstractmethod -from typing import Any, Iterator, Mapping, Optional, TypeVar +from functools import reduce +from typing import Any, NamedTuple, Optional, TypeVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types import settings_types -from bot.exts.filtering._settings_types.settings_entry import ActionEntry, ValidationEntry +from bot.exts.filtering._settings_types.settings_entry import ActionEntry, SettingsEntry, ValidationEntry from bot.exts.filtering._utils import FieldRequiring from bot.log import get_logger @@ -15,14 +17,18 @@ log = get_logger(__name__) _already_warned: set[str] = set() +T = TypeVar("T", bound=SettingsEntry) + def create_settings( - settings_data: dict, *, keep_empty: bool = False + settings_data: dict, *, defaults: Defaults | None = None, keep_empty: bool = False ) -> tuple[Optional[ActionSettings], Optional[ValidationSettings]]: """ Create and return instances of the Settings subclasses from the given data. Additionally, warn for data entries with no matching class. + + In case these are setting overrides, the defaults can be provided to keep track of the correct values. """ action_data = {} validation_data = {} @@ -36,13 +42,18 @@ def create_settings( f"A setting named {entry_name} was loaded from the database, but no matching class." ) _already_warned.add(entry_name) + if defaults is None: + default_actions = None + default_validations = None + else: + default_actions, default_validations = defaults return ( - ActionSettings.create(action_data, keep_empty=keep_empty), - ValidationSettings.create(validation_data, keep_empty=keep_empty) + ActionSettings.create(action_data, defaults=default_actions, keep_empty=keep_empty), + ValidationSettings.create(validation_data, defaults=default_validations, keep_empty=keep_empty) ) -class Settings(FieldRequiring): +class Settings(FieldRequiring, dict[str, T]): """ A collection of settings. @@ -54,13 +65,13 @@ class Settings(FieldRequiring): the filter list which contains the filter. """ - entry_type = FieldRequiring.MUST_SET + entry_type = T _already_warned: set[str] = set() @abstractmethod - def __init__(self, settings_data: dict, *, keep_empty: bool = False): - self._entries: dict[str, Settings.entry_type] = {} + def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False): + super().__init__() entry_classes = settings_types.get(self.entry_type.__name__) for entry_name, entry_data in settings_data.items(): @@ -75,63 +86,55 @@ class Settings(FieldRequiring): self._already_warned.add(entry_name) else: try: - new_entry = entry_cls.create(entry_data, keep_empty=keep_empty) + entry_defaults = None if defaults is None else defaults[entry_name] + new_entry = entry_cls.create( + entry_data, defaults=entry_defaults, keep_empty=keep_empty + ) if new_entry: - self._entries[entry_name] = new_entry + self[entry_name] = new_entry except TypeError as e: raise TypeError( f"Attempted to load a {entry_name} setting, but the response is malformed: {entry_data}" ) from e - def __contains__(self, item: str) -> bool: - return item in self._entries - - def __setitem__(self, key: str, value: entry_type) -> None: - self._entries[key] = value + @property + def overrides(self) -> dict[str, Any]: + """Return a dictionary of overrides across all entries.""" + return reduce(operator.or_, (entry.overrides for entry in self.values() if entry), {}) def copy(self: TSettings) -> TSettings: """Create a shallow copy of the object.""" copy = self.__class__({}) - copy._entries = self._entries.copy() + copy.update(super().copy()) # Copy the internal dict. return copy - def items(self) -> Iterator[tuple[str, entry_type]]: - """Return an iterator for the items in the entries dictionary.""" - yield from self._entries.items() - - def update(self, mapping: Mapping[str, entry_type], **kwargs: entry_type) -> None: - """Update the entries with items from `mapping` and the kwargs.""" - self._entries.update(mapping, **kwargs) - - def get(self, key: str, default: Optional[Any] = None) -> entry_type: - """Get the entry matching the key, or fall back to the default value if the key is missing.""" - return self._entries.get(key, default) - def get_setting(self, key: str, default: Optional[Any] = None) -> Any: """Get the setting matching the key, or fall back to the default value if the key is missing.""" - for entry in self._entries.values(): + for entry in self.values(): if hasattr(entry, key): return getattr(entry, key) return default @classmethod - def create(cls, settings_data: dict, *, keep_empty: bool = False) -> Optional[Settings]: + def create( + cls, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False + ) -> Optional[Settings]: """ Returns a Settings object from `settings_data` if it holds any value, None otherwise. Use this method to create Settings objects instead of the init. The None value is significant for how a filter list iterates over its filters. """ - settings = cls(settings_data, keep_empty=keep_empty) + settings = cls(settings_data, defaults=defaults, keep_empty=keep_empty) # If an entry doesn't hold any values, its `create` method will return None. # If all entries are None, then the settings object holds no values. - if not keep_empty and not any(settings._entries.values()): + if not keep_empty and not any(settings.values()): return None return settings -class ValidationSettings(Settings): +class ValidationSettings(Settings[ValidationEntry]): """ A collection of validation settings. @@ -141,16 +144,15 @@ class ValidationSettings(Settings): entry_type = ValidationEntry - def __init__(self, settings_data: dict, *, keep_empty: bool = False): - super().__init__(settings_data, keep_empty=keep_empty) + def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False): + super().__init__(settings_data, defaults=defaults, keep_empty=keep_empty) def evaluate(self, ctx: FilterContext) -> tuple[set[str], set[str]]: """Evaluates for each setting whether the context is relevant to the filter.""" passed = set() failed = set() - self._entries: dict[str, ValidationEntry] - for name, validation in self._entries.items(): + for name, validation in self.items(): if validation: if validation.triggers_on(ctx): passed.add(name) @@ -160,7 +162,7 @@ class ValidationSettings(Settings): return passed, failed -class ActionSettings(Settings): +class ActionSettings(Settings[ActionEntry]): """ A collection of action settings. @@ -170,21 +172,21 @@ class ActionSettings(Settings): entry_type = ActionEntry - def __init__(self, settings_data: dict, *, keep_empty: bool = False): - super().__init__(settings_data, keep_empty=keep_empty) + def __init__(self, settings_data: dict, *, defaults: Settings | None = None, keep_empty: bool = False): + super().__init__(settings_data, defaults=defaults, keep_empty=keep_empty) def __or__(self, other: ActionSettings) -> ActionSettings: """Combine the entries of two collections of settings into a new ActionsSettings.""" actions = {} # A settings object doesn't necessarily have all types of entries (e.g in the case of filter overrides). - for entry in self._entries: - if entry in other._entries: - actions[entry] = self._entries[entry] | other._entries[entry] + for entry in self: + if entry in other: + actions[entry] = self[entry] | other[entry] else: - actions[entry] = self._entries[entry] - for entry in other._entries: + actions[entry] = self[entry] + for entry in other: if entry not in actions: - actions[entry] = other._entries[entry] + actions[entry] = other[entry] result = ActionSettings({}) result.update(actions) @@ -192,13 +194,20 @@ class ActionSettings(Settings): async def action(self, ctx: FilterContext) -> None: """Execute the action of every action entry stored.""" - for entry in self._entries.values(): + for entry in self.values(): await entry.action(ctx) def fallback_to(self, fallback: ActionSettings) -> ActionSettings: """Fill in missing entries from `fallback`.""" new_actions = self.copy() for entry_name, entry_value in fallback.items(): - if entry_name not in self._entries: - new_actions._entries[entry_name] = entry_value + if entry_name not in self: + new_actions[entry_name] = entry_value return new_actions + + +class Defaults(NamedTuple): + """Represents an atomic list's default settings.""" + + actions: ActionSettings + validations: ValidationSettings diff --git a/bot/exts/filtering/_settings_types/actions/ping.py b/bot/exts/filtering/_settings_types/actions/ping.py index 85590478c..faac8f4b9 100644 --- a/bot/exts/filtering/_settings_types/actions/ping.py +++ b/bot/exts/filtering/_settings_types/actions/ping.py @@ -1,9 +1,10 @@ from functools import cache from typing import ClassVar -from discord import Guild from pydantic import validator +import bot +from bot.constants import Guild from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry @@ -37,7 +38,7 @@ class Ping(ActionEntry): async def action(self, ctx: FilterContext) -> None: """Add the stored pings to the alert message content.""" mentions = self.guild_pings if ctx.channel.guild else self.dm_pings - new_content = " ".join([self._resolve_mention(mention, ctx.channel.guild) for mention in mentions]) + new_content = " ".join([self._resolve_mention(mention) for mention in mentions]) ctx.alert_content = f"{new_content} {ctx.alert_content}" def __or__(self, other: ActionEntry): @@ -49,12 +50,16 @@ class Ping(ActionEntry): @staticmethod @cache - def _resolve_mention(mention: str, guild: Guild) -> str: + def _resolve_mention(mention: str) -> str: """Return the appropriate formatting for the formatting, be it a literal, a user ID, or a role ID.""" + guild = bot.instance.get_guild(Guild.id) if mention in ("here", "everyone"): return f"@{mention}" - if mention.isdigit(): # It's an ID. - mention = int(mention) + try: + mention = int(mention) # It's an ID. + except ValueError: + pass + else: if any(mention == role.id for role in guild.roles): return f"<@&{mention}>" else: diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py index 5a7e41cac..31e11108d 100644 --- a/bot/exts/filtering/_settings_types/settings_entry.py +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -1,9 +1,9 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Union -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._utils import FieldRequiring @@ -19,12 +19,33 @@ class SettingsEntry(BaseModel, FieldRequiring): # Each subclass must define a name matching the entry name we're expecting to receive from the database. # Names must be unique across all filter lists. name: ClassVar[str] = FieldRequiring.MUST_SET_UNIQUE - # Each subclass must define a description of what it does. If the data an entry type receives is comprised of + # Each subclass must define a description of what it does. If the data an entry type receives comprises # several DB fields, the value should a dictionary of field names and their descriptions. description: ClassVar[Union[str, dict[str, str]]] = FieldRequiring.MUST_SET + _overrides: set[str] = PrivateAttr(default_factory=set) + + def __init__(self, defaults: SettingsEntry | None = None, /, **data): + overrides = set() + if defaults: + defaults_dict = defaults.dict() + for field_name, field_value in list(data.items()): + if field_value is None: + data[field_name] = defaults_dict[field_name] + else: + overrides.add(field_name) + super().__init__(**data) + self._overrides |= overrides + + @property + def overrides(self) -> dict[str, Any]: + """Return a dictionary of overrides.""" + return {name: getattr(self, name) for name in self._overrides} + @classmethod - def create(cls, entry_data: Optional[dict[str, Any]], *, keep_empty: bool = False) -> Optional[SettingsEntry]: + def create( + cls, entry_data: dict[str, Any] | None, *, defaults: SettingsEntry | None = None, keep_empty: bool = False + ) -> SettingsEntry | None: """ Returns a SettingsEntry object from `entry_data` if it holds any value, None otherwise. @@ -38,7 +59,7 @@ class SettingsEntry(BaseModel, FieldRequiring): if not isinstance(entry_data, dict): entry_data = {cls.name: entry_data} - return cls(**entry_data) + return cls(defaults, **entry_data) class ValidationEntry(SettingsEntry): diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py index 38eef3ca6..765fba683 100644 --- a/bot/exts/filtering/_ui/filter.py +++ b/bot/exts/filtering/_ui/filter.py @@ -424,23 +424,10 @@ def description_and_settings_converter( return description, settings, filter_settings -def filter_overrides(filter_: Filter, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]: - """Get the filter's overrides to the filter list settings and the extra fields settings.""" - overrides_values = {} - for settings in (filter_.actions, filter_.validations): - if settings: - for _, setting in settings.items(): - for setting_name, value in to_serializable(setting.dict()).items(): - if not repr_equals(value, filter_list[list_type].default(setting_name)): - overrides_values[setting_name] = value - - if filter_.extra_fields_type: - # The values here can be safely used since overrides equal to the defaults won't be saved. - extra_fields_overrides = filter_.extra_fields.dict(exclude_unset=True) - else: - extra_fields_overrides = {} - - return overrides_values, extra_fields_overrides +def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]: + """Get a serializable version of the filter's overrides.""" + overrides_values, extra_fields_overrides = filter_.overrides + return to_serializable(overrides_values), to_serializable(extra_fields_overrides) def template_settings(filter_id: str, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]: @@ -457,4 +444,4 @@ def template_settings(filter_id: str, filter_list: FilterList, list_type: ListTy f"Could not find filter with ID `{filter_id}` in the {list_type.name} {filter_list.name} list." ) filter_ = filter_list[list_type].filters[filter_id] - return filter_overrides(filter_, filter_list, list_type) + return filter_serializable_overrides(filter_) diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 5a60bb21e..980eba02a 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -291,10 +291,8 @@ class SequenceEditView(discord.ui.View): if _i != len(self.stored_value): self.stored_value.pop(_i) - select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]] - if not self.stored_value: - self.remove_item(self.removal_select) - await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self) + await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self.copy()) + self.stop() async def apply_addition(self, interaction: Interaction, item: str) -> None: """Add an item to the list.""" @@ -303,18 +301,14 @@ class SequenceEditView(discord.ui.View): return self.stored_value.append(item) - self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]] - if len(self.stored_value) == 1: - self.add_item(self.removal_select) - await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self) + await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self.copy()) + self.stop() async def apply_edit(self, interaction: Interaction, new_list: str) -> None: """Change the contents of the list.""" - self.stored_value = list(set(new_list.split(","))) - self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]] - if len(self.stored_value) == 1: - self.add_item(self.removal_select) - await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self) + self.stored_value = list(set(part.strip() for part in new_list.split(",") if part.strip())) + await interaction.response.edit_message(content=f"Current list: {self.stored_value}", view=self.copy()) + self.stop() @discord.ui.button(label="Add Value") async def add_value(self, interaction: Interaction, button: discord.ui.Button) -> None: @@ -340,6 +334,10 @@ class SequenceEditView(discord.ui.View): await interaction.response.edit_message(content="🚫 Canceled", view=None) self.stop() + def copy(self) -> SequenceEditView: + """Return a copy of this view.""" + return SequenceEditView(self.setting_name, self.stored_value, self.update_callback) + class EnumSelectView(discord.ui.View): """A view containing an instance of EnumSelect.""" diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 563bdacb5..6ff5181a9 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -23,7 +23,7 @@ from bot.exts.filtering._filter_lists.filter_list import AtomicList from bot.exts.filtering._filters.filter import Filter from bot.exts.filtering._settings import ActionSettings from bot.exts.filtering._ui.filter import ( - build_filter_repr_dict, description_and_settings_converter, filter_overrides, populate_embed_from_dict + build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict ) from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter from bot.exts.filtering._ui.ui import ArgumentCompletionView, DeleteConfirmationView, format_response_error @@ -280,7 +280,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result - overrides_values, extra_fields_overrides = filter_overrides(filter_, filter_list, list_type) + overrides_values, extra_fields_overrides = filter_serializable_overrides(filter_) all_settings_repr_dict = build_filter_repr_dict( filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides @@ -388,7 +388,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result filter_type = type(filter_) - settings, filter_settings = filter_overrides(filter_, filter_list, list_type) + settings, filter_settings = filter_serializable_overrides(filter_) description, new_settings, new_filter_settings = description_and_settings_converter( filter_list, list_type, filter_type, |