aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Matteo Bertucci <[email protected]>2021-02-24 12:07:41 +0100
committerGravatar Matteo Bertucci <[email protected]>2021-02-24 12:07:41 +0100
commit6c38d1f153211e1731ed805da992fa5978ead91e (patch)
treea3ed7c7bca59efc758046263fdc26b9c47681072
parentAdd unittest template (diff)
Support code unit testing through snekbox
-rw-r--r--backend/routes/forms/unittesting.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py
new file mode 100644
index 0000000..3e1d280
--- /dev/null
+++ b/backend/routes/forms/unittesting.py
@@ -0,0 +1,91 @@
+import ast
+from collections import namedtuple
+from textwrap import indent
+from typing import Optional
+
+import httpx
+
+from backend.constants import SNEKBOX_URL
+from backend.models import FormResponse, Form
+
+with open("resources/unittest_template.py") as file:
+ TEST_TEMPLATE = file.read()
+
+
+UnittestResult = namedtuple("UnittestResult", "question_id return_code passed result")
+
+
+def _make_unit_code(units: dict[str, str]) -> str:
+ result = ""
+
+ for unit_name, unit_code in units.items():
+ result += f"\ndef test_{unit_name}(unit):\n{indent(unit_code, ' ')}"
+
+ return indent(result, " ")
+
+
+def _make_user_code(code: str) -> str:
+ # Make sure that we we escape triple quotes and backslashes in the user code
+ code = code.replace('"""', '\\"""').replace("\\", "\\\\")
+ return f'USER_CODE = """{code}"""'
+
+
+async def _post_eval(code: str) -> Optional[dict[str, str]]:
+ data = {"input": code}
+ async with httpx.AsyncClient() as client:
+ response = await client.post(SNEKBOX_URL, json=data)
+
+ if not response.status_code == 200:
+ return
+
+ return response.json()
+
+
+async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]:
+ unittest_results = []
+
+ for question in form.questions:
+ if question.type == "code" and "unittests" in question.data:
+ passed = False
+
+ unit_code = _make_unit_code(question.data["unittests"])
+ user_code = _make_user_code(form_response.response[question.id])
+
+ code = TEST_TEMPLATE.replace("### USER CODE", user_code).replace("### UNIT CODE", unit_code)
+
+ # Make sure that the code is well formatted (we don't check for the user code)
+ try:
+ ast.parse(code)
+ except SyntaxError:
+ return_code = 99
+ result = "Invalid generated unit code."
+
+ else:
+ response = await _post_eval(code)
+
+ if not response:
+ return_code = 99
+ result = "Unable to contact code runner."
+ else:
+ return_code = int(response["returncode"])
+
+ if return_code not in (0, 5, 99):
+ return_code = 99
+ result = "Internal error."
+ else:
+ stdout = response["stdout"]
+ passed = bool(int(stdout[0]))
+
+ if not passed:
+ result = stdout[1:].strip()
+ else:
+ result = ""
+
+ unittest_results.append(UnittestResult(
+ question_id=question.id,
+ return_code=return_code,
+ passed=passed,
+ result=result
+ ))
+
+ return unittest_results