diff options
| -rw-r--r-- | backend/routes/forms/submit.py | 2 | ||||
| -rw-r--r-- | backend/routes/forms/unittesting.py | 20 | ||||
| -rw-r--r-- | resources/unittest_template.py | 13 | 
3 files changed, 23 insertions, 12 deletions
| diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 7618a33..d6b549e 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -131,12 +131,14 @@ class SubmitForm(Route):              except ValidationError as e:                  return JSONResponse(e.errors(), status_code=422) +            # Run unittests if needed              has_unittests = any("unittests" in question.data for question in form.questions)              if has_unittests:                  unittest_results = await execute_unittest(response_obj, form)                  was_successful = all(test.passed for test in unittest_results)                  if not was_successful: +                    # Return 500 if we encountered an internal error (code 99).                      status_code = 500 if any(                          test.return_code == 99 for test in unittest_results                      ) else 403 diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index 0cb7d8d..e038f3a 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -30,6 +30,7 @@ def filter_unittests(form: Form) -> Form:  def _make_unit_code(units: dict[str, str]) -> str: +    """Compose a dict mapping unit names to their code into an actual class body."""      result = ""      for unit_name, unit_code in units.items(): @@ -39,14 +40,16 @@ def _make_unit_code(units: dict[str, str]) -> str:  def _make_user_code(code: str) -> str: -    # Make sure that we we escape triple quotes and backslashes in the user code -    code = code.replace('"""', '\\"""').replace("\\", "\\\\") -    return f'USER_CODE = """{code}"""' +    """Compose the user code into an actual string variable.""" +    # Make sure that we we escape triple quotes in the user code +    code = code.replace('"""', '\\"""') +    return f'USER_CODE = r"""{code}"""'  async def _post_eval(code: str) -> Optional[dict[str, str]]: -    data = {"input": code} +    """Post the eval to snekbox and return the response."""      async with httpx.AsyncClient() as client: +        data = {"input": code}          response = await client.post(SNEKBOX_URL, json=data)          if not response.status_code == 200: @@ -56,12 +59,14 @@ async def _post_eval(code: str) -> Optional[dict[str, str]]:  async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]: +    """Execute all the unittests in this form and return the results."""      unittest_results = []      for question in form.questions:          if question.type == "code" and "unittests" in question.data:              passed = False +            # Tests starting with an hashtag should have censored names.              hidden_test_counter = count(1)              hidden_tests = {                  test.lstrip("#"): next(hidden_test_counter) @@ -69,19 +74,20 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit                  if test.startswith("#")              } +            # Compose runner code              unit_code = _make_unit_code(question.data["unittests"])              user_code = _make_user_code(form_response.response[question.id])              code = TEST_TEMPLATE.replace("### USER CODE", user_code)              code = code.replace("### UNIT CODE", unit_code) -            # Make sure that the code is well formatted (we don't check for the user code) +            # Make sure that the code is well formatted (we don't check for the user code).              try:                  ast.parse(code)              except SyntaxError:                  return_code = 99                  result = "Invalid generated unit code." - +            # The runner is correctly formatted, we can run it.              else:                  response = await _post_eval(code) @@ -91,6 +97,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit                  else:                      return_code = int(response["returncode"]) +                    # Another code has been returned by CPython because of another failure.                      if return_code not in (0, 5, 99):                          return_code = 99                          result = "Internal error." @@ -98,6 +105,7 @@ async def execute_unittest(form_response: FormResponse, form: Form) -> list[Unit                          stdout = response["stdout"]                          passed = bool(int(stdout[0])) +                        # If the test failed, we have to populate the result string.                          if not passed:                              failed_tests = stdout[1:].strip().split(";") diff --git a/resources/unittest_template.py b/resources/unittest_template.py index c792944..4c9b0bb 100644 --- a/resources/unittest_template.py +++ b/resources/unittest_template.py @@ -1,4 +1,5 @@  # flake8: noqa +"""This template is used inside snekbox to evaluate and test user code."""  import ast  import io  import os @@ -23,27 +24,26 @@ DEVNULL = SimpleNamespace(write=lambda *_: None, flush=lambda *_: None)  RESULT = io.StringIO()  ORIGINAL_STDOUT = sys.stdout +# stdout/err is patched in order to control what is outputted by the runner  sys.stdout = DEVNULL  sys.stderr = DEVNULL  def _exit_sandbox(code: int) -> NoReturn:      """ +    Exit the sandbox by printing the result to the actual stdout and exit with the provided code. +      Codes:      - 0: Executed with success      - 5: Syntax error while parsing user code      - 99: Internal error      """ -    result_content = RESULT.getvalue() - -    print( -        f"{result_content}", -        file=ORIGINAL_STDOUT -    ) +    print(RESULT.getvalue(), file=ORIGINAL_STDOUT, end="")      sys.exit(code)  def _load_user_module() -> ModuleType: +    """Load the user code into a new module and return it."""      try:          ast.parse(USER_CODE, "<input>")      except SyntaxError: @@ -74,6 +74,7 @@ def _main() -> None:  try: +    # Load the user code as a global module variable      module = _load_user_module()      _main()  except Exception: | 
