diff options
-rw-r--r-- | bot/exts/filtering/_settings_types/actions/infraction_and_notification.py | 52 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/filter.py | 10 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/search.py | 4 | ||||
-rw-r--r-- | bot/exts/filtering/_utils.py | 66 | ||||
-rw-r--r-- | bot/exts/filtering/filtering.py | 8 | ||||
-rw-r--r-- | tests/bot/exts/filtering/test_settings_entries.py | 16 |
6 files changed, 127 insertions, 29 deletions
diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py index 5ae4901b6..e3df47029 100644 --- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -1,9 +1,9 @@ -from datetime import timedelta from enum import Enum, auto from typing import ClassVar import arrow import discord.abc +from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Member, User from discord.errors import Forbidden from pydantic import validator @@ -15,7 +15,8 @@ import bot as bot_module from bot.constants import Channels from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry -from bot.exts.filtering._utils import FakeContext +from bot.exts.filtering._utils import CustomIOField, FakeContext +from bot.utils.time import humanize_delta, parse_duration_string, relativedelta_to_timedelta log = get_logger(__name__) @@ -31,6 +32,38 @@ passive_form = { } +class InfractionDuration(CustomIOField): + """A field that converts a string to a duration and presents it in a human-readable format.""" + + @classmethod + def process_value(cls, v: str | relativedelta) -> relativedelta: + """ + Transform the given string into a relativedelta. + + Raise a ValueError if the conversion is not possible. + """ + if isinstance(v, relativedelta): + return v + + try: + v = float(v) + except ValueError: # Not a float. + if not (delta := parse_duration_string(v)): + raise ValueError(f"`{v}` is not a valid duration string.") + else: + delta = relativedelta(seconds=float(v)).normalized() + + return delta + + def serialize(self) -> float: + """The serialized value is the total number of seconds this duration represents.""" + return relativedelta_to_timedelta(self.value).total_seconds() + + def __str__(self): + """Represent the stored duration in a human-readable format.""" + return humanize_delta(self.value, max_units=2) if self.value else "Permanent" + + class Infraction(Enum): """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" @@ -53,7 +86,7 @@ class Infraction(Enum): message: discord.Message, channel: discord.abc.GuildChannel | discord.DMChannel, alerts_channel: discord.TextChannel, - duration: float, + duration: InfractionDuration, reason: str ) -> None: """Invokes the command matching the infraction name.""" @@ -72,7 +105,7 @@ class Infraction(Enum): if self.name in ("KICK", "WARNING", "WATCH", "NOTE"): await command(ctx, user, reason=reason or None) else: - duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None + duration = arrow.utcnow().datetime + duration.value if duration.value else None await command(ctx, user, duration, reason=reason or None) @@ -91,7 +124,10 @@ class InfractionAndNotification(ActionEntry): "the harsher one will be applied (by type or duration).\n\n" "Valid infraction types in order of harshness: " ) + ", ".join(infraction.name for infraction in Infraction), - "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.", + "infraction_duration": ( + "How long the infraction should last for in seconds. 0 for permanent. " + "Also supports durations as in an infraction invocation (such as `10d`)." + ), "infraction_reason": "The reason delivered with the infraction.", "infraction_channel": ( "The channel ID in which to invoke the infraction (and send the confirmation message). " @@ -106,7 +142,7 @@ class InfractionAndNotification(ActionEntry): dm_embed: str infraction_type: Infraction infraction_reason: str - infraction_duration: float + infraction_duration: InfractionDuration infraction_channel: int @validator("infraction_type", pre=True) @@ -184,8 +220,10 @@ class InfractionAndNotification(ActionEntry): result = other.copy() other = self else: + now = arrow.utcnow().datetime if self.infraction_duration is None or ( - other.infraction_duration is not None and self.infraction_duration > other.infraction_duration + other.infraction_duration is not None + and now + self.infraction_duration.value > now + other.infraction_duration.value ): result = self.copy() else: diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py index 1ef25f17a..5b23b71e9 100644 --- a/bot/exts/filtering/_ui/filter.py +++ b/bot/exts/filtering/_ui/filter.py @@ -33,7 +33,7 @@ def build_filter_repr_dict( default_setting_values = {} for settings_group in filter_list[list_type].defaults: for _, setting in settings_group.items(): - default_setting_values.update(to_serializable(setting.dict())) + default_setting_values.update(to_serializable(setting.dict(), ui_repr=True)) # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings. total_values = {} @@ -434,10 +434,10 @@ def description_and_settings_converter( return description, settings, filter_settings -def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]: - """Get a serializable version of the filter's overrides.""" +def filter_overrides_for_ui(filter_: Filter) -> tuple[dict, dict]: + """Get the filter's overrides in a format that can be displayed in the UI.""" overrides_values, extra_fields_overrides = filter_.overrides - return to_serializable(overrides_values), to_serializable(extra_fields_overrides) + return to_serializable(overrides_values, ui_repr=True), to_serializable(extra_fields_overrides, ui_repr=True) def template_settings( @@ -461,4 +461,4 @@ def template_settings( raise BadArgument( f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}" ) - return filter_serializable_overrides(filter_) + return filter_.overrides diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py index d553c28ea..dba7f3cea 100644 --- a/bot/exts/filtering/_ui/search.py +++ b/bot/exts/filtering/_ui/search.py @@ -10,7 +10,7 @@ from discord.ext.commands import BadArgument from bot.exts.filtering._filter_lists import FilterList, ListType from bot.exts.filtering._filters.filter import Filter from bot.exts.filtering._settings_types.settings_entry import SettingsEntry -from bot.exts.filtering._ui.filter import filter_serializable_overrides +from bot.exts.filtering._ui.filter import filter_overrides_for_ui from bot.exts.filtering._ui.ui import ( COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value, populate_embed_from_dict @@ -114,7 +114,7 @@ def template_settings( if filter_type and not isinstance(filter_, filter_type): raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.") - settings, filter_settings = filter_serializable_overrides(filter_) + settings, filter_settings = filter_overrides_for_ui(filter_) return settings, filter_settings, type(filter_) diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index da433330f..a43233f20 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import importlib.util import inspect @@ -12,6 +14,7 @@ from typing import Any, Iterable, TypeVar, Union, get_args, get_origin import discord import regex from discord.ext.commands import Command +from typing_extensions import Self import bot from bot.bot import Bot @@ -24,6 +27,8 @@ ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIAT T = TypeVar('T') +Serializable = Union[bool, int, float, str, list, dict, None] + def subclasses_in_package(package: str, prefix: str, parent: T) -> set[T]: """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path.""" @@ -62,8 +67,13 @@ def past_tense(word: str) -> str: return word + "ed" -def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]: - """Convert the item into an object that can be converted to JSON.""" +def to_serializable(item: Any, *, ui_repr: bool = False) -> Serializable: + """ + Convert the item into an object that can be converted to JSON. + + `ui_repr` dictates whether to use the UI representation of `CustomIOField` instances (if any) + or the DB-oriented representation. + """ if isinstance(item, (bool, int, float, str, type(None))): return item if isinstance(item, dict): @@ -71,10 +81,12 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None] for key, value in item.items(): if not isinstance(key, (bool, int, float, str, type(None))): key = str(key) - result[key] = to_serializable(value) + result[key] = to_serializable(value, ui_repr=ui_repr) return result if isinstance(item, Iterable): - return [to_serializable(subitem) for subitem in item] + return [to_serializable(subitem, ui_repr=ui_repr) for subitem in item] + if not ui_repr and hasattr(item, "serialize"): + return item.serialize() return str(item) @@ -222,3 +234,49 @@ class FakeContext: async def send(self, *args, **kwargs) -> discord.Message: """A wrapper for channel.send.""" return await self.channel.send(*args, **kwargs) + + +class CustomIOField: + """ + A class to be used as a data type in SettingEntry subclasses. + + Its subclasses can have custom methods to read and represent the value, which will be used by the UI. + """ + + def __init__(self, value: Any): + self.value = self.process_value(value) + + @classmethod + def __get_validators__(cls): + """Boilerplate for Pydantic.""" + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Self: + """Takes the given value and returns a class instance with that value.""" + if isinstance(v, CustomIOField): + return cls(v.value) + + return cls(v) + + def __eq__(self, other: CustomIOField): + if not isinstance(other, CustomIOField): + return NotImplemented + return self.value == other.value + + @classmethod + def process_value(cls, v: str) -> Any: + """ + Perform any necessary transformations before the value is stored in a new instance. + + Override this method to customize the input behavior. + """ + return v + + def serialize(self) -> Serializable: + """Override this method to customize how the value will be serialized.""" + return self.value + + def __str__(self): + """Override this method to change how the value will be displayed by the UI.""" + return self.value diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 58d2f125e..8fd4ddb13 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -31,7 +31,7 @@ from bot.exts.filtering._filters.filter import Filter, UniqueFilter from bot.exts.filtering._settings import ActionSettings from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction from bot.exts.filtering._ui.filter import ( - build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict + build_filter_repr_dict, description_and_settings_converter, filter_overrides_for_ui, populate_embed_from_dict ) from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter @@ -383,7 +383,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result - overrides_values, extra_fields_overrides = filter_serializable_overrides(filter_) + overrides_values, extra_fields_overrides = filter_overrides_for_ui(filter_) all_settings_repr_dict = build_filter_repr_dict( filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides @@ -493,7 +493,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result filter_type = type(filter_) - settings, filter_settings = filter_serializable_overrides(filter_) + settings, filter_settings = filter_overrides_for_ui(filter_) description, new_settings, new_filter_settings = description_and_settings_converter( filter_list, list_type, filter_type, @@ -734,7 +734,7 @@ class Filtering(Cog): setting_values = {} for settings_group in filter_list[list_type].defaults: for _, setting in settings_group.items(): - setting_values.update(to_serializable(setting.dict())) + setting_values.update(to_serializable(setting.dict(), ui_repr=True)) embed = Embed(colour=Colour.blue()) populate_embed_from_dict(embed, setting_values) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index c5f0152b0..3ae0b5ab5 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -1,7 +1,9 @@ import unittest from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( + Infraction, InfractionAndNotification, InfractionDuration +) from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM @@ -154,7 +156,7 @@ class FilterTests(unittest.TestCase): infraction1 = InfractionAndNotification( infraction_type="TIMEOUT", infraction_reason="hi", - infraction_duration=10, + infraction_duration=InfractionDuration(10), dm_content="how", dm_embed="what is", infraction_channel=0 @@ -162,7 +164,7 @@ class FilterTests(unittest.TestCase): infraction2 = InfractionAndNotification( infraction_type="TIMEOUT", infraction_reason="there", - infraction_duration=20, + infraction_duration=InfractionDuration(20), dm_content="are you", dm_embed="your name", infraction_channel=0 @@ -175,7 +177,7 @@ class FilterTests(unittest.TestCase): { "infraction_type": Infraction.TIMEOUT, "infraction_reason": "there", - "infraction_duration": 20.0, + "infraction_duration": InfractionDuration(20.0), "dm_content": "are you", "dm_embed": "your name", "infraction_channel": 0 @@ -187,7 +189,7 @@ class FilterTests(unittest.TestCase): infraction1 = InfractionAndNotification( infraction_type="TIMEOUT", infraction_reason="hi", - infraction_duration=20, + infraction_duration=InfractionDuration(20), dm_content="", dm_embed="", infraction_channel=0 @@ -195,7 +197,7 @@ class FilterTests(unittest.TestCase): infraction2 = InfractionAndNotification( infraction_type="BAN", infraction_reason="", - infraction_duration=10, + infraction_duration=InfractionDuration(10), dm_content="there", dm_embed="", infraction_channel=0 @@ -208,7 +210,7 @@ class FilterTests(unittest.TestCase): { "infraction_type": Infraction.BAN, "infraction_reason": "", - "infraction_duration": 10.0, + "infraction_duration": InfractionDuration(10), "dm_content": "there", "dm_embed": "", "infraction_channel": 0 |