diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 12 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings.py | 7 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings_types/settings_entry.py | 2 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui.py | 48 | ||||
| -rw-r--r-- | bot/exts/filtering/_utils.py | 16 | ||||
| -rw-r--r-- | bot/exts/filtering/filtering.py | 28 | 
6 files changed, 90 insertions, 23 deletions
| diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index c34f46878..a62013192 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -1,6 +1,6 @@  from abc import abstractmethod  from enum import Enum -from typing import Optional, Type +from typing import Any, Optional, Type  from discord.ext.commands import BadArgument @@ -71,6 +71,16 @@ class FilterList(FieldRequiring):          else:              return new_filter +    def default(self, list_type: ListType, setting: str) -> Any: +        """Get the default value of a specific setting.""" +        missing = object() +        value = self.defaults[list_type]["actions"].get_setting(setting, missing) +        if value is missing: +            value = self.defaults[list_type]["validations"].get_setting(setting, missing) +            if value is missing: +                raise ValueError(f"Could find a setting named {setting}.") +        return value +      @abstractmethod      def get_filter_type(self, content: str) -> Type[Filter]:          """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py index cbd682d6d..7b09e3c52 100644 --- a/bot/exts/filtering/_settings.py +++ b/bot/exts/filtering/_settings.py @@ -107,6 +107,13 @@ class Settings(FieldRequiring):          """Get the entry matching the key, or fall back to the default value if the key is missing."""          return self._entries.get(key, default) +    def get_setting(self, key: str, default: Optional[Any] = None) -> Any: +        """Get the setting matching the key, or fall back to the default value if the key is missing.""" +        for entry in self._entries.values(): +            if hasattr(entry, key): +                return getattr(entry, key) +        return default +      @classmethod      def create(cls, settings_data: dict, *, keep_empty: bool = False) -> Optional[Settings]:          """ diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py index 2b3b030a0..5a7e41cac 100644 --- a/bot/exts/filtering/_settings_types/settings_entry.py +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -33,7 +33,7 @@ class SettingsEntry(BaseModel, FieldRequiring):          """          if entry_data is None:              return None -        if not keep_empty and hasattr(entry_data, "values") and not any(value for value in entry_data.values()): +        if not keep_empty and hasattr(entry_data, "values") and all(value is None for value in entry_data.values()):              return None          if not isinstance(entry_data, dict): diff --git a/bot/exts/filtering/_ui.py b/bot/exts/filtering/_ui.py index ec2051083..8bfcded77 100644 --- a/bot/exts/filtering/_ui.py +++ b/bot/exts/filtering/_ui.py @@ -15,7 +15,7 @@ from discord.ui.select import MISSING, SelectOption  from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType  from bot.exts.filtering._filters.filter import Filter -from bot.exts.filtering._utils import to_serializable +from bot.exts.filtering._utils import repr_equals, to_serializable  from bot.log import get_logger  log = get_logger(__name__) @@ -115,7 +115,7 @@ def build_filter_repr_dict(      # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings.      total_values = {}      for name, value in default_setting_values.items(): -        if name not in settings_overrides: +        if name not in settings_overrides or repr_equals(settings_overrides[name], value):              total_values[name] = value          else:              total_values[f"{name}*"] = settings_overrides[name] @@ -124,7 +124,7 @@ def build_filter_repr_dict(      if filter_type.extra_fields_type:          # This iterates over the default values of the extra fields model.          for name, value in filter_type.extra_fields_type().dict().items(): -            if name not in extra_fields_overrides: +            if name not in extra_fields_overrides or repr_equals(extra_fields_overrides[name], value):                  total_values[f"{filter_type.name}/{name}"] = value              else:                  total_values[f"{filter_type.name}/{name}*"] = value @@ -248,7 +248,7 @@ class FreeInputModal(discord.ui.Modal):      async def on_submit(self, interaction: Interaction) -> None:          """Update the setting with the new value in the embed."""          try: -            value = self.type_(self.setting_input.value) or None +            value = self.type_(self.setting_input.value)          except (ValueError, TypeError):              await interaction.response.send_message(                  f"Could not process the input value for `{self.setting_name}`.", ephemeral=True @@ -318,6 +318,10 @@ class SequenceEditView(discord.ui.View):      async def apply_addition(self, interaction: Interaction, item: str) -> None:          """Add an item to the list.""" +        if item in self.stored_value:  # Ignore duplicates +            await interaction.response.defer() +            return +          self.stored_value.append(item)          self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]]          if len(self.stored_value) == 1: @@ -326,7 +330,7 @@ class SequenceEditView(discord.ui.View):      async def apply_edit(self, interaction: Interaction, new_list: str) -> None:          """Change the contents of the list.""" -        self.stored_value = new_list.split(",") +        self.stored_value = list(set(new_list.split(",")))          self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]]          if len(self.stored_value) == 1:              self.add_item(self.removal_select) @@ -571,11 +575,17 @@ class SettingsEditView(discord.ui.View):              if "/" in setting_name:                  filter_name, setting_name = setting_name.split("/", maxsplit=1)                  dict_to_edit = self.filter_settings_overrides +                default_value = self.filter_type.extra_fields_type().dict()[setting_name]              else:                  dict_to_edit = self.settings_overrides +                default_value = self.filter_list.default(self.list_type, setting_name)              # Update the setting override value or remove it              if setting_value is not self._REMOVE: -                dict_to_edit[setting_name] = setting_value +                if not repr_equals(setting_value, default_value): +                    dict_to_edit[setting_name] = setting_value +                # If there's already an override, remove it, since the new value is the same as the default. +                elif setting_name in dict_to_edit: +                    del dict_to_edit[setting_name]              elif setting_name in dict_to_edit:                  del dict_to_edit[setting_name] @@ -657,7 +667,12 @@ def _parse_value(value: str, type_: type[T]) -> T:  def description_and_settings_converter( -    list_name: str, loaded_settings: dict, loaded_filter_settings: dict, input_data: str +    filter_list: FilterList, +    list_type: ListType, +    filter_type: type[Filter], +    loaded_settings: dict, +    loaded_filter_settings: dict, +    input_data: str  ) -> tuple[str, dict[str, Any], dict[str, Any]]:      """Parse a string representing a possible description and setting overrides, and validate the setting names."""      if not input_data: @@ -679,25 +694,32 @@ def description_and_settings_converter(      filter_settings = {}      for setting, _ in list(settings.items()):          if setting not in loaded_settings: +            # It's a filter setting              if "/" in setting:                  setting_list_name, filter_setting_name = setting.split("/", maxsplit=1) -                if setting_list_name.lower() != list_name.lower(): +                if setting_list_name.lower() != filter_list.name.lower():                      raise BadArgument( -                        f"A setting for a {setting_list_name!r} filter was provided, but the list name is {list_name!r}" +                        f"A setting for a {setting_list_name!r} filter was provided, " +                        f"but the list name is {filter_list.name!r}"                      ) -                if filter_setting_name not in loaded_filter_settings[list_name]: +                if filter_setting_name not in loaded_filter_settings[filter_list.name]:                      raise BadArgument(f"{setting!r} is not a recognized setting.") -                type_ = loaded_filter_settings[list_name][filter_setting_name][2] +                type_ = loaded_filter_settings[filter_list.name][filter_setting_name][2]                  try: -                    filter_settings[filter_setting_name] = _parse_value(settings.pop(setting), type_) +                    parsed_value = _parse_value(settings.pop(setting), type_) +                    if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): +                        filter_settings[filter_setting_name] = parsed_value                  except (TypeError, ValueError) as e:                      raise BadArgument(e)              else:                  raise BadArgument(f"{setting!r} is not a recognized setting.") +        # It's a filter list setting          else:              type_ = loaded_settings[setting][2]              try: -                settings[setting] = _parse_value(settings.pop(setting), type_) +                parsed_value = _parse_value(settings.pop(setting), type_) +                if not repr_equals(parsed_value, filter_list.default(list_type, setting)): +                    settings[setting] = parsed_value              except (TypeError, ValueError) as e:                  raise BadArgument(e) diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index 158f1e7bd..438c22d41 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -66,6 +66,22 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]      return str(item) +def repr_equals(override: Any, default: Any) -> bool: +    """Return whether the override and the default have the same representation.""" +    if override is None:  # It's not an override +        return True + +    override_is_sequence = isinstance(override, (tuple, list, set)) +    default_is_sequence = isinstance(default, (tuple, list, set)) +    if override_is_sequence != default_is_sequence:  # One is a sequence and the other isn't. +        return False +    if override_is_sequence: +        if len(override) != len(default): +            return False +        return all(str(item1) == str(item2) for item1, item2 in zip(set(override), set(default))) +    return str(override) == str(default) + +  class FieldRequiring(ABC):      """A mixin class that can force its concrete subclasses to set a value for specific class attributes.""" diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 5f42e2cab..aa90d1600 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -21,7 +21,7 @@ from bot.exts.filtering._settings import ActionSettings  from bot.exts.filtering._ui import (      ArgumentCompletionView, build_filter_repr_dict, description_and_settings_converter, populate_embed_from_dict  ) -from bot.exts.filtering._utils import past_tense, to_serializable +from bot.exts.filtering._utils import past_tense, repr_equals, to_serializable  from bot.log import get_logger  from bot.pagination import LinePaginator  from bot.utils.messages import format_channel, format_user @@ -269,7 +269,7 @@ class Filtering(Cog):              return          filter_, filter_list, list_type = result -        overrides_values, extra_fields_overrides = self._filter_overrides(filter_) +        overrides_values, extra_fields_overrides = self._filter_overrides(filter_, filter_list, list_type)          all_settings_repr_dict = build_filter_repr_dict(              filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides @@ -371,9 +371,13 @@ class Filtering(Cog):              return          filter_, filter_list, list_type = result          filter_type = type(filter_) -        settings, filter_settings = self._filter_overrides(filter_) +        settings, filter_settings = self._filter_overrides(filter_, filter_list, list_type)          description, new_settings, new_filter_settings = description_and_settings_converter( -            filter_list.name, self.loaded_settings, self.loaded_filter_settings, description_and_settings +            filter_list, +            list_type, filter_type, +            self.loaded_settings, +            self.loaded_filter_settings, +            description_and_settings          )          content = filter_.content @@ -620,15 +624,18 @@ class Filtering(Cog):                      return sublist[id_], filter_list, list_type      @staticmethod -    def _filter_overrides(filter_: Filter) -> tuple[dict, dict]: +    def _filter_overrides(filter_: Filter, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]:          """Get the filter's overrides to the filter list settings and the extra fields settings."""          overrides_values = {}          for settings in (filter_.actions, filter_.validations):              if settings:                  for _, setting in settings.items(): -                    overrides_values.update(to_serializable(setting.dict())) +                    for setting_name, value in to_serializable(setting.dict()).items(): +                        if not repr_equals(value, filter_list.default(list_type, setting_name)): +                            overrides_values[setting_name] = value          if filter_.extra_fields_type: +            # The values here can be safely used since overrides equal to the defaults won't be saved.              extra_fields_overrides = filter_.extra_fields.dict(exclude_unset=True)          else:              extra_fields_overrides = {} @@ -645,10 +652,15 @@ class Filtering(Cog):          description_and_settings: Optional[str] = None      ) -> None:          """Add a filter to the database.""" +        filter_type = filter_list.get_filter_type(content)          description, settings, filter_settings = description_and_settings_converter( -            filter_list.name, self.loaded_settings, self.loaded_filter_settings, description_and_settings +            filter_list, +            list_type, +            filter_type, +            self.loaded_settings, +            self.loaded_filter_settings, +            description_and_settings          ) -        filter_type = filter_list.get_filter_type(content)          if noui:              try: | 
