diff options
| -rw-r--r-- | .github/CODEOWNERS | 4 | ||||
| -rw-r--r-- | README.md | 2 | ||||
| -rw-r--r-- | SCHEMA.md | 7 | ||||
| -rw-r--r-- | backend/constants.py | 9 | ||||
| -rw-r--r-- | backend/models/form.py | 2 | ||||
| -rw-r--r-- | backend/routes/forms/form.py | 4 | ||||
| -rw-r--r-- | backend/routes/forms/submit.py | 23 | ||||
| -rw-r--r-- | backend/routes/forms/unittesting.py | 127 | ||||
| -rw-r--r-- | docker-compose.yml | 9 | ||||
| -rw-r--r-- | resources/unittest_template.py | 90 | ||||
| -rw-r--r-- | tox.ini | 4 | 
11 files changed, 271 insertions, 10 deletions
| diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 483fdc7..3fa665c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ -# Request joe and ks123 for all PRs. -* @jb3 @ks129 +# Request joe, ks123, HassanAbouelela for all PRs. +* @jb3 @ks129 @HassanAbouelela @@ -8,6 +8,7 @@ To start working on forms-backend, you'll need few things:  2. Poetry  3. Docker and docker-compose (optional)  4. Running MongoDB instance (when not using Docker) +5. Running [Snekbox](https://git.pydis.com/snekbox) instance (when not using Docker, optional)  ### Running with Docker  The easiest way to run forms-backend is using Docker (and docker-compose). @@ -30,6 +31,7 @@ Create a `.env` file with the same contents as the Docker section above and the  - `FRONTEND_URL`: Forms frontend URL.  - `DATABASE_URL`: MongoDB instance URI, in format `mongodb://(username):(password)@(database IP or domain):(port)`.  - `MONGO_DB`: MongoDB database name, defaults to `pydis_forms`. +- `SNEKBOX_URL`: Snekbox evaluation endpoint.  #### Running  Simply run: `$ uvicorn --reload --host 0.0.0.0 --debug backend:app`. @@ -123,7 +123,12 @@ Textareas require no additional configuration.  ```js  {      // A supported language from https://prismjs.com/#supported-languages -    "language": "python" +    "language": "python", +    // An optional mapping of unit test names to their code +    "unittests": { +        "unit_1": "unit_code()", +        ... 
+    }  }  ``` diff --git a/backend/constants.py b/backend/constants.py index af25d84..e1f4a5b 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -1,14 +1,15 @@  from dotenv import load_dotenv -load_dotenv() +import os +import binascii +from enum import Enum -import os  # noqa -import binascii  # noqa -from enum import Enum  # noqa +load_dotenv()  FRONTEND_URL = os.getenv("FRONTEND_URL", "https://forms.pythondiscord.com")  DATABASE_URL = os.getenv("DATABASE_URL")  MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms") +SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval")  PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false" diff --git a/backend/models/form.py b/backend/models/form.py index 8e59905..eac0b63 100644 --- a/backend/models/form.py +++ b/backend/models/form.py @@ -47,7 +47,7 @@ class Form(BaseModel):          if any(v not in allowed_values for v in value):              raise ValueError("Form features list contains one or more invalid values.") -        if FormFeatures.COLLECT_EMAIL in value and FormFeatures.REQUIRES_LOGIN not in value:  # noqa +        if FormFeatures.COLLECT_EMAIL in value and FormFeatures.REQUIRES_LOGIN not in value:              raise ValueError("COLLECT_EMAIL feature require REQUIRES_LOGIN feature.")          return value diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index e3360b1..1c6e44a 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -10,6 +10,7 @@ from starlette.responses import JSONResponse  from backend.models import Form  from backend.route import Route +from backend.routes.forms.unittesting import filter_unittests  from backend.validation import ErrorMessage, OkayResponse, api @@ -37,6 +38,9 @@ class SingleForm(Route):          if raw_form := await request.state.db.forms.find_one(filters):              form = Form(**raw_form) +            if not admin: +                form = filter_unittests(form) +       
       return JSONResponse(form.dict(admin=admin))          return JSONResponse({"error": "not_found"}, status_code=404) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 8627a29..4224586 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -20,6 +20,7 @@ from backend import constants  from backend.authentication.user import User  from backend.models import Form, FormResponse  from backend.route import Route +from backend.routes.forms.unittesting import execute_unittest  from backend.validation import ErrorMessage, api  HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify" @@ -129,7 +130,10 @@ class SubmitForm(Route):                      response["user"] = request.user.payload                      response["user"]["admin"] = request.user.admin -                    if constants.FormFeatures.COLLECT_EMAIL.value in form.features and "email" not in response["user"]:  # noqa +                    if ( +                            constants.FormFeatures.COLLECT_EMAIL.value in form.features +                            and "email" not in response["user"] +                    ):                          return JSONResponse({                              "error": "email_required"                          }, status_code=400) @@ -157,6 +161,23 @@ class SubmitForm(Route):              except ValidationError as e:                  return JSONResponse(e.errors(), status_code=422) +            # Run unittests if needed +            if any("unittests" in question.data for question in form.questions): +                unittest_results = await execute_unittest(response_obj, form) + +                if not all(test.passed for test in unittest_results): +                    # Return 500 if we encountered an internal error (code 99). 
+                    status_code = 500 if any( +                        test.return_code == 99 for test in unittest_results +                    ) else 403 + +                    return JSONResponse({ +                        "error": "failed_tests", +                        "test_results": [ +                            test._asdict() for test in unittest_results if not test.passed +                        ] +                    }, status_code=status_code) +              await request.state.db.responses.insert_one(                  response_obj.dict(by_alias=True)              ) diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py new file mode 100644 index 0000000..3854314 --- /dev/null +++ b/backend/routes/forms/unittesting.py @@ -0,0 +1,127 @@ +import base64 +from collections import namedtuple +from itertools import count +from textwrap import indent + +import httpx +from httpx import HTTPStatusError + +from backend.constants import SNEKBOX_URL +from backend.models import FormResponse, Form + +with open("resources/unittest_template.py") as file: +    TEST_TEMPLATE = file.read() + + +UnittestResult = namedtuple("UnittestResult", "question_id return_code passed result") + + +def filter_unittests(form: Form) -> Form: +    """ +    Replace the unittest data section of code questions with the number of test cases. + +    This is used to redact the exact tests when sending the form back to the frontend. 
+    """ +    for question in form.questions: +        if question.type == "code" and "unittests" in question.data: +            question.data["unittests"] = len(question.data["unittests"]) + +    return form + + +def _make_unit_code(units: dict[str, str]) -> str: +    """Compose a dict mapping unit names to their code into an actual class body.""" +    result = "" + +    for unit_name, unit_code in units.items(): +        result += ( +            f"\ndef test_{unit_name.lstrip('#')}(unit):"  # Function definition +            f"\n{indent(unit_code, '    ')}"  # Unit code +        ) + +    return indent(result, "    ") + + +def _make_user_code(code: str) -> str: +    """Compose the user code into an actual base64-encoded string variable.""" +    code = base64.b64encode(code.encode("utf8")).decode("utf8") +    return f'USER_CODE = b"{code}"' + + +async def _post_eval(code: str) -> dict[str, str]: +    """Post the eval to snekbox and return the response.""" +    async with httpx.AsyncClient() as client: +        data = {"input": code} +        response = await client.post(SNEKBOX_URL, json=data, timeout=10) + +        response.raise_for_status() +        return response.json() + + +async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]: +    """Execute all the unittests in this form and return the results.""" +    unittest_results = [] + +    for question in form.questions: +        if question.type == "code" and "unittests" in question.data: +            passed = False + +            # Tests starting with an hashtag should have censored names. 
+            hidden_test_counter = count(1) +            hidden_tests = { +                test.lstrip("#").lstrip("test_"): next(hidden_test_counter) +                for test in question.data["unittests"].keys() +                if test.startswith("#") +            } + +            # Compose runner code +            unit_code = _make_unit_code(question.data["unittests"]) +            user_code = _make_user_code(form_response.response[question.id]) + +            code = TEST_TEMPLATE.replace("### USER CODE", user_code) +            code = code.replace("### UNIT CODE", unit_code) + +            try: +                response = await _post_eval(code) +            except HTTPStatusError: +                return_code = 99 +                result = "Unable to contact code runner." +            else: +                return_code = int(response["returncode"]) + +                # Parse the stdout if the tests ran successfully +                if return_code == 0: +                    stdout = response["stdout"] +                    passed = bool(int(stdout[0])) + +                    # If the test failed, we have to populate the result string. +                    if not passed: +                        failed_tests = stdout[1:].strip().split(";") + +                        # Redact failed hidden tests +                        for i, failed_test in enumerate(failed_tests.copy()): +                            if failed_test in hidden_tests: +                                failed_tests[i] = f"hidden_test_{hidden_tests[failed_test]}" + +                        result = ";".join(failed_tests) +                    else: +                        result = "" +                elif return_code in (5, 6, 99): +                    result = response["stdout"] +                # Killed by NsJail +                elif return_code == 137: +                    return_code = 7 +                    result = "Timed out or ran out of memory." 
+                # Another code has been returned by CPython because of another failure. +                else: +                    return_code = 99 +                    result = "Internal error." + +            unittest_results.append(UnittestResult( +                question_id=question.id, +                return_code=return_code, +                passed=passed, +                result=result +            )) + +    return unittest_results diff --git a/docker-compose.yml b/docker-compose.yml index d44b4e0..4e58ef7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,13 @@ services:        MONGO_INITDB_ROOT_PASSWORD: forms-backend        MONGO_INITDB_DATABASE: pydis_forms +  snekbox: +    image: ghcr.io/python-discord/snekbox:latest +    ipc: none +    ports: +      - "127.0.0.1:8060:8060" +    privileged: true +    backend:      build:        context: . @@ -19,11 +26,13 @@ services:        - "127.0.0.1:8000:8000"      depends_on:        - mongo +      - snekbox      tty: true      volumes:        - .:/app:ro      environment:        - DATABASE_URL=mongodb://forms-backend:forms-backend@mongo:27017 +      - SNEKBOX_URL=http://snekbox:8060/eval        - OAUTH2_CLIENT_ID        - OAUTH2_CLIENT_SECRET        - ALLOWED_URL diff --git a/resources/unittest_template.py b/resources/unittest_template.py new file mode 100644 index 0000000..2410278 --- /dev/null +++ b/resources/unittest_template.py @@ -0,0 +1,90 @@ +# flake8: noqa +"""This template is used inside snekbox to evaluate and test user code.""" +import ast +import base64 +import io +import os +import sys +import traceback +import unittest +from itertools import chain +from types import ModuleType, SimpleNamespace +from typing import NoReturn +from unittest import mock + +### USER CODE + + +class RunnerTestCase(unittest.TestCase): +### UNIT CODE + + +def _exit_sandbox(code: int) -> NoReturn: +    """ +    Exit the sandbox by printing the result to the actual stdout and exit with the provided code. 
+ +    Codes: +    - 0: Executed with success +    - 5: Syntax error while parsing user code +    - 6: Uncaught exception while loading user code +    - 99: Internal error + +    137 can also be generated by NsJail when killing the process. +    """ +    print(RESULT.getvalue(), file=ORIGINAL_STDOUT, end="") +    sys.exit(code) + + +def _load_user_module() -> ModuleType: +    """Load the user code into a new module and return it.""" +    code = base64.b64decode(USER_CODE).decode("utf8") +    try: +        ast.parse(code, "<input>") +    except SyntaxError: +        RESULT.write("".join(traceback.format_exception(*sys.exc_info(), limit=0))) +        _exit_sandbox(5) + +    _module = ModuleType("module") +    exec(code, _module.__dict__) + +    return _module + + +def _main() -> None: +    suite = unittest.defaultTestLoader.loadTestsFromTestCase(RunnerTestCase) +    result = suite.run(unittest.TestResult()) + +    RESULT.write(str(int(result.wasSuccessful()))) + +    if not result.wasSuccessful(): +        RESULT.write( +            ";".join(chain( +                (error[0]._testMethodName.lstrip("test_") for error in result.errors), +                (failure[0]._testMethodName.lstrip("test_") for failure in result.failures) +            )) +        ) + +    _exit_sandbox(0) + + +try: +    # Fake file object not writing anything +    DEVNULL = SimpleNamespace(write=lambda *_: None, flush=lambda *_: None) + +    RESULT = io.StringIO() +    ORIGINAL_STDOUT = sys.stdout + +    # stdout/err is patched in order to control what is outputted by the runner +    sys.stdout = DEVNULL +    sys.stderr = DEVNULL +     +    # Load the user code as a global module variable +    try: +        module = _load_user_module() +    except Exception: +        RESULT.write("Uncaught exception while loading user code.") +        _exit_sandbox(6) +    _main() +except Exception: +    RESULT.write("Uncaught exception inside runner.") +    _exit_sandbox(99) @@ -1,8 +1,10 @@  [flake8] 
-max-line-length=88 +max-line-length=100  exclude=.cache,.venv,.git  docstring-convention=all  import-order-style=pycharm  ignore=      # Type annotations      ANN101,ANN102 +    # Line breaks +    W503 | 
