aboutsummaryrefslogtreecommitdiffstats
path: root/backend/routes
diff options
context:
space:
mode:
Diffstat (limited to 'backend/routes')
-rw-r--r--backend/routes/auth/authorize.py121
-rw-r--r--backend/routes/forms/form.py6
-rw-r--r--backend/routes/forms/submit.py58
-rw-r--r--backend/routes/forms/unittesting.py127
4 files changed, 291 insertions, 21 deletions
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 975936a..d4587f0 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -2,26 +2,101 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""
+import datetime
+from typing import Union
+
import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
+from starlette import responses
+from starlette.authentication import requires
from starlette.requests import Request
-from starlette.responses import JSONResponse
+from backend import constants
+from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.route import Route
from backend.validation import ErrorMessage, api
+AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400)
+
class AuthorizeRequest(BaseModel):
token: str = Field(description="The access token received from Discord.")
class AuthorizeResponse(BaseModel):
- token: str = Field(description="A JWT token containing the user information")
+ username: str = Field("Discord display name.")
+ expiry: str = Field("ISO formatted timestamp of expiry.")
+
+
+async def process_token(
+ bearer_token: dict,
+ request: Request
+) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+ """Post a bearer token to Discord, and return a JWT and username."""
+ interaction_start = datetime.datetime.now()
+
+ try:
+ user_details = await fetch_user_details(bearer_token["access_token"])
+ except httpx.HTTPStatusError:
+ AUTH_FAILURE.delete_cookie("token")
+ return AUTH_FAILURE
+
+ max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
+ token_expiry = interaction_start + max_age
+
+ data = {
+ "token": bearer_token["access_token"],
+ "refresh": bearer_token["refresh_token"],
+ "user_details": user_details,
+ "expiry": token_expiry.isoformat()
+ }
+
+ token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
+ user = User(token, user_details)
+
+ response = responses.JSONResponse({
+ "username": user.display_name,
+ "expiry": token_expiry.isoformat()
+ })
+
+ await set_response_token(response, request, token, bearer_token["expires_in"])
+ return response
+
+
+async def set_response_token(
+ response: responses.Response,
+ request: Request,
+ new_token: str,
+ expiry: int
+) -> None:
+ """Helper that handles logic for updating a token in a set-cookie response."""
+ origin_url = request.headers.get("origin")
+
+ if origin_url == constants.PRODUCTION_URL:
+ domain = request.url.netloc
+ samesite = "strict"
+
+ elif not constants.PRODUCTION:
+ domain = None
+ samesite = "strict"
+
+ else:
+ domain = request.url.netloc
+ samesite = "None"
+
+ response.set_cookie(
+ "token", f"JWT {new_token}",
+ secure=constants.PRODUCTION,
+ httponly=True,
+ samesite=samesite,
+ domain=domain,
+ max_age=expiry
+ )
class AuthorizeRoute(Route):
@@ -37,22 +112,38 @@ class AuthorizeRoute(Route):
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"]
)
- async def post(self, request: Request) -> JSONResponse:
+ async def post(self, request: Request) -> responses.JSONResponse:
"""Generate an authorization token."""
data = await request.json()
-
try:
- bearer_token = await fetch_bearer_token(data["token"])
- user_details = await fetch_user_details(bearer_token["access_token"])
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)
except httpx.HTTPStatusError:
- return JSONResponse({
- "error": "auth_failure"
- }, status_code=400)
+ return AUTH_FAILURE
- user_details["admin"] = await request.state.db.admins.find_one(
- {"_id": user_details["id"]}
- ) is not None
+ return await process_token(bearer_token, request)
+
+
+class TokenRefreshRoute(Route):
+ """
+ Use the refresh code from a JWT to get a new token and generate a new JWT token.
+ """
- token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")
+ name = "refresh"
+ path = "/refresh"
+
+ @requires(["authenticated"])
+ @api.validate(
+ resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
+ tags=["auth"]
+ )
+ async def post(self, request: Request) -> responses.JSONResponse:
+ """Refresh an authorization token."""
+ try:
+ token = request.user.decoded_token.get("refresh")
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(token, url, refresh=True)
+ except httpx.HTTPStatusError:
+ return AUTH_FAILURE
- return JSONResponse({"token": token})
+ return await process_token(bearer_token, request)
diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py
index b6b722e..1c6e44a 100644
--- a/backend/routes/forms/form.py
+++ b/backend/routes/forms/form.py
@@ -10,6 +10,7 @@ from starlette.responses import JSONResponse
from backend.models import Form
from backend.route import Route
+from backend.routes.forms.unittesting import filter_unittests
from backend.validation import ErrorMessage, OkayResponse, api
@@ -26,7 +27,7 @@ class SingleForm(Route):
@api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Returns single form information by ID."""
- admin = request.user.payload["admin"] if request.user.is_authenticated else False # noqa
+ admin = request.user.admin if request.user.is_authenticated else False
filters = {
"_id": request.path_params["form_id"]
@@ -37,6 +38,9 @@ class SingleForm(Route):
if raw_form := await request.state.db.forms.find_one(filters):
form = Form(**raw_form)
+ if not admin:
+ form = filter_unittests(form)
+
return JSONResponse(form.dict(admin=admin))
return JSONResponse({"error": "not_found"}, status_code=404)
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index 37f76e0..23444a0 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -4,6 +4,7 @@ Submit a form.
import asyncio
import binascii
+import datetime
import hashlib
import uuid
from typing import Any, Optional
@@ -17,10 +18,12 @@ from starlette.requests import Request
from starlette.responses import JSONResponse
from backend import constants
-from backend.authentication import User
+from backend.authentication.user import User
from backend.models import Form, FormResponse
from backend.route import Route
-from backend.validation import AuthorizationHeaders, ErrorMessage, api
+from backend.routes.auth.authorize import set_response_token
+from backend.routes.forms.unittesting import execute_unittest
+from backend.validation import ErrorMessage, api
HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"
HCAPTCHA_HEADERS = {
@@ -57,13 +60,37 @@ class SubmitForm(Route):
HTTP_404=ErrorMessage,
HTTP_400=ErrorMessage
),
- headers=AuthorizationHeaders,
tags=["forms", "responses"]
)
async def post(self, request: Request) -> JSONResponse:
"""Submit a response to the form."""
- data = await request.json()
+ response = await self.submit(request)
+
+ # Silently try to update user data
+ try:
+ if hasattr(request.user, User.refresh_data.__name__):
+ old = request.user.token
+ await request.user.refresh_data()
+
+ if old != request.user.token:
+ try:
+ expiry = datetime.datetime.fromisoformat(
+ request.user.decoded_token.get("expiry")
+ )
+ except ValueError:
+ expiry = None
+
+ expiry_seconds = (expiry - datetime.datetime.now()).seconds
+ await set_response_token(response, request, request.user.token, expiry_seconds)
+
+ except httpx.HTTPStatusError:
+ pass
+ return response
+
+ async def submit(self, request: Request) -> JSONResponse:
+ """Helper method for handling submission logic."""
+ data = await request.json()
data["timestamp"] = None
if form := await request.state.db.forms.find_one(
@@ -104,8 +131,12 @@ class SubmitForm(Route):
if constants.FormFeatures.REQUIRES_LOGIN.value in form.features:
if request.user.is_authenticated:
response["user"] = request.user.payload
+ response["user"]["admin"] = request.user.admin
- if constants.FormFeatures.COLLECT_EMAIL.value in form.features and "email" not in response["user"]: # noqa
+ if (
+ constants.FormFeatures.COLLECT_EMAIL.value in form.features
+ and "email" not in response["user"]
+ ):
return JSONResponse({
"error": "email_required"
}, status_code=400)
@@ -133,6 +164,23 @@ class SubmitForm(Route):
except ValidationError as e:
return JSONResponse(e.errors(), status_code=422)
+ # Run unittests if needed
+ if any("unittests" in question.data for question in form.questions):
+ unittest_results = await execute_unittest(response_obj, form)
+
+ if not all(test.passed for test in unittest_results):
+ # Return 500 if we encountered an internal error (code 99).
+ status_code = 500 if any(
+ test.return_code == 99 for test in unittest_results
+ ) else 403
+
+ return JSONResponse({
+ "error": "failed_tests",
+ "test_results": [
+ test._asdict() for test in unittest_results if not test.passed
+ ]
+ }, status_code=status_code)
+
await request.state.db.responses.insert_one(
response_obj.dict(by_alias=True)
)
diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py
new file mode 100644
index 0000000..3854314
--- /dev/null
+++ b/backend/routes/forms/unittesting.py
@@ -0,0 +1,127 @@
+import base64
+from collections import namedtuple
+from itertools import count
+from textwrap import indent
+
+import httpx
+from httpx import HTTPStatusError
+
+from backend.constants import SNEKBOX_URL
+from backend.models import FormResponse, Form
+
+with open("resources/unittest_template.py") as file:
+ TEST_TEMPLATE = file.read()
+
+
+UnittestResult = namedtuple("UnittestResult", "question_id return_code passed result")
+
+
+def filter_unittests(form: Form) -> Form:
+ """
+ Replace the unittest data section of code questions with the number of test cases.
+
+ This is used to redact the exact tests when sending the form back to the frontend.
+ """
+ for question in form.questions:
+ if question.type == "code" and "unittests" in question.data:
+ question.data["unittests"] = len(question.data["unittests"])
+
+ return form
+
+
+def _make_unit_code(units: dict[str, str]) -> str:
+ """Compose a dict mapping unit names to their code into an actual class body."""
+ result = ""
+
+ for unit_name, unit_code in units.items():
+ result += (
+ f"\ndef test_{unit_name.lstrip('#')}(unit):" # Function definition
+ f"\n{indent(unit_code, ' ')}" # Unit code
+ )
+
+ return indent(result, " ")
+
+
+def _make_user_code(code: str) -> str:
+ """Compose the user code into an actual base64-encoded string variable."""
+ code = base64.b64encode(code.encode("utf8")).decode("utf8")
+ return f'USER_CODE = b"{code}"'
+
+
+async def _post_eval(code: str) -> dict[str, str]:
+ """Post the eval to snekbox and return the response."""
+ async with httpx.AsyncClient() as client:
+ data = {"input": code}
+ response = await client.post(SNEKBOX_URL, json=data, timeout=10)
+
+ response.raise_for_status()
+ return response.json()
+
+
+async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]:
+ """Execute all the unittests in this form and return the results."""
+ unittest_results = []
+
+ for question in form.questions:
+ if question.type == "code" and "unittests" in question.data:
+ passed = False
+
+ # Tests starting with an hashtag should have censored names.
+ hidden_test_counter = count(1)
+ hidden_tests = {
+ test.lstrip("#").lstrip("test_"): next(hidden_test_counter)
+ for test in question.data["unittests"].keys()
+ if test.startswith("#")
+ }
+
+ # Compose runner code
+ unit_code = _make_unit_code(question.data["unittests"])
+ user_code = _make_user_code(form_response.response[question.id])
+
+ code = TEST_TEMPLATE.replace("### USER CODE", user_code)
+ code = code.replace("### UNIT CODE", unit_code)
+
+ try:
+ response = await _post_eval(code)
+ except HTTPStatusError:
+ return_code = 99
+ result = "Unable to contact code runner."
+ else:
+ return_code = int(response["returncode"])
+
+ # Parse the stdout if the tests ran successfully
+ if return_code == 0:
+ stdout = response["stdout"]
+ passed = bool(int(stdout[0]))
+
+ # If the test failed, we have to populate the result string.
+ if not passed:
+ failed_tests = stdout[1:].strip().split(";")
+
+ # Redact failed hidden tests
+ for i, failed_test in enumerate(failed_tests.copy()):
+ if failed_test in hidden_tests:
+ failed_tests[i] = f"hidden_test_{hidden_tests[failed_test]}"
+
+ result = ";".join(failed_tests)
+ else:
+ result = ""
+ elif return_code in (5, 6, 99):
+ result = response["stdout"]
+ # Killed by NsJail
+ elif return_code == 137:
+ return_code = 7
+ result = "Timed out or ran out of memory."
+ # Another code has been returned by CPython because of another failure.
+ else:
+ return_code = 99
+ result = "Internal error."
+
+ unittest_results.append(UnittestResult(
+ question_id=question.id,
+ return_code=return_code,
+ passed=passed,
+ result=result
+ ))
+
+ return unittest_results