diff options
author | 2023-04-06 15:14:53 +0300 | |
---|---|---|
committer | 2023-04-06 15:14:53 +0300 | |
commit | e0ae45a006bf0f6502e9ae55abb9887785393651 (patch) | |
tree | e39acc5bef091244c57fa3d9ef44e8302e58635b | |
parent | Add in_guild attribute to FilterContext (diff) |
.__origin__ -> get_origin
-rw-r--r-- | bot/exts/filtering/_ui/ui.py | 17 | ||||
-rw-r--r-- | bot/exts/filtering/_utils.py | 16 |
2 files changed, 21 insertions, 12 deletions
diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 1690d2286..0de511f03 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from enum import EnumMeta from functools import partial -from typing import Any, Callable, Coroutine, Optional, TypeVar +from typing import Any, Callable, Coroutine, Optional, TypeVar, get_origin import discord from discord import Embed, Interaction @@ -21,7 +21,7 @@ import bot from bot.constants import Colours from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._filter_lists import FilterList -from bot.exts.filtering._utils import FakeContext +from bot.exts.filtering._utils import FakeContext, normalize_type from bot.utils.messages import format_channel, format_user, upload_log log = get_logger(__name__) @@ -135,10 +135,11 @@ def populate_embed_from_dict(embed: Embed, data: dict) -> None: def parse_value(value: str, type_: type[T]) -> T: - """Parse the value and attempt to convert it to the provided type.""" - if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias - type_ = type_.__origin__ - if value == '""': + """Parse the value provided in the CLI and attempt to convert it to the provided type.""" + blank = value == '""' + type_ = normalize_type(type_, prioritize_nonetype=blank) + + if blank or isinstance(None, type_): return type_() if type_ in (tuple, list, set): return list(value.split(",")) @@ -461,8 +462,8 @@ class EditBaseView(ABC, discord.ui.View): """Prompt the user to give an override value for the setting they selected, and respond to the interaction.""" setting_name = select.values[0] type_ = self.type_per_setting_name[setting_name] - if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias - type_ = type_.__origin__ + if origin := get_origin(type_): # In case this is a types.GenericAlias or a typing._GenericAlias + type_ = origin new_view = self.copy() # This is in order to not block the interaction response. There's a potential race condition here, since # a view's method is used without guaranteeing the task completed, but since it depends on user input diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index a43233f20..97a0fa8d4 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -132,16 +132,24 @@ def repr_equals(override: Any, default: Any) -> bool: return str(override) == str(default) -def starting_value(type_: type[T]) -> T: - """Return a value of the given type.""" +def normalize_type(type_: type, *, prioritize_nonetype: bool = True) -> type: + """Reduce a given type to one that can be initialized.""" if get_origin(type_) in (Union, types.UnionType): # In case of a Union args = get_args(type_) if type(None) in args: - return None + if prioritize_nonetype: + return type(None) + else: + args = tuple(set(args) - {type(None)}) type_ = args[0] # Pick one, doesn't matter if origin := get_origin(type_): # In case of a parameterized List, Set, Dict etc. - type_ = origin + return origin + return type_ + +def starting_value(type_: type[T]) -> T: + """Return a value of the given type.""" + type_ = normalize_type(type_) try: return type_() except TypeError: # In case it all fails, return a string and let the user handle it. |