diff options
| -rw-r--r-- | bot/exts/filtering/_filter_lists/filter_list.py | 8 | ||||
| -rw-r--r-- | bot/exts/filtering/_filters/filter.py | 6 | ||||
| -rw-r--r-- | bot/exts/filtering/_settings.py | 7 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/filter.py | 88 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/filter_list.py | 6 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/search.py | 365 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/ui.py | 4 | ||||
| -rw-r--r-- | bot/exts/filtering/filtering.py | 127 | ||||
| -rw-r--r-- | bot/pagination.py | 7 | 
9 files changed, 563 insertions, 55 deletions
| diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index f993665f2..84a43072b 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -84,14 +84,14 @@ class AtomicList(NamedTuple):          return relevant_filters -    def default(self, setting: str) -> Any: +    def default(self, setting_name: str) -> Any:          """Get the default value of a specific setting."""          missing = object() -        value = self.defaults.actions.get_setting(setting, missing) +        value = self.defaults.actions.get_setting(setting_name, missing)          if value is missing: -            value = self.defaults.validations.get_setting(setting, missing) +            value = self.defaults.validations.get_setting(setting_name, missing)              if value is missing: -                raise ValueError(f"Couldn't find a setting named {setting!r}.") +                raise ValueError(f"Couldn't find a setting named {setting_name!r}.")          return value diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py index 0d11d5b3c..095799781 100644 --- a/bot/exts/filtering/_filters/filter.py +++ b/bot/exts/filtering/_filters/filter.py @@ -1,5 +1,5 @@  from abc import abstractmethod -from typing import Any, Optional +from typing import Any  from pydantic import ValidationError @@ -28,7 +28,7 @@ class Filter(FieldRequiring):          self.description = filter_data["description"]          self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults)          if self.extra_fields_type: -            self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"]) +            self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"] or "{}")  # noqa: P103          else:              self.extra_fields = None @@ -52,7 +52,7 @@ class Filter(FieldRequiring):          """Search for the filter's content within a given context."""      @classmethod -    def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, Optional[str]]: +    def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, str | None]:          """Validate whether the supplied fields are valid for the filter, and provide the error message if not."""          if cls.extra_fields_type is None:              return True, None diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py index 4c2114f07..066c7a369 100644 --- a/bot/exts/filtering/_settings.py +++ b/bot/exts/filtering/_settings.py @@ -211,3 +211,10 @@ class Defaults(NamedTuple):      actions: ActionSettings      validations: ValidationSettings + +    def dict(self) -> dict[str, Any]: +        """Return a dict representation of the stored fields across all entries.""" +        dict_ = {} +        for settings in self: +            dict_ = reduce(operator.or_, (entry.dict() for entry in settings.values()), dict_) +        return dict_ diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py index 765fba683..37584e9fd 100644 --- a/bot/exts/filtering/_ui/filter.py +++ b/bot/exts/filtering/_ui/filter.py @@ -28,7 +28,7 @@ def build_filter_repr_dict(      settings_overrides: dict,      extra_fields_overrides: dict  ) -> dict: -    """Build a dictionary of field names and values to pass to `_build_embed_from_dict`.""" +    """Build a dictionary of field names and values to pass to `populate_embed_from_dict`."""      # Get filter list settings      default_setting_values = {}      for settings_group in filter_list[list_type].defaults: @@ -155,16 +155,16 @@ class FilterEditView(EditBaseView):          )          self.add_item(add_select) -        override_names = ( -            list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides] -        ) -        remove_select = CustomCallbackSelect( -            self._remove_override, -            placeholder="Select an override to remove", -            options=[SelectOption(label=name) for name in sorted(override_names)], -            row=2 -        ) -        if remove_select.options: +        if settings_overrides or filter_settings_overrides: +            override_names = ( +                list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides] +            ) +            remove_select = CustomCallbackSelect( +                self._remove_override, +                placeholder="Select an override to remove", +                options=[SelectOption(label=name) for name in sorted(override_names)], +                row=2 +            )              self.add_item(remove_select)      @discord.ui.button(label="Edit Content", row=3) @@ -285,9 +285,9 @@ class FilterEditView(EditBaseView):                      dict_to_edit[setting_name] = setting_value                  # If there's already an override, remove it, since the new value is the same as the default.                  elif setting_name in dict_to_edit: -                    del dict_to_edit[setting_name] +                    dict_to_edit.pop(setting_name)              elif setting_name in dict_to_edit: -                del dict_to_edit[setting_name] +                dict_to_edit.pop(setting_name)          # This is inefficient, but otherwise the selects go insane if the user attempts to edit the same setting          # multiple times, even when replacing the select with a new one. @@ -315,8 +315,10 @@ class FilterEditView(EditBaseView):      async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None:          """Replace any non-overridden settings with overrides from the given filter."""          try: -            settings, filter_settings = template_settings(template_id, self.filter_list, self.list_type) -        except ValueError as e:  # The interaction is necessary to send an ephemeral message. +            settings, filter_settings = template_settings( +                template_id, self.filter_list, self.list_type, self.filter_type +            ) +        except BadArgument as e:  # The interaction object is necessary to send an ephemeral message.              await interaction.response.send_message(f":x: {e}", ephemeral=True)              return          else: @@ -326,6 +328,7 @@ class FilterEditView(EditBaseView):          self.filter_settings_overrides = filter_settings | self.filter_settings_overrides          self.embed.clear_fields()          await embed_message.edit(embed=self.embed, view=self.copy()) +        self.stop()      async def _remove_override(self, interaction: Interaction, select: discord.ui.Select) -> None:          """ @@ -380,28 +383,7 @@ def description_and_settings_converter(      filter_settings = {}      for setting, _ in list(settings.items()): -        if setting not in loaded_settings: -            # It's a filter setting -            if "/" in setting: -                setting_list_name, filter_setting_name = setting.split("/", maxsplit=1) -                if setting_list_name.lower() != filter_list.name.lower(): -                    raise BadArgument( -                        f"A setting for a {setting_list_name!r} filter was provided, " -                        f"but the list name is {filter_list.name!r}" -                    ) -                if filter_setting_name not in loaded_filter_settings[filter_list.name]: -                    raise BadArgument(f"{setting!r} is not a recognized setting.") -                type_ = loaded_filter_settings[filter_list.name][filter_setting_name][2] -                try: -                    parsed_value = parse_value(settings.pop(setting), type_) -                    if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): -                        filter_settings[filter_setting_name] = parsed_value -                except (TypeError, ValueError) as e: -                    raise BadArgument(e) -            else: -                raise BadArgument(f"{setting!r} is not a recognized setting.") -        # It's a filter list setting -        else: +        if setting in loaded_settings:  # It's a filter list setting              type_ = loaded_settings[setting][2]              try:                  parsed_value = parse_value(settings.pop(setting), type_) @@ -409,11 +391,28 @@ def description_and_settings_converter(                      settings[setting] = parsed_value              except (TypeError, ValueError) as e:                  raise BadArgument(e) +        elif "/" not in setting: +            raise BadArgument(f"{setting!r} is not a recognized setting.") +        else:  # It's a filter setting +            filter_name, filter_setting_name = setting.split("/", maxsplit=1) +            if filter_name.lower() != filter_type.name.lower(): +                raise BadArgument( +                    f"A setting for a {filter_name!r} filter was provided, but the filter name is {filter_type.name!r}" +                ) +            if filter_setting_name not in loaded_filter_settings[filter_type.name]: +                raise BadArgument(f"{setting!r} is not a recognized setting.") +            type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2] +            try: +                parsed_value = parse_value(settings.pop(setting), type_) +                if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)): +                    filter_settings[filter_setting_name] = parsed_value +            except (TypeError, ValueError) as e: +                raise BadArgument(e)      # Pull templates settings and apply them.      if template is not None:          try: -            t_settings, t_filter_settings = template_settings(template, filter_list, list_type) +            t_settings, t_filter_settings = template_settings(template, filter_list, list_type, filter_type)          except ValueError as e:              raise BadArgument(str(e))          else: @@ -430,18 +429,25 @@ def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]:      return to_serializable(overrides_values), to_serializable(extra_fields_overrides) -def template_settings(filter_id: str, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]: +def template_settings( +    filter_id: str, filter_list: FilterList, list_type: ListType, filter_type: type[Filter] +) -> tuple[dict, dict]:      """Find the filter with specified ID, and return its settings."""      try:          filter_id = int(filter_id)          if filter_id < 0:              raise ValueError()      except ValueError: -        raise ValueError("Template value must be a non-negative integer.") +        raise BadArgument("Template value must be a non-negative integer.")      if filter_id not in filter_list[list_type].filters: -        raise ValueError( +        raise BadArgument(              f"Could not find filter with ID `{filter_id}` in the {list_type.name} {filter_list.name} list."          )      filter_ = filter_list[list_type].filters[filter_id] + +    if not isinstance(filter_, filter_type): +        raise BadArgument( +            f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}" +        )      return filter_serializable_overrides(filter_) diff --git a/bot/exts/filtering/_ui/filter_list.py b/bot/exts/filtering/_ui/filter_list.py index 15d81322b..e77e29ec9 100644 --- a/bot/exts/filtering/_ui/filter_list.py +++ b/bot/exts/filtering/_ui/filter_list.py @@ -24,7 +24,11 @@ def settings_converter(loaded_settings: dict, input_data: str) -> dict[str, Any]      if not parsed:          return {} -    settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} +    try: +        settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} +    except ValueError: +        raise BadArgument("The settings provided are not in the correct format.") +      for setting in settings:          if setting not in loaded_settings:              raise BadArgument(f"{setting!r} is not a recognized setting.") diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py new file mode 100644 index 000000000..d553c28ea --- /dev/null +++ b/bot/exts/filtering/_ui/search.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import discord +from discord import Interaction, SelectOption +from discord.ext.commands import BadArgument + +from bot.exts.filtering._filter_lists import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings_types.settings_entry import SettingsEntry +from bot.exts.filtering._ui.filter import filter_serializable_overrides +from bot.exts.filtering._ui.ui import ( +    COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value, +    populate_embed_from_dict +) + + +def search_criteria_converter( +    filter_lists: dict, +    loaded_filters: dict, +    loaded_settings: dict, +    loaded_filter_settings: dict, +    filter_type: type[Filter] | None, +    input_data: str +) -> tuple[dict[str, Any], dict[str, Any], type[Filter]]: +    """Parse a string representing setting overrides, and validate the setting names.""" +    if not input_data: +        return {}, {}, filter_type + +    parsed = SETTINGS_DELIMITER.split(input_data) +    if not parsed: +        return {}, {}, filter_type + +    try: +        settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]} +    except ValueError: +        raise BadArgument("The settings provided are not in the correct format.") + +    template = None +    if "--template" in settings: +        template = settings.pop("--template") + +    filter_settings = {} +    for setting, _ in list(settings.items()): +        if setting in loaded_settings:  # It's a filter list setting +            type_ = loaded_settings[setting][2] +            try: +                settings[setting] = parse_value(settings[setting], type_) +            except (TypeError, ValueError) as e: +                raise BadArgument(e) +        elif "/" not in setting: +            raise BadArgument(f"{setting!r} is not a recognized setting.") +        else:  # It's a filter setting +            filter_name, filter_setting_name = setting.split("/", maxsplit=1) +            if not filter_type: +                if filter_name in loaded_filters: +                    filter_type = loaded_filters[filter_name] +                else: +                    raise BadArgument(f"There's no filter type named {filter_name!r}.") +            if filter_name.lower() != filter_type.name.lower(): +                raise BadArgument( +                    f"A setting for a {filter_name!r} filter was provided, " +                    f"but the filter name is {filter_type.name!r}" +                ) +            if filter_setting_name not in loaded_filter_settings[filter_type.name]: +                raise BadArgument(f"{setting!r} is not a recognized setting.") +            type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2] +            try: +                filter_settings[filter_setting_name] = parse_value(settings.pop(setting), type_) +            except (TypeError, ValueError) as e: +                raise BadArgument(e) + +    # Pull templates settings and apply them. +    if template is not None: +        try: +            t_settings, t_filter_settings, filter_type = template_settings(template, filter_lists, filter_type) +        except ValueError as e: +            raise BadArgument(str(e)) +        else: +            # The specified settings go on top of the template +            settings = t_settings | settings +            filter_settings = t_filter_settings | filter_settings + +    return settings, filter_settings, filter_type + + +def get_filter(filter_id: int, filter_lists: dict) -> tuple[Filter, FilterList, ListType] | None: +    """Return a filter with the specific filter_id, if found.""" +    for filter_list in filter_lists.values(): +        for list_type, sublist in filter_list.items(): +            if filter_id in sublist.filters: +                return sublist.filters[filter_id], filter_list, list_type +    return None + + +def template_settings( +    filter_id: str, filter_lists: dict, filter_type: type[Filter] | None +) -> tuple[dict, dict, type[Filter]]: +    """Find a filter with the specified ID and filter type, and return its settings and (maybe newly found) type.""" +    try: +        filter_id = int(filter_id) +        if filter_id < 0: +            raise ValueError() +    except ValueError: +        raise BadArgument("Template value must be a non-negative integer.") + +    result = get_filter(filter_id, filter_lists) +    if not result: +        raise BadArgument(f"Could not find a filter with ID `{filter_id}`.") +    filter_, filter_list, list_type = result + +    if filter_type and not isinstance(filter_, filter_type): +        raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.") + +    settings, filter_settings = filter_serializable_overrides(filter_) +    return settings, filter_settings, type(filter_) + + +def build_search_repr_dict( +    settings: dict[str, Any], filter_settings: dict[str, Any], filter_type: type[Filter] | None +) -> dict: +    """Build a dictionary of field names and values to pass to `populate_embed_from_dict`.""" +    total_values = settings.copy() +    if filter_type: +        for setting_name, value in filter_settings.items(): +            total_values[f"{filter_type.name}/{setting_name}"] = value + +    return total_values + + +class SearchEditView(EditBaseView): +    """A view used to edit the search criteria before performing the search.""" + +    class _REMOVE: +        """Sentinel value for when an override should be removed.""" + +    def __init__( +        self, +        filter_type: type[Filter] | None, +        settings: dict[str, Any], +        filter_settings: dict[str, Any], +        loaded_filter_lists: dict[str, FilterList], +        loaded_filters: dict[str, type[Filter]], +        loaded_settings: dict[str, tuple[str, SettingsEntry, type]], +        loaded_filter_settings: dict[str, dict[str, tuple[str, SettingsEntry, type]]], +        author: discord.User | discord.Member, +        embed: discord.Embed, +        confirm_callback: Callable +    ): +        super().__init__(author) +        self.filter_type = filter_type +        self.settings = settings +        self.filter_settings = filter_settings +        self.loaded_filter_lists = loaded_filter_lists +        self.loaded_filters = loaded_filters +        self.loaded_settings = loaded_settings +        self.loaded_filter_settings = loaded_filter_settings +        self.embed = embed +        self.confirm_callback = confirm_callback + +        title = "Filters Search" +        if filter_type: +            title += f" - {filter_type.name.title()}" +        embed.set_author(name=title) + +        settings_repr_dict = build_search_repr_dict(settings, filter_settings, filter_type) +        populate_embed_from_dict(embed, settings_repr_dict) + +        self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()} +        if filter_type: +            self.type_per_setting_name.update({ +                f"{filter_type.name}/{name}": type_ +                for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items() +            }) + +        add_select = CustomCallbackSelect( +            self._prompt_new_value, +            placeholder="Add or edit criterion", +            options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)], +            row=0 +        ) +        self.add_item(add_select) + +        if settings_repr_dict: +            remove_select = CustomCallbackSelect( +                self._remove_criterion, +                placeholder="Select a criterion to remove", +                options=[SelectOption(label=name) for name in sorted(settings_repr_dict)], +                row=1 +            ) +            self.add_item(remove_select) + +    @discord.ui.button(label="Template", row=2) +    async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to enter a filter template ID and copy its overrides over.""" +        modal = TemplateModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="Filter Type", row=2) +    async def enter_filter_type(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """A button to enter a filter type.""" +        modal = FilterTypeModal(self, interaction.message) +        await interaction.response.send_modal(modal) + +    @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=3) +    async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Confirm the search criteria and perform the search.""" +        await interaction.response.edit_message(view=None)  # Make sure the interaction succeeds first. +        try: +            await self.confirm_callback(interaction.message, self.filter_type, self.settings, self.filter_settings) +        except BadArgument as e: +            await interaction.message.reply( +                embed=discord.Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e)) +            ) +            await interaction.message.edit(view=self) +        else: +            self.stop() + +    @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=3) +    async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None: +        """Cancel the operation.""" +        await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None) +        self.stop() + +    def current_value(self, setting_name: str) -> Any: +        """Get the current value stored for the setting or MISSING if none found.""" +        if setting_name in self.settings: +            return self.settings[setting_name] +        if "/" in setting_name: +            _, setting_name = setting_name.split("/", maxsplit=1) +            if setting_name in self.filter_settings: +                return self.filter_settings[setting_name] +        return MISSING + +    async def update_embed( +        self, +        interaction_or_msg: discord.Interaction | discord.Message, +        *, +        setting_name: str | None = None, +        setting_value: str | type[SearchEditView._REMOVE] | None = None, +    ) -> None: +        """ +        Update the embed with the new information. + +        If a setting name is provided with a _REMOVE value, remove the override. +        If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function. +        """ +        if not setting_name:  # Can be None just to make the function signature compatible with the parent class. +            return + +        if "/" in setting_name: +            filter_name, setting_name = setting_name.split("/", maxsplit=1) +            dict_to_edit = self.filter_settings +        else: +            dict_to_edit = self.settings + +        # Update the criterion value or remove it +        if setting_value is not self._REMOVE: +            dict_to_edit[setting_name] = setting_value +        elif setting_name in dict_to_edit: +            dict_to_edit.pop(setting_name) + +        self.embed.clear_fields() +        new_view = self.copy() + +        try: +            if isinstance(interaction_or_msg, discord.Interaction): +                await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view) +            else: +                await interaction_or_msg.edit(embed=self.embed, view=new_view) +        except discord.errors.HTTPException:  # Just in case of faulty input. +            pass +        else: +            self.stop() + +    async def _remove_criterion(self, interaction: Interaction, select: discord.ui.Select) -> None: +        """ +        Remove the criterion the user selected, and edit the embed. + +        The interaction needs to be the selection of the setting attached to the embed. +        """ +        await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE) + +    async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None: +        """Set any unset criteria with settings values from the given filter.""" +        try: +            settings, filter_settings, self.filter_type = template_settings( +                template_id, self.loaded_filter_lists, self.filter_type +            ) +        except BadArgument as e:  # The interaction object is necessary to send an ephemeral message. +            await interaction.response.send_message(f":x: {e}", ephemeral=True) +            return +        else: +            await interaction.response.defer() + +        self.settings = settings | self.settings +        self.filter_settings = filter_settings | self.filter_settings +        self.embed.clear_fields() +        await embed_message.edit(embed=self.embed, view=self.copy()) +        self.stop() + +    async def apply_filter_type(self, type_name: str, embed_message: discord.Message, interaction: Interaction) -> None: +        """Set a new filter type and reset any criteria for settings of the old filter type.""" +        if type_name.lower() not in self.loaded_filters: +            if type_name.lower()[:-1] not in self.loaded_filters:  # In case the user entered the plural form. +                await interaction.response.send_message(f":x: No such filter type {type_name!r}.", ephemeral=True) +                return +            type_name = type_name[:-1] +        type_name = type_name.lower() +        await interaction.response.defer() + +        if self.filter_type and type_name == self.filter_type.name: +            return +        self.filter_type = self.loaded_filters[type_name] +        self.filter_settings = {} +        self.embed.clear_fields() +        await embed_message.edit(embed=self.embed, view=self.copy()) +        self.stop() + +    def copy(self) -> SearchEditView: +        """Create a copy of this view.""" +        return SearchEditView( +            self.filter_type, +            self.settings, +            self.filter_settings, +            self.loaded_filter_lists, +            self.loaded_filters, +            self.loaded_settings, +            self.loaded_filter_settings, +            self.author, +            self.embed, +            self.confirm_callback +        ) + + +class TemplateModal(discord.ui.Modal, title="Template"): +    """A modal to enter a filter ID to copy its overrides over.""" + +    template = discord.ui.TextInput(label="Template Filter ID", required=False) + +    def __init__(self, embed_view: SearchEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new description.""" +        await self.embed_view.apply_template(self.template.value, self.message, interaction) + + +class FilterTypeModal(discord.ui.Modal, title="Template"): +    """A modal to enter a filter ID to copy its overrides over.""" + +    filter_type = discord.ui.TextInput(label="Filter Type") + +    def __init__(self, embed_view: SearchEditView, message: discord.Message): +        super().__init__(timeout=COMPONENT_TIMEOUT) +        self.embed_view = embed_view +        self.message = message + +    async def on_submit(self, interaction: Interaction) -> None: +        """Update the embed with the new description.""" +        await self.embed_view.apply_filter_type(self.filter_type.value, self.message, interaction) diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 980eba02a..c506db1fe 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -34,7 +34,7 @@ MAX_SELECT_ITEMS = 25  MAX_EMBED_DESCRIPTION = 4000  SETTINGS_DELIMITER = re.compile(r"\s+(?=\S+=\S+)") -SINGLE_SETTING_PATTERN = re.compile(r"\w+=.+") +SINGLE_SETTING_PATTERN = re.compile(r"[\w/]+=.+")  # Sentinel value to denote that a value is missing  MISSING = object() @@ -76,7 +76,7 @@ def parse_value(value: str, type_: type[T]) -> T:      if type_ in (tuple, list, set):          return type_(value.split(","))      if type_ is bool: -        return value == "True" +        return value.lower() == "true" or value == "1"      if isinstance(type_, EnumMeta):          return type_[value.upper()] diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 6ff5181a9..890b25718 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -26,8 +26,9 @@ from bot.exts.filtering._ui.filter import (      build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict  )  from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter +from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter  from bot.exts.filtering._ui.ui import ArgumentCompletionView, DeleteConfirmationView, format_response_error -from bot.exts.filtering._utils import past_tense, starting_value, to_serializable +from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable  from bot.log import get_logger  from bot.pagination import LinePaginator  from bot.utils.messages import format_channel, format_user @@ -523,6 +524,63 @@ class Filtering(Cog):          embed = Embed(colour=Colour.blue(), title="Match results")          await LinePaginator.paginate(lines, ctx, embed, max_lines=10, empty=False) +    @filter.command(name="search") +    async def f_search( +        self, +        ctx: Context, +        noui: Literal["noui"] | None, +        filter_type_name: str | None, +        *, +        settings: str = "" +    ) -> None: +        """ +        Find filters with the provided settings. The format is identical to that of the add and edit commands. + +        If a list type and/or a list name are provided, the search will be limited to those parameters. A list name must +        be provided in order to search by filter-specific settings. +        """ +        filter_type = None +        if filter_type_name: +            filter_type_name = filter_type_name.lower() +            filter_type = self.loaded_filters.get(filter_type_name) +            if not filter_type: +                self.loaded_filters.get(filter_type_name[:-1])  # In case the user tried to specify the plural form. +        # If settings were provided with no filter_type, discord.py will capture the first word as the filter type. +        if filter_type is None and filter_type_name is not None: +            if settings: +                settings = f"{filter_type_name} {settings}" +            else: +                settings = filter_type_name +            filter_type_name = None + +        settings, filter_settings, filter_type = search_criteria_converter( +            self.filter_lists, +            self.loaded_filters, +            self.loaded_settings, +            self.loaded_filter_settings, +            filter_type, +            settings +        ) + +        if noui: +            await self._search_filters(ctx.message, filter_type, settings, filter_settings) +            return + +        embed = Embed(colour=Colour.blue()) +        view = SearchEditView( +            filter_type, +            settings, +            filter_settings, +            self.filter_lists, +            self.loaded_filters, +            self.loaded_settings, +            self.loaded_filter_settings, +            ctx.author, +            embed, +            self._search_filters +        ) +        await ctx.send(embed=embed, reference=ctx.message, view=view) +      # endregion      # region: filterlist group @@ -787,7 +845,7 @@ class Filtering(Cog):          embed = Embed(colour=Colour.blue())          embed.set_author(name=f"List of {filter_list[list_type].label}s ({len(lines)} total)") -        await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) +        await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True)      def _get_filter_by_id(self, id_: int) -> Optional[tuple[Filter, FilterList, ListType]]:          """Get the filter object corresponding to the provided ID, along with its containing list and list type.""" @@ -954,6 +1012,71 @@ class Filtering(Cog):          filter_list.add_list(response)          await msg.reply(f"✅ Edited filter list: {filter_list[list_type].label}") +    def _filter_match_query( +        self, filter_: Filter, settings_query: dict, filter_settings_query: dict, differ_by_default: set[str] +    ) -> bool: +        """Return whether the given filter matches the query.""" +        override_matches = set() +        overrides, _ = filter_.overrides +        for setting_name, setting_value in settings_query.items(): +            if setting_name not in overrides: +                continue +            if repr_equals(overrides[setting_name], setting_value): +                override_matches.add(setting_name) +            else:  # If an override doesn't match then the filter doesn't match. +                return False +        if not (differ_by_default <= override_matches):  # The overrides didn't cover for the default mismatches. +            return False + +        filter_settings = filter_.extra_fields.dict() if filter_.extra_fields else {} +        # If the dict changes then some fields were not the same. +        return (filter_settings | filter_settings_query) == filter_settings + +    def _search_filter_list( +        self, atomic_list: AtomicList, filter_type: type[Filter] | None, settings: dict, filter_settings: dict +    ) -> list[Filter]: +        """Find all filters in the filter list which match the settings.""" +        # If the default answers are known, only the overrides need to be checked for each filter. +        all_defaults = atomic_list.defaults.dict() +        match_by_default = set() +        differ_by_default = set() +        for setting_name, setting_value in settings.items(): +            if repr_equals(all_defaults[setting_name], setting_value): +                match_by_default.add(setting_name) +            else: +                differ_by_default.add(setting_name) + +        result_filters = [] +        for filter_ in atomic_list.filters.values(): +            if filter_type and not isinstance(filter_, filter_type): +                continue +            if self._filter_match_query(filter_, settings, filter_settings, differ_by_default): +                result_filters.append(filter_) + +        return result_filters + +    async def _search_filters( +        self, message: Message, filter_type: type[Filter] | None, settings: dict, filter_settings: dict +    ) -> None: +        """Find all filters which match the settings and display them.""" +        lines = [] +        result_count = 0 +        for filter_list in self.filter_lists.values(): +            if filter_type and filter_type not in filter_list.filter_types: +                continue +            for atomic_list in filter_list.values(): +                list_results = self._search_filter_list(atomic_list, filter_type, settings, filter_settings) +                if list_results: +                    lines.append(f"**{atomic_list.label.title()}**") +                    lines.extend(map(str, list_results)) +                    lines.append("") +                    result_count += len(list_results) + +        embed = Embed(colour=Colour.blue()) +        embed.set_author(name=f"Search Results ({result_count} total)") +        ctx = await bot.instance.get_context(message) +        await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True) +      # endregion diff --git a/bot/pagination.py b/bot/pagination.py index 10bef1c9f..92fa781ee 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -204,6 +204,7 @@ class LinePaginator(Paginator):          footer_text: str = None,          url: str = None,          exception_on_empty_embed: bool = False, +        reply: bool = False,      ) -> t.Optional[discord.Message]:          """          Use a paginator and set of reactions to provide pagination over a set of lines. @@ -251,6 +252,8 @@ class LinePaginator(Paginator):          embed.description = paginator.pages[current_page] +        reference = ctx.message if reply else None +          if len(paginator.pages) <= 1:              if footer_text:                  embed.set_footer(text=footer_text) @@ -261,7 +264,7 @@ class LinePaginator(Paginator):                  log.trace(f"Setting embed url to '{url}'")              log.debug("There's less than two pages, so we won't paginate - sending single page on its own") -            return await ctx.send(embed=embed) +            return await ctx.send(embed=embed, reference=reference)          else:              if footer_text:                  embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") @@ -274,7 +277,7 @@ class LinePaginator(Paginator):                  log.trace(f"Setting embed url to '{url}'")              log.debug("Sending first page to channel...") -            message = await ctx.send(embed=embed) +            message = await ctx.send(embed=embed, reference=reference)          log.debug("Adding emoji reactions to message...") | 
