diff options
-rw-r--r-- | backend/models/question.py | 6 | ||||
-rw-r--r-- | backend/routes/forms/unittesting.py | 14 | ||||
-rw-r--r-- | resources/unittest_template.py | 6 |
3 files changed, 16 insertions, 10 deletions
diff --git a/backend/models/question.py b/backend/models/question.py index 5a1334a..201aa51 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -15,8 +15,10 @@ class Unittests(BaseModel): @validator("tests") def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE: """Confirm that at least one test exists in a test suite.""" - if isinstance(value, dict) and len(value.keys()) == 0: - raise ValueError("Must have at least one test in a test suite.") + if isinstance(value, dict): + keys = len(value.keys()) - (1 if "setUp" in value.keys() else 0) + if keys == 0: + raise ValueError("Must have at least one test in a test suite.") return value diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a830775..13fa639 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -36,10 +36,14 @@ def _make_unit_code(units: dict[str, str]) -> str: result = "" for unit_name, unit_code in units.items(): - result += ( - f"\ndef test_{unit_name.lstrip('#')}(unit):" # Function definition - f"\n{indent(unit_code, ' ')}" # Unit code - ) + # Function definition + if unit_name == "setUp": + result += "\ndef setUp(self):" + else: + result += f"\nasync def {unit_name.removeprefix('#')}(self):" + + # Unite code + result += f"\n{indent(unit_code, ' ')}" return indent(result, " ") @@ -83,7 +87,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit # Tests starting with an hashtag should have censored names. hidden_test_counter = count(1) hidden_tests = { - test.lstrip("#").lstrip("test_"): next(hidden_test_counter) + test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter) for test in question.data["unittests"]["tests"].keys() if test.startswith("#") } diff --git a/resources/unittest_template.py b/resources/unittest_template.py index 05730ce..104b3b4 100644 --- a/resources/unittest_template.py +++ b/resources/unittest_template.py @@ -15,7 +15,7 @@ from unittest import mock ### USER CODE -class RunnerTestCase(unittest.TestCase): +class RunnerTestCase(unittest.IsolatedAsyncioTestCase): ### UNIT CODE @@ -64,8 +64,8 @@ def _main() -> None: if not result.wasSuccessful(): RESULT.write( ";".join(chain( - (error[0]._testMethodName.lstrip("test_") for error in result.errors), - (failure[0]._testMethodName.lstrip("test_") for failure in result.failures) + (error[0]._testMethodName.removeprefix("test_") for error in result.errors), + (failure[0]._testMethodName.removeprefix("test_") for failure in result.failures) )) ) |