diff options
| author | 2022-11-04 00:28:52 +0200 | |
|---|---|---|
| committer | 2022-11-04 00:28:52 +0200 | |
| commit | e100ae9b63f8fbb075e7ab5793d5028c74c4607b (patch) | |
| tree | 90d3b924d89d289a5abc82a500cb81b6575c69a2 | |
| parent | Remove old filtering constants (diff) | |
Stop using None as a valid setting value
A None value signifies that there's no override, and it shouldn't be used for anything else.
Doing so is confusing and bug-prone.
| -rw-r--r-- | bot/exts/filtering/_settings_types/actions/infraction_and_notification.py | 35 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/ui.py | 30 |
2 files changed, 29 insertions, 36 deletions
diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py index fb679855a..b8b463626 100644 --- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -61,6 +61,7 @@ class Infraction(Enum): WARNING = auto() WATCH = auto() NOTE = auto() + NONE = auto() def __str__(self) -> str: return self.name @@ -70,8 +71,8 @@ class Infraction(Enum): user: Member | User, channel: discord.abc.Messageable, alerts_channel: discord.TextChannel, - duration: float | None, - reason: str | None + duration: float, + reason: str ) -> None: """Invokes the command matching the infraction name.""" command_name = self.name.lower() @@ -81,10 +82,10 @@ class Infraction(Enum): ctx = FakeContext(channel) if self.name in ("KICK", "WARNING", "WATCH", "NOTE"): - await command(ctx, user, reason=reason) + await command(ctx, user, reason=reason or None) else: duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None - await command(ctx, user, duration, reason=reason) + await command(ctx, user, duration, reason=reason or None) class InfractionAndNotification(ActionEntry): @@ -102,29 +103,31 @@ class InfractionAndNotification(ActionEntry): "the harsher one will be applied (by type or duration).\n\n" "Valid infraction types in order of harshness: " ) + ", ".join(infraction.name for infraction in Infraction), - "infraction_duration": "How long the infraction should last for in seconds, or 'None' for permanent.", + "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.", "infraction_reason": "The reason delivered with the infraction.", "infraction_channel": ( "The channel ID in which to invoke the infraction (and send the confirmation message). " - "If blank, the infraction will be sent in the context channel. If the ID fails to resolve, it will default " - "to the mod-alerts channel." + "If 0, the infraction will be sent in the context channel. If the ID otherwise fails to resolve, " + "it will default to the mod-alerts channel." ), "dm_content": "The contents of a message to be DMed to the offending user.", "dm_embed": "The contents of the embed to be DMed to the offending user." } - dm_content: str | None - dm_embed: str | None - infraction_type: Infraction | None - infraction_reason: str | None - infraction_duration: float | None - infraction_channel: int | None + dm_content: str + dm_embed: str + infraction_type: Infraction + infraction_reason: str + infraction_duration: float + infraction_channel: int @validator("infraction_type", pre=True) @classmethod - def convert_infraction_name(cls, infr_type: str) -> Infraction: + def convert_infraction_name(cls, infr_type: str | Infraction) -> Infraction: """Convert the string to an Infraction by name.""" - return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else None + if isinstance(infr_type, Infraction): + return infr_type + return Infraction[infr_type.replace(" ", "_").upper()] async def action(self, ctx: FilterContext) -> None: """Send the notification to the user, and apply any specified infractions.""" @@ -150,7 +153,7 @@ class InfractionAndNotification(ActionEntry): except Forbidden: ctx.action_descriptions.append("failed to notify") - if self.infraction_type is not None: + if self.infraction_type != Infraction.NONE: alerts_channel = bot_module.instance.get_channel(Channels.mod_alerts) if self.infraction_channel: channel = bot_module.instance.get_channel(self.infraction_channel) diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 9fc15410e..17a933783 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -4,7 +4,7 @@ import re from abc import ABC, abstractmethod from enum import EnumMeta from functools import partial -from typing import Any, Callable, Coroutine, Optional, TypeVar, Union +from typing import Any, Callable, Coroutine, Optional, TypeVar import discord from botcore.site_api import ResponseCodeError @@ -114,24 +114,12 @@ def populate_embed_from_dict(embed: Embed, data: dict) -> None: embed.add_field(name=setting, value=value, inline=len(value) < MAX_INLINE_SIZE) -def remove_optional(type_: type) -> tuple[bool, type]: - """Return whether the type is Optional, and the Union of types which aren't None.""" - if not hasattr(type_, "__args__"): - return False, type_ - args = list(type_.__args__) - if type(None) not in args: - return False, type_ - args.remove(type(None)) - return True, Union[tuple(args)] - - def parse_value(value: str, type_: type[T]) -> T: """Parse the value and attempt to convert it to the provided type.""" - is_optional, type_ = remove_optional(type_) - if is_optional and value == '""': - return None if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias type_ = type_.__origin__ + if value == '""': + return type_() if type_ in (tuple, list, set): return list(value.split(",")) if type_ is bool: @@ -273,7 +261,7 @@ class BooleanSelectView(discord.ui.View): class FreeInputModal(discord.ui.Modal): """A modal to freely enter a value for a setting.""" - def __init__(self, setting_name: str, required: bool, type_: type, update_callback: Callable): + def __init__(self, setting_name: str, type_: type, update_callback: Callable): title = f"{setting_name} Input" if len(setting_name) < MAX_MODAL_TITLE_LENGTH - 6 else "Setting Input" super().__init__(timeout=COMPONENT_TIMEOUT, title=title) @@ -282,13 +270,16 @@ class FreeInputModal(discord.ui.Modal): self.update_callback = update_callback label = setting_name if len(setting_name) < MAX_MODAL_TITLE_LENGTH else "Value" - self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=required) + self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=False) self.add_item(self.setting_input) async def on_submit(self, interaction: Interaction) -> None: """Update the setting with the new value in the embed.""" try: - value = self.type_(self.setting_input.value) + if not self.setting_input.value: + value = self.type_() + else: + value = self.type_(self.setting_input.value) except (ValueError, TypeError): await interaction.response.send_message( f"Could not process the input value for `{self.setting_name}`.", ephemeral=True @@ -436,7 +427,6 @@ class EditBaseView(ABC, discord.ui.View): """Prompt the user to give an override value for the setting they selected, and respond to the interaction.""" setting_name = select.values[0] type_ = self.type_per_setting_name[setting_name] - is_optional, type_ = remove_optional(type_) if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias type_ = type_.__origin__ new_view = self.copy() @@ -462,7 +452,7 @@ class EditBaseView(ABC, discord.ui.View): view = EnumSelectView(setting_name, type_, update_callback) await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True) else: - await interaction.response.send_modal(FreeInputModal(setting_name, not is_optional, type_, update_callback)) + await interaction.response.send_modal(FreeInputModal(setting_name, type_, update_callback)) self.stop() @abstractmethod |