diff options
author | 2022-09-10 20:31:59 +0300 | |
---|---|---|
committer | 2022-09-10 20:31:59 +0300 | |
commit | a13417c82c73a5d30a66ae52eeb280622747ac3a (patch) | |
tree | 1d587a0f0706585c47f4c0218dc7a36716452f76 | |
parent | Add settings display for individual filters and filter lists (diff) |
Convert all setting entries to pydnatic models
In order to facilitate this change, the init of all setting entries was removed, and as such the base SettingsEntry class doesn't count as abstract anymore, which broke the MUST_SET behavior. It was changed to not raise errors for MUST_SET values of attributes which were set in the current class.
Additionally fixed a small bug where filters weren't listed properly in the list command.
Despite the pydantic manual not writing it that way, I first made the validators class methods, otherwise it gave linting errors which couldn't be ignored with noqa (because then it complained about blanket noqas).
-rw-r--r-- | bot/exts/filtering/_settings_types/bypass_roles.py | 26 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/channel_scope.py | 46 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/delete_messages.py | 14 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/enabled.py | 12 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/filter_dm.py | 10 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/infraction_and_notification.py | 68 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/ping.py | 25 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/send_alert.py | 12 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/settings_entry.py | 38 | ||||
-rw-r--r-- | bot/exts/filtering/_utils.py | 16 | ||||
-rw-r--r-- | bot/exts/filtering/filtering.py | 8 |
11 files changed, 125 insertions, 150 deletions
diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py index e183e0b42..a5c18cffc 100644 --- a/bot/exts/filtering/_settings_types/bypass_roles.py +++ b/bot/exts/filtering/_settings_types/bypass_roles.py @@ -1,6 +1,7 @@ -from typing import Any +from typing import ClassVar, Union from discord import Member +from pydantic import validator from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry @@ -9,17 +10,18 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry class RoleBypass(ValidationEntry): """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" - name = "bypass_roles" - description = "A list of role IDs or role names. Users with these roles will not trigger the filter." - - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.bypass_roles = set() - for role in entry_data: - if role.isdigit(): - self.bypass_roles.add(int(role)) - else: - self.bypass_roles.add(role) + name: ClassVar[str] = "bypass_roles" + description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter." + + bypass_roles: set[Union[int, str]] + + @validator("bypass_roles", each_item=True) + @classmethod + def maybe_cast_to_int(cls, role: str) -> Union[int, str]: + """If the string is alphanumeric, cast it to int.""" + if role.isdigit(): + return int(role) + return role def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter should be triggered on this user given their roles.""" diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py index 3a95834b3..fd5206b81 100644 --- a/bot/exts/filtering/_settings_types/channel_scope.py +++ b/bot/exts/filtering/_settings_types/channel_scope.py @@ -1,21 +1,16 @@ -from typing import Any, Union +from typing import ClassVar, Union + +from pydantic import validator from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry -def maybe_cast_to_int(item: str) -> Union[str, int]: - """Cast the item to int if it consists of only digit, or leave as is otherwise.""" - if item.isdigit(): - return int(item) - return item - - class ChannelScope(ValidationEntry): """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" - name = "channel_scope" - description = { + name: ClassVar[str] = "channel_scope" + description: ClassVar[str] = { "disabled_channels": "A list of channel IDs or channel names. The filter will not trigger in these channels.", "disabled_categories": ( "A list of category IDs or category names. The filter will not trigger in these categories." @@ -26,22 +21,25 @@ class ChannelScope(ValidationEntry): ) } - def __init__(self, entry_data: Any): - super().__init__(entry_data) - if entry_data["disabled_channels"]: - self.disabled_channels = set(map(maybe_cast_to_int, entry_data["disabled_channels"])) - else: - self.disabled_channels = set() + disabled_channels: set[Union[str, int]] + disabled_categories: set[Union[str, int]] + enabled_channels: set[Union[str, int]] - if entry_data["disabled_categories"]: - self.disabled_categories = set(map(maybe_cast_to_int, entry_data["disabled_categories"])) - else: - self.disabled_categories = set() + @validator("*", pre=True) + @classmethod + def init_if_sequence_none(cls, sequence: list[str]) -> list[str]: + """Initialize an empty sequence if the value is None.""" + if sequence is None: + return [] + return sequence - if entry_data["enabled_channels"]: - self.enabled_channels = set(map(maybe_cast_to_int, entry_data["enabled_channels"])) - else: - self.enabled_channels = set() + @validator("*", each_item=True) + @classmethod + def maybe_cast_items(cls, channel_or_category: str) -> Union[str, int]: + """Cast to int each value in each sequence if it is alphanumeric.""" + if channel_or_category.isdigit(): + return int(channel_or_category) + return channel_or_category def triggers_on(self, ctx: FilterContext) -> bool: """ diff --git a/bot/exts/filtering/_settings_types/delete_messages.py b/bot/exts/filtering/_settings_types/delete_messages.py index 8de58f804..710cb0ed8 100644 --- a/bot/exts/filtering/_settings_types/delete_messages.py +++ b/bot/exts/filtering/_settings_types/delete_messages.py @@ -1,5 +1,5 @@ from contextlib import suppress -from typing import Any +from typing import ClassVar from discord.errors import NotFound @@ -10,12 +10,12 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry class DeleteMessages(ActionEntry): """A setting entry which tells whether to delete the offending message(s).""" - name = "delete_messages" - description = "A boolean field. If True, the filter being triggered will cause the offending message to be deleted." + name: ClassVar[str] = "delete_messages" + description: ClassVar[str] = ( + "A boolean field. If True, the filter being triggered will cause the offending message to be deleted." + ) - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.delete_messages: bool = entry_data + delete_messages: bool async def action(self, ctx: FilterContext) -> None: """Delete the context message(s).""" @@ -32,4 +32,4 @@ class DeleteMessages(ActionEntry): if not isinstance(other, DeleteMessages): return NotImplemented - return DeleteMessages(self.delete_messages or other.delete_messages) + return DeleteMessages(delete_messages=self.delete_messages or other.delete_messages) diff --git a/bot/exts/filtering/_settings_types/enabled.py b/bot/exts/filtering/_settings_types/enabled.py index 081ae02b0..3b5e3e446 100644 --- a/bot/exts/filtering/_settings_types/enabled.py +++ b/bot/exts/filtering/_settings_types/enabled.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import ClassVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry @@ -7,12 +7,12 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry class Enabled(ValidationEntry): """A setting entry which tells whether the filter is enabled.""" - name = "enabled" - description = "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." + name: ClassVar[str] = "enabled" + description: ClassVar[str] = ( + "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." + ) - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.enabled = entry_data + enabled: bool def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter is enabled.""" diff --git a/bot/exts/filtering/_settings_types/filter_dm.py b/bot/exts/filtering/_settings_types/filter_dm.py index 1405a636f..93022320f 100644 --- a/bot/exts/filtering/_settings_types/filter_dm.py +++ b/bot/exts/filtering/_settings_types/filter_dm.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import ClassVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry @@ -7,12 +7,10 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry class FilterDM(ValidationEntry): """A setting entry which tells whether to apply the filter to DMs.""" - name = "filter_dm" - description = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." + name: ClassVar[str] = "filter_dm" + description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.filter_dm = entry_data + filter_dm: bool def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter should be triggered even if it was triggered in DMs.""" diff --git a/bot/exts/filtering/_settings_types/infraction_and_notification.py b/bot/exts/filtering/_settings_types/infraction_and_notification.py index 4fae09f23..9c7d7b8ff 100644 --- a/bot/exts/filtering/_settings_types/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/infraction_and_notification.py @@ -1,11 +1,12 @@ from collections import namedtuple from datetime import timedelta from enum import Enum, auto -from typing import Any, Optional +from typing import ClassVar, Optional import arrow from discord import Colour, Embed from discord.errors import Forbidden +from pydantic import validator import bot from bot.constants import Channels, Guild @@ -50,8 +51,8 @@ class InfractionAndNotification(ActionEntry): Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. """ - name = "infraction_and_notification" - description = { + name: ClassVar[str] = "infraction_and_notification" + description: ClassVar[dict[str, str]] = { "infraction_type": ( "The type of infraction to issue when the filter triggers, or 'NONE'. " "If two infractions are triggered for the same message, " @@ -65,27 +66,18 @@ class InfractionAndNotification(ActionEntry): "dm_embed": "The contents of the embed to be DMed to the offending user." } - def __init__(self, entry_data: Any): - super().__init__(entry_data) + dm_content: str + dm_embed: str + infraction_type: Optional[Infraction] + infraction_reason: Optional[str] + infraction_duration: Optional[float] + superstar: Optional[superstar] = None - if entry_data["infraction_type"]: - self.infraction_type = entry_data["infraction_type"] - if isinstance(self.infraction_type, str): - self.infraction_type = Infraction[self.infraction_type.replace(" ", "_").upper()] - self.infraction_reason = entry_data["infraction_reason"] - if entry_data["infraction_duration"] is not None: - self.infraction_duration = float(entry_data["infraction_duration"]) - else: - self.infraction_duration = None - else: - self.infraction_type = Infraction.NONE - self.infraction_reason = None - self.infraction_duration = 0 - - self.dm_content = entry_data["dm_content"] - self.dm_embed = entry_data["dm_embed"] - - self._superstar = entry_data.get("superstar", None) + @validator("infraction_type", pre=True) + @classmethod + def convert_infraction_name(cls, infr_type: str) -> Infraction: + """Convert the string to an Infraction by name.""" + return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else Infraction.NONE async def action(self, ctx: FilterContext) -> None: """Send the notification to the user, and apply any specified infractions.""" @@ -115,14 +107,14 @@ class InfractionAndNotification(ActionEntry): msg_ctx.guild = bot.instance.get_guild(Guild.id) msg_ctx.author = ctx.author msg_ctx.channel = ctx.channel - if self._superstar: + if self.superstar: msg_ctx.command = bot.instance.get_command("superstarify") await msg_ctx.invoke( msg_ctx.command, ctx.author, - arrow.utcnow() + timedelta(seconds=self._superstar.duration) - if self._superstar.duration is not None else None, - reason=self._superstar.reason + arrow.utcnow() + timedelta(seconds=self.superstar.duration) + if self.superstar.duration is not None else None, + reason=self.superstar.reason ) ctx.action_descriptions.append("superstar") @@ -160,31 +152,31 @@ class InfractionAndNotification(ActionEntry): # Lower number -> higher in the hierarchy if self.infraction_type.value < other.infraction_type.value and other.infraction_type != Infraction.SUPERSTAR: - result = InfractionAndNotification(self.to_dict()) - result._superstar = self._merge_superstars(self._superstar, other._superstar) + result = self.copy() + result.superstar = self._merge_superstars(self.superstar, other.superstar) return result elif self.infraction_type.value > other.infraction_type.value and self.infraction_type != Infraction.SUPERSTAR: - result = InfractionAndNotification(other.to_dict()) - result._superstar = self._merge_superstars(self._superstar, other._superstar) + result = other.copy() + result.superstar = self._merge_superstars(self.superstar, other.superstar) return result if self.infraction_type == other.infraction_type: if self.infraction_duration is None or ( other.infraction_duration is not None and self.infraction_duration > other.infraction_duration ): - result = InfractionAndNotification(self.to_dict()) + result = self.copy() else: - result = InfractionAndNotification(other.to_dict()) - result._superstar = self._merge_superstars(self._superstar, other._superstar) + result = other.copy() + result.superstar = self._merge_superstars(self.superstar, other.superstar) return result # At this stage the infraction types are different, and the lower one is a superstar. if self.infraction_type.value < other.infraction_type.value: - result = InfractionAndNotification(self.to_dict()) - result._superstar = superstar(other.infraction_reason, other.infraction_duration) + result = self.copy() + result.superstar = superstar(other.infraction_reason, other.infraction_duration) else: - result = InfractionAndNotification(other.to_dict()) - result._superstar = superstar(self.infraction_reason, self.infraction_duration) + result = other.copy() + result.superstar = superstar(self.infraction_reason, self.infraction_duration) return result @staticmethod diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py index 1e0067690..8a3403b59 100644 --- a/bot/exts/filtering/_settings_types/ping.py +++ b/bot/exts/filtering/_settings_types/ping.py @@ -1,7 +1,8 @@ from functools import cache -from typing import Any +from typing import ClassVar from discord import Guild +from pydantic import validator from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry @@ -10,8 +11,8 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry class Ping(ActionEntry): """A setting entry which adds the appropriate pings to the alert.""" - name = "mentions" - description = { + name: ClassVar[str] = "mentions" + description: ClassVar[dict[str, str]] = { "guild_pings": ( "A list of role IDs/role names/user IDs/user names/here/everyone. " "If a mod-alert is generated for a filter triggered in a public channel, these will be pinged." @@ -22,11 +23,16 @@ class Ping(ActionEntry): ) } - def __init__(self, entry_data: Any): - super().__init__(entry_data) + guild_pings: set[str] + dm_pings: set[str] - self.guild_pings = set(entry_data["guild_pings"]) if entry_data["guild_pings"] else set() - self.dm_pings = set(entry_data["dm_pings"]) if entry_data["dm_pings"] else set() + @validator("*") + @classmethod + def init_sequence_if_none(cls, pings: list[str]) -> list[str]: + """Initialize an empty sequence if the value is None.""" + if pings is None: + return [] + return pings async def action(self, ctx: FilterContext) -> None: """Add the stored pings to the alert message content.""" @@ -39,10 +45,7 @@ class Ping(ActionEntry): if not isinstance(other, Ping): return NotImplemented - return Ping({ - "ping_type": self.guild_pings | other.guild_pings, - "dm_ping_type": self.dm_pings | other.dm_pings - }) + return Ping(ping_type=self.guild_pings | other.guild_pings, dm_ping_type=self.dm_pings | other.dm_pings) @staticmethod @cache diff --git a/bot/exts/filtering/_settings_types/send_alert.py b/bot/exts/filtering/_settings_types/send_alert.py index 6429b99ac..04e400764 100644 --- a/bot/exts/filtering/_settings_types/send_alert.py +++ b/bot/exts/filtering/_settings_types/send_alert.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import ClassVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry @@ -7,12 +7,10 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry class SendAlert(ActionEntry): """A setting entry which tells whether to send an alert message.""" - name = "send_alert" - description = "A boolean field. If all filters triggered set this to False, no mod-alert will be created." + name: ClassVar[str] = "send_alert" + description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created." - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.send_alert: bool = entry_data + send_alert: bool async def action(self, ctx: FilterContext) -> None: """Add the stored pings to the alert message content.""" @@ -23,4 +21,4 @@ class SendAlert(ActionEntry): if not isinstance(other, SendAlert): return NotImplemented - return SendAlert(self.send_alert or other.send_alert) + return SendAlert(send_alert=self.send_alert or other.send_alert) diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py index 2883deed8..2b3b030a0 100644 --- a/bot/exts/filtering/_settings_types/settings_entry.py +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -1,13 +1,15 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Optional +from typing import Any, ClassVar, Optional, Union + +from pydantic import BaseModel from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._utils import FieldRequiring -class SettingsEntry(FieldRequiring): +class SettingsEntry(BaseModel, FieldRequiring): """ A basic entry in the settings field appearing in every filter list and filter. @@ -16,34 +18,10 @@ class SettingsEntry(FieldRequiring): # Each subclass must define a name matching the entry name we're expecting to receive from the database. # Names must be unique across all filter lists. - name = FieldRequiring.MUST_SET_UNIQUE + name: ClassVar[str] = FieldRequiring.MUST_SET_UNIQUE # Each subclass must define a description of what it does. If the data an entry type receives is comprised of # several DB fields, the value should a dictionary of field names and their descriptions. - description = FieldRequiring.MUST_SET - - @abstractmethod - def __init__(self, entry_data: Any): - super().__init__() - self._dict = {} - - def __setattr__(self, key: str, value: Any) -> None: - super().__setattr__(key, value) - if key == "_dict": - return - self._dict[key] = value - - def __eq__(self, other: SettingsEntry) -> bool: - if not isinstance(other, SettingsEntry): - return NotImplemented - return self._dict == other._dict - - def to_dict(self) -> dict[str, Any]: - """Return a dictionary representation of the entry.""" - return self._dict.copy() - - def copy(self) -> SettingsEntry: - """Return a new entry object with the same parameters.""" - return self.__class__(self.to_dict()) + description: ClassVar[Union[str, dict[str, str]]] = FieldRequiring.MUST_SET @classmethod def create(cls, entry_data: Optional[dict[str, Any]], *, keep_empty: bool = False) -> Optional[SettingsEntry]: @@ -58,7 +36,9 @@ class SettingsEntry(FieldRequiring): if not keep_empty and hasattr(entry_data, "values") and not any(value for value in entry_data.values()): return None - return cls(entry_data) + if not isinstance(entry_data, dict): + entry_data = {cls.name: entry_data} + return cls(**entry_data) class ValidationEntry(SettingsEntry): diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index 14c6bd13b..158f1e7bd 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -84,14 +84,18 @@ class FieldRequiring(ABC): ... def __init_subclass__(cls, **kwargs): + def inherited(attr: str) -> bool: + """True if `attr` was inherited from a parent class.""" + for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object. + if hasattr(parent, attr): # The attribute was inherited. + return True + return False + # If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it. if inspect.isabstract(cls): for attribute in dir(cls): if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE: - for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object. - if hasattr(parent, attribute): # The attribute was inherited. - break - else: + if not inherited(attribute): # A new attribute with the value MUST_SET_UNIQUE. FieldRequiring.__unique_attributes[cls][attribute] = set() return @@ -100,9 +104,9 @@ class FieldRequiring(ABC): if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"): continue value = getattr(cls, attribute) - if value is FieldRequiring.MUST_SET: + if value is FieldRequiring.MUST_SET and inherited(attribute): raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}") - elif value is FieldRequiring.MUST_SET_UNIQUE: + elif value is FieldRequiring.MUST_SET_UNIQUE and inherited(attribute): raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}") else: # Check if the value needs to be unique. diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 2a24769d0..630474c13 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -215,14 +215,14 @@ class Filtering(Cog): default_setting_values = {} for type_ in ("actions", "validations"): for _, setting in filter_list.defaults[list_type][type_].items(): - default_setting_values.update(to_serializable(setting.to_dict())) + default_setting_values.update(to_serializable(setting.dict())) # Get the filter's overridden settings overrides_values = {} for settings in (filter_.actions, filter_.validations): if settings: for _, setting in settings.items(): - overrides_values.update(to_serializable(setting.to_dict())) + overrides_values.update(to_serializable(setting.dict())) # Combine them. It's done in this way to preserve field order, since the filter won't have all settings. total_values = {} @@ -345,7 +345,7 @@ class Filtering(Cog): setting_values = {} for type_ in ("actions", "validations"): for _, setting in list_defaults[type_].items(): - setting_values.update(to_serializable(setting.to_dict())) + setting_values.update(to_serializable(setting.dict())) embed = self._build_embed_from_dict(setting_values) # Use the class's docstring, and ignore single newlines. @@ -458,7 +458,7 @@ class Filtering(Cog): await ctx.send(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.") return - lines = list(map(str, type_filters)) + lines = list(map(str, type_filters.values())) log.trace(f"Sending a list of {len(lines)} filters.") embed = Embed(colour=Colour.blue()) |