diff options
| -rw-r--r-- | bot/exts/filtering/_settings_types/actions/infraction_and_notification.py | 52 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/filter.py | 10 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/search.py | 4 | ||||
| -rw-r--r-- | bot/exts/filtering/_utils.py | 66 | ||||
| -rw-r--r-- | bot/exts/filtering/filtering.py | 8 | ||||
| -rw-r--r-- | tests/bot/exts/filtering/test_settings_entries.py | 16 | 
6 files changed, 127 insertions, 29 deletions
| diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py index 5ae4901b6..e3df47029 100644 --- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -1,9 +1,9 @@ -from datetime import timedelta  from enum import Enum, auto  from typing import ClassVar  import arrow  import discord.abc +from dateutil.relativedelta import relativedelta  from discord import Colour, Embed, Member, User  from discord.errors import Forbidden  from pydantic import validator @@ -15,7 +15,8 @@ import bot as bot_module  from bot.constants import Channels  from bot.exts.filtering._filter_context import FilterContext  from bot.exts.filtering._settings_types.settings_entry import ActionEntry -from bot.exts.filtering._utils import FakeContext +from bot.exts.filtering._utils import CustomIOField, FakeContext +from bot.utils.time import humanize_delta, parse_duration_string, relativedelta_to_timedelta  log = get_logger(__name__) @@ -31,6 +32,38 @@ passive_form = {  } +class InfractionDuration(CustomIOField): +    """A field that converts a string to a duration and presents it in a human-readable format.""" + +    @classmethod +    def process_value(cls, v: str | relativedelta) -> relativedelta: +        """ +        Transform the given string into a relativedelta. + +        Raise a ValueError if the conversion is not possible. +        """ +        if isinstance(v, relativedelta): +            return v + +        try: +            v = float(v) +        except ValueError:  # Not a float. +            if not (delta := parse_duration_string(v)): +                raise ValueError(f"`{v}` is not a valid duration string.") +        else: +            delta = relativedelta(seconds=float(v)).normalized() + +        return delta + +    def serialize(self) -> float: +        """The serialized value is the total number of seconds this duration represents.""" +        return relativedelta_to_timedelta(self.value).total_seconds() + +    def __str__(self): +        """Represent the stored duration in a human-readable format.""" +        return humanize_delta(self.value, max_units=2) if self.value else "Permanent" + +  class Infraction(Enum):      """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" @@ -53,7 +86,7 @@ class Infraction(Enum):          message: discord.Message,          channel: discord.abc.GuildChannel | discord.DMChannel,          alerts_channel: discord.TextChannel, -        duration: float, +        duration: InfractionDuration,          reason: str      ) -> None:          """Invokes the command matching the infraction name.""" @@ -72,7 +105,7 @@ class Infraction(Enum):          if self.name in ("KICK", "WARNING", "WATCH", "NOTE"):              await command(ctx, user, reason=reason or None)          else: -            duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None +            duration = arrow.utcnow().datetime + duration.value if duration.value else None              await command(ctx, user, duration, reason=reason or None) @@ -91,7 +124,10 @@ class InfractionAndNotification(ActionEntry):              "the harsher one will be applied (by type or duration).\n\n"              "Valid infraction types in order of harshness: "          ) + ", ".join(infraction.name for infraction in Infraction), -        "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.", +        "infraction_duration": ( +            "How long the infraction should last for in seconds. 0 for permanent. " +            "Also supports durations as in an infraction invocation (such as `10d`)." +        ),          "infraction_reason": "The reason delivered with the infraction.",          "infraction_channel": (              "The channel ID in which to invoke the infraction (and send the confirmation message). " @@ -106,7 +142,7 @@ class InfractionAndNotification(ActionEntry):      dm_embed: str      infraction_type: Infraction      infraction_reason: str -    infraction_duration: float +    infraction_duration: InfractionDuration      infraction_channel: int      @validator("infraction_type", pre=True) @@ -184,8 +220,10 @@ class InfractionAndNotification(ActionEntry):              result = other.copy()              other = self          else: +            now = arrow.utcnow().datetime              if self.infraction_duration is None or ( -                other.infraction_duration is not None and self.infraction_duration > other.infraction_duration +                other.infraction_duration is not None +                and now + self.infraction_duration.value > now + other.infraction_duration.value              ):                  result = self.copy()              else: diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py index 1ef25f17a..5b23b71e9 100644 --- a/bot/exts/filtering/_ui/filter.py +++ b/bot/exts/filtering/_ui/filter.py @@ -33,7 +33,7 @@ def build_filter_repr_dict(      default_setting_values = {}      for settings_group in filter_list[list_type].defaults:          for _, setting in settings_group.items(): -            default_setting_values.update(to_serializable(setting.dict())) +            default_setting_values.update(to_serializable(setting.dict(), ui_repr=True))      # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings.      total_values = {} @@ -434,10 +434,10 @@ def description_and_settings_converter(      return description, settings, filter_settings -def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]: -    """Get a serializable version of the filter's overrides.""" +def filter_overrides_for_ui(filter_: Filter) -> tuple[dict, dict]: +    """Get the filter's overrides in a format that can be displayed in the UI."""      overrides_values, extra_fields_overrides = filter_.overrides -    return to_serializable(overrides_values), to_serializable(extra_fields_overrides) +    return to_serializable(overrides_values, ui_repr=True), to_serializable(extra_fields_overrides, ui_repr=True)  def template_settings( @@ -461,4 +461,4 @@ def template_settings(          raise BadArgument(              f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}"          ) -    return filter_serializable_overrides(filter_) +    return filter_.overrides diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py index d553c28ea..dba7f3cea 100644 --- a/bot/exts/filtering/_ui/search.py +++ b/bot/exts/filtering/_ui/search.py @@ -10,7 +10,7 @@ from discord.ext.commands import BadArgument  from bot.exts.filtering._filter_lists import FilterList, ListType  from bot.exts.filtering._filters.filter import Filter  from bot.exts.filtering._settings_types.settings_entry import SettingsEntry -from bot.exts.filtering._ui.filter import filter_serializable_overrides +from bot.exts.filtering._ui.filter import filter_overrides_for_ui  from bot.exts.filtering._ui.ui import (      COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value,      populate_embed_from_dict @@ -114,7 +114,7 @@ def template_settings(      if filter_type and not isinstance(filter_, filter_type):          raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.") -    settings, filter_settings = filter_serializable_overrides(filter_) +    settings, filter_settings = filter_overrides_for_ui(filter_)      return settings, filter_settings, type(filter_) diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index da433330f..a43233f20 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations +  import importlib  import importlib.util  import inspect @@ -12,6 +14,7 @@ from typing import Any, Iterable, TypeVar, Union, get_args, get_origin  import discord  import regex  from discord.ext.commands import Command +from typing_extensions import Self  import bot  from bot.bot import Bot @@ -24,6 +27,8 @@ ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIAT  T = TypeVar('T') +Serializable = Union[bool, int, float, str, list, dict, None] +  def subclasses_in_package(package: str, prefix: str, parent: T) -> set[T]:      """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path.""" @@ -62,8 +67,13 @@ def past_tense(word: str) -> str:      return word + "ed" -def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]: -    """Convert the item into an object that can be converted to JSON.""" +def to_serializable(item: Any, *, ui_repr: bool = False) -> Serializable: +    """ +    Convert the item into an object that can be converted to JSON. + +    `ui_repr` dictates whether to use the UI representation of `CustomIOField` instances (if any) +    or the DB-oriented representation. +    """      if isinstance(item, (bool, int, float, str, type(None))):          return item      if isinstance(item, dict): @@ -71,10 +81,12 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]          for key, value in item.items():              if not isinstance(key, (bool, int, float, str, type(None))):                  key = str(key) -            result[key] = to_serializable(value) +            result[key] = to_serializable(value, ui_repr=ui_repr)          return result      if isinstance(item, Iterable): -        return [to_serializable(subitem) for subitem in item] +        return [to_serializable(subitem, ui_repr=ui_repr) for subitem in item] +    if not ui_repr and hasattr(item, "serialize"): +        return item.serialize()      return str(item) @@ -222,3 +234,49 @@ class FakeContext:      async def send(self, *args, **kwargs) -> discord.Message:          """A wrapper for channel.send."""          return await self.channel.send(*args, **kwargs) + + +class CustomIOField: +    """ +    A class to be used as a data type in SettingEntry subclasses. + +    Its subclasses can have custom methods to read and represent the value, which will be used by the UI. +    """ + +    def __init__(self, value: Any): +        self.value = self.process_value(value) + +    @classmethod +    def __get_validators__(cls): +        """Boilerplate for Pydantic.""" +        yield cls.validate + +    @classmethod +    def validate(cls, v: Any) -> Self: +        """Takes the given value and returns a class instance with that value.""" +        if isinstance(v, CustomIOField): +            return cls(v.value) + +        return cls(v) + +    def __eq__(self, other: CustomIOField): +        if not isinstance(other, CustomIOField): +            return NotImplemented +        return self.value == other.value + +    @classmethod +    def process_value(cls, v: str) -> Any: +        """ +        Perform any necessary transformations before the value is stored in a new instance. + +        Override this method to customize the input behavior. +        """ +        return v + +    def serialize(self) -> Serializable: +        """Override this method to customize how the value will be serialized.""" +        return self.value + +    def __str__(self): +        """Override this method to change how the value will be displayed by the UI.""" +        return self.value diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 58d2f125e..8fd4ddb13 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -31,7 +31,7 @@ from bot.exts.filtering._filters.filter import Filter, UniqueFilter  from bot.exts.filtering._settings import ActionSettings  from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction  from bot.exts.filtering._ui.filter import ( -    build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict +    build_filter_repr_dict, description_and_settings_converter, filter_overrides_for_ui, populate_embed_from_dict  )  from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter  from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter @@ -383,7 +383,7 @@ class Filtering(Cog):              return          filter_, filter_list, list_type = result -        overrides_values, extra_fields_overrides = filter_serializable_overrides(filter_) +        overrides_values, extra_fields_overrides = filter_overrides_for_ui(filter_)          all_settings_repr_dict = build_filter_repr_dict(              filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides @@ -493,7 +493,7 @@ class Filtering(Cog):              return          filter_, filter_list, list_type = result          filter_type = type(filter_) -        settings, filter_settings = filter_serializable_overrides(filter_) +        settings, filter_settings = filter_overrides_for_ui(filter_)          description, new_settings, new_filter_settings = description_and_settings_converter(              filter_list,              list_type, filter_type, @@ -734,7 +734,7 @@ class Filtering(Cog):          setting_values = {}          for settings_group in filter_list[list_type].defaults:              for _, setting in settings_group.items(): -                setting_values.update(to_serializable(setting.dict())) +                setting_values.update(to_serializable(setting.dict(), ui_repr=True))          embed = Embed(colour=Colour.blue())          populate_embed_from_dict(embed, setting_values) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index c5f0152b0..3ae0b5ab5 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -1,7 +1,9 @@  import unittest  from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( +    Infraction, InfractionAndNotification, InfractionDuration +)  from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass  from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope  from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM @@ -154,7 +156,7 @@ class FilterTests(unittest.TestCase):          infraction1 = InfractionAndNotification(              infraction_type="TIMEOUT",              infraction_reason="hi", -            infraction_duration=10, +            infraction_duration=InfractionDuration(10),              dm_content="how",              dm_embed="what is",              infraction_channel=0 @@ -162,7 +164,7 @@ class FilterTests(unittest.TestCase):          infraction2 = InfractionAndNotification(              infraction_type="TIMEOUT",              infraction_reason="there", -            infraction_duration=20, +            infraction_duration=InfractionDuration(20),              dm_content="are you",              dm_embed="your name",              infraction_channel=0 @@ -175,7 +177,7 @@ class FilterTests(unittest.TestCase):              {                  "infraction_type": Infraction.TIMEOUT,                  "infraction_reason": "there", -                "infraction_duration": 20.0, +                "infraction_duration": InfractionDuration(20.0),                  "dm_content": "are you",                  "dm_embed": "your name",                  "infraction_channel": 0 @@ -187,7 +189,7 @@ class FilterTests(unittest.TestCase):          infraction1 = InfractionAndNotification(              infraction_type="TIMEOUT",              infraction_reason="hi", -            infraction_duration=20, +            infraction_duration=InfractionDuration(20),              dm_content="",              dm_embed="",              infraction_channel=0 @@ -195,7 +197,7 @@ class FilterTests(unittest.TestCase):          infraction2 = InfractionAndNotification(              infraction_type="BAN",              infraction_reason="", -            infraction_duration=10, +            infraction_duration=InfractionDuration(10),              dm_content="there",              dm_embed="",              infraction_channel=0 @@ -208,7 +210,7 @@ class FilterTests(unittest.TestCase):              {                  "infraction_type": Infraction.BAN,                  "infraction_reason": "", -                "infraction_duration": 10.0, +                "infraction_duration": InfractionDuration(10),                  "dm_content": "there",                  "dm_embed": "",                  "infraction_channel": 0 | 
