diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/models/question.py | 6 | ||||
| -rw-r--r-- | backend/routes/forms/unittesting.py | 14 | 
2 files changed, 13 insertions, 7 deletions
| diff --git a/backend/models/question.py b/backend/models/question.py index 5a1334a..201aa51 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -15,8 +15,10 @@ class Unittests(BaseModel):      @validator("tests")      def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE:          """Confirm that at least one test exists in a test suite.""" -        if isinstance(value, dict) and len(value.keys()) == 0: -            raise ValueError("Must have at least one test in a test suite.") +        if isinstance(value, dict): +            keys = len(value.keys()) - (1 if "setUp" in value.keys() else 0) +            if keys == 0: +                raise ValueError("Must have at least one test in a test suite.")          return value diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a830775..13fa639 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -36,10 +36,14 @@ def _make_unit_code(units: dict[str, str]) -> str:      result = ""      for unit_name, unit_code in units.items(): -        result += ( -            f"\ndef test_{unit_name.lstrip('#')}(unit):"  # Function definition -            f"\n{indent(unit_code, '    ')}"  # Unit code -        ) +        # Function definition +        if unit_name == "setUp": +            result += "\ndef setUp(self):" +        else: +            result += f"\nasync def {unit_name.removeprefix('#')}(self):" + +        # Unite code +        result += f"\n{indent(unit_code, '    ')}"      return indent(result, "    ") @@ -83,7 +87,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit              # Tests starting with an hashtag should have censored names.              hidden_test_counter = count(1)              hidden_tests = { -                test.lstrip("#").lstrip("test_"): next(hidden_test_counter) +                test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter)                  for test in question.data["unittests"]["tests"].keys()                  if test.startswith("#")              } | 
