diff options
author | 2021-03-09 08:48:58 +0200 | |
---|---|---|
committer | 2021-03-09 08:48:58 +0200 | |
commit | 0ec4a370d476f3f8b7453c887b0b02fe83aced9c (patch) | |
tree | e7fb0e7a71369affd79222445a35438815ad4cd3 /backend | |
parent | Add missing "is" to error message (diff) | |
parent | Fixes Production URL Constant (diff) |
Merge branch 'main' into ks123/role-assigning
Diffstat (limited to 'backend')
-rw-r--r-- | backend/__init__.py | 21 | ||||
-rw-r--r-- | backend/authentication/backend.py | 37 | ||||
-rw-r--r-- | backend/authentication/user.py | 26 | ||||
-rw-r--r-- | backend/constants.py | 13 | ||||
-rw-r--r-- | backend/discord.py | 15 | ||||
-rw-r--r-- | backend/routes/auth/authorize.py | 121 | ||||
-rw-r--r-- | backend/routes/forms/form.py | 6 | ||||
-rw-r--r-- | backend/routes/forms/submit.py | 58 | ||||
-rw-r--r-- | backend/routes/forms/unittesting.py | 127 | ||||
-rw-r--r-- | backend/validation.py | 11 |
10 files changed, 378 insertions, 57 deletions
diff --git a/backend/__init__.py b/backend/__init__.py index a3704a0..220b457 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -7,10 +7,21 @@ from starlette.middleware.cors import CORSMiddleware from backend import constants from backend.authentication import JWTAuthenticationBackend -from backend.route_manager import create_route_map from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware +from backend.route_manager import create_route_map from backend.validation import api +ORIGINS = [ + r"(https://[^.?#]*--pydis-forms\.netlify\.app)", # Netlify Previews + r"(https?://[^.?#]*.forms-frontend.pages.dev)", # Cloudflare Previews +] + +if not constants.PRODUCTION: + # Allow all hosts on non-production deployments + ORIGINS.append(r"(.*)") + +ALLOW_ORIGIN_REGEX = "|".join(ORIGINS) + sentry_sdk.init( dsn=constants.FORMS_BACKEND_DSN, send_default_pii=True, @@ -20,13 +31,13 @@ sentry_sdk.init( middleware = [ Middleware( CORSMiddleware, - # TODO: Convert this into a RegEx that works for prod, netlify & previews - allow_origins=["*"], + allow_origins=["https://forms.pythondiscord.com"], + allow_origin_regex=ALLOW_ORIGIN_REGEX, allow_headers=[ - "Authorization", "Content-Type" ], - allow_methods=["*"] + allow_methods=["*"], + allow_credentials=True ), Middleware(DatabaseMiddleware), Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()), diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index f1d2ece..c7590e9 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,6 +1,6 @@ -import jwt import typing as t +import jwt from starlette import authentication from starlette.requests import Request @@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): """Custom Starlette authentication backend for JWT.""" @staticmethod - def get_token_from_header(header: str) -> str: - """Parse JWT token from header value.""" + def get_token_from_cookie(cookie: str) -> str: + """Parse JWT token from cookie.""" try: - prefix, token = header.split() + prefix, token = cookie.split() except ValueError: raise authentication.AuthenticationError( - "Unable to split prefix and token from Authorization header." + "Unable to split prefix and token from authorization cookie." ) if prefix.upper() != "JWT": raise authentication.AuthenticationError( - f"Invalid Authorization header prefix '{prefix}'." + f"Invalid authorization cookie prefix '{prefix}'." ) return token @@ -33,11 +33,11 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): self, request: Request ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]: """Handles JWT authentication process.""" - if "Authorization" not in request.headers: + cookie = request.cookies.get("token") + if not cookie: return None - auth = request.headers["Authorization"] - token = self.get_token_from_header(auth) + token = self.get_token_from_cookie(cookie) try: payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"]) @@ -46,7 +46,22 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): scopes = ["authenticated"] - if payload.get("admin") is True: + if not payload.get("token"): + raise authentication.AuthenticationError("Token is missing from JWT.") + if not payload.get("refresh"): + raise authentication.AuthenticationError( + "Refresh token is missing from JWT." + ) + + try: + user_details = payload.get("user_details") + if not user_details or not user_details.get("id"): + raise authentication.AuthenticationError("Improper user details.") + except Exception: + raise authentication.AuthenticationError("Could not parse user details.") + + user = User(token, user_details) + if await user.fetch_admin_status(request): scopes.append("admin") - return authentication.AuthCredentials(scopes), User(token, payload) + return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index f40c68c..857c2ed 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,6 +1,11 @@ import typing as t +import jwt from starlette.authentication import BaseUser +from starlette.requests import Request + +from backend.constants import SECRET_KEY +from backend.discord import fetch_user_details class User(BaseUser): @@ -9,6 +14,7 @@ class User(BaseUser): def __init__(self, token: str, payload: dict[str, t.Any]) -> None: self.token = token self.payload = payload + self.admin = False @property def is_authenticated(self) -> bool: @@ -23,3 +29,23 @@ class User(BaseUser): @property def discord_mention(self) -> str: return f"<@{self.payload['id']}>" + + @property + def decoded_token(self) -> dict[str, any]: + return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) + + async def fetch_admin_status(self, request: Request) -> bool: + self.admin = await request.state.db.admins.find_one( + {"_id": self.payload["id"]} + ) is not None + + return self.admin + + async def refresh_data(self) -> None: + """Fetches user data from discord, and updates the instance.""" + self.payload = await fetch_user_details(self.decoded_token.get("token")) + + updated_info = self.decoded_token + updated_info["user_details"] = self.payload + + self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256") diff --git a/backend/constants.py b/backend/constants.py index 812bef4..7ea4519 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -1,14 +1,19 @@ +import binascii +import os +from enum import Enum + from dotenv import load_dotenv -load_dotenv() -import os # noqa -import binascii # noqa -from enum import Enum # noqa +load_dotenv() FRONTEND_URL = os.getenv("FRONTEND_URL", "https://forms.pythondiscord.com") DATABASE_URL = os.getenv("DATABASE_URL") MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms") +SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval") + +PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false" +PRODUCTION_URL = "https://forms.pythondiscord.com" OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID") OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET") diff --git a/backend/discord.py b/backend/discord.py index 1dc8ed7..e5c7f8f 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -2,20 +2,25 @@ import httpx from backend.constants import ( - DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET, OAUTH2_REDIRECT_URI + DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET ) -async def fetch_bearer_token(access_code: str) -> dict: +async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict: async with httpx.AsyncClient() as client: data = { "client_id": OAUTH2_CLIENT_ID, "client_secret": OAUTH2_CLIENT_SECRET, - "grant_type": "authorization_code", - "code": access_code, - "redirect_uri": OAUTH2_REDIRECT_URI + "redirect_uri": f"{redirect}/callback" } + if refresh: + data["grant_type"] = "refresh_token" + data["refresh_token"] = code + else: + data["grant_type"] = "authorization_code" + data["code"] = code + r = await client.post(f"{DISCORD_API_BASE_URL}/oauth2/token", headers={ "Content-Type": "application/x-www-form-urlencoded" }, data=data) diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 975936a..d4587f0 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -2,26 +2,101 @@ Use a token received from the Discord OAuth2 system to fetch user information. """ +import datetime +from typing import Union + import httpx import jwt from pydantic.fields import Field from pydantic.main import BaseModel from spectree.response import Response +from starlette import responses +from starlette.authentication import requires from starlette.requests import Request -from starlette.responses import JSONResponse +from backend import constants +from backend.authentication.user import User from backend.constants import SECRET_KEY -from backend.route import Route from backend.discord import fetch_bearer_token, fetch_user_details +from backend.route import Route from backend.validation import ErrorMessage, api +AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400) + class AuthorizeRequest(BaseModel): token: str = Field(description="The access token received from Discord.") class AuthorizeResponse(BaseModel): - token: str = Field(description="A JWT token containing the user information") + username: str = Field("Discord display name.") + expiry: str = Field("ISO formatted timestamp of expiry.") + + +async def process_token( + bearer_token: dict, + request: Request +) -> Union[AuthorizeResponse, AUTH_FAILURE]: + """Post a bearer token to Discord, and return a JWT and username.""" + interaction_start = datetime.datetime.now() + + try: + user_details = await fetch_user_details(bearer_token["access_token"]) + except httpx.HTTPStatusError: + AUTH_FAILURE.delete_cookie("token") + return AUTH_FAILURE + + max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"])) + token_expiry = interaction_start + max_age + + data = { + "token": bearer_token["access_token"], + "refresh": bearer_token["refresh_token"], + "user_details": user_details, + "expiry": token_expiry.isoformat() + } + + token = jwt.encode(data, SECRET_KEY, algorithm="HS256") + user = User(token, user_details) + + response = responses.JSONResponse({ + "username": user.display_name, + "expiry": token_expiry.isoformat() + }) + + await set_response_token(response, request, token, bearer_token["expires_in"]) + return response + + +async def set_response_token( + response: responses.Response, + request: Request, + new_token: str, + expiry: int +) -> None: + """Helper that handles logic for updating a token in a set-cookie response.""" + origin_url = request.headers.get("origin") + + if origin_url == constants.PRODUCTION_URL: + domain = request.url.netloc + samesite = "strict" + + elif not constants.PRODUCTION: + domain = None + samesite = "strict" + + else: + domain = request.url.netloc + samesite = "None" + + response.set_cookie( + "token", f"JWT {new_token}", + secure=constants.PRODUCTION, + httponly=True, + samesite=samesite, + domain=domain, + max_age=expiry + ) class AuthorizeRoute(Route): @@ -37,22 +112,38 @@ class AuthorizeRoute(Route): resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), tags=["auth"] ) - async def post(self, request: Request) -> JSONResponse: + async def post(self, request: Request) -> responses.JSONResponse: """Generate an authorization token.""" data = await request.json() - try: - bearer_token = await fetch_bearer_token(data["token"]) - user_details = await fetch_user_details(bearer_token["access_token"]) + url = request.headers.get("origin") + bearer_token = await fetch_bearer_token(data["token"], url, refresh=False) except httpx.HTTPStatusError: - return JSONResponse({ - "error": "auth_failure" - }, status_code=400) + return AUTH_FAILURE - user_details["admin"] = await request.state.db.admins.find_one( - {"_id": user_details["id"]} - ) is not None + return await process_token(bearer_token, request) + + +class TokenRefreshRoute(Route): + """ + Use the refresh code from a JWT to get a new token and generate a new JWT token. + """ - token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256") + name = "refresh" + path = "/refresh" + + @requires(["authenticated"]) + @api.validate( + resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), + tags=["auth"] + ) + async def post(self, request: Request) -> responses.JSONResponse: + """Refresh an authorization token.""" + try: + token = request.user.decoded_token.get("refresh") + url = request.headers.get("origin") + bearer_token = await fetch_bearer_token(token, url, refresh=True) + except httpx.HTTPStatusError: + return AUTH_FAILURE - return JSONResponse({"token": token}) + return await process_token(bearer_token, request) diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index b6b722e..1c6e44a 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -10,6 +10,7 @@ from starlette.responses import JSONResponse from backend.models import Form from backend.route import Route +from backend.routes.forms.unittesting import filter_unittests from backend.validation import ErrorMessage, OkayResponse, api @@ -26,7 +27,7 @@ class SingleForm(Route): @api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: """Returns single form information by ID.""" - admin = request.user.payload["admin"] if request.user.is_authenticated else False # noqa + admin = request.user.admin if request.user.is_authenticated else False filters = { "_id": request.path_params["form_id"] @@ -37,6 +38,9 @@ class SingleForm(Route): if raw_form := await request.state.db.forms.find_one(filters): form = Form(**raw_form) + if not admin: + form = filter_unittests(form) + return JSONResponse(form.dict(admin=admin)) return JSONResponse({"error": "not_found"}, status_code=404) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 37f76e0..23444a0 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -4,6 +4,7 @@ Submit a form. import asyncio import binascii +import datetime import hashlib import uuid from typing import Any, Optional @@ -17,10 +18,12 @@ from starlette.requests import Request from starlette.responses import JSONResponse from backend import constants -from backend.authentication import User +from backend.authentication.user import User from backend.models import Form, FormResponse from backend.route import Route -from backend.validation import AuthorizationHeaders, ErrorMessage, api +from backend.routes.auth.authorize import set_response_token +from backend.routes.forms.unittesting import execute_unittest +from backend.validation import ErrorMessage, api HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify" HCAPTCHA_HEADERS = { @@ -57,13 +60,37 @@ class SubmitForm(Route): HTTP_404=ErrorMessage, HTTP_400=ErrorMessage ), - headers=AuthorizationHeaders, tags=["forms", "responses"] ) async def post(self, request: Request) -> JSONResponse: """Submit a response to the form.""" - data = await request.json() + response = await self.submit(request) + + # Silently try to update user data + try: + if hasattr(request.user, User.refresh_data.__name__): + old = request.user.token + await request.user.refresh_data() + + if old != request.user.token: + try: + expiry = datetime.datetime.fromisoformat( + request.user.decoded_token.get("expiry") + ) + except ValueError: + expiry = None + + expiry_seconds = (expiry - datetime.datetime.now()).seconds + await set_response_token(response, request, request.user.token, expiry_seconds) + + except httpx.HTTPStatusError: + pass + return response + + async def submit(self, request: Request) -> JSONResponse: + """Helper method for handling submission logic.""" + data = await request.json() data["timestamp"] = None if form := await request.state.db.forms.find_one( @@ -104,8 +131,12 @@ class SubmitForm(Route): if constants.FormFeatures.REQUIRES_LOGIN.value in form.features: if request.user.is_authenticated: response["user"] = request.user.payload + response["user"]["admin"] = request.user.admin - if constants.FormFeatures.COLLECT_EMAIL.value in form.features and "email" not in response["user"]: # noqa + if ( + constants.FormFeatures.COLLECT_EMAIL.value in form.features + and "email" not in response["user"] + ): return JSONResponse({ "error": "email_required" }, status_code=400) @@ -133,6 +164,23 @@ class SubmitForm(Route): except ValidationError as e: return JSONResponse(e.errors(), status_code=422) + # Run unittests if needed + if any("unittests" in question.data for question in form.questions): + unittest_results = await execute_unittest(response_obj, form) + + if not all(test.passed for test in unittest_results): + # Return 500 if we encountered an internal error (code 99). + status_code = 500 if any( + test.return_code == 99 for test in unittest_results + ) else 403 + + return JSONResponse({ + "error": "failed_tests", + "test_results": [ + test._asdict() for test in unittest_results if not test.passed + ] + }, status_code=status_code) + await request.state.db.responses.insert_one( response_obj.dict(by_alias=True) ) diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py new file mode 100644 index 0000000..3854314 --- /dev/null +++ b/backend/routes/forms/unittesting.py @@ -0,0 +1,127 @@ +import base64 +from collections import namedtuple +from itertools import count +from textwrap import indent + +import httpx +from httpx import HTTPStatusError + +from backend.constants import SNEKBOX_URL +from backend.models import FormResponse, Form + +with open("resources/unittest_template.py") as file: + TEST_TEMPLATE = file.read() + + +UnittestResult = namedtuple("UnittestResult", "question_id return_code passed result") + + +def filter_unittests(form: Form) -> Form: + """ + Replace the unittest data section of code questions with the number of test cases. + + This is used to redact the exact tests when sending the form back to the frontend. + """ + for question in form.questions: + if question.type == "code" and "unittests" in question.data: + question.data["unittests"] = len(question.data["unittests"]) + + return form + + +def _make_unit_code(units: dict[str, str]) -> str: + """Compose a dict mapping unit names to their code into an actual class body.""" + result = "" + + for unit_name, unit_code in units.items(): + result += ( + f"\ndef test_{unit_name.lstrip('#')}(unit):" # Function definition + f"\n{indent(unit_code, ' ')}" # Unit code + ) + + return indent(result, " ") + + +def _make_user_code(code: str) -> str: + """Compose the user code into an actual base64-encoded string variable.""" + code = base64.b64encode(code.encode("utf8")).decode("utf8") + return f'USER_CODE = b"{code}"' + + +async def _post_eval(code: str) -> dict[str, str]: + """Post the eval to snekbox and return the response.""" + async with httpx.AsyncClient() as client: + data = {"input": code} + response = await client.post(SNEKBOX_URL, json=data, timeout=10) + + response.raise_for_status() + return response.json() + + +async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]: + """Execute all the unittests in this form and return the results.""" + unittest_results = [] + + for question in form.questions: + if question.type == "code" and "unittests" in question.data: + passed = False + + # Tests starting with an hashtag should have censored names. + hidden_test_counter = count(1) + hidden_tests = { + test.lstrip("#").lstrip("test_"): next(hidden_test_counter) + for test in question.data["unittests"].keys() + if test.startswith("#") + } + + # Compose runner code + unit_code = _make_unit_code(question.data["unittests"]) + user_code = _make_user_code(form_response.response[question.id]) + + code = TEST_TEMPLATE.replace("### USER CODE", user_code) + code = code.replace("### UNIT CODE", unit_code) + + try: + response = await _post_eval(code) + except HTTPStatusError: + return_code = 99 + result = "Unable to contact code runner." + else: + return_code = int(response["returncode"]) + + # Parse the stdout if the tests ran successfully + if return_code == 0: + stdout = response["stdout"] + passed = bool(int(stdout[0])) + + # If the test failed, we have to populate the result string. + if not passed: + failed_tests = stdout[1:].strip().split(";") + + # Redact failed hidden tests + for i, failed_test in enumerate(failed_tests.copy()): + if failed_test in hidden_tests: + failed_tests[i] = f"hidden_test_{hidden_tests[failed_test]}" + + result = ";".join(failed_tests) + else: + result = "" + elif return_code in (5, 6, 99): + result = response["stdout"] + # Killed by NsJail + elif return_code == 137: + return_code = 7 + result = "Timed out or ran out of memory." + # Another code has been returned by CPython because of another failure. + else: + return_code = 99 + result = "Internal error." + + unittest_results.append(UnittestResult( + question_id=question.id, + return_code=return_code, + passed=passed, + result=result + )) + + return unittest_results diff --git a/backend/validation.py b/backend/validation.py index e696683..8771924 100644 --- a/backend/validation.py +++ b/backend/validation.py @@ -1,6 +1,5 @@ """Utilities for providing API payload validation.""" -from typing import Optional from pydantic.fields import Field from pydantic.main import BaseModel from spectree import SpecTree @@ -18,13 +17,3 @@ class ErrorMessage(BaseModel): class OkayResponse(BaseModel): status: str = "ok" - - -class AuthorizationHeaders(BaseModel): - authorization: Optional[str] = Field( - title="Authorization", - description=( - "The Authorization JWT token received from the " - "authorize route in the format `JWT {token}`" - ) - ) |