aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/filtering/_settings_types/bypass_roles.py26
-rw-r--r--bot/exts/filtering/_settings_types/channel_scope.py46
-rw-r--r--bot/exts/filtering/_settings_types/delete_messages.py14
-rw-r--r--bot/exts/filtering/_settings_types/enabled.py12
-rw-r--r--bot/exts/filtering/_settings_types/filter_dm.py10
-rw-r--r--bot/exts/filtering/_settings_types/infraction_and_notification.py68
-rw-r--r--bot/exts/filtering/_settings_types/ping.py25
-rw-r--r--bot/exts/filtering/_settings_types/send_alert.py12
-rw-r--r--bot/exts/filtering/_settings_types/settings_entry.py38
-rw-r--r--bot/exts/filtering/_utils.py16
-rw-r--r--bot/exts/filtering/filtering.py8
11 files changed, 125 insertions, 150 deletions
diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py
index e183e0b42..a5c18cffc 100644
--- a/bot/exts/filtering/_settings_types/bypass_roles.py
+++ b/bot/exts/filtering/_settings_types/bypass_roles.py
@@ -1,6 +1,7 @@
-from typing import Any
+from typing import ClassVar, Union
from discord import Member
+from pydantic import validator
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
@@ -9,17 +10,18 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
class RoleBypass(ValidationEntry):
"""A setting entry which tells whether the roles the member has allow them to bypass the filter."""
- name = "bypass_roles"
- description = "A list of role IDs or role names. Users with these roles will not trigger the filter."
-
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
- self.bypass_roles = set()
- for role in entry_data:
- if role.isdigit():
- self.bypass_roles.add(int(role))
- else:
- self.bypass_roles.add(role)
+ name: ClassVar[str] = "bypass_roles"
+ description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter."
+
+ bypass_roles: set[Union[int, str]]
+
+ @validator("bypass_roles", each_item=True)
+ @classmethod
+ def maybe_cast_to_int(cls, role: str) -> Union[int, str]:
+ """If the string is alphanumeric, cast it to int."""
+ if role.isdigit():
+ return int(role)
+ return role
def triggers_on(self, ctx: FilterContext) -> bool:
"""Return whether the filter should be triggered on this user given their roles."""
diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py
index 3a95834b3..fd5206b81 100644
--- a/bot/exts/filtering/_settings_types/channel_scope.py
+++ b/bot/exts/filtering/_settings_types/channel_scope.py
@@ -1,21 +1,16 @@
-from typing import Any, Union
+from typing import ClassVar, Union
+
+from pydantic import validator
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
-def maybe_cast_to_int(item: str) -> Union[str, int]:
- """Cast the item to int if it consists of only digit, or leave as is otherwise."""
- if item.isdigit():
- return int(item)
- return item
-
-
class ChannelScope(ValidationEntry):
"""A setting entry which tells whether the filter was invoked in a whitelisted channel or category."""
- name = "channel_scope"
- description = {
+ name: ClassVar[str] = "channel_scope"
+ description: ClassVar[str] = {
"disabled_channels": "A list of channel IDs or channel names. The filter will not trigger in these channels.",
"disabled_categories": (
"A list of category IDs or category names. The filter will not trigger in these categories."
@@ -26,22 +21,25 @@ class ChannelScope(ValidationEntry):
)
}
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
- if entry_data["disabled_channels"]:
- self.disabled_channels = set(map(maybe_cast_to_int, entry_data["disabled_channels"]))
- else:
- self.disabled_channels = set()
+ disabled_channels: set[Union[str, int]]
+ disabled_categories: set[Union[str, int]]
+ enabled_channels: set[Union[str, int]]
- if entry_data["disabled_categories"]:
- self.disabled_categories = set(map(maybe_cast_to_int, entry_data["disabled_categories"]))
- else:
- self.disabled_categories = set()
+ @validator("*", pre=True)
+ @classmethod
+ def init_if_sequence_none(cls, sequence: list[str]) -> list[str]:
+ """Initialize an empty sequence if the value is None."""
+ if sequence is None:
+ return []
+ return sequence
- if entry_data["enabled_channels"]:
- self.enabled_channels = set(map(maybe_cast_to_int, entry_data["enabled_channels"]))
- else:
- self.enabled_channels = set()
+ @validator("*", each_item=True)
+ @classmethod
+ def maybe_cast_items(cls, channel_or_category: str) -> Union[str, int]:
+ """Cast to int each value in each sequence if it is alphanumeric."""
+ if channel_or_category.isdigit():
+ return int(channel_or_category)
+ return channel_or_category
def triggers_on(self, ctx: FilterContext) -> bool:
"""
diff --git a/bot/exts/filtering/_settings_types/delete_messages.py b/bot/exts/filtering/_settings_types/delete_messages.py
index 8de58f804..710cb0ed8 100644
--- a/bot/exts/filtering/_settings_types/delete_messages.py
+++ b/bot/exts/filtering/_settings_types/delete_messages.py
@@ -1,5 +1,5 @@
from contextlib import suppress
-from typing import Any
+from typing import ClassVar
from discord.errors import NotFound
@@ -10,12 +10,12 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry
class DeleteMessages(ActionEntry):
"""A setting entry which tells whether to delete the offending message(s)."""
- name = "delete_messages"
- description = "A boolean field. If True, the filter being triggered will cause the offending message to be deleted."
+ name: ClassVar[str] = "delete_messages"
+ description: ClassVar[str] = (
+ "A boolean field. If True, the filter being triggered will cause the offending message to be deleted."
+ )
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
- self.delete_messages: bool = entry_data
+ delete_messages: bool
async def action(self, ctx: FilterContext) -> None:
"""Delete the context message(s)."""
@@ -32,4 +32,4 @@ class DeleteMessages(ActionEntry):
if not isinstance(other, DeleteMessages):
return NotImplemented
- return DeleteMessages(self.delete_messages or other.delete_messages)
+ return DeleteMessages(delete_messages=self.delete_messages or other.delete_messages)
diff --git a/bot/exts/filtering/_settings_types/enabled.py b/bot/exts/filtering/_settings_types/enabled.py
index 081ae02b0..3b5e3e446 100644
--- a/bot/exts/filtering/_settings_types/enabled.py
+++ b/bot/exts/filtering/_settings_types/enabled.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import ClassVar
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
@@ -7,12 +7,12 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
class Enabled(ValidationEntry):
"""A setting entry which tells whether the filter is enabled."""
- name = "enabled"
- description = "A boolean field. Setting it to False allows disabling the filter without deleting it entirely."
+ name: ClassVar[str] = "enabled"
+ description: ClassVar[str] = (
+ "A boolean field. Setting it to False allows disabling the filter without deleting it entirely."
+ )
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
- self.enabled = entry_data
+ enabled: bool
def triggers_on(self, ctx: FilterContext) -> bool:
"""Return whether the filter is enabled."""
diff --git a/bot/exts/filtering/_settings_types/filter_dm.py b/bot/exts/filtering/_settings_types/filter_dm.py
index 1405a636f..93022320f 100644
--- a/bot/exts/filtering/_settings_types/filter_dm.py
+++ b/bot/exts/filtering/_settings_types/filter_dm.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import ClassVar
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
@@ -7,12 +7,10 @@ from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
class FilterDM(ValidationEntry):
"""A setting entry which tells whether to apply the filter to DMs."""
- name = "filter_dm"
- description = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs."
+ name: ClassVar[str] = "filter_dm"
+ description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs."
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
- self.filter_dm = entry_data
+ filter_dm: bool
def triggers_on(self, ctx: FilterContext) -> bool:
"""Return whether the filter should be triggered even if it was triggered in DMs."""
diff --git a/bot/exts/filtering/_settings_types/infraction_and_notification.py b/bot/exts/filtering/_settings_types/infraction_and_notification.py
index 4fae09f23..9c7d7b8ff 100644
--- a/bot/exts/filtering/_settings_types/infraction_and_notification.py
+++ b/bot/exts/filtering/_settings_types/infraction_and_notification.py
@@ -1,11 +1,12 @@
from collections import namedtuple
from datetime import timedelta
from enum import Enum, auto
-from typing import Any, Optional
+from typing import ClassVar, Optional
import arrow
from discord import Colour, Embed
from discord.errors import Forbidden
+from pydantic import validator
import bot
from bot.constants import Channels, Guild
@@ -50,8 +51,8 @@ class InfractionAndNotification(ActionEntry):
Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together.
"""
- name = "infraction_and_notification"
- description = {
+ name: ClassVar[str] = "infraction_and_notification"
+ description: ClassVar[dict[str, str]] = {
"infraction_type": (
"The type of infraction to issue when the filter triggers, or 'NONE'. "
"If two infractions are triggered for the same message, "
@@ -65,27 +66,18 @@ class InfractionAndNotification(ActionEntry):
"dm_embed": "The contents of the embed to be DMed to the offending user."
}
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
+ dm_content: str
+ dm_embed: str
+ infraction_type: Optional[Infraction]
+ infraction_reason: Optional[str]
+ infraction_duration: Optional[float]
+ superstar: Optional[superstar] = None
- if entry_data["infraction_type"]:
- self.infraction_type = entry_data["infraction_type"]
- if isinstance(self.infraction_type, str):
- self.infraction_type = Infraction[self.infraction_type.replace(" ", "_").upper()]
- self.infraction_reason = entry_data["infraction_reason"]
- if entry_data["infraction_duration"] is not None:
- self.infraction_duration = float(entry_data["infraction_duration"])
- else:
- self.infraction_duration = None
- else:
- self.infraction_type = Infraction.NONE
- self.infraction_reason = None
- self.infraction_duration = 0
-
- self.dm_content = entry_data["dm_content"]
- self.dm_embed = entry_data["dm_embed"]
-
- self._superstar = entry_data.get("superstar", None)
+ @validator("infraction_type", pre=True)
+ @classmethod
+ def convert_infraction_name(cls, infr_type: str) -> Infraction:
+ """Convert the string to an Infraction by name."""
+ return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else Infraction.NONE
async def action(self, ctx: FilterContext) -> None:
"""Send the notification to the user, and apply any specified infractions."""
@@ -115,14 +107,14 @@ class InfractionAndNotification(ActionEntry):
msg_ctx.guild = bot.instance.get_guild(Guild.id)
msg_ctx.author = ctx.author
msg_ctx.channel = ctx.channel
- if self._superstar:
+ if self.superstar:
msg_ctx.command = bot.instance.get_command("superstarify")
await msg_ctx.invoke(
msg_ctx.command,
ctx.author,
- arrow.utcnow() + timedelta(seconds=self._superstar.duration)
- if self._superstar.duration is not None else None,
- reason=self._superstar.reason
+ arrow.utcnow() + timedelta(seconds=self.superstar.duration)
+ if self.superstar.duration is not None else None,
+ reason=self.superstar.reason
)
ctx.action_descriptions.append("superstar")
@@ -160,31 +152,31 @@ class InfractionAndNotification(ActionEntry):
# Lower number -> higher in the hierarchy
if self.infraction_type.value < other.infraction_type.value and other.infraction_type != Infraction.SUPERSTAR:
- result = InfractionAndNotification(self.to_dict())
- result._superstar = self._merge_superstars(self._superstar, other._superstar)
+ result = self.copy()
+ result.superstar = self._merge_superstars(self.superstar, other.superstar)
return result
elif self.infraction_type.value > other.infraction_type.value and self.infraction_type != Infraction.SUPERSTAR:
- result = InfractionAndNotification(other.to_dict())
- result._superstar = self._merge_superstars(self._superstar, other._superstar)
+ result = other.copy()
+ result.superstar = self._merge_superstars(self.superstar, other.superstar)
return result
if self.infraction_type == other.infraction_type:
if self.infraction_duration is None or (
other.infraction_duration is not None and self.infraction_duration > other.infraction_duration
):
- result = InfractionAndNotification(self.to_dict())
+ result = self.copy()
else:
- result = InfractionAndNotification(other.to_dict())
- result._superstar = self._merge_superstars(self._superstar, other._superstar)
+ result = other.copy()
+ result.superstar = self._merge_superstars(self.superstar, other.superstar)
return result
# At this stage the infraction types are different, and the lower one is a superstar.
if self.infraction_type.value < other.infraction_type.value:
- result = InfractionAndNotification(self.to_dict())
- result._superstar = superstar(other.infraction_reason, other.infraction_duration)
+ result = self.copy()
+ result.superstar = superstar(other.infraction_reason, other.infraction_duration)
else:
- result = InfractionAndNotification(other.to_dict())
- result._superstar = superstar(self.infraction_reason, self.infraction_duration)
+ result = other.copy()
+ result.superstar = superstar(self.infraction_reason, self.infraction_duration)
return result
@staticmethod
diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py
index 1e0067690..8a3403b59 100644
--- a/bot/exts/filtering/_settings_types/ping.py
+++ b/bot/exts/filtering/_settings_types/ping.py
@@ -1,7 +1,8 @@
from functools import cache
-from typing import Any
+from typing import ClassVar
from discord import Guild
+from pydantic import validator
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ActionEntry
@@ -10,8 +11,8 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry
class Ping(ActionEntry):
"""A setting entry which adds the appropriate pings to the alert."""
- name = "mentions"
- description = {
+ name: ClassVar[str] = "mentions"
+ description: ClassVar[dict[str, str]] = {
"guild_pings": (
"A list of role IDs/role names/user IDs/user names/here/everyone. "
"If a mod-alert is generated for a filter triggered in a public channel, these will be pinged."
@@ -22,11 +23,16 @@ class Ping(ActionEntry):
)
}
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
+ guild_pings: set[str]
+ dm_pings: set[str]
- self.guild_pings = set(entry_data["guild_pings"]) if entry_data["guild_pings"] else set()
- self.dm_pings = set(entry_data["dm_pings"]) if entry_data["dm_pings"] else set()
+ @validator("*")
+ @classmethod
+ def init_sequence_if_none(cls, pings: list[str]) -> list[str]:
+ """Initialize an empty sequence if the value is None."""
+ if pings is None:
+ return []
+ return pings
async def action(self, ctx: FilterContext) -> None:
"""Add the stored pings to the alert message content."""
@@ -39,10 +45,7 @@ class Ping(ActionEntry):
if not isinstance(other, Ping):
return NotImplemented
- return Ping({
- "ping_type": self.guild_pings | other.guild_pings,
- "dm_ping_type": self.dm_pings | other.dm_pings
- })
+ return Ping(ping_type=self.guild_pings | other.guild_pings, dm_ping_type=self.dm_pings | other.dm_pings)
@staticmethod
@cache
diff --git a/bot/exts/filtering/_settings_types/send_alert.py b/bot/exts/filtering/_settings_types/send_alert.py
index 6429b99ac..04e400764 100644
--- a/bot/exts/filtering/_settings_types/send_alert.py
+++ b/bot/exts/filtering/_settings_types/send_alert.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import ClassVar
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ActionEntry
@@ -7,12 +7,10 @@ from bot.exts.filtering._settings_types.settings_entry import ActionEntry
class SendAlert(ActionEntry):
"""A setting entry which tells whether to send an alert message."""
- name = "send_alert"
- description = "A boolean field. If all filters triggered set this to False, no mod-alert will be created."
+ name: ClassVar[str] = "send_alert"
+ description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created."
- def __init__(self, entry_data: Any):
- super().__init__(entry_data)
- self.send_alert: bool = entry_data
+ send_alert: bool
async def action(self, ctx: FilterContext) -> None:
"""Add the stored pings to the alert message content."""
@@ -23,4 +21,4 @@ class SendAlert(ActionEntry):
if not isinstance(other, SendAlert):
return NotImplemented
- return SendAlert(self.send_alert or other.send_alert)
+ return SendAlert(send_alert=self.send_alert or other.send_alert)
diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py
index 2883deed8..2b3b030a0 100644
--- a/bot/exts/filtering/_settings_types/settings_entry.py
+++ b/bot/exts/filtering/_settings_types/settings_entry.py
@@ -1,13 +1,15 @@
from __future__ import annotations
from abc import abstractmethod
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional, Union
+
+from pydantic import BaseModel
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._utils import FieldRequiring
-class SettingsEntry(FieldRequiring):
+class SettingsEntry(BaseModel, FieldRequiring):
"""
A basic entry in the settings field appearing in every filter list and filter.
@@ -16,34 +18,10 @@ class SettingsEntry(FieldRequiring):
# Each subclass must define a name matching the entry name we're expecting to receive from the database.
# Names must be unique across all filter lists.
- name = FieldRequiring.MUST_SET_UNIQUE
+ name: ClassVar[str] = FieldRequiring.MUST_SET_UNIQUE
# Each subclass must define a description of what it does. If the data an entry type receives is comprised of
# several DB fields, the value should a dictionary of field names and their descriptions.
- description = FieldRequiring.MUST_SET
-
- @abstractmethod
- def __init__(self, entry_data: Any):
- super().__init__()
- self._dict = {}
-
- def __setattr__(self, key: str, value: Any) -> None:
- super().__setattr__(key, value)
- if key == "_dict":
- return
- self._dict[key] = value
-
- def __eq__(self, other: SettingsEntry) -> bool:
- if not isinstance(other, SettingsEntry):
- return NotImplemented
- return self._dict == other._dict
-
- def to_dict(self) -> dict[str, Any]:
- """Return a dictionary representation of the entry."""
- return self._dict.copy()
-
- def copy(self) -> SettingsEntry:
- """Return a new entry object with the same parameters."""
- return self.__class__(self.to_dict())
+ description: ClassVar[Union[str, dict[str, str]]] = FieldRequiring.MUST_SET
@classmethod
def create(cls, entry_data: Optional[dict[str, Any]], *, keep_empty: bool = False) -> Optional[SettingsEntry]:
@@ -58,7 +36,9 @@ class SettingsEntry(FieldRequiring):
if not keep_empty and hasattr(entry_data, "values") and not any(value for value in entry_data.values()):
return None
- return cls(entry_data)
+ if not isinstance(entry_data, dict):
+ entry_data = {cls.name: entry_data}
+ return cls(**entry_data)
class ValidationEntry(SettingsEntry):
diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py
index 14c6bd13b..158f1e7bd 100644
--- a/bot/exts/filtering/_utils.py
+++ b/bot/exts/filtering/_utils.py
@@ -84,14 +84,18 @@ class FieldRequiring(ABC):
...
def __init_subclass__(cls, **kwargs):
+ def inherited(attr: str) -> bool:
+ """True if `attr` was inherited from a parent class."""
+ for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object.
+ if hasattr(parent, attr): # The attribute was inherited.
+ return True
+ return False
+
# If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it.
if inspect.isabstract(cls):
for attribute in dir(cls):
if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE:
- for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object.
- if hasattr(parent, attribute): # The attribute was inherited.
- break
- else:
+ if not inherited(attribute):
# A new attribute with the value MUST_SET_UNIQUE.
FieldRequiring.__unique_attributes[cls][attribute] = set()
return
@@ -100,9 +104,9 @@ class FieldRequiring(ABC):
if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"):
continue
value = getattr(cls, attribute)
- if value is FieldRequiring.MUST_SET:
+ if value is FieldRequiring.MUST_SET and inherited(attribute):
raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}")
- elif value is FieldRequiring.MUST_SET_UNIQUE:
+ elif value is FieldRequiring.MUST_SET_UNIQUE and inherited(attribute):
raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}")
else:
# Check if the value needs to be unique.
diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py
index 2a24769d0..630474c13 100644
--- a/bot/exts/filtering/filtering.py
+++ b/bot/exts/filtering/filtering.py
@@ -215,14 +215,14 @@ class Filtering(Cog):
default_setting_values = {}
for type_ in ("actions", "validations"):
for _, setting in filter_list.defaults[list_type][type_].items():
- default_setting_values.update(to_serializable(setting.to_dict()))
+ default_setting_values.update(to_serializable(setting.dict()))
# Get the filter's overridden settings
overrides_values = {}
for settings in (filter_.actions, filter_.validations):
if settings:
for _, setting in settings.items():
- overrides_values.update(to_serializable(setting.to_dict()))
+ overrides_values.update(to_serializable(setting.dict()))
# Combine them. It's done in this way to preserve field order, since the filter won't have all settings.
total_values = {}
@@ -345,7 +345,7 @@ class Filtering(Cog):
setting_values = {}
for type_ in ("actions", "validations"):
for _, setting in list_defaults[type_].items():
- setting_values.update(to_serializable(setting.to_dict()))
+ setting_values.update(to_serializable(setting.dict()))
embed = self._build_embed_from_dict(setting_values)
# Use the class's docstring, and ignore single newlines.
@@ -458,7 +458,7 @@ class Filtering(Cog):
await ctx.send(f":x: There is no list of {past_tense(list_type.name.lower())} {filter_list.name}s.")
return
- lines = list(map(str, type_filters))
+ lines = list(map(str, type_filters.values()))
log.trace(f"Sending a list of {len(lines)} filters.")
embed = Embed(colour=Colour.blue())