diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/bot/test_constants.py | 40 |
1 files changed, 32 insertions, 8 deletions
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index db9a9bcb0..2937b6189 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -5,6 +5,31 @@ import unittest from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: + """ + Return True if `value` is an instance of the type represented by `annotation`. + + This doesn't account for things like Unions or checking for homogenous types in collections. + """ + origin = typing.get_origin(annotation) + + # This is done in case a bare e.g. `typing.List` is used. + # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. + # `get_origin()` does this normalisation for us. + type_ = annotation if origin is None else origin + + return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: + """Return True if `value` is an instance of any type in `types`.""" + for type_ in types: + if is_annotation_instance(value, type_): + return True + + return False + + class ConstantsTests(unittest.TestCase): """Tests for our constants.""" @@ -20,14 +45,13 @@ class ConstantsTests(unittest.TestCase): for name, annotation in section.__annotations__.items(): with self.subTest(section=section, name=name, annotation=annotation): value = getattr(section, name) + origin = typing.get_origin(annotation) annotation_args = typing.get_args(annotation) + failure_msg = f"{value} is not an instance of {annotation}" - if not annotation_args: - self.assertIsInstance(value, annotation) + if origin is typing.Union: + is_instance = is_any_instance(value, annotation_args) + self.assertTrue(is_instance, failure_msg) else: - origin = typing.get_origin(annotation) - if origin is typing.Union: - is_instance = any(isinstance(value, arg) for arg in annotation_args) - self.assertTrue(is_instance) - else: - self.skipTest(f"Validating type {annotation} is unsupported.") + is_instance = is_annotation_instance(value, annotation) + self.assertTrue(is_instance, failure_msg) |