diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/filtering/_settings_types/actions/infraction_and_notification.py | 35 | ||||
| -rw-r--r-- | bot/exts/filtering/_ui/ui.py | 30 | 
2 files changed, 29 insertions, 36 deletions
diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py index fb679855a..b8b463626 100644 --- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -61,6 +61,7 @@ class Infraction(Enum):      WARNING = auto()      WATCH = auto()      NOTE = auto() +    NONE = auto()      def __str__(self) -> str:          return self.name @@ -70,8 +71,8 @@ class Infraction(Enum):          user: Member | User,          channel: discord.abc.Messageable,          alerts_channel: discord.TextChannel, -        duration: float | None, -        reason: str | None +        duration: float, +        reason: str      ) -> None:          """Invokes the command matching the infraction name."""          command_name = self.name.lower() @@ -81,10 +82,10 @@ class Infraction(Enum):          ctx = FakeContext(channel)          if self.name in ("KICK", "WARNING", "WATCH", "NOTE"): -            await command(ctx, user, reason=reason) +            await command(ctx, user, reason=reason or None)          else:              duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None -            await command(ctx, user, duration, reason=reason) +            await command(ctx, user, duration, reason=reason or None)  class InfractionAndNotification(ActionEntry): @@ -102,29 +103,31 @@ class InfractionAndNotification(ActionEntry):              "the harsher one will be applied (by type or duration).\n\n"              "Valid infraction types in order of harshness: "          ) + ", ".join(infraction.name for infraction in Infraction), -        "infraction_duration": "How long the infraction should last for in seconds, or 'None' for permanent.", +        "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.",          "infraction_reason": "The reason delivered with the infraction.",          "infraction_channel": (              "The channel ID in which to invoke the infraction (and send the confirmation message). " -            "If blank, the infraction will be sent in the context channel. If the ID fails to resolve, it will default " -            "to the mod-alerts channel." +            "If 0, the infraction will be sent in the context channel. If the ID otherwise fails to resolve, " +            "it will default to the mod-alerts channel."          ),          "dm_content": "The contents of a message to be DMed to the offending user.",          "dm_embed": "The contents of the embed to be DMed to the offending user."      } -    dm_content: str | None -    dm_embed: str | None -    infraction_type: Infraction | None -    infraction_reason: str | None -    infraction_duration: float | None -    infraction_channel: int | None +    dm_content: str +    dm_embed: str +    infraction_type: Infraction +    infraction_reason: str +    infraction_duration: float +    infraction_channel: int      @validator("infraction_type", pre=True)      @classmethod -    def convert_infraction_name(cls, infr_type: str) -> Infraction: +    def convert_infraction_name(cls, infr_type: str | Infraction) -> Infraction:          """Convert the string to an Infraction by name.""" -        return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else None +        if isinstance(infr_type, Infraction): +            return infr_type +        return Infraction[infr_type.replace(" ", "_").upper()]      async def action(self, ctx: FilterContext) -> None:          """Send the notification to the user, and apply any specified infractions.""" @@ -150,7 +153,7 @@ class InfractionAndNotification(ActionEntry):              except Forbidden:                  ctx.action_descriptions.append("failed to notify") -        if self.infraction_type is not None: +        if self.infraction_type != Infraction.NONE:              alerts_channel = bot_module.instance.get_channel(Channels.mod_alerts)              if self.infraction_channel:                  channel = bot_module.instance.get_channel(self.infraction_channel) diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 9fc15410e..17a933783 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -4,7 +4,7 @@ import re  from abc import ABC, abstractmethod  from enum import EnumMeta  from functools import partial -from typing import Any, Callable, Coroutine, Optional, TypeVar, Union +from typing import Any, Callable, Coroutine, Optional, TypeVar  import discord  from botcore.site_api import ResponseCodeError @@ -114,24 +114,12 @@ def populate_embed_from_dict(embed: Embed, data: dict) -> None:          embed.add_field(name=setting, value=value, inline=len(value) < MAX_INLINE_SIZE) -def remove_optional(type_: type) -> tuple[bool, type]: -    """Return whether the type is Optional, and the Union of types which aren't None.""" -    if not hasattr(type_, "__args__"): -        return False, type_ -    args = list(type_.__args__) -    if type(None) not in args: -        return False, type_ -    args.remove(type(None)) -    return True, Union[tuple(args)] - -  def parse_value(value: str, type_: type[T]) -> T:      """Parse the value and attempt to convert it to the provided type.""" -    is_optional, type_ = remove_optional(type_) -    if is_optional and value == '""': -        return None      if hasattr(type_, "__origin__"):  # In case this is a types.GenericAlias or a typing._GenericAlias          type_ = type_.__origin__ +    if value == '""': +        return type_()      if type_ in (tuple, list, set):          return list(value.split(","))      if type_ is bool: @@ -273,7 +261,7 @@ class BooleanSelectView(discord.ui.View):  class FreeInputModal(discord.ui.Modal):      """A modal to freely enter a value for a setting.""" -    def __init__(self, setting_name: str, required: bool, type_: type, update_callback: Callable): +    def __init__(self, setting_name: str, type_: type, update_callback: Callable):          title = f"{setting_name} Input" if len(setting_name) < MAX_MODAL_TITLE_LENGTH - 6 else "Setting Input"          super().__init__(timeout=COMPONENT_TIMEOUT, title=title) @@ -282,13 +270,16 @@ class FreeInputModal(discord.ui.Modal):          self.update_callback = update_callback          label = setting_name if len(setting_name) < MAX_MODAL_TITLE_LENGTH else "Value" -        self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=required) +        self.setting_input = discord.ui.TextInput(label=label, style=discord.TextStyle.paragraph, required=False)          self.add_item(self.setting_input)      async def on_submit(self, interaction: Interaction) -> None:          """Update the setting with the new value in the embed."""          try: -            value = self.type_(self.setting_input.value) +            if not self.setting_input.value: +                value = self.type_() +            else: +                value = self.type_(self.setting_input.value)          except (ValueError, TypeError):              await interaction.response.send_message(                  f"Could not process the input value for `{self.setting_name}`.", ephemeral=True @@ -436,7 +427,6 @@ class EditBaseView(ABC, discord.ui.View):          """Prompt the user to give an override value for the setting they selected, and respond to the interaction."""          setting_name = select.values[0]          type_ = self.type_per_setting_name[setting_name] -        is_optional, type_ = remove_optional(type_)          if hasattr(type_, "__origin__"):  # In case this is a types.GenericAlias or a typing._GenericAlias              type_ = type_.__origin__          new_view = self.copy() @@ -462,7 +452,7 @@ class EditBaseView(ABC, discord.ui.View):              view = EnumSelectView(setting_name, type_, update_callback)              await interaction.response.send_message(f"Choose a value for `{setting_name}`:", view=view, ephemeral=True)          else: -            await interaction.response.send_modal(FreeInputModal(setting_name, not is_optional, type_, update_callback)) +            await interaction.response.send_modal(FreeInputModal(setting_name, type_, update_callback))          self.stop()      @abstractmethod  |