aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/models/__init__.py3
-rw-r--r--backend/models/question.py23
-rw-r--r--backend/routes/forms/unittesting.py8
3 files changed, 29 insertions, 5 deletions
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
index 29ccb24..8ad7f7f 100644
--- a/backend/models/__init__.py
+++ b/backend/models/__init__.py
@@ -2,13 +2,14 @@ from .antispam import AntiSpam
from .discord_user import DiscordUser
from .form import Form, FormList
from .form_response import FormResponse, ResponseList
-from .question import Question
+from .question import CodeQuestion, Question
__all__ = [
"AntiSpam",
"DiscordUser",
"Form",
"FormResponse",
+ "CodeQuestion",
"Question",
"FormList",
"ResponseList"
diff --git a/backend/models/question.py b/backend/models/question.py
index 7daeb5a..9829843 100644
--- a/backend/models/question.py
+++ b/backend/models/question.py
@@ -4,6 +4,25 @@ from pydantic import BaseModel, Field, root_validator, validator
from backend.constants import QUESTION_TYPES, REQUIRED_QUESTION_TYPE_DATA
+_TESTS_TYPE = t.Union[t.Dict[str, str], int]
+
+
+class Unittests(BaseModel):
+ allow_failure: bool = False
+ tests: _TESTS_TYPE
+
+ @validator("tests")
+ def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE:
+ if isinstance(value, dict) and not len(value.keys()):
+ raise ValueError("Must have at least one test in a test suite.")
+
+ return value
+
+
+class CodeQuestion(BaseModel):
+ language: str
+ unittests: t.Optional[Unittests]
+
class Question(BaseModel):
"""Schema model for form question."""
@@ -49,4 +68,8 @@ class Question(BaseModel):
f"got {type(value['data'][key]).__name__} instead."
)
+ # Validate unittest options
+ if value.get("type").lower() == "code":
+ value["data"] = CodeQuestion(**value.get("data")).dict()
+
return value
diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py
index 3854314..590cb52 100644
--- a/backend/routes/forms/unittesting.py
+++ b/backend/routes/forms/unittesting.py
@@ -24,7 +24,7 @@ def filter_unittests(form: Form) -> Form:
"""
for question in form.questions:
if question.type == "code" and "unittests" in question.data:
- question.data["unittests"] = len(question.data["unittests"])
+ question.data["unittests"]["tests"] = len(question.data["unittests"]["tests"])
return form
@@ -62,7 +62,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit
"""Execute all the unittests in this form and return the results."""
unittest_results = []
- for question in form.questions:
+ for index, question in enumerate(form.questions):
if question.type == "code" and "unittests" in question.data:
passed = False
@@ -70,12 +70,12 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit
hidden_test_counter = count(1)
hidden_tests = {
test.lstrip("#").lstrip("test_"): next(hidden_test_counter)
- for test in question.data["unittests"].keys()
+ for test in question.data["unittests"]["tests"].keys()
if test.startswith("#")
}
# Compose runner code
- unit_code = _make_unit_code(question.data["unittests"])
+ unit_code = _make_unit_code(question.data["unittests"]["tests"])
user_code = _make_user_code(form_response.response[question.id])
code = TEST_TEMPLATE.replace("### USER CODE", user_code)