diff options
Diffstat (limited to 'backend')
-rw-r--r-- | backend/models/__init__.py | 3 | ||||
-rw-r--r-- | backend/models/question.py | 23 | ||||
-rw-r--r-- | backend/routes/forms/unittesting.py | 8 |
3 files changed, 29 insertions, 5 deletions
diff --git a/backend/models/__init__.py b/backend/models/__init__.py index 29ccb24..8ad7f7f 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -2,13 +2,14 @@ from .antispam import AntiSpam from .discord_user import DiscordUser from .form import Form, FormList from .form_response import FormResponse, ResponseList -from .question import Question +from .question import CodeQuestion, Question __all__ = [ "AntiSpam", "DiscordUser", "Form", "FormResponse", + "CodeQuestion", "Question", "FormList", "ResponseList" diff --git a/backend/models/question.py b/backend/models/question.py index 7daeb5a..9829843 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -4,6 +4,25 @@ from pydantic import BaseModel, Field, root_validator, validator from backend.constants import QUESTION_TYPES, REQUIRED_QUESTION_TYPE_DATA +_TESTS_TYPE = t.Union[t.Dict[str, str], int] + + +class Unittests(BaseModel): + allow_failure: bool = False + tests: _TESTS_TYPE + + @validator("tests") + def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE: + if isinstance(value, dict) and not len(value.keys()): + raise ValueError("Must have at least one test in a test suite.") + + return value + + +class CodeQuestion(BaseModel): + language: str + unittests: t.Optional[Unittests] + class Question(BaseModel): """Schema model for form question.""" @@ -49,4 +68,8 @@ class Question(BaseModel): f"got {type(value['data'][key]).__name__} instead." ) + # Validate unittest options + if value.get("type").lower() == "code": + value["data"] = CodeQuestion(**value.get("data")).dict() + return value diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index 3854314..590cb52 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -24,7 +24,7 @@ def filter_unittests(form: Form) -> Form: """ for question in form.questions: if question.type == "code" and "unittests" in question.data: - question.data["unittests"] = len(question.data["unittests"]) + question.data["unittests"]["tests"] = len(question.data["unittests"]["tests"]) return form @@ -62,7 +62,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit """Execute all the unittests in this form and return the results.""" unittest_results = [] - for question in form.questions: + for index, question in enumerate(form.questions): if question.type == "code" and "unittests" in question.data: passed = False @@ -70,12 +70,12 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit hidden_test_counter = count(1) hidden_tests = { test.lstrip("#").lstrip("test_"): next(hidden_test_counter) - for test in question.data["unittests"].keys() + for test in question.data["unittests"]["tests"].keys() if test.startswith("#") } # Compose runner code - unit_code = _make_unit_code(question.data["unittests"]) + unit_code = _make_unit_code(question.data["unittests"]["tests"]) user_code = _make_user_code(form_response.response[question.id]) code = TEST_TEMPLATE.replace("### USER CODE", user_code) |