diff options
| -rw-r--r-- | tests/bot/test_constants.py | 40 | 
1 files changed, 32 insertions, 8 deletions
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index db9a9bcb0..2937b6189 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -5,6 +5,31 @@ import unittest  from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: +    """ +    Return True if `value` is an instance of the type represented by `annotation`. + +    This doesn't account for things like Unions or checking for homogenous types in collections. +    """ +    origin = typing.get_origin(annotation) + +    # This is done in case a bare e.g. `typing.List` is used. +    # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. +    # `get_origin()` does this normalisation for us. +    type_ = annotation if origin is None else origin + +    return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: +    """Return True if `value` is an instance of any type in `types`.""" +    for type_ in types: +        if is_annotation_instance(value, type_): +            return True + +    return False + +  class ConstantsTests(unittest.TestCase):      """Tests for our constants.""" @@ -20,14 +45,13 @@ class ConstantsTests(unittest.TestCase):              for name, annotation in section.__annotations__.items():                  with self.subTest(section=section, name=name, annotation=annotation):                      value = getattr(section, name) +                    origin = typing.get_origin(annotation)                      annotation_args = typing.get_args(annotation) +                    failure_msg = f"{value} is not an instance of {annotation}" -                    if not annotation_args: -                        self.assertIsInstance(value, annotation) +                    if origin is typing.Union: +                        is_instance = is_any_instance(value, annotation_args) +                        self.assertTrue(is_instance, failure_msg)                      else: -                        origin = typing.get_origin(annotation) -                        if origin is typing.Union: -                            is_instance = any(isinstance(value, arg) for arg in annotation_args) -                            self.assertTrue(is_instance) -                        else: -                            self.skipTest(f"Validating type {annotation} is unsupported.") +                        is_instance = is_annotation_instance(value, annotation) +                        self.assertTrue(is_instance, failure_msg)  |