diff options
| -rw-r--r-- | bot/exts/filtering/_ui/ui.py | 17 | ||||
| -rw-r--r-- | bot/exts/filtering/_utils.py | 16 | 
2 files changed, 21 insertions, 12 deletions
| diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 1690d2286..0de511f03 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod  from collections.abc import Iterable  from enum import EnumMeta  from functools import partial -from typing import Any, Callable, Coroutine, Optional, TypeVar +from typing import Any, Callable, Coroutine, Optional, TypeVar, get_origin  import discord  from discord import Embed, Interaction @@ -21,7 +21,7 @@ import bot  from bot.constants import Colours  from bot.exts.filtering._filter_context import FilterContext  from bot.exts.filtering._filter_lists import FilterList -from bot.exts.filtering._utils import FakeContext +from bot.exts.filtering._utils import FakeContext, normalize_type  from bot.utils.messages import format_channel, format_user, upload_log  log = get_logger(__name__) @@ -135,10 +135,11 @@ def populate_embed_from_dict(embed: Embed, data: dict) -> None:  def parse_value(value: str, type_: type[T]) -> T: -    """Parse the value and attempt to convert it to the provided type.""" -    if hasattr(type_, "__origin__"):  # In case this is a types.GenericAlias or a typing._GenericAlias -        type_ = type_.__origin__ -    if value == '""': +    """Parse the value provided in the CLI and attempt to convert it to the provided type.""" +    blank = value == '""' +    type_ = normalize_type(type_, prioritize_nonetype=blank) + +    if blank or isinstance(None, type_):          return type_()      if type_ in (tuple, list, set):          return list(value.split(",")) @@ -461,8 +462,8 @@ class EditBaseView(ABC, discord.ui.View):          """Prompt the user to give an override value for the setting they selected, and respond to the interaction."""          setting_name = select.values[0]          type_ = self.type_per_setting_name[setting_name] -        if hasattr(type_, "__origin__"):  # In case this is a types.GenericAlias or a typing._GenericAlias -            type_ = type_.__origin__ +        if origin := get_origin(type_):  # In case this is a types.GenericAlias or a typing._GenericAlias +            type_ = origin          new_view = self.copy()          # This is in order to not block the interaction response. There's a potential race condition here, since          # a view's method is used without guaranteeing the task completed, but since it depends on user input diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index a43233f20..97a0fa8d4 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -132,16 +132,24 @@ def repr_equals(override: Any, default: Any) -> bool:      return str(override) == str(default) -def starting_value(type_: type[T]) -> T: -    """Return a value of the given type.""" +def normalize_type(type_: type, *, prioritize_nonetype: bool = True) -> type: +    """Reduce a given type to one that can be initialized."""      if get_origin(type_) in (Union, types.UnionType):  # In case of a Union          args = get_args(type_)          if type(None) in args: -            return None +            if prioritize_nonetype: +                return type(None) +            else: +                args = tuple(set(args) - {type(None)})          type_ = args[0]  # Pick one, doesn't matter      if origin := get_origin(type_):  # In case of a parameterized List, Set, Dict etc. -        type_ = origin +        return origin +    return type_ + +def starting_value(type_: type[T]) -> T: +    """Return a value of the given type.""" +    type_ = normalize_type(type_)      try:          return type_()      except TypeError:  # In case it all fails, return a string and let the user handle it. | 
