aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar mbaruh <[email protected]>2022-11-04 00:28:52 +0200
committerGravatar mbaruh <[email protected]>2022-11-04 00:28:52 +0200
commite100ae9b63f8fbb075e7ab5793d5028c74c4607b (patch)
tree90d3b924d89d289a5abc82a500cb81b6575c69a2
parentRemove old filtering constants (diff)
Stop using None as a valid setting value
A None value signifies that there's no override, and it shouldn't be used for anything else. Doing so is confusing and bug-prone.
-rw-r--r--bot/exts/filtering/_settings_types/actions/infraction_and_notification.py35
-rw-r--r--bot/exts/filtering/_ui/ui.py30
2 files changed, 29 insertions, 36 deletions
diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
index fb679855a..b8b463626 100644
--- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
+++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py
@@ -61,6 +61,7 @@ class Infraction(Enum):
WARNING = auto()
WATCH = auto()
NOTE = auto()
+ NONE = auto()
def __str__(self) -> str:
return self.name
@@ -70,8 +71,8 @@ class Infraction(Enum):
user: Member | User,
channel: discord.abc.Messageable,
alerts_channel: discord.TextChannel,
- duration: float | None,
- reason: str | None
+ duration: float,
+ reason: str
) -> None:
"""Invokes the command matching the infraction name."""
command_name = self.name.lower()
@@ -81,10 +82,10 @@ class Infraction(Enum):
ctx = FakeContext(channel)
if self.name in ("KICK", "WARNING", "WATCH", "NOTE"):
- await command(ctx, user, reason=reason)
+ await command(ctx, user, reason=reason or None)
else:
duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None
- await command(ctx, user, duration, reason=reason)
+ await command(ctx, user, duration, reason=reason or None)
class InfractionAndNotification(ActionEntry):
@@ -102,29 +103,31 @@ class InfractionAndNotification(ActionEntry):
"the harsher one will be applied (by type or duration).\n\n"
"Valid infraction types in order of harshness: "
) + ", ".join(infraction.name for infraction in Infraction),
- "infraction_duration": "How long the infraction should last for in seconds, or 'None' for permanent.",
+ "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.",
"infraction_reason": "The reason delivered with the infraction.",
"infraction_channel": (
"The channel ID in which to invoke the infraction (and send the confirmation message). "
- "If blank, the infraction will be sent in the context channel. If the ID fails to resolve, it will default "
- "to the mod-alerts channel."
+ "If 0, the infraction will be sent in the context channel. If the ID otherwise fails to resolve, "
+ "it will default to the mod-alerts channel."
),
"dm_content": "The contents of a message to be DMed to the offending user.",
"dm_embed": "The contents of the embed to be DMed to the offending user."
}
- dm_content: str | None
- dm_embed: str | None
- infraction_type: Infraction | None
- infraction_reason: str | None
- infraction_duration: float | None
- infraction_channel: int | None
+ dm_content: str
+ dm_embed: str
+ infraction_type: Infraction
+ infraction_reason: str
+ infraction_duration: float
+ infraction_channel: int
@validator("infraction_type", pre=True)
@classmethod
- def convert_infraction_name(cls, infr_type: str) -> Infraction:
+ def convert_infraction_name(cls, infr_type: str | Infraction) -> Infraction:
"""Convert the string to an Infraction by name."""
- return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else None
+ if isinstance(infr_type, Infraction):
+ return infr_type
+ return Infraction[infr_type.replace(" ", "_").upper()]
async def action(self, ctx: FilterContext) -> None:
"""Send the notification to the user, and apply any specified infractions."""
@@ -150,7 +153,7 @@ class InfractionAndNotification(ActionEntry):
except Forbidden:
ctx.action_descriptions.append("failed to notify")
- if self.infraction_type is not None:
+ if self.infraction_type != Infraction.NONE:
alerts_channel = bot_module.instance.get_channel(Channels.mod_alerts)
if self.infraction_channel:
channel = bot_module.instance.get_channel(self.infraction_channel)
diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py
index 9fc15410e..17a933783 100644
--- a/bot/exts/filtering/_ui/ui.py
+++ b/bot/exts/filtering/_ui/ui.py
@@ -4,7 +4,7 @@ import re
from abc import ABC, abstractmethod
from enum import EnumMeta
from functools import partial
-from typing import Any, Callable, Coroutine, Optional, TypeVar, Union
+from typing import Any, Callable, Coroutine, Optional, TypeVar
import discord
from botcore.site_api import ResponseCodeError
@@ -114,24 +114,12 @@ def populate_embed_from_dict(embed: Embed, data: dict) -> None:
embed.add_field(name=setting, value=value, inline=len(value) < MAX_INLINE_SIZE)
-def remove_optional(type_: type) -> tuple[bool, type]:
- """Return whether the type is Optional, and the Union of types which aren't None."""
- if not hasattr(type_, "__args__"):
- return False, type_
- args = list(type_.__args__)
- if type(None) not in args:
- return False, type_
- args.remove(type(None))
- return True, Union[tuple(args)]
-
-
def parse_value(value: str, type_: type[T]) -> T:
"""Parse the value and attempt to convert it to the provided type."""
- is_optional, type_ = remove_optional(type_)
- if is_optional and value == '""':
- return None
if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias
type_ = type_.__origin__
+ if value == '""':
+ return type_()
if type_ in (tuple, list, set):
return list(value.split(","))
if type_ is bool:
@@ -273,7 +261,7 @@ class BooleanSelectView(discord.ui.View):
class FreeInputModal(discord.ui.Modal):
"""A modal to freely enter a value for a setting."""
- def __init__(self, setting_name: str, required: bool, type_: type, update_callback: Callable):
+ def __init__(self, setting_name: str, type_: type, update_callback: Callable):
title = f"{setting_name} Input" if len(setting_name) < MAX_MODAL_TITLE_LENGTH - 6 else "Setting Input"
super().__init__(timeout=COMPONENT_TIMEOUT, title=title)
@@ -282,13 +270,16 @@ class FreeInputModal(discord.ui.Modal):
self.update_callback = update_callback
label = setting_name if len(setting_name) < MAX_MODAL_TITLE_LENGTH else "Value"
- self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=required)
+ self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=False)
self.add_item(self.setting_input)
async def on_submit(self, interaction: Interaction) -> None:
"""Update the setting with the new value in the embed."""
try:
- value = self.type_(self.setting_input.value)
+ if not self.setting_input.value:
+ value = self.type_()
+ else:
+ value = self.type_(self.setting_input.value)
except (ValueError, TypeError):
await interaction.response.send_message(
f"Could not process the input value for `{self.setting_name}`.", ephemeral=True
@@ -436,7 +427,6 @@ class EditBaseView(ABC, discord.ui.View):
"""Prompt the user to give an override value for the setting they selected, and respond to the interaction."""
setting_name = select.values[0]
type_ = self.type_per_setting_name[setting_name]
- is_optional, type_ = remove_optional(type_)
if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias
type_ = type_.__origin__
new_view = self.copy()
@@ -462,7 +452,7 @@ class EditBaseView(ABC, discord.ui.View):
view = EnumSelectView(setting_name, type_, update_callback)
await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True)
else:
- await interaction.response.send_modal(FreeInputModal(setting_name, not is_optional, type_, update_callback))
+ await interaction.response.send_modal(FreeInputModal(setting_name, type_, update_callback))
self.stop()
@abstractmethod