diff options
| -rw-r--r-- | backend/models/question.py | 6 | ||||
| -rw-r--r-- | backend/routes/forms/unittesting.py | 6 | ||||
| -rw-r--r-- | resources/unittest_template.py | 4 | 
3 files changed, 10 insertions, 6 deletions
| diff --git a/backend/models/question.py b/backend/models/question.py index 5a1334a..201aa51 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -15,8 +15,10 @@ class Unittests(BaseModel):      @validator("tests")      def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE:          """Confirm that at least one test exists in a test suite.""" -        if isinstance(value, dict) and len(value.keys()) == 0: -            raise ValueError("Must have at least one test in a test suite.") +        if isinstance(value, dict): +            keys = len(value.keys()) - (1 if "setUp" in value.keys() else 0) +            if keys == 0: +                raise ValueError("Must have at least one test in a test suite.")          return value diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a830775..c093718 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -36,8 +36,10 @@ def _make_unit_code(units: dict[str, str]) -> str:      result = ""      for unit_name, unit_code in units.items(): +        test_prefix = "test_" if unit_name != "setUp" else "" +          result += ( -            f"\ndef test_{unit_name.lstrip('#')}(unit):"  # Function definition +            f"\ndef {test_prefix}{unit_name.removeprefix('#')}(self):"  # Function definition              f"\n{indent(unit_code, '    ')}"  # Unit code          ) @@ -83,7 +85,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit              # Tests starting with an hashtag should have censored names.              hidden_test_counter = count(1)              hidden_tests = { -                test.lstrip("#").lstrip("test_"): next(hidden_test_counter) +                test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter)                  for test in question.data["unittests"]["tests"].keys()                  if test.startswith("#")              } diff --git a/resources/unittest_template.py b/resources/unittest_template.py index 05730ce..6f704f3 100644 --- a/resources/unittest_template.py +++ b/resources/unittest_template.py @@ -64,8 +64,8 @@ def _main() -> None:      if not result.wasSuccessful():          RESULT.write(              ";".join(chain( -                (error[0]._testMethodName.lstrip("test_") for error in result.errors), -                (failure[0]._testMethodName.lstrip("test_") for failure in result.failures) +                (error[0]._testMethodName.removeprefix("test_") for error in result.errors), +                (failure[0]._testMethodName.removeprefix("test_") for failure in result.failures)              ))          ) | 
