aboutsummaryrefslogtreecommitdiffstats
path: root/backend/routes/forms/unittesting.py
blob: fe8320f2144e0a3bbea7bc561624e6c7a3a35973 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import ast
import re
from collections import namedtuple
from textwrap import indent
from typing import Optional

import httpx

from backend.constants import SNEKBOX_URL
from backend.models import FormResponse, Form

# Template the generated test program is built from; it contains the
# "### USER CODE" and "### UNIT CODE" placeholders filled in per question.
# It is Python source, so read it as UTF-8 explicitly instead of relying on the
# locale-dependent default encoding, which varies between hosts/containers.
# NOTE(review): the path is resolved relative to the process working directory.
with open("resources/unittest_template.py", encoding="utf-8") as file:
    TEST_TEMPLATE = file.read()


# Outcome of running one code question's unit tests:
#   question_id -- id of the tested question
#   return_code -- return code reported by the runner, or 99 on failure
#   passed      -- whether the unit tests passed
#   result      -- failure output or error message ("" when passed)
UnittestResult = namedtuple(
    "UnittestResult",
    ["question_id", "return_code", "passed", "result"],
)


def _make_unit_code(units: dict[str, str]) -> str:
    result = ""

    for unit_name, unit_code in units.items():
        result += f"\ndef test_{unit_name}(unit):\n{indent(unit_code, '    ')}"

    return indent(result, "    ")


def _make_user_code(code: str) -> str:
    # Make sure that we we escape triple quotes and backslashes in the user code
    code = code.replace('"""', '\\"""').replace("\\", "\\\\")
    return f'USER_CODE = """{code}"""'


async def _post_eval(code: str) -> Optional[dict[str, str]]:
    """Send ``code`` to the snekbox runner and return its JSON response.

    Args:
        code: Complete Python program to execute.

    Returns:
        The decoded JSON body when the runner answers with HTTP 200, otherwise
        ``None``. A failed request (connection refused, timeout, ...) also
        yields ``None``, so the caller can report that the runner could not be
        contacted instead of the whole request crashing.
    """
    data = {"input": code}
    try:
        # NOTE(review): httpx applies its default timeout (5 s) here -- confirm
        # it is longer than snekbox's own execution limit.
        async with httpx.AsyncClient() as client:
            response = await client.post(SNEKBOX_URL, json=data)
    except httpx.HTTPError:
        return None

    if response.status_code != 200:
        return None

    # The body is fully read by client.post (non-streaming), so decoding after
    # the client is closed is safe.
    return response.json()


def _build_test_code(user_code: str, unit_code: str) -> str:
    """Fill both placeholders of ``TEST_TEMPLATE`` in a single pass.

    Text inserted for one placeholder is never rescanned, so a submission that
    contains the literal ``### UNIT CODE`` cannot pull the unit tests into its
    own string. A callable replacement keeps backslashes in the inserted code
    verbatim (no ``re`` template processing).
    """
    replacements = {"USER": user_code, "UNIT": unit_code}
    return re.sub(
        r"### (USER|UNIT) CODE",
        lambda match: replacements[match.group(1)],
        TEST_TEMPLATE,
    )


def _interpret_response(response: dict) -> tuple[int, bool, str]:
    """Convert a runner response into ``(return_code, passed, result)``.

    Return codes 0, 5 and 99 are kept; the first character of stdout is read
    as the pass flag and the rest as the failure report. Anything unexpected
    (unknown, missing or non-numeric return code, empty or malformed stdout)
    becomes ``(99, False, "Internal error.")``.
    """
    internal_error = (99, False, "Internal error.")

    try:
        return_code = int(response["returncode"])
    except (KeyError, TypeError, ValueError):
        # NOTE(review): the runner may report no/None return code when the job
        # itself fails; treat that as an internal error rather than crashing.
        return internal_error

    if return_code not in (0, 5, 99):
        return internal_error

    stdout = response.get("stdout") or ""
    try:
        passed = bool(int(stdout[0]))
    except (IndexError, ValueError):
        # Empty output or a first character that is not a digit.
        return internal_error

    result = "" if passed else stdout[1:].strip()
    return return_code, passed, result


async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]:
    """Run the unit tests attached to the code questions of ``form``.

    For every question of type ``"code"`` whose data has ``"unittests"``, the
    respondent's answer and the question's tests are rendered into
    ``TEST_TEMPLATE``, syntax-checked locally and sent to the code runner.

    Args:
        form_response: The submission; its ``response`` maps question ids to
            answers.
        form: The form whose questions are inspected.

    Returns:
        One ``UnittestResult`` per tested question, in ``form.questions``
        order. ``return_code`` is the runner's code when it is 0, 5 or 99, and
        99 for any local, transport or runner failure (with ``result`` holding
        an explanatory message).
    """
    unittest_results = []

    for question in form.questions:
        if question.type != "code" or "unittests" not in question.data:
            continue

        unit_code = _make_unit_code(question.data["unittests"])
        # NOTE(review): raises KeyError if the respondent left this question
        # unanswered -- confirm such questions are always required.
        user_code = _make_user_code(form_response.response[question.id])
        code = _build_test_code(user_code, unit_code)

        # Validate the generated scaffolding; the user's code is embedded as a
        # string literal, so its own syntax is not checked here.
        try:
            ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError: some Python versions raise it for null bytes.
            return_code, passed, result = 99, False, "Invalid generated unit code."
        else:
            response = await _post_eval(code)
            if not response:
                return_code, passed, result = 99, False, "Unable to contact code runner."
            else:
                return_code, passed, result = _interpret_response(response)

        unittest_results.append(UnittestResult(
            question_id=question.id,
            return_code=return_code,
            passed=passed,
            result=result,
        ))

    return unittest_results