diff options
-rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 8 | ||||
-rw-r--r-- | bot/exts/filtering/_filters/filter.py | 6 | ||||
-rw-r--r-- | bot/exts/filtering/_settings.py | 7 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/filter.py | 88 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/filter_list.py | 6 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/search.py | 365 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/ui.py | 4 | ||||
-rw-r--r-- | bot/exts/filtering/filtering.py | 127 | ||||
-rw-r--r-- | bot/pagination.py | 7 |
9 files changed, 563 insertions, 55 deletions
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index f993665f2..84a43072b 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -84,14 +84,14 @@ class AtomicList(NamedTuple): return relevant_filters - def default(self, setting: str) -> Any: + def default(self, setting_name: str) -> Any: """Get the default value of a specific setting.""" missing = object() - value = self.defaults.actions.get_setting(setting, missing) + value = self.defaults.actions.get_setting(setting_name, missing) if value is missing: - value = self.defaults.validations.get_setting(setting, missing) + value = self.defaults.validations.get_setting(setting_name, missing) if value is missing: - raise ValueError(f"Couldn't find a setting named {setting!r}.") + raise ValueError(f"Couldn't find a setting named {setting_name!r}.") return value diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py index 0d11d5b3c..095799781 100644 --- a/bot/exts/filtering/_filters/filter.py +++ b/bot/exts/filtering/_filters/filter.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Optional +from typing import Any from pydantic import ValidationError @@ -28,7 +28,7 @@ class Filter(FieldRequiring): self.description = filter_data["description"] self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults) if self.extra_fields_type: - self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"]) + self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"] or "{}") # noqa: P103 else: self.extra_fields = None @@ -52,7 +52,7 @@ class Filter(FieldRequiring): """Search for the filter's content within a given context.""" @classmethod - def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, Optional[str]]: + def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, str | None]: """Validate whether the supplied fields are valid for the filter, and provide the error message if not.""" if cls.extra_fields_type is None: return True, None diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py index 4c2114f07..066c7a369 100644 --- a/bot/exts/filtering/_settings.py +++ b/bot/exts/filtering/_settings.py @@ -211,3 +211,10 @@ class Defaults(NamedTuple): actions: ActionSettings validations: ValidationSettings + + def dict(self) -> dict[str, Any]: + """Return a dict representation of the stored fields across all entries.""" + dict_ = {} + for settings in self: + dict_ = reduce(operator.or_, (entry.dict() for entry in settings.values()), dict_) + return dict_ diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py index 765fba683..37584e9fd 100644 --- a/bot/exts/filtering/_ui/filter.py +++ b/bot/exts/filtering/_ui/filter.py @@ -28,7 +28,7 @@ def build_filter_repr_dict( settings_overrides: dict, extra_fields_overrides: dict ) -> dict: - """Build a dictionary of field names and values to pass to `_build_embed_from_dict`.""" + """Build a dictionary of field names and values to pass to `populate_embed_from_dict`.""" # Get filter list settings default_setting_values = {} for settings_group in filter_list[list_type].defaults: @@ -155,16 +155,16 @@ class FilterEditView(EditBaseView): ) self.add_item(add_select) - override_names = ( - list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides] - ) - remove_select = CustomCallbackSelect( - self._remove_override, - placeholder="Select an override to remove", - options=[SelectOption(label=name) for name in sorted(override_names)], - row=2 - ) - if remove_select.options: + if settings_overrides or filter_settings_overrides: + override_names = ( + list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides] + ) + remove_select = CustomCallbackSelect( + self._remove_override, + placeholder="Select an override to remove", + options=[SelectOption(label=name) for name in sorted(override_names)], + row=2 + ) self.add_item(remove_select) @discord.ui.button(label="Edit Content", row=3) @@ -285,9 +285,9 @@ class FilterEditView(EditBaseView): dict_to_edit[setting_name] = setting_value # If there's already an override, remove it, since the new value is the same as the default. elif setting_name in dict_to_edit: - del dict_to_edit[setting_name] + dict_to_edit.pop(setting_name) elif setting_name in dict_to_edit: - del dict_to_edit[setting_name] + dict_to_edit.pop(setting_name) # This is inefficient, but otherwise the selects go insane if the user attempts to edit the same setting # multiple times, even when replacing the select with a new one. @@ -315,8 +315,10 @@ class FilterEditView(EditBaseView): async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None: """Replace any non-overridden settings with overrides from the given filter.""" try: - settings, filter_settings = template_settings(template_id, self.filter_list, self.list_type) - except ValueError as e: # The interaction is necessary to send an ephemeral message. + settings, filter_settings = template_settings( + template_id, self.filter_list, self.list_type, self.filter_type + ) + except BadArgument as e: # The interaction object is necessary to send an ephemeral message. await interaction.response.send_message(f":x: {e}", ephemeral=True) return else: @@ -326,6 +328,7 @@ class FilterEditView(EditBaseView): self.filter_settings_overrides = filter_settings | self.filter_settings_overrides self.embed.clear_fields() await embed_message.edit(embed=self.embed, view=self.copy()) + self.stop() async def _remove_override(self, interaction: Interaction, select: discord.ui.Select) -> None: """ @@ -380,28 +383,7 @@ def description_and_settings_converter( filter_settings = {} for setting, _ in list(settings.items()): - if setting not in loaded_settings: - # It's a filter setting - if "/" in setting: - setting_list_name, filter_setting_name = setting.split("/", maxsplit=1) - if setting_list_name.lower() != filter_list.name.lower(): - raise BadArgument( - f"A setting for a {setting_list_name!r} filter was provided, " - f"but the list name is {filter_list.name!r}" - ) - if filter_setting_name not in loaded_filter_settings[filter_list.name]: - raise BadArgument(f"{setting!r} is not a recognized setting.") - type_ = loaded_filter_settings[filter_list.name][filter_setting_name][2] - try: - parsed_value = parse_value(settings.pop(setting), type_) - if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): - filter_settings[filter_setting_name] = parsed_value - except (TypeError, ValueError) as e: - raise BadArgument(e) - else: - raise BadArgument(f"{setting!r} is not a recognized setting.") - # It's a filter list setting - else: + if setting in loaded_settings: # It's a filter list setting type_ = loaded_settings[setting][2] try: parsed_value = parse_value(settings.pop(setting), type_) @@ -409,11 +391,28 @@ def description_and_settings_converter( settings[setting] = parsed_value except (TypeError, ValueError) as e: raise BadArgument(e) + elif "/" not in setting: + raise BadArgument(f"{setting!r} is not a recognized setting.") + else: # It's a filter setting + filter_name, filter_setting_name = setting.split("/", maxsplit=1) + if filter_name.lower() != filter_type.name.lower(): + raise BadArgument( + f"A setting for a {filter_name!r} filter was provided, but the filter name is {filter_type.name!r}" + ) + if filter_setting_name not in loaded_filter_settings[filter_type.name]: + raise BadArgument(f"{setting!r} is not a recognized setting.") + type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2] + try: + parsed_value = parse_value(settings.pop(setting), type_) + if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): + filter_settings[filter_setting_name] = parsed_value + except (TypeError, ValueError) as e: + raise BadArgument(e) # Pull templates settings and apply them. if template is not None: try: - t_settings, t_filter_settings = template_settings(template, filter_list, list_type) + t_settings, t_filter_settings = template_settings(template, filter_list, list_type, filter_type) except ValueError as e: raise BadArgument(str(e)) else: @@ -430,18 +429,25 @@ def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]: return to_serializable(overrides_values), to_serializable(extra_fields_overrides) -def template_settings(filter_id: str, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]: +def template_settings( + filter_id: str, filter_list: FilterList, list_type: ListType, filter_type: type[Filter] +) -> tuple[dict, dict]: """Find the filter with specified ID, and return its settings.""" try: filter_id = int(filter_id) if filter_id < 0: raise ValueError() except ValueError: - raise ValueError("Template value must be a non-negative integer.") + raise BadArgument("Template value must be a non-negative integer.") if filter_id not in filter_list[list_type].filters: - raise ValueError( + raise BadArgument( f"Could not find filter with ID `{filter_id}` in the {list_type.name} {filter_list.name} list." ) filter_ = filter_list[list_type].filters[filter_id] + + if not isinstance(filter_, filter_type): + raise BadArgument( + f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}" + ) return filter_serializable_overrides(filter_) diff --git a/bot/exts/filtering/_ui/filter_list.py b/bot/exts/filtering/_ui/filter_list.py index 15d81322b..e77e29ec9 100644 --- a/bot/exts/filtering/_ui/filter_list.py +++ b/bot/exts/filtering/_ui/filter_list.py @@ -24,7 +24,11 @@ def settings_converter(loaded_settings: dict, input_data: str) -> dict[str, Any] if not parsed: return {} - settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} + try: + settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} + except ValueError: + raise BadArgument("The settings provided are not in the correct format.") + for setting in settings: if setting not in loaded_settings: raise BadArgument(f"{setting!r} is not a recognized setting.") diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py new file mode 100644 index 000000000..d553c28ea --- /dev/null +++ b/bot/exts/filtering/_ui/search.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import discord +from discord import Interaction, SelectOption +from discord.ext.commands import BadArgument + +from bot.exts.filtering._filter_lists import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings_types.settings_entry import SettingsEntry +from bot.exts.filtering._ui.filter import filter_serializable_overrides +from bot.exts.filtering._ui.ui import ( + COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value, + populate_embed_from_dict +) + + +def search_criteria_converter( + filter_lists: dict, + loaded_filters: dict, + loaded_settings: dict, + loaded_filter_settings: dict, + filter_type: type[Filter] | None, + input_data: str +) -> tuple[dict[str, Any], dict[str, Any], type[Filter]]: + """Parse a string representing setting overrides, and validate the setting names.""" + if not input_data: + return {}, {}, filter_type + + parsed = SETTINGS_DELIMITER.split(input_data) + if not parsed: + return {}, {}, filter_type + + try: + settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} + except ValueError: + raise BadArgument("The settings provided are not in the correct format.") + + template = None + if "--template" in settings: + template = settings.pop("--template") + + filter_settings = {} + for setting, _ in list(settings.items()): + if setting in loaded_settings: # It's a filter list setting + type_ = loaded_settings[setting][2] + try: + settings[setting] = parse_value(settings[setting], type_) + except (TypeError, ValueError) as e: + raise BadArgument(e) + elif "/" not in setting: + raise BadArgument(f"{setting!r} is not a recognized setting.") + else: # It's a filter setting + filter_name, filter_setting_name = setting.split("/", maxsplit=1) + if not filter_type: + if filter_name in loaded_filters: + filter_type = loaded_filters[filter_name] + else: + raise BadArgument(f"There's no filter type named {filter_name!r}.") + if filter_name.lower() != filter_type.name.lower(): + raise BadArgument( + f"A setting for a {filter_name!r} filter was provided, " + f"but the filter name is {filter_type.name!r}" + ) + if filter_setting_name not in loaded_filter_settings[filter_type.name]: + raise BadArgument(f"{setting!r} is not a recognized setting.") + type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2] + try: + filter_settings[filter_setting_name] = parse_value(settings.pop(setting), type_) + except (TypeError, ValueError) as e: + raise BadArgument(e) + + # Pull templates settings and apply them. + if template is not None: + try: + t_settings, t_filter_settings, filter_type = template_settings(template, filter_lists, filter_type) + except ValueError as e: + raise BadArgument(str(e)) + else: + # The specified settings go on top of the template + settings = t_settings | settings + filter_settings = t_filter_settings | filter_settings + + return settings, filter_settings, filter_type + + +def get_filter(filter_id: int, filter_lists: dict) -> tuple[Filter, FilterList, ListType] | None: + """Return a filter with the specific filter_id, if found.""" + for filter_list in filter_lists.values(): + for list_type, sublist in filter_list.items(): + if filter_id in sublist.filters: + return sublist.filters[filter_id], filter_list, list_type + return None + + +def template_settings( + filter_id: str, filter_lists: dict, filter_type: type[Filter] | None +) -> tuple[dict, dict, type[Filter]]: + """Find a filter with the specified ID and filter type, and return its settings and (maybe newly found) type.""" + try: + filter_id = int(filter_id) + if filter_id < 0: + raise ValueError() + except ValueError: + raise BadArgument("Template value must be a non-negative integer.") + + result = get_filter(filter_id, filter_lists) + if not result: + raise BadArgument(f"Could not find a filter with ID `{filter_id}`.") + filter_, filter_list, list_type = result + + if filter_type and not isinstance(filter_, filter_type): + raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.") + + settings, filter_settings = filter_serializable_overrides(filter_) + return settings, filter_settings, type(filter_) + + +def build_search_repr_dict( + settings: dict[str, Any], filter_settings: dict[str, Any], filter_type: type[Filter] | None +) -> dict: + """Build a dictionary of field names and values to pass to `populate_embed_from_dict`.""" + total_values = settings.copy() + if filter_type: + for setting_name, value in filter_settings.items(): + total_values[f"{filter_type.name}/{setting_name}"] = value + + return total_values + + +class SearchEditView(EditBaseView): + """A view used to edit the search criteria before performing the search.""" + + class _REMOVE: + """Sentinel value for when an override should be removed.""" + + def __init__( + self, + filter_type: type[Filter] | None, + settings: dict[str, Any], + filter_settings: dict[str, Any], + loaded_filter_lists: dict[str, FilterList], + loaded_filters: dict[str, type[Filter]], + loaded_settings: dict[str, tuple[str, SettingsEntry, type]], + loaded_filter_settings: dict[str, dict[str, tuple[str, SettingsEntry, type]]], + author: discord.User | discord.Member, + embed: discord.Embed, + confirm_callback: Callable + ): + super().__init__(author) + self.filter_type = filter_type + self.settings = settings + self.filter_settings = filter_settings + self.loaded_filter_lists = loaded_filter_lists + self.loaded_filters = loaded_filters + self.loaded_settings = loaded_settings + self.loaded_filter_settings = loaded_filter_settings + self.embed = embed + self.confirm_callback = confirm_callback + + title = "Filters Search" + if filter_type: + title += f" - {filter_type.name.title()}" + embed.set_author(name=title) + + settings_repr_dict = build_search_repr_dict(settings, filter_settings, filter_type) + populate_embed_from_dict(embed, settings_repr_dict) + + self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()} + if filter_type: + self.type_per_setting_name.update({ + f"{filter_type.name}/{name}": type_ + for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items() + }) + + add_select = CustomCallbackSelect( + self._prompt_new_value, + placeholder="Add or edit criterion", + options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)], + row=0 + ) + self.add_item(add_select) + + if settings_repr_dict: + remove_select = CustomCallbackSelect( + self._remove_criterion, + placeholder="Select a criterion to remove", + options=[SelectOption(label=name) for name in sorted(settings_repr_dict)], + row=1 + ) + self.add_item(remove_select) + + @discord.ui.button(label="Template", row=2) + async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None: + """A button to enter a filter template ID and copy its overrides over.""" + modal = TemplateModal(self, interaction.message) + await interaction.response.send_modal(modal) + + @discord.ui.button(label="Filter Type", row=2) + async def enter_filter_type(self, interaction: Interaction, button: discord.ui.Button) -> None: + """A button to enter a filter type.""" + modal = FilterTypeModal(self, interaction.message) + await interaction.response.send_modal(modal) + + @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=3) + async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: + """Confirm the search criteria and perform the search.""" + await interaction.response.edit_message(view=None) # Make sure the interaction succeeds first. + try: + await self.confirm_callback(interaction.message, self.filter_type, self.settings, self.filter_settings) + except BadArgument as e: + await interaction.message.reply( + embed=discord.Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e)) + ) + await interaction.message.edit(view=self) + else: + self.stop() + + @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=3) + async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: + """Cancel the operation.""" + await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None) + self.stop() + + def current_value(self, setting_name: str) -> Any: + """Get the current value stored for the setting or MISSING if none found.""" + if setting_name in self.settings: + return self.settings[setting_name] + if "/" in setting_name: + _, setting_name = setting_name.split("/", maxsplit=1) + if setting_name in self.filter_settings: + return self.filter_settings[setting_name] + return MISSING + + async def update_embed( + self, + interaction_or_msg: discord.Interaction | discord.Message, + *, + setting_name: str | None = None, + setting_value: str | type[SearchEditView._REMOVE] | None = None, + ) -> None: + """ + Update the embed with the new information. + + If a setting name is provided with a _REMOVE value, remove the override. + If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. + """ + if not setting_name: # Can be None just to make the function signature compatible with the parent class. + return + + if "/" in setting_name: + filter_name, setting_name = setting_name.split("/", maxsplit=1) + dict_to_edit = self.filter_settings + else: + dict_to_edit = self.settings + + # Update the criterion value or remove it + if setting_value is not self._REMOVE: + dict_to_edit[setting_name] = setting_value + elif setting_name in dict_to_edit: + dict_to_edit.pop(setting_name) + + self.embed.clear_fields() + new_view = self.copy() + + try: + if isinstance(interaction_or_msg, discord.Interaction): + await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view) + else: + await interaction_or_msg.edit(embed=self.embed, view=new_view) + except discord.errors.HTTPException: # Just in case of faulty input. + pass + else: + self.stop() + + async def _remove_criterion(self, interaction: Interaction, select: discord.ui.Select) -> None: + """ + Remove the criterion the user selected, and edit the embed. + + The interaction needs to be the selection of the setting attached to the embed. + """ + await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE) + + async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None: + """Set any unset criteria with settings values from the given filter.""" + try: + settings, filter_settings, self.filter_type = template_settings( + template_id, self.loaded_filter_lists, self.filter_type + ) + except BadArgument as e: # The interaction object is necessary to send an ephemeral message. + await interaction.response.send_message(f":x: {e}", ephemeral=True) + return + else: + await interaction.response.defer() + + self.settings = settings | self.settings + self.filter_settings = filter_settings | self.filter_settings + self.embed.clear_fields() + await embed_message.edit(embed=self.embed, view=self.copy()) + self.stop() + + async def apply_filter_type(self, type_name: str, embed_message: discord.Message, interaction: Interaction) -> None: + """Set a new filter type and reset any criteria for settings of the old filter type.""" + if type_name.lower() not in self.loaded_filters: + if type_name.lower()[:-1] not in self.loaded_filters: # In case the user entered the plural form. + await interaction.response.send_message(f":x: No such filter type {type_name!r}.", ephemeral=True) + return + type_name = type_name[:-1] + type_name = type_name.lower() + await interaction.response.defer() + + if self.filter_type and type_name == self.filter_type.name: + return + self.filter_type = self.loaded_filters[type_name] + self.filter_settings = {} + self.embed.clear_fields() + await embed_message.edit(embed=self.embed, view=self.copy()) + self.stop() + + def copy(self) -> SearchEditView: + """Create a copy of this view.""" + return SearchEditView( + self.filter_type, + self.settings, + self.filter_settings, + self.loaded_filter_lists, + self.loaded_filters, + self.loaded_settings, + self.loaded_filter_settings, + self.author, + self.embed, + self.confirm_callback + ) + + +class TemplateModal(discord.ui.Modal, title="Template"): + """A modal to enter a filter ID to copy its overrides over.""" + + template = discord.ui.TextInput(label="Template Filter ID", required=False) + + def __init__(self, embed_view: SearchEditView, message: discord.Message): + super().__init__(timeout=COMPONENT_TIMEOUT) + self.embed_view = embed_view + self.message = message + + async def on_submit(self, interaction: Interaction) -> None: + """Update the embed with the new description.""" + await self.embed_view.apply_template(self.template.value, self.message, interaction) + + +class FilterTypeModal(discord.ui.Modal, title="Template"): + """A modal to enter a filter ID to copy its overrides over.""" + + filter_type = discord.ui.TextInput(label="Filter Type") + + def __init__(self, embed_view: SearchEditView, message: discord.Message): + super().__init__(timeout=COMPONENT_TIMEOUT) + self.embed_view = embed_view + self.message = message + + async def on_submit(self, interaction: Interaction) -> None: + """Update the embed with the new description.""" + await self.embed_view.apply_filter_type(self.filter_type.value, self.message, interaction) diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 980eba02a..c506db1fe 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -34,7 +34,7 @@ MAX_SELECT_ITEMS = 25 MAX_EMBED_DESCRIPTION = 4000 SETTINGS_DELIMITER = re.compile(r"\s+(?=\S+=\S+)") -SINGLE_SETTING_PATTERN = re.compile(r"\w+=.+") +SINGLE_SETTING_PATTERN = re.compile(r"[\w/]+=.+") # Sentinel value to denote that a value is missing MISSING = object() @@ -76,7 +76,7 @@ def parse_value(value: str, type_: type[T]) -> T: if type_ in (tuple, list, set): return type_(value.split(",")) if type_ is bool: - return value == "True" + return value.lower() == "true" or value == "1" if isinstance(type_, EnumMeta): return type_[value.upper()] diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 6ff5181a9..890b25718 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -26,8 +26,9 @@ from bot.exts.filtering._ui.filter import ( build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict ) from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter +from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter from bot.exts.filtering._ui.ui import ArgumentCompletionView, DeleteConfirmationView, format_response_error -from bot.exts.filtering._utils import past_tense, starting_value, to_serializable +from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable from bot.log import get_logger from bot.pagination import LinePaginator from bot.utils.messages import format_channel, format_user @@ -523,6 +524,63 @@ class Filtering(Cog): embed = Embed(colour=Colour.blue(), title="Match results") await LinePaginator.paginate(lines, ctx, embed, max_lines=10, empty=False) + @filter.command(name="search") + async def f_search( + self, + ctx: Context, + noui: Literal["noui"] | None, + filter_type_name: str | None, + *, + settings: str = "" + ) -> None: + """ + Find filters with the provided settings. The format is identical to that of the add and edit commands. + + If a list type and/or a list name are provided, the search will be limited to those parameters. A list name must + be provided in order to search by filter-specific settings. + """ + filter_type = None + if filter_type_name: + filter_type_name = filter_type_name.lower() + filter_type = self.loaded_filters.get(filter_type_name) + if not filter_type: + self.loaded_filters.get(filter_type_name[:-1]) # In case the user tried to specify the plural form. + # If settings were provided with no filter_type, discord.py will capture the first word as the filter type. + if filter_type is None and filter_type_name is not None: + if settings: + settings = f"{filter_type_name} {settings}" + else: + settings = filter_type_name + filter_type_name = None + + settings, filter_settings, filter_type = search_criteria_converter( + self.filter_lists, + self.loaded_filters, + self.loaded_settings, + self.loaded_filter_settings, + filter_type, + settings + ) + + if noui: + await self._search_filters(ctx.message, filter_type, settings, filter_settings) + return + + embed = Embed(colour=Colour.blue()) + view = SearchEditView( + filter_type, + settings, + filter_settings, + self.filter_lists, + self.loaded_filters, + self.loaded_settings, + self.loaded_filter_settings, + ctx.author, + embed, + self._search_filters + ) + await ctx.send(embed=embed, reference=ctx.message, view=view) + # endregion # region: filterlist group @@ -787,7 +845,7 @@ class Filtering(Cog): embed = Embed(colour=Colour.blue()) embed.set_author(name=f"List of {filter_list[list_type].label}s ({len(lines)} total)") - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True) def _get_filter_by_id(self, id_: int) -> Optional[tuple[Filter, FilterList, ListType]]: """Get the filter object corresponding to the provided ID, along with its containing list and list type.""" @@ -954,6 +1012,71 @@ class Filtering(Cog): filter_list.add_list(response) await msg.reply(f"✅ Edited filter list: {filter_list[list_type].label}") + def _filter_match_query( + self, filter_: Filter, settings_query: dict, filter_settings_query: dict, differ_by_default: set[str] + ) -> bool: + """Return whether the given filter matches the query.""" + override_matches = set() + overrides, _ = filter_.overrides + for setting_name, setting_value in settings_query.items(): + if setting_name not in overrides: + continue + if repr_equals(overrides[setting_name], setting_value): + override_matches.add(setting_name) + else: # If an override doesn't match then the filter doesn't match. + return False + if not (differ_by_default <= override_matches): # The overrides didn't cover for the default mismatches. + return False + + filter_settings = filter_.extra_fields.dict() if filter_.extra_fields else {} + # If the dict changes then some fields were not the same. + return (filter_settings | filter_settings_query) == filter_settings + + def _search_filter_list( + self, atomic_list: AtomicList, filter_type: type[Filter] | None, settings: dict, filter_settings: dict + ) -> list[Filter]: + """Find all filters in the filter list which match the settings.""" + # If the default answers are known, only the overrides need to be checked for each filter. + all_defaults = atomic_list.defaults.dict() + match_by_default = set() + differ_by_default = set() + for setting_name, setting_value in settings.items(): + if repr_equals(all_defaults[setting_name], setting_value): + match_by_default.add(setting_name) + else: + differ_by_default.add(setting_name) + + result_filters = [] + for filter_ in atomic_list.filters.values(): + if filter_type and not isinstance(filter_, filter_type): + continue + if self._filter_match_query(filter_, settings, filter_settings, differ_by_default): + result_filters.append(filter_) + + return result_filters + + async def _search_filters( + self, message: Message, filter_type: type[Filter] | None, settings: dict, filter_settings: dict + ) -> None: + """Find all filters which match the settings and display them.""" + lines = [] + result_count = 0 + for filter_list in self.filter_lists.values(): + if filter_type and filter_type not in filter_list.filter_types: + continue + for atomic_list in filter_list.values(): + list_results = self._search_filter_list(atomic_list, filter_type, settings, filter_settings) + if list_results: + lines.append(f"**{atomic_list.label.title()}**") + lines.extend(map(str, list_results)) + lines.append("") + result_count += len(list_results) + + embed = Embed(colour=Colour.blue()) + embed.set_author(name=f"Search Results ({result_count} total)") + ctx = await bot.instance.get_context(message) + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True) + # endregion diff --git a/bot/pagination.py b/bot/pagination.py index 10bef1c9f..92fa781ee 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -204,6 +204,7 @@ class LinePaginator(Paginator): footer_text: str = None, url: str = None, exception_on_empty_embed: bool = False, + reply: bool = False, ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -251,6 +252,8 @@ class LinePaginator(Paginator): embed.description = paginator.pages[current_page] + reference = ctx.message if reply else None + if len(paginator.pages) <= 1: if footer_text: embed.set_footer(text=footer_text) @@ -261,7 +264,7 @@ class LinePaginator(Paginator): log.trace(f"Setting embed url to '{url}'") log.debug("There's less than two pages, so we won't paginate - sending single page on its own") - return await ctx.send(embed=embed) + return await ctx.send(embed=embed, reference=reference) else: if footer_text: embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") @@ -274,7 +277,7 @@ class LinePaginator(Paginator): log.trace(f"Setting embed url to '{url}'") log.debug("Sending first page to channel...") - message = await ctx.send(embed=embed) + message = await ctx.send(embed=embed, reference=reference) log.debug("Adding emoji reactions to message...") |