diff options
-rw-r--r-- | bot/exts/filtering/_settings_types/bypass_roles.py | 26 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/channel_scope.py | 46 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/delete_messages.py | 14 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/enabled.py | 12 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/filter_dm.py | 10 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/infraction_and_notification.py | 68 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/ping.py | 25 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/send_alert.py | 12 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/settings_entry.py | 38 | ||||
-rw-r--r-- | bot/exts/filtering/_utils.py | 16 | ||||
-rw-r--r-- | bot/exts/filtering/filtering.py | 8 |
11 files changed, 125 insertions, 150 deletions
diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py index e183e0b42..a5c18cffc 100644 --- a/bot/exts/filtering/_settings_types/bypass_roles.py +++ b/bot/exts/filtering/_settings_types/bypass_roles.py @@ -1,6 +1,7 @@ -from typing import Any +from typing import ClassVar, Union from discord import Member +from pydantic import validator from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry @@ -9,17 +10,18 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry class RoleBypass(ValidationEntry): """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" - name = "bypass_roles" - description = "A list of role IDs or role names. Users with these roles will not trigger the filter." - - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.bypass_roles = set() - for role in entry_data: - if role.isdigit(): - self.bypass_roles.add(int(role)) - else: - self.bypass_roles.add(role) + name: ClassVar[str] = "bypass_roles" + description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter." + + bypass_roles: set[Union[int, str]] + + @validator("bypass_roles", each_item=True) + @classmethod + def maybe_cast_to_int(cls, role: str) -> Union[int, str]: + """If the string is alphanumeric, cast it to int.""" + if role.isdigit(): + return int(role) + return role def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter should be triggered on this user given their roles.""" diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py index 3a95834b3..fd5206b81 100644 --- a/bot/exts/filtering/_settings_types/channel_scope.py +++ b/bot/exts/filtering/_settings_types/channel_scope.py @@ -1,21 +1,16 @@ -from typing import Any, Union +from typing import ClassVar, Union + +from pydantic import validator from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry -def maybe_cast_to_int(item: str) -> Union[str, int]: - """Cast the item to int if it consists of only digit, or leave as is otherwise.""" - if item.isdigit(): - return int(item) - return item - - class ChannelScope(ValidationEntry): """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" - name = "channel_scope" - description = { + name: ClassVar[str] = "channel_scope" + description: ClassVar[str] = { "disabled_channels": "A list of channel IDs or channel names. The filter will not trigger in these channels.", "disabled_categories": ( "A list of category IDs or category names. The filter will not trigger in these categories." @@ -26,22 +21,25 @@ class ChannelScope(ValidationEntry): ) } - def __init__(self, entry_data: Any): - super().__init__(entry_data) - if entry_data["disabled_channels"]: - self.disabled_channels = set(map(maybe_cast_to_int, entry_data["disabled_channels"])) - else: - self.disabled_channels = set() + disabled_channels: set[Union[str, int]] + disabled_categories: set[Union[str, int]] + enabled_channels: set[Union[str, int]] - if entry_data["disabled_categories"]: - self.disabled_categories = set(map(maybe_cast_to_int, entry_data["disabled_categories"])) - else: - self.disabled_categories = set() + @validator("*", pre=True) + @classmethod + def init_if_sequence_none(cls, sequence: list[str]) -> list[str]: + """Initialize an empty sequence if the value is None.""" + if sequence is None: + return [] + return sequence - if entry_data["enabled_channels"]: - self.enabled_channels = set(map(maybe_cast_to_int, entry_data["enabled_channels"])) - else: - self.enabled_channels = set() + @validator("*", each_item=True) + @classmethod + def maybe_cast_items(cls, channel_or_category: str) -> Union[str, int]: + """Cast to int each value in each sequence if it is alphanumeric.""" + if channel_or_category.isdigit(): + return int(channel_or_category) + return channel_or_category def triggers_on(self, ctx: FilterContext) -> bool: """ diff --git a/bot/exts/filtering/_settings_types/delete_messages.py b/bot/exts/filtering/_settings_types/delete_messages.py index 8de58f804..710cb0ed8 100644 --- a/bot/exts/filtering/_settings_types/delete_messages.py +++ b/bot/exts/filtering/_settings_types/delete_messages.py @@ -1,5 +1,5 @@ from contextlib import suppress -from typing import Any +from typing import ClassVar from discord.errors import NotFound @@ -10,12 +10,12 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry class DeleteMessages(ActionEntry): """A setting entry which tells whether to delete the offending message(s).""" - name = "delete_messages" - description = "A boolean field. If True, the filter being triggered will cause the offending message to be deleted." + name: ClassVar[str] = "delete_messages" + description: ClassVar[str] = ( + "A boolean field. If True, the filter being triggered will cause the offending message to be deleted." + ) - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.delete_messages: bool = entry_data + delete_messages: bool async def action(self, ctx: FilterContext) -> None: """Delete the context message(s).""" @@ -32,4 +32,4 @@ class DeleteMessages(ActionEntry): if not isinstance(other, DeleteMessages): return NotImplemented - return DeleteMessages(self.delete_messages or other.delete_messages) + return DeleteMessages(delete_messages=self.delete_messages or other.delete_messages) diff --git a/bot/exts/filtering/_settings_types/enabled.py b/bot/exts/filtering/_settings_types/enabled.py index 081ae02b0..3b5e3e446 100644 --- a/bot/exts/filtering/_settings_types/enabled.py +++ b/bot/exts/filtering/_settings_types/enabled.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import ClassVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry @@ -7,12 +7,12 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry class Enabled(ValidationEntry): """A setting entry which tells whether the filter is enabled.""" - name = "enabled" - description = "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." + name: ClassVar[str] = "enabled" + description: ClassVar[str] = ( + "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." + ) - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.enabled = entry_data + enabled: bool def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter is enabled.""" diff --git a/bot/exts/filtering/_settings_types/filter_dm.py b/bot/exts/filtering/_settings_types/filter_dm.py index 1405a636f..93022320f 100644 --- a/bot/exts/filtering/_settings_types/filter_dm.py +++ b/bot/exts/filtering/_settings_types/filter_dm.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import ClassVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry @@ -7,12 +7,10 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry class FilterDM(ValidationEntry): """A setting entry which tells whether to apply the filter to DMs.""" - name = "filter_dm" - description = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." + name: ClassVar[str] = "filter_dm" + description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.filter_dm = entry_data + filter_dm: bool def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter should be triggered even if it was triggered in DMs.""" diff --git a/bot/exts/filtering/_settings_types/infraction_and_notification.py b/bot/exts/filtering/_settings_types/infraction_and_notification.py index 4fae09f23..9c7d7b8ff 100644 --- a/bot/exts/filtering/_settings_types/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/infraction_and_notification.py @@ -1,11 +1,12 @@ from collections import namedtuple from datetime import timedelta from enum import Enum, auto -from typing import Any, Optional +from typing import ClassVar, Optional import arrow from discord import Colour, Embed from discord.errors import Forbidden +from pydantic import validator import bot from bot.constants import Channels, Guild @@ -50,8 +51,8 @@ class InfractionAndNotification(ActionEntry): Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. """ - name = "infraction_and_notification" - description = { + name: ClassVar[str] = "infraction_and_notification" + description: ClassVar[dict[str, str]] = { "infraction_type": ( "The type of infraction to issue when the filter triggers, or 'NONE'. " "If two infractions are triggered for the same message, " @@ -65,27 +66,18 @@ class InfractionAndNotification(ActionEntry): "dm_embed": "The contents of the embed to be DMed to the offending user." } - def __init__(self, entry_data: Any): - super().__init__(entry_data) + dm_content: str + dm_embed: str + infraction_type: Optional[Infraction] + infraction_reason: Optional[str] + infraction_duration: Optional[float] + superstar: Optional[superstar] = None - if entry_data["infraction_type"]: - self.infraction_type = entry_data["infraction_type"] - if isinstance(self.infraction_type, str): - self.infraction_type = Infraction[self.infraction_type.replace(" ", "_").upper()] - self.infraction_reason = entry_data["infraction_reason"] - if entry_data["infraction_duration"] is not None: - self.infraction_duration = float(entry_data["infraction_duration"]) - else: - self.infraction_duration = None - else: - self.infraction_type = Infraction.NONE - self.infraction_reason = None - self.infraction_duration = 0 - - self.dm_content = entry_data["dm_content"] - self.dm_embed = entry_data["dm_embed"] - - self._superstar = entry_data.get("superstar", None) + @validator("infraction_type", pre=True) + @classmethod + def convert_infraction_name(cls, infr_type: str) -> Infraction: + """Convert the string to an Infraction by name.""" + return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else Infraction.NONE async def action(self, ctx: FilterContext) -> None: """Send the notification to the user, and apply any specified infractions.""" @@ -115,14 +107,14 @@ class InfractionAndNotification(ActionEntry): msg_ctx.guild = bot.instance.get_guild(Guild.id) msg_ctx.author = ctx.author msg_ctx.channel = ctx.channel - if self._superstar: + if self.superstar: msg_ctx.command = bot.instance.get_command("superstarify") await msg_ctx.invoke( msg_ctx.command, ctx.author, - arrow.utcnow() + timedelta(seconds=self._superstar.duration) - if self._superstar.duration is not None else None, - reason=self._superstar.reason + arrow.utcnow() + timedelta(seconds=self.superstar.duration) + if self.superstar.duration is not None else None, + reason=self.superstar.reason ) ctx.action_descriptions.append("superstar") @@ -160,31 +152,31 @@ class InfractionAndNotification(ActionEntry): # Lower number -> higher in the hierarchy if self.infraction_type.value < other.infraction_type.value and other.infraction_type != Infraction.SUPERSTAR: - result = InfractionAndNotification(self.to_dict()) - result._superstar = self._merge_superstars(self._superstar, other._superstar) + result = self.copy() + result.superstar = self._merge_superstars(self.superstar, other.superstar) return result elif self.infraction_type.value > other.infraction_type.value and self.infraction_type != Infraction.SUPERSTAR: - result = InfractionAndNotification(other.to_dict()) - result._superstar = self._merge_superstars(self._superstar, other._superstar) + result = other.copy() + result.superstar = self._merge_superstars(self.superstar, other.superstar) return result if self.infraction_type == other.infraction_type: if self.infraction_duration is None or ( other.infraction_duration is not None and self.infraction_duration > other.infraction_duration ): - result = InfractionAndNotification(self.to_dict()) + result = self.copy() else: - result = InfractionAndNotification(other.to_dict()) - result._superstar = self._merge_superstars(self._superstar, other._superstar) + result = other.copy() + result.superstar = self._merge_superstars(self.superstar, other.superstar) return result # At this stage the infraction types are different, and the lower one is a superstar. if self.infraction_type.value < other.infraction_type.value: - result = InfractionAndNotification(self.to_dict()) - result._superstar = superstar(other.infraction_reason, other.infraction_duration) + result = self.copy() + result.superstar = superstar(other.infraction_reason, other.infraction_duration) else: - result = InfractionAndNotification(other.to_dict()) - result._superstar = superstar(self.infraction_reason, self.infraction_duration) + result = other.copy() + result.superstar = superstar(self.infraction_reason, self.infraction_duration) return result @staticmethod diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py index 1e0067690..8a3403b59 100644 --- a/bot/exts/filtering/_settings_types/ping.py +++ b/bot/exts/filtering/_settings_types/ping.py @@ -1,7 +1,8 @@ from functools import cache -from typing import Any +from typing import ClassVar from discord import Guild +from pydantic import validator from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry @@ -10,8 +11,8 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry class Ping(ActionEntry): """A setting entry which adds the appropriate pings to the alert.""" - name = "mentions" - description = { + name: ClassVar[str] = "mentions" + description: ClassVar[dict[str, str]] = { "guild_pings": ( "A list of role IDs/role names/user IDs/user names/here/everyone. " "If a mod-alert is generated for a filter triggered in a public channel, these will be pinged." @@ -22,11 +23,16 @@ class Ping(ActionEntry): ) } - def __init__(self, entry_data: Any): - super().__init__(entry_data) + guild_pings: set[str] + dm_pings: set[str] - self.guild_pings = set(entry_data["guild_pings"]) if entry_data["guild_pings"] else set() - self.dm_pings = set(entry_data["dm_pings"]) if entry_data["dm_pings"] else set() + @validator("*") + @classmethod + def init_sequence_if_none(cls, pings: list[str]) -> list[str]: + """Initialize an empty sequence if the value is None.""" + if pings is None: + return [] + return pings async def action(self, ctx: FilterContext) -> None: """Add the stored pings to the alert message content.""" @@ -39,10 +45,7 @@ class Ping(ActionEntry): if not isinstance(other, Ping): return NotImplemented - return Ping({ - "ping_type": self.guild_pings | other.guild_pings, - "dm_ping_type": self.dm_pings | other.dm_pings - }) + return Ping(ping_type=self.guild_pings | other.guild_pings, dm_ping_type=self.dm_pings | other.dm_pings) @staticmethod @cache diff --git a/bot/exts/filtering/_settings_types/send_alert.py b/bot/exts/filtering/_settings_types/send_alert.py index 6429b99ac..04e400764 100644 --- a/bot/exts/filtering/_settings_types/send_alert.py +++ b/bot/exts/filtering/_settings_types/send_alert.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import ClassVar from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry @@ -7,12 +7,10 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry class SendAlert(ActionEntry): """A setting entry which tells whether to send an alert message.""" - name = "send_alert" - description = "A boolean field. If all filters triggered set this to False, no mod-alert will be created." + name: ClassVar[str] = "send_alert" + description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created." - def __init__(self, entry_data: Any): - super().__init__(entry_data) - self.send_alert: bool = entry_data + send_alert: bool async def action(self, ctx: FilterContext) -> None: """Add the stored pings to the alert message content.""" @@ -23,4 +21,4 @@ class SendAlert(ActionEntry): if not isinstance(other, SendAlert): return NotImplemented - return SendAlert(self.send_alert or other.send_alert) + return SendAlert(send_alert=self.send_alert or other.send_alert) diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py index 2883deed8..2b3b030a0 100644 --- a/bot/exts/filtering/_settings_types/settings_entry.py +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -1,13 +1,15 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Optional +from typing import Any, ClassVar, Optional, Union + +from pydantic import BaseModel from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._utils import FieldRequiring -class SettingsEntry(FieldRequiring): +class SettingsEntry(BaseModel, FieldRequiring): """ A basic entry in the settings field appearing in every filter list and filter. @@ -16,34 +18,10 @@ class SettingsEntry(FieldRequiring): # Each subclass must define a name matching the entry name we're expecting to receive from the database. # Names must be unique across all filter lists. - name = FieldRequiring.MUST_SET_UNIQUE + name: ClassVar[str] = FieldRequiring.MUST_SET_UNIQUE # Each subclass must define a description of what it does. If the data an entry type receives is comprised of # several DB fields, the value should a dictionary of field names and their descriptions. - description = FieldRequiring.MUST_SET - - @abstractmethod - def __init__(self, entry_data: Any): - super().__init__() - self._dict = {} - - def __setattr__(self, key: str, value: Any) -> None: - super().__setattr__(key, value) - if key == "_dict": - return - self._dict[key] = value - - def __eq__(self, other: SettingsEntry) -> bool: - if not isinstance(other, SettingsEntry): - return NotImplemented - return self._dict == other._dict - - def to_dict(self) -> dict[str, Any]: - """Return a dictionary representation of the entry.""" - return self._dict.copy() - - def copy(self) -> SettingsEntry: - """Return a new entry object with the same parameters.""" - return self.__class__(self.to_dict()) + description: ClassVar[Union[str, dict[str, str]]] = FieldRequiring.MUST_SET @classmethod def create(cls, entry_data: Optional[dict[str, Any]], *, keep_empty: bool = False) -> Optional[SettingsEntry]: @@ -58,7 +36,9 @@ class SettingsEntry(FieldRequiring): if not keep_empty and hasattr(entry_data, "values") and not any(value for value in entry_data.values()): return None - return cls(entry_data) + if not isinstance(entry_data, dict): + entry_data = {cls.name: entry_data} + return cls(**entry_data) class ValidationEntry(SettingsEntry): diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index 14c6bd13b..158f1e7bd 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -84,14 +84,18 @@ class FieldRequiring(ABC): ... def __init_subclass__(cls, **kwargs): + def inherited(attr: str) -> bool: + """True if `attr` was inherited from a parent class.""" + for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object. + if hasattr(parent, attr): # The attribute was inherited. + return True + return False + # If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it. if inspect.isabstract(cls): for attribute in dir(cls): if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE: - for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object. - if hasattr(parent, attribute): # The attribute was inherited. - break - else: + if not inherited(attribute): # A new attribute with the value MUST_SET_UNIQUE. FieldRequiring.__unique_attributes[cls][attribute] = set() return @@ -100,9 +104,9 @@ class FieldRequiring(ABC): if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"): continue value = getattr(cls, attribute) - if value is FieldRequiring.MUST_SET: + if value is FieldRequiring.MUST_SET and inherited(attribute): raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}") - elif value is FieldRequiring.MUST_SET_UNIQUE: + elif value is FieldRequiring.MUST_SET_UNIQUE and inherited(attribute): raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}") else: # Check if the value needs to be unique. diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 2a24769d0..630474c13 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -215,14 +215,14 @@ class Filtering(Cog): default_setting_values = {} for type_ in ("actions", "validations"): for _, setting in filter_list.defaults[list_type][type_].items(): - default_setting_values.update(to_serializable(setting.to_dict())) + default_setting_values.update(to_serializable(setting.dict())) # Get the filter's overridden settings overrides_values = {} for settings in (filter_.actions, filter_.validations): if settings: for _, setting in settings.items(): - overrides_values.update(to_serializable(setting.to_dict())) + overrides_values.update(to_serializable(setting.dict())) # Combine them. It's done in this way to preserve field order, since the filter won't have all settings. total_values = {} @@ -345,7 +345,7 @@ class Filtering(Cog): setting_values = {} for type_ in ("actions", "validations"): for _, setting in list_defaults[type_].items(): - setting_values.update(to_serializable(setting.to_dict())) + setting_values.update(to_serializable(setting.dict())) embed = self._build_embed_from_dict(setting_values) # Use the class's docstring, and ignore single newlines. @@ -458,7 +458,7 @@ class Filtering(Cog): await ctx.send(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.") return - lines = list(map(str, type_filters)) + lines = list(map(str, type_filters.values())) log.trace(f"Sending a list of {len(lines)} filters.") embed = Embed(colour=Colour.blue()) |