aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar mbaruh <[email protected]>2022-10-17 23:36:22 +0300
committerGravatar mbaruh <[email protected]>2022-10-17 23:36:22 +0300
commita7a04a118e1a9bf5a4d777ad43d40df9f035021c (patch)
tree11014e6a799a69ba143fdf08818ef387346e55a7
parentChange override handling (diff)
Add a command to query filters by settings
-rw-r--r--bot/exts/filtering/_filter_lists/filter_list.py8
-rw-r--r--bot/exts/filtering/_filters/filter.py6
-rw-r--r--bot/exts/filtering/_settings.py7
-rw-r--r--bot/exts/filtering/_ui/filter.py88
-rw-r--r--bot/exts/filtering/_ui/filter_list.py6
-rw-r--r--bot/exts/filtering/_ui/search.py365
-rw-r--r--bot/exts/filtering/_ui/ui.py4
-rw-r--r--bot/exts/filtering/filtering.py127
-rw-r--r--bot/pagination.py7
9 files changed, 563 insertions, 55 deletions
diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py
index f993665f2..84a43072b 100644
--- a/bot/exts/filtering/_filter_lists/filter_list.py
+++ b/bot/exts/filtering/_filter_lists/filter_list.py
@@ -84,14 +84,14 @@ class AtomicList(NamedTuple):
return relevant_filters
- def default(self, setting: str) -> Any:
+ def default(self, setting_name: str) -> Any:
"""Get the default value of a specific setting."""
missing = object()
- value = self.defaults.actions.get_setting(setting, missing)
+ value = self.defaults.actions.get_setting(setting_name, missing)
if value is missing:
- value = self.defaults.validations.get_setting(setting, missing)
+ value = self.defaults.validations.get_setting(setting_name, missing)
if value is missing:
- raise ValueError(f"Couldn't find a setting named {setting!r}.")
+ raise ValueError(f"Couldn't find a setting named {setting_name!r}.")
return value
diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py
index 0d11d5b3c..095799781 100644
--- a/bot/exts/filtering/_filters/filter.py
+++ b/bot/exts/filtering/_filters/filter.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import Any, Optional
+from typing import Any
from pydantic import ValidationError
@@ -28,7 +28,7 @@ class Filter(FieldRequiring):
self.description = filter_data["description"]
self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults)
if self.extra_fields_type:
- self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"])
+ self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"] or "{}") # noqa: P103
else:
self.extra_fields = None
@@ -52,7 +52,7 @@ class Filter(FieldRequiring):
"""Search for the filter's content within a given context."""
@classmethod
- def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, Optional[str]]:
+ def validate_filter_settings(cls, extra_fields: dict) -> tuple[bool, str | None]:
"""Validate whether the supplied fields are valid for the filter, and provide the error message if not."""
if cls.extra_fields_type is None:
return True, None
diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py
index 4c2114f07..066c7a369 100644
--- a/bot/exts/filtering/_settings.py
+++ b/bot/exts/filtering/_settings.py
@@ -211,3 +211,10 @@ class Defaults(NamedTuple):
actions: ActionSettings
validations: ValidationSettings
+
+ def dict(self) -> dict[str, Any]:
+ """Return a dict representation of the stored fields across all entries."""
+ dict_ = {}
+ for settings in self:
+ dict_ = reduce(operator.or_, (entry.dict() for entry in settings.values()), dict_)
+ return dict_
diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py
index 765fba683..37584e9fd 100644
--- a/bot/exts/filtering/_ui/filter.py
+++ b/bot/exts/filtering/_ui/filter.py
@@ -28,7 +28,7 @@ def build_filter_repr_dict(
settings_overrides: dict,
extra_fields_overrides: dict
) -> dict:
- """Build a dictionary of field names and values to pass to `_build_embed_from_dict`."""
+ """Build a dictionary of field names and values to pass to `populate_embed_from_dict`."""
# Get filter list settings
default_setting_values = {}
for settings_group in filter_list[list_type].defaults:
@@ -155,16 +155,16 @@ class FilterEditView(EditBaseView):
)
self.add_item(add_select)
- override_names = (
- list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides]
- )
- remove_select = CustomCallbackSelect(
- self._remove_override,
- placeholder="Select an override to remove",
- options=[SelectOption(label=name) for name in sorted(override_names)],
- row=2
- )
- if remove_select.options:
+ if settings_overrides or filter_settings_overrides:
+ override_names = (
+ list(settings_overrides) + [f"{filter_list.name}/{setting}" for setting in filter_settings_overrides]
+ )
+ remove_select = CustomCallbackSelect(
+ self._remove_override,
+ placeholder="Select an override to remove",
+ options=[SelectOption(label=name) for name in sorted(override_names)],
+ row=2
+ )
self.add_item(remove_select)
@discord.ui.button(label="Edit Content", row=3)
@@ -285,9 +285,9 @@ class FilterEditView(EditBaseView):
dict_to_edit[setting_name] = setting_value
# If there's already an override, remove it, since the new value is the same as the default.
elif setting_name in dict_to_edit:
- del dict_to_edit[setting_name]
+ dict_to_edit.pop(setting_name)
elif setting_name in dict_to_edit:
- del dict_to_edit[setting_name]
+ dict_to_edit.pop(setting_name)
# This is inefficient, but otherwise the selects go insane if the user attempts to edit the same setting
# multiple times, even when replacing the select with a new one.
@@ -315,8 +315,10 @@ class FilterEditView(EditBaseView):
async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None:
"""Replace any non-overridden settings with overrides from the given filter."""
try:
- settings, filter_settings = template_settings(template_id, self.filter_list, self.list_type)
- except ValueError as e: # The interaction is necessary to send an ephemeral message.
+ settings, filter_settings = template_settings(
+ template_id, self.filter_list, self.list_type, self.filter_type
+ )
+ except BadArgument as e: # The interaction object is necessary to send an ephemeral message.
await interaction.response.send_message(f":x: {e}", ephemeral=True)
return
else:
@@ -326,6 +328,7 @@ class FilterEditView(EditBaseView):
self.filter_settings_overrides = filter_settings | self.filter_settings_overrides
self.embed.clear_fields()
await embed_message.edit(embed=self.embed, view=self.copy())
+ self.stop()
async def _remove_override(self, interaction: Interaction, select: discord.ui.Select) -> None:
"""
@@ -380,28 +383,7 @@ def description_and_settings_converter(
filter_settings = {}
for setting, _ in list(settings.items()):
- if setting not in loaded_settings:
- # It's a filter setting
- if "/" in setting:
- setting_list_name, filter_setting_name = setting.split("/", maxsplit=1)
- if setting_list_name.lower() != filter_list.name.lower():
- raise BadArgument(
- f"A setting for a {setting_list_name!r} filter was provided, "
- f"but the list name is {filter_list.name!r}"
- )
- if filter_setting_name not in loaded_filter_settings[filter_list.name]:
- raise BadArgument(f"{setting!r} is not a recognized setting.")
- type_ = loaded_filter_settings[filter_list.name][filter_setting_name][2]
- try:
- parsed_value = parse_value(settings.pop(setting), type_)
- if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)):
- filter_settings[filter_setting_name] = parsed_value
- except (TypeError, ValueError) as e:
- raise BadArgument(e)
- else:
- raise BadArgument(f"{setting!r} is not a recognized setting.")
- # It's a filter list setting
- else:
+ if setting in loaded_settings: # It's a filter list setting
type_ = loaded_settings[setting][2]
try:
parsed_value = parse_value(settings.pop(setting), type_)
@@ -409,11 +391,28 @@ def description_and_settings_converter(
settings[setting] = parsed_value
except (TypeError, ValueError) as e:
raise BadArgument(e)
+ elif "/" not in setting:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ else: # It's a filter setting
+ filter_name, filter_setting_name = setting.split("/", maxsplit=1)
+ if filter_name.lower() != filter_type.name.lower():
+ raise BadArgument(
+ f"A setting for a {filter_name!r} filter was provided, but the filter name is {filter_type.name!r}"
+ )
+ if filter_setting_name not in loaded_filter_settings[filter_type.name]:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2]
+ try:
+ parsed_value = parse_value(settings.pop(setting), type_)
+ if not repr_equals(parsed_value, getattr(filter_type.extra_fields_type(), filter_setting_name)):
+ filter_settings[filter_setting_name] = parsed_value
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
# Pull templates settings and apply them.
if template is not None:
try:
- t_settings, t_filter_settings = template_settings(template, filter_list, list_type)
+ t_settings, t_filter_settings = template_settings(template, filter_list, list_type, filter_type)
except ValueError as e:
raise BadArgument(str(e))
else:
@@ -430,18 +429,25 @@ def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]:
return to_serializable(overrides_values), to_serializable(extra_fields_overrides)
-def template_settings(filter_id: str, filter_list: FilterList, list_type: ListType) -> tuple[dict, dict]:
+def template_settings(
+ filter_id: str, filter_list: FilterList, list_type: ListType, filter_type: type[Filter]
+) -> tuple[dict, dict]:
"""Find the filter with specified ID, and return its settings."""
try:
filter_id = int(filter_id)
if filter_id < 0:
raise ValueError()
except ValueError:
- raise ValueError("Template value must be a non-negative integer.")
+ raise BadArgument("Template value must be a non-negative integer.")
if filter_id not in filter_list[list_type].filters:
- raise ValueError(
+ raise BadArgument(
f"Could not find filter with ID `{filter_id}` in the {list_type.name} {filter_list.name} list."
)
filter_ = filter_list[list_type].filters[filter_id]
+
+ if not isinstance(filter_, filter_type):
+ raise BadArgument(
+ f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}"
+ )
return filter_serializable_overrides(filter_)
diff --git a/bot/exts/filtering/_ui/filter_list.py b/bot/exts/filtering/_ui/filter_list.py
index 15d81322b..e77e29ec9 100644
--- a/bot/exts/filtering/_ui/filter_list.py
+++ b/bot/exts/filtering/_ui/filter_list.py
@@ -24,7 +24,11 @@ def settings_converter(loaded_settings: dict, input_data: str) -> dict[str, Any]
if not parsed:
return {}
- settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]}
+ try:
+ settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]}
+ except ValueError:
+ raise BadArgument("The settings provided are not in the correct format.")
+
for setting in settings:
if setting not in loaded_settings:
raise BadArgument(f"{setting!r} is not a recognized setting.")
diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py
new file mode 100644
index 000000000..d553c28ea
--- /dev/null
+++ b/bot/exts/filtering/_ui/search.py
@@ -0,0 +1,365 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import discord
+from discord import Interaction, SelectOption
+from discord.ext.commands import BadArgument
+
+from bot.exts.filtering._filter_lists import FilterList, ListType
+from bot.exts.filtering._filters.filter import Filter
+from bot.exts.filtering._settings_types.settings_entry import SettingsEntry
+from bot.exts.filtering._ui.filter import filter_serializable_overrides
+from bot.exts.filtering._ui.ui import (
+ COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value,
+ populate_embed_from_dict
+)
+
+
+def search_criteria_converter(
+ filter_lists: dict,
+ loaded_filters: dict,
+ loaded_settings: dict,
+ loaded_filter_settings: dict,
+ filter_type: type[Filter] | None,
+ input_data: str
+) -> tuple[dict[str, Any], dict[str, Any], type[Filter]]:
+ """Parse a string representing setting overrides, and validate the setting names."""
+ if not input_data:
+ return {}, {}, filter_type
+
+ parsed = SETTINGS_DELIMITER.split(input_data)
+ if not parsed:
+ return {}, {}, filter_type
+
+ try:
+ settings = {setting: value for setting, value in [part.split("=", maxsplit=1) for part in parsed]}
+ except ValueError:
+ raise BadArgument("The settings provided are not in the correct format.")
+
+ template = None
+ if "--template" in settings:
+ template = settings.pop("--template")
+
+ filter_settings = {}
+ for setting, _ in list(settings.items()):
+ if setting in loaded_settings: # It's a filter list setting
+ type_ = loaded_settings[setting][2]
+ try:
+ settings[setting] = parse_value(settings[setting], type_)
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+ elif "/" not in setting:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ else: # It's a filter setting
+ filter_name, filter_setting_name = setting.split("/", maxsplit=1)
+ if not filter_type:
+ if filter_name in loaded_filters:
+ filter_type = loaded_filters[filter_name]
+ else:
+ raise BadArgument(f"There's no filter type named {filter_name!r}.")
+ if filter_name.lower() != filter_type.name.lower():
+ raise BadArgument(
+ f"A setting for a {filter_name!r} filter was provided, "
+ f"but the filter name is {filter_type.name!r}"
+ )
+ if filter_setting_name not in loaded_filter_settings[filter_type.name]:
+ raise BadArgument(f"{setting!r} is not a recognized setting.")
+ type_ = loaded_filter_settings[filter_type.name][filter_setting_name][2]
+ try:
+ filter_settings[filter_setting_name] = parse_value(settings.pop(setting), type_)
+ except (TypeError, ValueError) as e:
+ raise BadArgument(e)
+
+ # Pull templates settings and apply them.
+ if template is not None:
+ try:
+ t_settings, t_filter_settings, filter_type = template_settings(template, filter_lists, filter_type)
+ except ValueError as e:
+ raise BadArgument(str(e))
+ else:
+ # The specified settings go on top of the template
+ settings = t_settings | settings
+ filter_settings = t_filter_settings | filter_settings
+
+ return settings, filter_settings, filter_type
+
+
+def get_filter(filter_id: int, filter_lists: dict) -> tuple[Filter, FilterList, ListType] | None:
+ """Return a filter with the specific filter_id, if found."""
+ for filter_list in filter_lists.values():
+ for list_type, sublist in filter_list.items():
+ if filter_id in sublist.filters:
+ return sublist.filters[filter_id], filter_list, list_type
+ return None
+
+
+def template_settings(
+ filter_id: str, filter_lists: dict, filter_type: type[Filter] | None
+) -> tuple[dict, dict, type[Filter]]:
+ """Find a filter with the specified ID and filter type, and return its settings and (maybe newly found) type."""
+ try:
+ filter_id = int(filter_id)
+ if filter_id < 0:
+ raise ValueError()
+ except ValueError:
+ raise BadArgument("Template value must be a non-negative integer.")
+
+ result = get_filter(filter_id, filter_lists)
+ if not result:
+ raise BadArgument(f"Could not find a filter with ID `{filter_id}`.")
+ filter_, filter_list, list_type = result
+
+ if filter_type and not isinstance(filter_, filter_type):
+ raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.")
+
+ settings, filter_settings = filter_serializable_overrides(filter_)
+ return settings, filter_settings, type(filter_)
+
+
+def build_search_repr_dict(
+ settings: dict[str, Any], filter_settings: dict[str, Any], filter_type: type[Filter] | None
+) -> dict:
+ """Build a dictionary of field names and values to pass to `populate_embed_from_dict`."""
+ total_values = settings.copy()
+ if filter_type:
+ for setting_name, value in filter_settings.items():
+ total_values[f"{filter_type.name}/{setting_name}"] = value
+
+ return total_values
+
+
+class SearchEditView(EditBaseView):
+ """A view used to edit the search criteria before performing the search."""
+
+ class _REMOVE:
+ """Sentinel value for when an override should be removed."""
+
+ def __init__(
+ self,
+ filter_type: type[Filter] | None,
+ settings: dict[str, Any],
+ filter_settings: dict[str, Any],
+ loaded_filter_lists: dict[str, FilterList],
+ loaded_filters: dict[str, type[Filter]],
+ loaded_settings: dict[str, tuple[str, SettingsEntry, type]],
+ loaded_filter_settings: dict[str, dict[str, tuple[str, SettingsEntry, type]]],
+ author: discord.User | discord.Member,
+ embed: discord.Embed,
+ confirm_callback: Callable
+ ):
+ super().__init__(author)
+ self.filter_type = filter_type
+ self.settings = settings
+ self.filter_settings = filter_settings
+ self.loaded_filter_lists = loaded_filter_lists
+ self.loaded_filters = loaded_filters
+ self.loaded_settings = loaded_settings
+ self.loaded_filter_settings = loaded_filter_settings
+ self.embed = embed
+ self.confirm_callback = confirm_callback
+
+ title = "Filters Search"
+ if filter_type:
+ title += f" - {filter_type.name.title()}"
+ embed.set_author(name=title)
+
+ settings_repr_dict = build_search_repr_dict(settings, filter_settings, filter_type)
+ populate_embed_from_dict(embed, settings_repr_dict)
+
+ self.type_per_setting_name = {setting: info[2] for setting, info in loaded_settings.items()}
+ if filter_type:
+ self.type_per_setting_name.update({
+ f"{filter_type.name}/{name}": type_
+ for name, (_, _, type_) in loaded_filter_settings.get(filter_type.name, {}).items()
+ })
+
+ add_select = CustomCallbackSelect(
+ self._prompt_new_value,
+ placeholder="Add or edit criterion",
+ options=[SelectOption(label=name) for name in sorted(self.type_per_setting_name)],
+ row=0
+ )
+ self.add_item(add_select)
+
+ if settings_repr_dict:
+ remove_select = CustomCallbackSelect(
+ self._remove_criterion,
+ placeholder="Select a criterion to remove",
+ options=[SelectOption(label=name) for name in sorted(settings_repr_dict)],
+ row=1
+ )
+ self.add_item(remove_select)
+
+ @discord.ui.button(label="Template", row=2)
+ async def enter_template(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to enter a filter template ID and copy its overrides over."""
+ modal = TemplateModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="Filter Type", row=2)
+ async def enter_filter_type(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """A button to enter a filter type."""
+ modal = FilterTypeModal(self, interaction.message)
+ await interaction.response.send_modal(modal)
+
+ @discord.ui.button(label="✅ Confirm", style=discord.ButtonStyle.green, row=3)
+ async def confirm(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Confirm the search criteria and perform the search."""
+ await interaction.response.edit_message(view=None) # Make sure the interaction succeeds first.
+ try:
+ await self.confirm_callback(interaction.message, self.filter_type, self.settings, self.filter_settings)
+ except BadArgument as e:
+ await interaction.message.reply(
+ embed=discord.Embed(colour=discord.Colour.red(), title="Bad Argument", description=str(e))
+ )
+ await interaction.message.edit(view=self)
+ else:
+ self.stop()
+
+ @discord.ui.button(label="🚫 Cancel", style=discord.ButtonStyle.red, row=3)
+ async def cancel(self, interaction: Interaction, button: discord.ui.Button) -> None:
+ """Cancel the operation."""
+ await interaction.response.edit_message(content="🚫 Operation canceled.", embed=None, view=None)
+ self.stop()
+
+ def current_value(self, setting_name: str) -> Any:
+ """Get the current value stored for the setting or MISSING if none found."""
+ if setting_name in self.settings:
+ return self.settings[setting_name]
+ if "/" in setting_name:
+ _, setting_name = setting_name.split("/", maxsplit=1)
+ if setting_name in self.filter_settings:
+ return self.filter_settings[setting_name]
+ return MISSING
+
+ async def update_embed(
+ self,
+ interaction_or_msg: discord.Interaction | discord.Message,
+ *,
+ setting_name: str | None = None,
+ setting_value: str | type[SearchEditView._REMOVE] | None = None,
+ ) -> None:
+ """
+ Update the embed with the new information.
+
+ If a setting name is provided with a _REMOVE value, remove the override.
+ If `interaction_or_msg` is a Message, the invoking Interaction must be deferred before calling this function.
+ """
+ if not setting_name: # Can be None just to make the function signature compatible with the parent class.
+ return
+
+ if "/" in setting_name:
+ filter_name, setting_name = setting_name.split("/", maxsplit=1)
+ dict_to_edit = self.filter_settings
+ else:
+ dict_to_edit = self.settings
+
+ # Update the criterion value or remove it
+ if setting_value is not self._REMOVE:
+ dict_to_edit[setting_name] = setting_value
+ elif setting_name in dict_to_edit:
+ dict_to_edit.pop(setting_name)
+
+ self.embed.clear_fields()
+ new_view = self.copy()
+
+ try:
+ if isinstance(interaction_or_msg, discord.Interaction):
+ await interaction_or_msg.response.edit_message(embed=self.embed, view=new_view)
+ else:
+ await interaction_or_msg.edit(embed=self.embed, view=new_view)
+ except discord.errors.HTTPException: # Just in case of faulty input.
+ pass
+ else:
+ self.stop()
+
+ async def _remove_criterion(self, interaction: Interaction, select: discord.ui.Select) -> None:
+ """
+ Remove the criterion the user selected, and edit the embed.
+
+ The interaction needs to be the selection of the setting attached to the embed.
+ """
+ await self.update_embed(interaction, setting_name=select.values[0], setting_value=self._REMOVE)
+
+ async def apply_template(self, template_id: str, embed_message: discord.Message, interaction: Interaction) -> None:
+ """Set any unset criteria with settings values from the given filter."""
+ try:
+ settings, filter_settings, self.filter_type = template_settings(
+ template_id, self.loaded_filter_lists, self.filter_type
+ )
+ except BadArgument as e: # The interaction object is necessary to send an ephemeral message.
+ await interaction.response.send_message(f":x: {e}", ephemeral=True)
+ return
+ else:
+ await interaction.response.defer()
+
+ self.settings = settings | self.settings
+ self.filter_settings = filter_settings | self.filter_settings
+ self.embed.clear_fields()
+ await embed_message.edit(embed=self.embed, view=self.copy())
+ self.stop()
+
+ async def apply_filter_type(self, type_name: str, embed_message: discord.Message, interaction: Interaction) -> None:
+ """Set a new filter type and reset any criteria for settings of the old filter type."""
+ if type_name.lower() not in self.loaded_filters:
+ if type_name.lower()[:-1] not in self.loaded_filters: # In case the user entered the plural form.
+ await interaction.response.send_message(f":x: No such filter type {type_name!r}.", ephemeral=True)
+ return
+ type_name = type_name[:-1]
+ type_name = type_name.lower()
+ await interaction.response.defer()
+
+ if self.filter_type and type_name == self.filter_type.name:
+ return
+ self.filter_type = self.loaded_filters[type_name]
+ self.filter_settings = {}
+ self.embed.clear_fields()
+ await embed_message.edit(embed=self.embed, view=self.copy())
+ self.stop()
+
+ def copy(self) -> SearchEditView:
+ """Create a copy of this view."""
+ return SearchEditView(
+ self.filter_type,
+ self.settings,
+ self.filter_settings,
+ self.loaded_filter_lists,
+ self.loaded_filters,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ self.author,
+ self.embed,
+ self.confirm_callback
+ )
+
+
+class TemplateModal(discord.ui.Modal, title="Template"):
+ """A modal to enter a filter ID to copy its overrides over."""
+
+ template = discord.ui.TextInput(label="Template Filter ID", required=False)
+
+ def __init__(self, embed_view: SearchEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new description."""
+ await self.embed_view.apply_template(self.template.value, self.message, interaction)
+
+
+class FilterTypeModal(discord.ui.Modal, title="Template"):
+ """A modal to enter a filter ID to copy its overrides over."""
+
+ filter_type = discord.ui.TextInput(label="Filter Type")
+
+ def __init__(self, embed_view: SearchEditView, message: discord.Message):
+ super().__init__(timeout=COMPONENT_TIMEOUT)
+ self.embed_view = embed_view
+ self.message = message
+
+ async def on_submit(self, interaction: Interaction) -> None:
+ """Update the embed with the new description."""
+ await self.embed_view.apply_filter_type(self.filter_type.value, self.message, interaction)
diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py
index 980eba02a..c506db1fe 100644
--- a/bot/exts/filtering/_ui/ui.py
+++ b/bot/exts/filtering/_ui/ui.py
@@ -34,7 +34,7 @@ MAX_SELECT_ITEMS = 25
MAX_EMBED_DESCRIPTION = 4000
SETTINGS_DELIMITER = re.compile(r"\s+(?=\S+=\S+)")
-SINGLE_SETTING_PATTERN = re.compile(r"\w+=.+")
+SINGLE_SETTING_PATTERN = re.compile(r"[\w/]+=.+")
# Sentinel value to denote that a value is missing
MISSING = object()
@@ -76,7 +76,7 @@ def parse_value(value: str, type_: type[T]) -> T:
if type_ in (tuple, list, set):
return type_(value.split(","))
if type_ is bool:
- return value == "True"
+ return value.lower() == "true" or value == "1"
if isinstance(type_, EnumMeta):
return type_[value.upper()]
diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py
index 6ff5181a9..890b25718 100644
--- a/bot/exts/filtering/filtering.py
+++ b/bot/exts/filtering/filtering.py
@@ -26,8 +26,9 @@ from bot.exts.filtering._ui.filter import (
build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict
)
from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter
+from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter
from bot.exts.filtering._ui.ui import ArgumentCompletionView, DeleteConfirmationView, format_response_error
-from bot.exts.filtering._utils import past_tense, starting_value, to_serializable
+from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable
from bot.log import get_logger
from bot.pagination import LinePaginator
from bot.utils.messages import format_channel, format_user
@@ -523,6 +524,63 @@ class Filtering(Cog):
embed = Embed(colour=Colour.blue(), title="Match results")
await LinePaginator.paginate(lines, ctx, embed, max_lines=10, empty=False)
+ @filter.command(name="search")
+ async def f_search(
+ self,
+ ctx: Context,
+ noui: Literal["noui"] | None,
+ filter_type_name: str | None,
+ *,
+ settings: str = ""
+ ) -> None:
+ """
+ Find filters with the provided settings. The format is identical to that of the add and edit commands.
+
+ If a list type and/or a list name are provided, the search will be limited to those parameters. A list name must
+ be provided in order to search by filter-specific settings.
+ """
+ filter_type = None
+ if filter_type_name:
+ filter_type_name = filter_type_name.lower()
+ filter_type = self.loaded_filters.get(filter_type_name)
+ if not filter_type:
+ self.loaded_filters.get(filter_type_name[:-1]) # In case the user tried to specify the plural form.
+ # If settings were provided with no filter_type, discord.py will capture the first word as the filter type.
+ if filter_type is None and filter_type_name is not None:
+ if settings:
+ settings = f"{filter_type_name} {settings}"
+ else:
+ settings = filter_type_name
+ filter_type_name = None
+
+ settings, filter_settings, filter_type = search_criteria_converter(
+ self.filter_lists,
+ self.loaded_filters,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ filter_type,
+ settings
+ )
+
+ if noui:
+ await self._search_filters(ctx.message, filter_type, settings, filter_settings)
+ return
+
+ embed = Embed(colour=Colour.blue())
+ view = SearchEditView(
+ filter_type,
+ settings,
+ filter_settings,
+ self.filter_lists,
+ self.loaded_filters,
+ self.loaded_settings,
+ self.loaded_filter_settings,
+ ctx.author,
+ embed,
+ self._search_filters
+ )
+ await ctx.send(embed=embed, reference=ctx.message, view=view)
+
# endregion
# region: filterlist group
@@ -787,7 +845,7 @@ class Filtering(Cog):
embed = Embed(colour=Colour.blue())
embed.set_author(name=f"List of {filter_list[list_type].label}s ({len(lines)} total)")
- await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False)
+ await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True)
def _get_filter_by_id(self, id_: int) -> Optional[tuple[Filter, FilterList, ListType]]:
"""Get the filter object corresponding to the provided ID, along with its containing list and list type."""
@@ -954,6 +1012,71 @@ class Filtering(Cog):
filter_list.add_list(response)
await msg.reply(f"✅ Edited filter list: {filter_list[list_type].label}")
+ def _filter_match_query(
+ self, filter_: Filter, settings_query: dict, filter_settings_query: dict, differ_by_default: set[str]
+ ) -> bool:
+ """Return whether the given filter matches the query."""
+ override_matches = set()
+ overrides, _ = filter_.overrides
+ for setting_name, setting_value in settings_query.items():
+ if setting_name not in overrides:
+ continue
+ if repr_equals(overrides[setting_name], setting_value):
+ override_matches.add(setting_name)
+ else: # If an override doesn't match then the filter doesn't match.
+ return False
+ if not (differ_by_default <= override_matches): # The overrides didn't cover for the default mismatches.
+ return False
+
+ filter_settings = filter_.extra_fields.dict() if filter_.extra_fields else {}
+ # If the dict changes then some fields were not the same.
+ return (filter_settings | filter_settings_query) == filter_settings
+
+ def _search_filter_list(
+ self, atomic_list: AtomicList, filter_type: type[Filter] | None, settings: dict, filter_settings: dict
+ ) -> list[Filter]:
+ """Find all filters in the filter list which match the settings."""
+ # If the default answers are known, only the overrides need to be checked for each filter.
+ all_defaults = atomic_list.defaults.dict()
+ match_by_default = set()
+ differ_by_default = set()
+ for setting_name, setting_value in settings.items():
+ if repr_equals(all_defaults[setting_name], setting_value):
+ match_by_default.add(setting_name)
+ else:
+ differ_by_default.add(setting_name)
+
+ result_filters = []
+ for filter_ in atomic_list.filters.values():
+ if filter_type and not isinstance(filter_, filter_type):
+ continue
+ if self._filter_match_query(filter_, settings, filter_settings, differ_by_default):
+ result_filters.append(filter_)
+
+ return result_filters
+
+ async def _search_filters(
+ self, message: Message, filter_type: type[Filter] | None, settings: dict, filter_settings: dict
+ ) -> None:
+ """Find all filters which match the settings and display them."""
+ lines = []
+ result_count = 0
+ for filter_list in self.filter_lists.values():
+ if filter_type and filter_type not in filter_list.filter_types:
+ continue
+ for atomic_list in filter_list.values():
+ list_results = self._search_filter_list(atomic_list, filter_type, settings, filter_settings)
+ if list_results:
+ lines.append(f"**{atomic_list.label.title()}**")
+ lines.extend(map(str, list_results))
+ lines.append("")
+ result_count += len(list_results)
+
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(name=f"Search Results ({result_count} total)")
+ ctx = await bot.instance.get_context(message)
+ await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False, reply=True)
+
# endregion
diff --git a/bot/pagination.py b/bot/pagination.py
index 10bef1c9f..92fa781ee 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -204,6 +204,7 @@ class LinePaginator(Paginator):
footer_text: str = None,
url: str = None,
exception_on_empty_embed: bool = False,
+ reply: bool = False,
) -> t.Optional[discord.Message]:
"""
Use a paginator and set of reactions to provide pagination over a set of lines.
@@ -251,6 +252,8 @@ class LinePaginator(Paginator):
embed.description = paginator.pages[current_page]
+ reference = ctx.message if reply else None
+
if len(paginator.pages) <= 1:
if footer_text:
embed.set_footer(text=footer_text)
@@ -261,7 +264,7 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("There's less than two pages, so we won't paginate - sending single page on its own")
- return await ctx.send(embed=embed)
+ return await ctx.send(embed=embed, reference=reference)
else:
if footer_text:
embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
@@ -274,7 +277,7 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("Sending first page to channel...")
- message = await ctx.send(embed=embed)
+ message = await ctx.send(embed=embed, reference=reference)
log.debug("Adding emoji reactions to message...")