aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2023-04-06 15:14:53 +0300
committerGravatar Boris Muratov <[email protected]>2023-04-06 15:14:53 +0300
commite0ae45a006bf0f6502e9ae55abb9887785393651 (patch)
treee39acc5bef091244c57fa3d9ef44e8302e58635b
parentAdd in_guild attribute to FilterContext (diff)
.__origin__ -> get_origin
-rw-r--r--bot/exts/filtering/_ui/ui.py17
-rw-r--r--bot/exts/filtering/_utils.py16
2 files changed, 21 insertions, 12 deletions
diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py
index 1690d2286..0de511f03 100644
--- a/bot/exts/filtering/_ui/ui.py
+++ b/bot/exts/filtering/_ui/ui.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import Iterable
from enum import EnumMeta
from functools import partial
-from typing import Any, Callable, Coroutine, Optional, TypeVar
+from typing import Any, Callable, Coroutine, Optional, TypeVar, get_origin
import discord
from discord import Embed, Interaction
@@ -21,7 +21,7 @@ import bot
from bot.constants import Colours
from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._filter_lists import FilterList
-from bot.exts.filtering._utils import FakeContext
+from bot.exts.filtering._utils import FakeContext, normalize_type
from bot.utils.messages import format_channel, format_user, upload_log
log = get_logger(__name__)
@@ -135,10 +135,11 @@ def populate_embed_from_dict(embed: Embed, data: dict) -> None:
def parse_value(value: str, type_: type[T]) -> T:
- """Parse the value and attempt to convert it to the provided type."""
- if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias
- type_ = type_.__origin__
- if value == '""':
+ """Parse the value provided in the CLI and attempt to convert it to the provided type."""
+ blank = value == '""'
+ type_ = normalize_type(type_, prioritize_nonetype=blank)
+
+ if blank or isinstance(None, type_):
return type_()
if type_ in (tuple, list, set):
return list(value.split(","))
@@ -461,8 +462,8 @@ class EditBaseView(ABC, discord.ui.View):
"""Prompt the user to give an override value for the setting they selected, and respond to the interaction."""
setting_name = select.values[0]
type_ = self.type_per_setting_name[setting_name]
- if hasattr(type_, "__origin__"): # In case this is a types.GenericAlias or a typing._GenericAlias
- type_ = type_.__origin__
+ if origin := get_origin(type_): # In case this is a types.GenericAlias or a typing._GenericAlias
+ type_ = origin
new_view = self.copy()
# This is in order to not block the interaction response. There's a potential race condition here, since
# a view's method is used without guaranteeing the task completed, but since it depends on user input
diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py
index a43233f20..97a0fa8d4 100644
--- a/bot/exts/filtering/_utils.py
+++ b/bot/exts/filtering/_utils.py
@@ -132,16 +132,24 @@ def repr_equals(override: Any, default: Any) -> bool:
return str(override) == str(default)
-def starting_value(type_: type[T]) -> T:
- """Return a value of the given type."""
+def normalize_type(type_: type, *, prioritize_nonetype: bool = True) -> type:
+ """Reduce a given type to one that can be initialized."""
if get_origin(type_) in (Union, types.UnionType): # In case of a Union
args = get_args(type_)
if type(None) in args:
- return None
+ if prioritize_nonetype:
+ return type(None)
+ else:
+ args = tuple(set(args) - {type(None)})
type_ = args[0] # Pick one, doesn't matter
if origin := get_origin(type_): # In case of a parameterized List, Set, Dict etc.
- type_ = origin
+ return origin
+ return type_
+
+def starting_value(type_: type[T]) -> T:
+ """Return a value of the given type."""
+ type_ = normalize_type(type_)
try:
return type_()
except TypeError: # In case it all fails, return a string and let the user handle it.