diff options
| -rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 12 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings.py | 7 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings_types/settings_entry.py | 2 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui.py | 48 | ||||
| -rw-r--r-- | bot/exts/filtering/_utils.py | 16 | ||||
| -rw-r--r-- | bot/exts/filtering/filtering.py | 28 |
6 files changed, 90 insertions, 23 deletions
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index c34f46878..a62013192 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import Optional, Type +from typing import Any, Optional, Type from discord.ext.commands import BadArgument @@ -71,6 +71,16 @@ class FilterList(FieldRequiring): else: return new_filter + def default(self, list_type: ListType, setting: str) -> Any: + """Get the default value of a specific setting.""" + missing = object() + value = self.defaults[list_type]["actions"].get_setting(setting, missing) + if value is missing: + value = self.defaults[list_type]["validations"].get_setting(setting, missing) + if value is missing: + raise ValueError(f"Could find a setting named {setting}.") + return value + @abstractmethod def get_filter_type(self, content: str) -> Type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py index cbd682d6d..7b09e3c52 100644 --- a/bot/exts/filtering/_settings.py +++ b/bot/exts/filtering/_settings.py @@ -107,6 +107,13 @@ class Settings(FieldRequiring): """Get the entry matching the key, or fall back to the default value if the key is missing.""" return self._entries.get(key, default) + def get_setting(self, key: str, default: Optional[Any] = None) -> Any: + """Get the setting matching the key, or fall back to the default value if the key is missing.""" + for entry in self._entries.values(): + if hasattr(entry, key): + return getattr(entry, key) + return default + @classmethod def create(cls, settings_data: dict, *, keep_empty: bool = False) -> Optional[Settings]: """ diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py index 2b3b030a0..5a7e41cac 100644 --- a/bot/exts/filtering/_settings_types/settings_entry.py +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -33,7 +33,7 @@ class SettingsEntry(BaseModel, FieldRequiring): """ if entry_data is None: return None - if not keep_empty and hasattr(entry_data, "values") and not any(value for value in entry_data.values()): + if not keep_empty and hasattr(entry_data, "values") and all(value is None for value in entry_data.values()): return None if not isinstance(entry_data, dict): diff --git a/bot/exts/filtering/_ui.py b/bot/exts/filtering/_ui.py index ec2051083..8bfcded77 100644 --- a/bot/exts/filtering/_ui.py +++ b/bot/exts/filtering/_ui.py @@ -15,7 +15,7 @@ from discord.ui.select import MISSING, SelectOption from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType from bot.exts.filtering._filters.filter import Filter -from bot.exts.filtering._utils import to_serializable +from bot.exts.filtering._utils import repr_equals, to_serializable from bot.log import get_logger log = get_logger(__name__) @@ -115,7 +115,7 @@ def build_filter_repr_dict( # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings. total_values = {} for name, value in default_setting_values.items(): - if name not in settings_overrides: + if name not in settings_overrides or repr_equals(settings_overrides[name], value): total_values[name] = value else: total_values[f"{name}*"] = settings_overrides[name] @@ -124,7 +124,7 @@ def build_filter_repr_dict( if filter_type.extra_fields_type: # This iterates over the default values of the extra fields model. for name, value in filter_type.extra_fields_type().dict().items(): - if name not in extra_fields_overrides: + if name not in extra_fields_overrides or repr_equals(extra_fields_overrides[name], value): total_values[f"{filter_type.name}/{name}"] = value else: total_values[f"{filter_type.name}/{name}*"] = value @@ -248,7 +248,7 @@ class FreeInputModal(discord.ui.Modal): async def on_submit(self, interaction: Interaction) -> None: """Update the setting with the new value in the embed.""" try: - value = self.type_(self.setting_input.value) or None + value = self.type_(self.setting_input.value) except (ValueError, TypeError): await interaction.response.send_message( f"Could not process the input value for `{self.setting_name}`.", ephemeral=True @@ -318,6 +318,10 @@ class SequenceEditView(discord.ui.View): async def apply_addition(self, interaction: Interaction, item: str) -> None: """Add an item to the list.""" + if item in self.stored_value: # Ignore duplicates + await interaction.response.defer() + return + self.stored_value.append(item) self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]] if len(self.stored_value) == 1: @@ -326,7 +330,7 @@ class SequenceEditView(discord.ui.View): async def apply_edit(self, interaction: Interaction, new_list: str) -> None: """Change the contents of the list.""" - self.stored_value = new_list.split(",") + self.stored_value = list(set(new_list.split(","))) self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]] if len(self.stored_value) == 1: self.add_item(self.removal_select) @@ -571,11 +575,17 @@ class SettingsEditView(discord.ui.View): if "/" in setting_name: filter_name, setting_name = setting_name.split("/", maxsplit=1) dict_to_edit = self.filter_settings_overrides + default_value = self.filter_type.extra_fields_type().dict()[setting_name] else: dict_to_edit = self.settings_overrides + default_value = self.filter_list.default(self.list_type, setting_name) # Update the setting override value or remove it if setting_value is not self._REMOVE: - dict_to_edit[setting_name] = setting_value + if not repr_equals(setting_value, default_value): + dict_to_edit[setting_name] = setting_value + # If there's already an override, remove it, since the new value is the same as the default. + elif setting_name in dict_to_edit: + del dict_to_edit[setting_name] elif setting_name in dict_to_edit: del dict_to_edit[setting_name] @@ -657,7 +667,12 @@ def _parse_value(value: str, type_: type[T]) -> T: def description_and_settings_converter( - list_name: str, loaded_settings: dict, loaded_filter_settings: dict, input_data: str + filter_list: FilterList, + list_type: ListType, + filter_type: type[Filter], + loaded_settings: dict, + loaded_filter_settings: dict, + input_data: str ) -> tuple[str, dict[str, Any], dict[str, Any]]: """Parse a string representing a possible description and setting overrides, and validate the setting names.""" if not input_data: @@ -679,25 +694,32 @@ def description_and_settings_converter( filter_settings = {} for setting, _ in list(settings.items()): if setting not in loaded_settings: + # It's a filter setting if "/" in setting: setting_list_name, filter_setting_name = setting.split("/", maxsplit=1) - if setting_list_name.lower() != list_name.lower(): + if setting_list_name.lower() != filter_list.name.lower(): raise BadArgument( - f"A setting for a {setting_list_name!r} filter was provided, but the list name is {list_name!r}" + f"A setting for a {setting_list_name!r} filter was provided, " + f"but the list name is {filter_list.name!r}" ) - if filter_setting_name not in loaded_filter_settings[list_name]: + if filter_setting_name not in loaded_filter_settings[filter_list.name]: raise BadArgument(f"{setting!r} is not a recognized setting.") - type_ = loaded_filter_settings[list_name][filter_setting_name][2] + type_ = loaded_filter_settings[filter_list.name][filter_setting_name][2] try: - filter_settings[filter_setting_name] = _parse_value(settings.pop(setting), type_) + parsed_value = _parse_value(settings.pop(setting), type_) + if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): + filter_settings[filter_setting_name] = parsed_value except (TypeError, ValueError) as e: raise BadArgument(e) else: raise BadArgument(f"{setting!r} is not a recognized setting.") + # It's a filter list setting else: type_ = loaded_settings[setting][2] try: - settings[setting] = _parse_value(settings.pop(setting), type_) + parsed_value = _parse_value(settings.pop(setting), type_) + if not repr_equals(parsed_value, filter_list.default(list_type, setting)): + settings[setting] = parsed_value except (TypeError, ValueError) as e: raise BadArgument(e) diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index 158f1e7bd..438c22d41 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -66,6 +66,22 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None] return str(item) +def repr_equals(override: Any, default: Any) -> bool: + """Return whether the override and the default have the same representation.""" + if override is None: # It's not an override + return True + + override_is_sequence = isinstance(override, (tuple, list, set)) + default_is_sequence = isinstance(default, (tuple, list, set)) + if override_is_sequence != default_is_sequence: # One is a sequence and the other isn't. + return False + if override_is_sequence: + if len(override) != len(default): + return False + return all(str(item1) == str(item2) for item1, item2 in zip(set(override), set(default))) + return str(override) == str(default) + + class FieldRequiring(ABC): """A mixin class that can force its concrete subclasses to set a value for specific class attributes.""" diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 5f42e2cab..aa90d1600 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -21,7 +21,7 @@ from bot.exts.filtering._settings import ActionSettings from bot.exts.filtering._ui import ( ArgumentCompletionView, build_filter_repr_dict, description_and_settings_converter, populate_embed_from_dict ) -from bot.exts.filtering._utils import past_tense, to_serializable +from bot.exts.filtering._utils import past_tense, repr_equals, to_serializable from bot.log import get_logger from bot.pagination import LinePaginator from bot.utils.messages import format_channel, format_user @@ -269,7 +269,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result - overrides_values, extra_fields_overrides = self._filter_overrides(filter_) + overrides_values, extra_fields_overrides = self._filter_overrides(filter_, filter_list, list_type) all_settings_repr_dict = build_filter_repr_dict( filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides @@ -371,9 +371,13 @@ class Filtering(Cog): return filter_, filter_list, list_type = result filter_type = type(filter_) - settings, filter_settings = self._filter_overrides(filter_) + settings, filter_settings = self._filter_overrides(filter_, filter_list, list_type) description, new_settings, new_filter_settings = description_and_settings_converter( - filter_list.name, self.loaded_settings, self.loaded_filter_settings, description_and_settings + filter_list, + list_type, filter_type, + self.loaded_settings, + self.loaded_filter_settings, + description_and_settings ) content = filter_.content @@ -620,15 +624,18 @@ class Filtering(Cog): return sublist[id_], filter_list, list_type @staticmethod - def _filter_overrides(filter_: Filter) -> tuple[dict, dict]: + def _filter_overrides(filter_: Filter, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]: """Get the filter's overrides to the filter list settings and the extra fields settings.""" overrides_values = {} for settings in (filter_.actions, filter_.validations): if settings: for _, setting in settings.items(): - overrides_values.update(to_serializable(setting.dict())) + for setting_name, value in to_serializable(setting.dict()).items(): + if not repr_equals(value, filter_list.default(list_type, setting_name)): + overrides_values[setting_name] = value if filter_.extra_fields_type: + # The values here can be safely used since overrides equal to the defaults won't be saved. extra_fields_overrides = filter_.extra_fields.dict(exclude_unset=True) else: extra_fields_overrides = {} @@ -645,10 +652,15 @@ class Filtering(Cog): description_and_settings: Optional[str] = None ) -> None: """Add a filter to the database.""" + filter_type = filter_list.get_filter_type(content) description, settings, filter_settings = description_and_settings_converter( - filter_list.name, self.loaded_settings, self.loaded_filter_settings, description_and_settings + filter_list, + list_type, + filter_type, + self.loaded_settings, + self.loaded_filter_settings, + description_and_settings ) - filter_type = filter_list.get_filter_type(content) if noui: try: |