aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/filtering/_filter_lists/filter_list.py12
-rw-r--r--bot/exts/filtering/_settings.py7
-rw-r--r--bot/exts/filtering/_settings_types/settings_entry.py2
-rw-r--r--bot/exts/filtering/_ui.py48
-rw-r--r--bot/exts/filtering/_utils.py16
-rw-r--r--bot/exts/filtering/filtering.py28
6 files changed, 90 insertions, 23 deletions
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py
index c34f46878..a62013192 100644
--- a/bot/exts/filtering/_filter_lists/filter_list.py
+++ b/bot/exts/filtering/_filter_lists/filter_list.py
@@ -1,6 +1,6 @@
from abc import abstractmethod
from enum import Enum
-from typing import Optional, Type
+from typing import Any, Optional, Type
from discord.ext.commands import BadArgument
@@ -71,6 +71,16 @@ class FilterList(FieldRequiring):
else:
return new_filter
+ def default(self, list_type: ListType, setting: str) -> Any:
+ """Get the default value of a specific setting."""
+ missing = object()
+ value = self.defaults[list_type]["actions"].get_setting(setting, missing)
+ if value is missing:
+ value = self.defaults[list_type]["validations"].get_setting(setting, missing)
+ if value is missing:
+ raise ValueError(f"Could find a setting named {setting}.")
+ return value
+
@abstractmethod
def get_filter_type(self, content: str) -> Type[Filter]:
"""Get a subclass of filter matching the filter list and the filter's content."""
diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py
index cbd682d6d..7b09e3c52 100644
--- a/bot/exts/filtering/_settings.py
+++ b/bot/exts/filtering/_settings.py
@@ -107,6 +107,13 @@ class Settings(FieldRequiring):
"""Get the entry matching the key, or fall back to the default value if the key is missing."""
return self._entries.get(key, default)
+ def get_setting(self, key: str, default: Optional[Any] = None) -> Any:
+ """Get the setting matching the key, or fall back to the default value if the key is missing."""
+ for entry in self._entries.values():
+ if hasattr(entry, key):
+ return getattr(entry, key)
+ return default
+
@classmethod
def create(cls, settings_data: dict, *, keep_empty: bool = False) -> Optional[Settings]:
"""
diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py
index 2b3b030a0..5a7e41cac 100644
--- a/bot/exts/filtering/_settings_types/settings_entry.py
+++ b/bot/exts/filtering/_settings_types/settings_entry.py
@@ -33,7 +33,7 @@ class SettingsEntry(BaseModel, FieldRequiring):
"""
if entry_data is None:
return None
- if not keep_empty and hasattr(entry_data, "values") and not any(value for value in entry_data.values()):
+ if not keep_empty and hasattr(entry_data, "values") and all(value is None for value in entry_data.values()):
return None
if not isinstance(entry_data, dict):
diff --git a/bot/exts/filtering/_ui.py b/bot/exts/filtering/_ui.py
index ec2051083..8bfcded77 100644
--- a/bot/exts/filtering/_ui.py
+++ b/bot/exts/filtering/_ui.py
@@ -15,7 +15,7 @@ from discord.ui.select import MISSING, SelectOption
from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
from bot.exts.filtering._filters.filter import Filter
-from bot.exts.filtering._utils import to_serializable
+from bot.exts.filtering._utils import repr_equals, to_serializable
from bot.log import get_logger
log = get_logger(__name__)
@@ -115,7 +115,7 @@ def build_filter_repr_dict(
# Add overrides. It's done in this way to preserve field order, since the filter won't have all settings.
total_values = {}
for name, value in default_setting_values.items():
- if name not in settings_overrides:
+ if name not in settings_overrides or repr_equals(settings_overrides[name], value):
total_values[name] = value
else:
total_values[f"{name}*"] = settings_overrides[name]
@@ -124,7 +124,7 @@ def build_filter_repr_dict(
if filter_type.extra_fields_type:
# This iterates over the default values of the extra fields model.
for name, value in filter_type.extra_fields_type().dict().items():
- if name not in extra_fields_overrides:
+ if name not in extra_fields_overrides or repr_equals(extra_fields_overrides[name], value):
total_values[f"{filter_type.name}/{name}"] = value
else:
total_values[f"{filter_type.name}/{name}*"] = value
@@ -248,7 +248,7 @@ class FreeInputModal(discord.ui.Modal):
async def on_submit(self, interaction: Interaction) -> None:
"""Update the setting with the new value in the embed."""
try:
- value = self.type_(self.setting_input.value) or None
+ value = self.type_(self.setting_input.value)
except (ValueError, TypeError):
await interaction.response.send_message(
f"Could not process the input value for `{self.setting_name}`.", ephemeral=True
@@ -318,6 +318,10 @@ class SequenceEditView(discord.ui.View):
async def apply_addition(self, interaction: Interaction, item: str) -> None:
"""Add an item to the list."""
+ if item in self.stored_value: # Ignore duplicates
+ await interaction.response.defer()
+ return
+
self.stored_value.append(item)
self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]]
if len(self.stored_value) == 1:
@@ -326,7 +330,7 @@ class SequenceEditView(discord.ui.View):
async def apply_edit(self, interaction: Interaction, new_list: str) -> None:
"""Change the contents of the list."""
- self.stored_value = new_list.split(",")
+ self.stored_value = list(set(new_list.split(",")))
self.removal_select.options = [SelectOption(label=item) for item in self.stored_value[:MAX_SELECT_ITEMS]]
if len(self.stored_value) == 1:
self.add_item(self.removal_select)
@@ -571,11 +575,17 @@ class SettingsEditView(discord.ui.View):
if "/" in setting_name:
filter_name, setting_name = setting_name.split("/", maxsplit=1)
dict_to_edit = self.filter_settings_overrides
+ default_value = self.filter_type.extra_fields_type().dict()[setting_name]
else:
dict_to_edit = self.settings_overrides
+ default_value = self.filter_list.default(self.list_type, setting_name)
# Update the setting override value or remove it
if setting_value is not self._REMOVE:
- dict_to_edit[setting_name] = setting_value
+ if not repr_equals(setting_value, default_value):
+ dict_to_edit[setting_name] = setting_value
+ # If there's already an override, remove it, since the new value is the same as the default.
+ elif setting_name in dict_to_edit:
+ del dict_to_edit[setting_name]
elif setting_name in dict_to_edit:
del dict_to_edit[setting_name]
@@ -657,7 +667,12 @@ def _parse_value(value: str, type_: type[T]) -> T:
def description_and_settings_converter(
- list_name: str, loaded_settings: dict, loaded_filter_settings: dict, input_data: str
+ filter_list: FilterList,
+ list_type: ListType,
+ filter_type: type[Filter],
+ loaded_settings: dict,
+ loaded_filter_settings: dict,
+ input_data: str
) -> tuple[str, dict[str, Any], dict[str, Any]]:
"""Parse a string representing a possible description and setting overrides, and validate the setting names."""
if not input_data:
@@ -679,25 +694,32 @@ def description_and_settings_converter(
filter_settings = {}
for setting, _ in list(settings.items()):
if setting not in loaded_settings:
+ # It's a filter setting
if "/" in setting:
setting_list_name, filter_setting_name = setting.split("/", maxsplit=1)
- if setting_list_name.lower() != list_name.lower():
+ if setting_list_name.lower() != filter_list.name.lower():
raise BadArgument(
- f"A setting for a {setting_list_name!r} filter was provided, but the list name is {list_name!r}"
+ f"A setting for a {setting_list_name!r} filter was provided, "
+ f"but the list name is {filter_list.name!r}"
)
- if filter_setting_name not in loaded_filter_settings[list_name]:
+ if filter_setting_name not in loaded_filter_settings[filter_list.name]:
raise BadArgument(f"{setting!r} is not a recognized setting.")
- type_ = loaded_filter_settings[list_name][filter_setting_name][2]
+ type_ = loaded_filter_settings[filter_list.name][filter_setting_name][2]
try:
- filter_settings[filter_setting_name] = _parse_value(settings.pop(setting), type_)
+ parsed_value = _parse_value(settings.pop(setting), type_)
+ if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)):
+ filter_settings[filter_setting_name] = parsed_value
except (TypeError, ValueError) as e:
raise BadArgument(e)
else:
raise BadArgument(f"{setting!r} is not a recognized setting.")
+ # It's a filter list setting
else:
type_ = loaded_settings[setting][2]
try:
- settings[setting] = _parse_value(settings.pop(setting), type_)
+ parsed_value = _parse_value(settings.pop(setting), type_)
+ if not repr_equals(parsed_value, filter_list.default(list_type, setting)):
+ settings[setting] = parsed_value
except (TypeError, ValueError) as e:
raise BadArgument(e)
diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py
index 158f1e7bd..438c22d41 100644
--- a/bot/exts/filtering/_utils.py
+++ b/bot/exts/filtering/_utils.py
@@ -66,6 +66,22 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]
return str(item)
+def repr_equals(override: Any, default: Any) -> bool:
+ """Return whether the override and the default have the same representation."""
+ if override is None: # It's not an override
+ return True
+
+ override_is_sequence = isinstance(override, (tuple, list, set))
+ default_is_sequence = isinstance(default, (tuple, list, set))
+ if override_is_sequence != default_is_sequence: # One is a sequence and the other isn't.
+ return False
+ if override_is_sequence:
+ if len(override) != len(default):
+ return False
+ return all(str(item1) == str(item2) for item1, item2 in zip(set(override), set(default)))
+ return str(override) == str(default)
+
+
class FieldRequiring(ABC):
"""A mixin class that can force its concrete subclasses to set a value for specific class attributes."""
diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py
index 5f42e2cab..aa90d1600 100644
--- a/bot/exts/filtering/filtering.py
+++ b/bot/exts/filtering/filtering.py
@@ -21,7 +21,7 @@ from bot.exts.filtering._settings import ActionSettings
from bot.exts.filtering._ui import (
ArgumentCompletionView, build_filter_repr_dict, description_and_settings_converter, populate_embed_from_dict
)
-from bot.exts.filtering._utils import past_tense, to_serializable
+from bot.exts.filtering._utils import past_tense, repr_equals, to_serializable
from bot.log import get_logger
from bot.pagination import LinePaginator
from bot.utils.messages import format_channel, format_user
@@ -269,7 +269,7 @@ class Filtering(Cog):
return
filter_, filter_list, list_type = result
- overrides_values, extra_fields_overrides = self._filter_overrides(filter_)
+ overrides_values, extra_fields_overrides = self._filter_overrides(filter_, filter_list, list_type)
all_settings_repr_dict = build_filter_repr_dict(
filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides
@@ -371,9 +371,13 @@ class Filtering(Cog):
return
filter_, filter_list, list_type = result
filter_type = type(filter_)
- settings, filter_settings = self._filter_overrides(filter_)
+ settings, filter_settings = self._filter_overrides(filter_, filter_list, list_type)
description, new_settings, new_filter_settings = description_and_settings_converter(
- filter_list.name, self.loaded_settings, self.loaded_filter_settings, description_and_settings
+ filter_list,
+ list_type, filter_type,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ description_and_settings
)
content = filter_.content
@@ -620,15 +624,18 @@ class Filtering(Cog):
return sublist[id_], filter_list, list_type
@staticmethod
- def _filter_overrides(filter_: Filter) -> tuple[dict, dict]:
+ def _filter_overrides(filter_: Filter, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]:
"""Get the filter's overrides to the filter list settings and the extra fields settings."""
overrides_values = {}
for settings in (filter_.actions, filter_.validations):
if settings:
for _, setting in settings.items():
- overrides_values.update(to_serializable(setting.dict()))
+ for setting_name, value in to_serializable(setting.dict()).items():
+ if not repr_equals(value, filter_list.default(list_type, setting_name)):
+ overrides_values[setting_name] = value
if filter_.extra_fields_type:
+ # The values here can be safely used since overrides equal to the defaults won't be saved.
extra_fields_overrides = filter_.extra_fields.dict(exclude_unset=True)
else:
extra_fields_overrides = {}
@@ -645,10 +652,15 @@ class Filtering(Cog):
description_and_settings: Optional[str] = None
) -> None:
"""Add a filter to the database."""
+ filter_type = filter_list.get_filter_type(content)
description, settings, filter_settings = description_and_settings_converter(
- filter_list.name, self.loaded_settings, self.loaded_filter_settings, description_and_settings
+ filter_list,
+ list_type,
+ filter_type,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ description_and_settings
)
- filter_type = filter_list.get_filter_type(content)
if noui:
try: