diff options
Diffstat (limited to '')
| -rw-r--r-- | backend/__init__.py | 21 | ||||
| -rw-r--r-- | backend/authentication/backend.py | 37 | ||||
| -rw-r--r-- | backend/authentication/user.py | 26 | ||||
| -rw-r--r-- | backend/constants.py | 13 | ||||
| -rw-r--r-- | backend/discord.py | 15 | ||||
| -rw-r--r-- | backend/routes/auth/authorize.py | 121 | ||||
| -rw-r--r-- | backend/routes/forms/form.py | 6 | ||||
| -rw-r--r-- | backend/routes/forms/submit.py | 58 | ||||
| -rw-r--r-- | backend/routes/forms/unittesting.py | 127 | ||||
| -rw-r--r-- | backend/validation.py | 11 | 
10 files changed, 378 insertions, 57 deletions
diff --git a/backend/__init__.py b/backend/__init__.py index a3704a0..220b457 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -7,10 +7,21 @@ from starlette.middleware.cors import CORSMiddleware  from backend import constants  from backend.authentication import JWTAuthenticationBackend -from backend.route_manager import create_route_map  from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware +from backend.route_manager import create_route_map  from backend.validation import api +ORIGINS = [ +    r"(https://[^.?#]*--pydis-forms\.netlify\.app)",  # Netlify Previews +    r"(https?://[^.?#]*.forms-frontend.pages.dev)",  # Cloudflare Previews +] + +if not constants.PRODUCTION: +    # Allow all hosts on non-production deployments +    ORIGINS.append(r"(.*)") + +ALLOW_ORIGIN_REGEX = "|".join(ORIGINS) +  sentry_sdk.init(      dsn=constants.FORMS_BACKEND_DSN,      send_default_pii=True, @@ -20,13 +31,13 @@ sentry_sdk.init(  middleware = [      Middleware(          CORSMiddleware, -        # TODO: Convert this into a RegEx that works for prod, netlify & previews -        allow_origins=["*"], +        allow_origins=["https://forms.pythondiscord.com"], +        allow_origin_regex=ALLOW_ORIGIN_REGEX,          allow_headers=[ -            "Authorization",              "Content-Type"          ], -        allow_methods=["*"] +        allow_methods=["*"], +        allow_credentials=True      ),      Middleware(DatabaseMiddleware),      Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()), diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index f1d2ece..c7590e9 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,6 +1,6 @@ -import jwt  import typing as t +import jwt  from starlette import authentication  from starlette.requests import Request @@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):      """Custom Starlette authentication backend for JWT."""      @staticmethod -    def get_token_from_header(header: str) -> str: -        """Parse JWT token from header value.""" +    def get_token_from_cookie(cookie: str) -> str: +        """Parse JWT token from cookie."""          try: -            prefix, token = header.split() +            prefix, token = cookie.split()          except ValueError:              raise authentication.AuthenticationError( -                "Unable to split prefix and token from Authorization header." +                "Unable to split prefix and token from authorization cookie."              )          if prefix.upper() != "JWT":              raise authentication.AuthenticationError( -                f"Invalid Authorization header prefix '{prefix}'." +                f"Invalid authorization cookie prefix '{prefix}'."              )          return token @@ -33,11 +33,11 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):          self, request: Request      ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:          """Handles JWT authentication process.""" -        if "Authorization" not in request.headers: +        cookie = request.cookies.get("token") +        if not cookie:              return None -        auth = request.headers["Authorization"] -        token = self.get_token_from_header(auth) +        token = self.get_token_from_cookie(cookie)          try:              payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"]) @@ -46,7 +46,22 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):          scopes = ["authenticated"] -        if payload.get("admin") is True: +        if not payload.get("token"): +            raise authentication.AuthenticationError("Token is missing from JWT.") +        if not payload.get("refresh"): +            raise authentication.AuthenticationError( +                "Refresh token is missing from JWT." +            ) + +        try: +            user_details = payload.get("user_details") +            if not user_details or not user_details.get("id"): +                raise authentication.AuthenticationError("Improper user details.") +        except Exception: +            raise authentication.AuthenticationError("Could not parse user details.") + +        user = User(token, user_details) +        if await user.fetch_admin_status(request):              scopes.append("admin") -        return authentication.AuthCredentials(scopes), User(token, payload) +        return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index f40c68c..857c2ed 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,6 +1,11 @@  import typing as t +import jwt  from starlette.authentication import BaseUser +from starlette.requests import Request + +from backend.constants import SECRET_KEY +from backend.discord import fetch_user_details  class User(BaseUser): @@ -9,6 +14,7 @@ class User(BaseUser):      def __init__(self, token: str, payload: dict[str, t.Any]) -> None:          self.token = token          self.payload = payload +        self.admin = False      @property      def is_authenticated(self) -> bool: @@ -23,3 +29,23 @@ class User(BaseUser):      @property      def discord_mention(self) -> str:          return f"<@{self.payload['id']}>" + +    @property +    def decoded_token(self) -> dict[str, any]: +        return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) + +    async def fetch_admin_status(self, request: Request) -> bool: +        self.admin = await request.state.db.admins.find_one( +            {"_id": self.payload["id"]} +        ) is not None + +        return self.admin + +    async def refresh_data(self) -> None: +        """Fetches user data from discord, and updates the instance.""" +        self.payload = await fetch_user_details(self.decoded_token.get("token")) + +        updated_info = self.decoded_token +        updated_info["user_details"] = self.payload + +        self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256") diff --git a/backend/constants.py b/backend/constants.py index 812bef4..7ea4519 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -1,14 +1,19 @@ +import binascii +import os +from enum import Enum +  from dotenv import load_dotenv -load_dotenv() -import os  # noqa -import binascii  # noqa -from enum import Enum  # noqa +load_dotenv()  FRONTEND_URL = os.getenv("FRONTEND_URL", "https://forms.pythondiscord.com")  DATABASE_URL = os.getenv("DATABASE_URL")  MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms") +SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval") + +PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false" +PRODUCTION_URL = "https://forms.pythondiscord.com"  OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")  OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET") diff --git a/backend/discord.py b/backend/discord.py index 1dc8ed7..e5c7f8f 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -2,20 +2,25 @@  import httpx  from backend.constants import ( -    DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET, OAUTH2_REDIRECT_URI +    DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET  ) -async def fetch_bearer_token(access_code: str) -> dict: +async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:      async with httpx.AsyncClient() as client:          data = {              "client_id": OAUTH2_CLIENT_ID,              "client_secret": OAUTH2_CLIENT_SECRET, -            "grant_type": "authorization_code", -            "code": access_code, -            "redirect_uri": OAUTH2_REDIRECT_URI +            "redirect_uri": f"{redirect}/callback"          } +        if refresh: +            data["grant_type"] = "refresh_token" +            data["refresh_token"] = code +        else: +            data["grant_type"] = "authorization_code" +            data["code"] = code +          r = await client.post(f"{DISCORD_API_BASE_URL}/oauth2/token", headers={              "Content-Type": "application/x-www-form-urlencoded"          }, data=data) diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 975936a..d4587f0 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -2,26 +2,101 @@  Use a token received from the Discord OAuth2 system to fetch user information.  """ +import datetime +from typing import Union +  import httpx  import jwt  from pydantic.fields import Field  from pydantic.main import BaseModel  from spectree.response import Response +from starlette import responses +from starlette.authentication import requires  from starlette.requests import Request -from starlette.responses import JSONResponse +from backend import constants +from backend.authentication.user import User  from backend.constants import SECRET_KEY -from backend.route import Route  from backend.discord import fetch_bearer_token, fetch_user_details +from backend.route import Route  from backend.validation import ErrorMessage, api +AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400) +  class AuthorizeRequest(BaseModel):      token: str = Field(description="The access token received from Discord.")  class AuthorizeResponse(BaseModel): -    token: str = Field(description="A JWT token containing the user information") +    username: str = Field("Discord display name.") +    expiry: str = Field("ISO formatted timestamp of expiry.") + + +async def process_token( +        bearer_token: dict, +        request: Request +) -> Union[AuthorizeResponse, AUTH_FAILURE]: +    """Post a bearer token to Discord, and return a JWT and username.""" +    interaction_start = datetime.datetime.now() + +    try: +        user_details = await fetch_user_details(bearer_token["access_token"]) +    except httpx.HTTPStatusError: +        AUTH_FAILURE.delete_cookie("token") +        return AUTH_FAILURE + +    max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"])) +    token_expiry = interaction_start + max_age + +    data = { +        "token": bearer_token["access_token"], +        "refresh": bearer_token["refresh_token"], +        "user_details": user_details, +        "expiry": token_expiry.isoformat() +    } + +    token = jwt.encode(data, SECRET_KEY, algorithm="HS256") +    user = User(token, user_details) + +    response = responses.JSONResponse({ +        "username": user.display_name, +        "expiry": token_expiry.isoformat() +    }) + +    await set_response_token(response, request, token, bearer_token["expires_in"]) +    return response + + +async def set_response_token( +        response: responses.Response, +        request: Request, +        new_token: str, +        expiry: int +) -> None: +    """Helper that handles logic for updating a token in a set-cookie response.""" +    origin_url = request.headers.get("origin") + +    if origin_url == constants.PRODUCTION_URL: +        domain = request.url.netloc +        samesite = "strict" + +    elif not constants.PRODUCTION: +        domain = None +        samesite = "strict" + +    else: +        domain = request.url.netloc +        samesite = "None" + +    response.set_cookie( +        "token", f"JWT {new_token}", +        secure=constants.PRODUCTION, +        httponly=True, +        samesite=samesite, +        domain=domain, +        max_age=expiry +    )  class AuthorizeRoute(Route): @@ -37,22 +112,38 @@ class AuthorizeRoute(Route):          resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),          tags=["auth"]      ) -    async def post(self, request: Request) -> JSONResponse: +    async def post(self, request: Request) -> responses.JSONResponse:          """Generate an authorization token."""          data = await request.json() -          try: -            bearer_token = await fetch_bearer_token(data["token"]) -            user_details = await fetch_user_details(bearer_token["access_token"]) +            url = request.headers.get("origin") +            bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)          except httpx.HTTPStatusError: -            return JSONResponse({ -                "error": "auth_failure" -            }, status_code=400) +            return AUTH_FAILURE -        user_details["admin"] = await request.state.db.admins.find_one( -            {"_id": user_details["id"]} -        ) is not None +        return await process_token(bearer_token, request) + + +class TokenRefreshRoute(Route): +    """ +    Use the refresh code from a JWT to get a new token and generate a new JWT token. +    """ -        token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256") +    name = "refresh" +    path = "/refresh" + +    @requires(["authenticated"]) +    @api.validate( +        resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), +        tags=["auth"] +    ) +    async def post(self, request: Request) -> responses.JSONResponse: +        """Refresh an authorization token.""" +        try: +            token = request.user.decoded_token.get("refresh") +            url = request.headers.get("origin") +            bearer_token = await fetch_bearer_token(token, url, refresh=True) +        except httpx.HTTPStatusError: +            return AUTH_FAILURE -        return JSONResponse({"token": token}) +        return await process_token(bearer_token, request) diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index b6b722e..1c6e44a 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -10,6 +10,7 @@ from starlette.responses import JSONResponse  from backend.models import Form  from backend.route import Route +from backend.routes.forms.unittesting import filter_unittests  from backend.validation import ErrorMessage, OkayResponse, api @@ -26,7 +27,7 @@ class SingleForm(Route):      @api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])      async def get(self, request: Request) -> JSONResponse:          """Returns single form information by ID.""" -        admin = request.user.payload["admin"] if request.user.is_authenticated else False  # noqa +        admin = request.user.admin if request.user.is_authenticated else False          filters = {              "_id": request.path_params["form_id"] @@ -37,6 +38,9 @@ class SingleForm(Route):          if raw_form := await request.state.db.forms.find_one(filters):              form = Form(**raw_form) +            if not admin: +                form = filter_unittests(form) +              return JSONResponse(form.dict(admin=admin))          return JSONResponse({"error": "not_found"}, status_code=404) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 37f76e0..23444a0 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -4,6 +4,7 @@ Submit a form.  import asyncio  import binascii +import datetime  import hashlib  import uuid  from typing import Any, Optional @@ -17,10 +18,12 @@ from starlette.requests import Request  from starlette.responses import JSONResponse  from backend import constants -from backend.authentication import User +from backend.authentication.user import User  from backend.models import Form, FormResponse  from backend.route import Route -from backend.validation import AuthorizationHeaders, ErrorMessage, api +from backend.routes.auth.authorize import set_response_token +from backend.routes.forms.unittesting import execute_unittest +from backend.validation import ErrorMessage, api  HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"  HCAPTCHA_HEADERS = { @@ -57,13 +60,37 @@ class SubmitForm(Route):              HTTP_404=ErrorMessage,              HTTP_400=ErrorMessage          ), -        headers=AuthorizationHeaders,          tags=["forms", "responses"]      )      async def post(self, request: Request) -> JSONResponse:          """Submit a response to the form.""" -        data = await request.json() +        response = await self.submit(request) + +        # Silently try to update user data +        try: +            if hasattr(request.user, User.refresh_data.__name__): +                old = request.user.token +                await request.user.refresh_data() + +                if old != request.user.token: +                    try: +                        expiry = datetime.datetime.fromisoformat( +                            request.user.decoded_token.get("expiry") +                        ) +                    except ValueError: +                        expiry = None + +                    expiry_seconds = (expiry - datetime.datetime.now()).seconds +                    await set_response_token(response, request, request.user.token, expiry_seconds) + +        except httpx.HTTPStatusError: +            pass +        return response + +    async def submit(self, request: Request) -> JSONResponse: +        """Helper method for handling submission logic.""" +        data = await request.json()          data["timestamp"] = None          if form := await request.state.db.forms.find_one( @@ -104,8 +131,12 @@ class SubmitForm(Route):              if constants.FormFeatures.REQUIRES_LOGIN.value in form.features:                  if request.user.is_authenticated:                      response["user"] = request.user.payload +                    response["user"]["admin"] = request.user.admin -                    if constants.FormFeatures.COLLECT_EMAIL.value in form.features and "email" not in response["user"]:  # noqa +                    if ( +                            constants.FormFeatures.COLLECT_EMAIL.value in form.features +                            and "email" not in response["user"] +                    ):                          return JSONResponse({                              "error": "email_required"                          }, status_code=400) @@ -133,6 +164,23 @@ class SubmitForm(Route):              except ValidationError as e:                  return JSONResponse(e.errors(), status_code=422) +            # Run unittests if needed +            if any("unittests" in question.data for question in form.questions): +                unittest_results = await execute_unittest(response_obj, form) + +                if not all(test.passed for test in unittest_results): +                    # Return 500 if we encountered an internal error (code 99). +                    status_code = 500 if any( +                        test.return_code == 99 for test in unittest_results +                    ) else 403 + +                    return JSONResponse({ +                        "error": "failed_tests", +                        "test_results": [ +                            test._asdict() for test in unittest_results if not test.passed +                        ] +                    }, status_code=status_code) +              await request.state.db.responses.insert_one(                  response_obj.dict(by_alias=True)              ) diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py new file mode 100644 index 0000000..3854314 --- /dev/null +++ b/backend/routes/forms/unittesting.py @@ -0,0 +1,127 @@ +import base64 +from collections import namedtuple +from itertools import count +from textwrap import indent + +import httpx +from httpx import HTTPStatusError + +from backend.constants import SNEKBOX_URL +from backend.models import FormResponse, Form + +with open("resources/unittest_template.py") as file: +    TEST_TEMPLATE = file.read() + + +UnittestResult = namedtuple("UnittestResult", "question_id return_code passed result") + + +def filter_unittests(form: Form) -> Form: +    """ +    Replace the unittest data section of code questions with the number of test cases. + +    This is used to redact the exact tests when sending the form back to the frontend. +    """ +    for question in form.questions: +        if question.type == "code" and "unittests" in question.data: +            question.data["unittests"] = len(question.data["unittests"]) + +    return form + + +def _make_unit_code(units: dict[str, str]) -> str: +    """Compose a dict mapping unit names to their code into an actual class body.""" +    result = "" + +    for unit_name, unit_code in units.items(): +        result += ( +            f"\ndef test_{unit_name.lstrip('#')}(unit):"  # Function definition +            f"\n{indent(unit_code, '    ')}"  # Unit code +        ) + +    return indent(result, "    ") + + +def _make_user_code(code: str) -> str: +    """Compose the user code into an actual base64-encoded string variable.""" +    code = base64.b64encode(code.encode("utf8")).decode("utf8") +    return f'USER_CODE = b"{code}"' + + +async def _post_eval(code: str) -> dict[str, str]: +    """Post the eval to snekbox and return the response.""" +    async with httpx.AsyncClient() as client: +        data = {"input": code} +        response = await client.post(SNEKBOX_URL, json=data, timeout=10) + +        response.raise_for_status() +        return response.json() + + +async def execute_unittest(form_response: FormResponse, form: Form) -> list[UnittestResult]: +    """Execute all the unittests in this form and return the results.""" +    unittest_results = [] + +    for question in form.questions: +        if question.type == "code" and "unittests" in question.data: +            passed = False + +            # Tests starting with an hashtag should have censored names. +            hidden_test_counter = count(1) +            hidden_tests = { +                test.lstrip("#").lstrip("test_"): next(hidden_test_counter) +                for test in question.data["unittests"].keys() +                if test.startswith("#") +            } + +            # Compose runner code +            unit_code = _make_unit_code(question.data["unittests"]) +            user_code = _make_user_code(form_response.response[question.id]) + +            code = TEST_TEMPLATE.replace("### USER CODE", user_code) +            code = code.replace("### UNIT CODE", unit_code) + +            try: +                response = await _post_eval(code) +            except HTTPStatusError: +                return_code = 99 +                result = "Unable to contact code runner." +            else: +                return_code = int(response["returncode"]) + +                # Parse the stdout if the tests ran successfully +                if return_code == 0: +                    stdout = response["stdout"] +                    passed = bool(int(stdout[0])) + +                    # If the test failed, we have to populate the result string. +                    if not passed: +                        failed_tests = stdout[1:].strip().split(";") + +                        # Redact failed hidden tests +                        for i, failed_test in enumerate(failed_tests.copy()): +                            if failed_test in hidden_tests: +                                failed_tests[i] = f"hidden_test_{hidden_tests[failed_test]}" + +                        result = ";".join(failed_tests) +                    else: +                        result = "" +                elif return_code in (5, 6, 99): +                    result = response["stdout"] +                # Killed by NsJail +                elif return_code == 137: +                    return_code = 7 +                    result = "Timed out or ran out of memory." +                # Another code has been returned by CPython because of another failure. +                else: +                    return_code = 99 +                    result = "Internal error." + +            unittest_results.append(UnittestResult( +                question_id=question.id, +                return_code=return_code, +                passed=passed, +                result=result +            )) + +    return unittest_results diff --git a/backend/validation.py b/backend/validation.py index e696683..8771924 100644 --- a/backend/validation.py +++ b/backend/validation.py @@ -1,6 +1,5 @@  """Utilities for providing API payload validation.""" -from typing import Optional  from pydantic.fields import Field  from pydantic.main import BaseModel  from spectree import SpecTree @@ -18,13 +17,3 @@ class ErrorMessage(BaseModel):  class OkayResponse(BaseModel):      status: str = "ok" - - -class AuthorizationHeaders(BaseModel): -    authorization: Optional[str] = Field( -        title="Authorization", -        description=( -            "The Authorization JWT token received from the " -            "authorize route in the format `JWT {token}`" -        ) -    )  |