aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar mbaruh <[email protected]>2023-03-28 01:12:25 +0300
committerGravatar mbaruh <[email protected]>2023-03-28 01:31:23 +0300
commitf01883682f4d333382d8e8a89363dc906fe86342 (patch)
tree72521763d6ba67bb24d78393da1ce57dd2eadab8
parentCorrect filter match docstring (diff)
Support custom value representation in filtering UI
Adds the `CustomIOField` class which can be used as a base for wrappers that store a value with a customized way to process the user input and to present the value in the UI.
-rw-r--r--bot/exts/filtering/_settings_types/actions/infraction_and_notification.py52
-rw-r--r--bot/exts/filtering/_ui/filter.py10
-rw-r--r--bot/exts/filtering/_ui/search.py4
-rw-r--r--bot/exts/filtering/_utils.py66
-rw-r--r--bot/exts/filtering/filtering.py8
-rw-r--r--tests/bot/exts/filtering/test_settings_entries.py16
6 files changed, 127 insertions, 29 deletions
diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
index 5ae4901b6..e3df47029 100644
--- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
+++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
@@ -1,9 +1,9 @@
-from datetime import timedelta
from enum import Enum, auto
from typing import ClassVar
import arrow
import discord.abc
+from dateutil.relativedelta import relativedelta
from discord import Colour, Embed, Member, User
from discord.errors import Forbidden
from pydantic import validator
@@ -15,7 +15,8 @@ import bot as bot_module
from bot.constants import Channels
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ActionEntry
-from bot.exts.filtering._utils import FakeContext
+from bot.exts.filtering._utils import CustomIOField, FakeContext
+from bot.utils.time import humanize_delta, parse_duration_string, relativedelta_to_timedelta
log = get_logger(__name__)
@@ -31,6 +32,38 @@ passive_form = {
}
+class InfractionDuration(CustomIOField):
+ """A field that converts a string to a duration and presents it in a human-readable format."""
+
+ @classmethod
+ def process_value(cls, v: str | relativedelta) -> relativedelta:
+ """
+ Transform the given string into a relativedelta.
+
+ Raise a ValueError if the conversion is not possible.
+ """
+ if isinstance(v, relativedelta):
+ return v
+
+ try:
+ v = float(v)
+ except ValueError: # Not a float.
+ if not (delta := parse_duration_string(v)):
+ raise ValueError(f"`{v}` is not a valid duration string.")
+ else:
+ delta = relativedelta(seconds=float(v)).normalized()
+
+ return delta
+
+ def serialize(self) -> float:
+ """The serialized value is the total number of seconds this duration represents."""
+ return relativedelta_to_timedelta(self.value).total_seconds()
+
+ def __str__(self):
+ """Represent the stored duration in a human-readable format."""
+ return humanize_delta(self.value, max_units=2) if self.value else "Permanent"
+
+
class Infraction(Enum):
"""An enumeration of infraction types. The lower the value, the higher it is on the hierarchy."""
@@ -53,7 +86,7 @@ class Infraction(Enum):
message: discord.Message,
channel: discord.abc.GuildChannel | discord.DMChannel,
alerts_channel: discord.TextChannel,
- duration: float,
+ duration: InfractionDuration,
reason: str
) -> None:
"""Invokes the command matching the infraction name."""
@@ -72,7 +105,7 @@ class Infraction(Enum):
if self.name in ("KICK", "WARNING", "WATCH", "NOTE"):
await command(ctx, user, reason=reason or None)
else:
- duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None
+ duration = arrow.utcnow().datetime + duration.value if duration.value else None
await command(ctx, user, duration, reason=reason or None)
@@ -91,7 +124,10 @@ class InfractionAndNotification(ActionEntry):
"the harsher one will be applied (by type or duration).\n\n"
"Valid infraction types in order of harshness: "
) + ", ".join(infraction.name for infraction in Infraction),
- "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.",
+ "infraction_duration": (
+ "How long the infraction should last for in seconds. 0 for permanent. "
+ "Also supports durations as in an infraction invocation (such as `10d`)."
+ ),
"infraction_reason": "The reason delivered with the infraction.",
"infraction_channel": (
"The channel ID in which to invoke the infraction (and send the confirmation message). "
@@ -106,7 +142,7 @@ class InfractionAndNotification(ActionEntry):
dm_embed: str
infraction_type: Infraction
infraction_reason: str
- infraction_duration: float
+ infraction_duration: InfractionDuration
infraction_channel: int
@validator("infraction_type", pre=True)
@@ -184,8 +220,10 @@ class InfractionAndNotification(ActionEntry):
result = other.copy()
other = self
else:
+ now = arrow.utcnow().datetime
if self.infraction_duration is None or (
- other.infraction_duration is not None and self.infraction_duration > other.infraction_duration
+ other.infraction_duration is not None
+ and now + self.infraction_duration.value > now + other.infraction_duration.value
):
result = self.copy()
else:
diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py
index 1ef25f17a..5b23b71e9 100644
--- a/bot/exts/filtering/_ui/filter.py
+++ b/bot/exts/filtering/_ui/filter.py
@@ -33,7 +33,7 @@ def build_filter_repr_dict(
default_setting_values = {}
for settings_group in filter_list[list_type].defaults:
for _, setting in settings_group.items():
- default_setting_values.update(to_serializable(setting.dict()))
+ default_setting_values.update(to_serializable(setting.dict(), ui_repr=True))
# Add overrides. It's done in this way to preserve field order, since the filter won't have all settings.
total_values = {}
@@ -434,10 +434,10 @@ def description_and_settings_converter(
return description, settings, filter_settings
-def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]:
- """Get a serializable version of the filter's overrides."""
+def filter_overrides_for_ui(filter_: Filter) -> tuple[dict, dict]:
+ """Get the filter's overrides in a format that can be displayed in the UI."""
overrides_values, extra_fields_overrides = filter_.overrides
- return to_serializable(overrides_values), to_serializable(extra_fields_overrides)
+ return to_serializable(overrides_values, ui_repr=True), to_serializable(extra_fields_overrides, ui_repr=True)
def template_settings(
@@ -461,4 +461,4 @@ def template_settings(
raise BadArgument(
f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}"
)
- return filter_serializable_overrides(filter_)
+ return filter_.overrides
diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py
index d553c28ea..dba7f3cea 100644
--- a/bot/exts/filtering/_ui/search.py
+++ b/bot/exts/filtering/_ui/search.py
@@ -10,7 +10,7 @@ from discord.ext.commands import BadArgument
from bot.exts.filtering._filter_lists import FilterList, ListType
from bot.exts.filtering._filters.filter import Filter
from bot.exts.filtering._settings_types.settings_entry import SettingsEntry
-from bot.exts.filtering._ui.filter import filter_serializable_overrides
+from bot.exts.filtering._ui.filter import filter_overrides_for_ui
from bot.exts.filtering._ui.ui import (
COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value,
populate_embed_from_dict
@@ -114,7 +114,7 @@ def template_settings(
if filter_type and not isinstance(filter_, filter_type):
raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.")
- settings, filter_settings = filter_serializable_overrides(filter_)
+ settings, filter_settings = filter_overrides_for_ui(filter_)
return settings, filter_settings, type(filter_)
diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py
index da433330f..a43233f20 100644
--- a/bot/exts/filtering/_utils.py
+++ b/bot/exts/filtering/_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import importlib
import importlib.util
import inspect
@@ -12,6 +14,7 @@ from typing import Any, Iterable, TypeVar, Union, get_args, get_origin
import discord
import regex
from discord.ext.commands import Command
+from typing_extensions import Self
import bot
from bot.bot import Bot
@@ -24,6 +27,8 @@ ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIAT
T = TypeVar('T')
+Serializable = Union[bool, int, float, str, list, dict, None]
+
def subclasses_in_package(package: str, prefix: str, parent: T) -> set[T]:
"""Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path."""
@@ -62,8 +67,13 @@ def past_tense(word: str) -> str:
return word + "ed"
-def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]:
- """Convert the item into an object that can be converted to JSON."""
+def to_serializable(item: Any, *, ui_repr: bool = False) -> Serializable:
+ """
+ Convert the item into an object that can be converted to JSON.
+
+ `ui_repr` dictates whether to use the UI representation of `CustomIOField` instances (if any)
+ or the DB-oriented representation.
+ """
if isinstance(item, (bool, int, float, str, type(None))):
return item
if isinstance(item, dict):
@@ -71,10 +81,12 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]
for key, value in item.items():
if not isinstance(key, (bool, int, float, str, type(None))):
key = str(key)
- result[key] = to_serializable(value)
+ result[key] = to_serializable(value, ui_repr=ui_repr)
return result
if isinstance(item, Iterable):
- return [to_serializable(subitem) for subitem in item]
+ return [to_serializable(subitem, ui_repr=ui_repr) for subitem in item]
+ if not ui_repr and hasattr(item, "serialize"):
+ return item.serialize()
return str(item)
@@ -222,3 +234,49 @@ class FakeContext:
async def send(self, *args, **kwargs) -> discord.Message:
"""A wrapper for channel.send."""
return await self.channel.send(*args, **kwargs)
+
+
+class CustomIOField:
+ """
+ A class to be used as a data type in SettingEntry subclasses.
+
+ Its subclasses can have custom methods to read and represent the value, which will be used by the UI.
+ """
+
+ def __init__(self, value: Any):
+ self.value = self.process_value(value)
+
+ @classmethod
+ def __get_validators__(cls):
+ """Boilerplate for Pydantic."""
+ yield cls.validate
+
+ @classmethod
+ def validate(cls, v: Any) -> Self:
+ """Takes the given value and returns a class instance with that value."""
+ if isinstance(v, CustomIOField):
+ return cls(v.value)
+
+ return cls(v)
+
+ def __eq__(self, other: CustomIOField):
+ if not isinstance(other, CustomIOField):
+ return NotImplemented
+ return self.value == other.value
+
+ @classmethod
+ def process_value(cls, v: str) -> Any:
+ """
+ Perform any necessary transformations before the value is stored in a new instance.
+
+ Override this method to customize the input behavior.
+ """
+ return v
+
+ def serialize(self) -> Serializable:
+ """Override this method to customize how the value will be serialized."""
+ return self.value
+
+ def __str__(self):
+ """Override this method to change how the value will be displayed by the UI."""
+ return self.value
diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py
index 58d2f125e..8fd4ddb13 100644
--- a/bot/exts/filtering/filtering.py
+++ b/bot/exts/filtering/filtering.py
@@ -31,7 +31,7 @@ from bot.exts.filtering._filters.filter import Filter, UniqueFilter
from bot.exts.filtering._settings import ActionSettings
from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction
from bot.exts.filtering._ui.filter import (
- build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict
+ build_filter_repr_dict, description_and_settings_converter, filter_overrides_for_ui, populate_embed_from_dict
)
from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter
from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter
@@ -383,7 +383,7 @@ class Filtering(Cog):
return
filter_, filter_list, list_type = result
- overrides_values, extra_fields_overrides = filter_serializable_overrides(filter_)
+ overrides_values, extra_fields_overrides = filter_overrides_for_ui(filter_)
all_settings_repr_dict = build_filter_repr_dict(
filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides
@@ -493,7 +493,7 @@ class Filtering(Cog):
return
filter_, filter_list, list_type = result
filter_type = type(filter_)
- settings, filter_settings = filter_serializable_overrides(filter_)
+ settings, filter_settings = filter_overrides_for_ui(filter_)
description, new_settings, new_filter_settings = description_and_settings_converter(
filter_list,
list_type, filter_type,
@@ -734,7 +734,7 @@ class Filtering(Cog):
setting_values = {}
for settings_group in filter_list[list_type].defaults:
for _, setting in settings_group.items():
- setting_values.update(to_serializable(setting.dict()))
+ setting_values.update(to_serializable(setting.dict(), ui_repr=True))
embed = Embed(colour=Colour.blue())
populate_embed_from_dict(embed, setting_values)
diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py
index c5f0152b0..3ae0b5ab5 100644
--- a/tests/bot/exts/filtering/test_settings_entries.py
+++ b/tests/bot/exts/filtering/test_settings_entries.py
@@ -1,7 +1,9 @@
import unittest
from bot.exts.filtering._filter_context import Event, FilterContext
-from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification
+from bot.exts.filtering._settings_types.actions.infraction_and_notification import (
+ Infraction, InfractionAndNotification, InfractionDuration
+)
from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass
from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope
from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM
@@ -154,7 +156,7 @@ class FilterTests(unittest.TestCase):
infraction1 = InfractionAndNotification(
infraction_type="TIMEOUT",
infraction_reason="hi",
- infraction_duration=10,
+ infraction_duration=InfractionDuration(10),
dm_content="how",
dm_embed="what is",
infraction_channel=0
@@ -162,7 +164,7 @@ class FilterTests(unittest.TestCase):
infraction2 = InfractionAndNotification(
infraction_type="TIMEOUT",
infraction_reason="there",
- infraction_duration=20,
+ infraction_duration=InfractionDuration(20),
dm_content="are you",
dm_embed="your name",
infraction_channel=0
@@ -175,7 +177,7 @@ class FilterTests(unittest.TestCase):
{
"infraction_type": Infraction.TIMEOUT,
"infraction_reason": "there",
- "infraction_duration": 20.0,
+ "infraction_duration": InfractionDuration(20.0),
"dm_content": "are you",
"dm_embed": "your name",
"infraction_channel": 0
@@ -187,7 +189,7 @@ class FilterTests(unittest.TestCase):
infraction1 = InfractionAndNotification(
infraction_type="TIMEOUT",
infraction_reason="hi",
- infraction_duration=20,
+ infraction_duration=InfractionDuration(20),
dm_content="",
dm_embed="",
infraction_channel=0
@@ -195,7 +197,7 @@ class FilterTests(unittest.TestCase):
infraction2 = InfractionAndNotification(
infraction_type="BAN",
infraction_reason="",
- infraction_duration=10,
+ infraction_duration=InfractionDuration(10),
dm_content="there",
dm_embed="",
infraction_channel=0
@@ -208,7 +210,7 @@ class FilterTests(unittest.TestCase):
{
"infraction_type": Infraction.BAN,
"infraction_reason": "",
- "infraction_duration": 10.0,
+ "infraction_duration": InfractionDuration(10),
"dm_content": "there",
"dm_embed": "",
"infraction_channel": 0