diff options
Diffstat (limited to 'backend/routes')
| -rw-r--r-- | backend/routes/admin.py | 12 | ||||
| -rw-r--r-- | backend/routes/auth/authorize.py | 42 | ||||
| -rw-r--r-- | backend/routes/discord.py | 18 | ||||
| -rw-r--r-- | backend/routes/forms/discover.py | 21 | ||||
| -rw-r--r-- | backend/routes/forms/form.py | 16 | ||||
| -rw-r--r-- | backend/routes/forms/index.py | 26 | ||||
| -rw-r--r-- | backend/routes/forms/response.py | 21 | ||||
| -rw-r--r-- | backend/routes/forms/responses.py | 46 | ||||
| -rw-r--r-- | backend/routes/forms/submit.py | 134 | ||||
| -rw-r--r-- | backend/routes/forms/unittesting.py | 68 | ||||
| -rw-r--r-- | backend/routes/index.py | 21 | 
11 files changed, 197 insertions, 228 deletions
| diff --git a/backend/routes/admin.py b/backend/routes/admin.py index 0fd0700..848abce 100644 --- a/backend/routes/admin.py +++ b/backend/routes/admin.py @@ -1,6 +1,5 @@ -""" -Adds new admin user. -""" +"""Adds new admin user.""" +  from pydantic import BaseModel, Field  from spectree import Response  from starlette.authentication import requires @@ -22,7 +21,7 @@ async def grant(request: Request) -> JSONResponse:      admin = AdminModel(**data)      if await request.state.db.admins.find_one( -            {"_id": admin.id} +        {"_id": admin.id},      ):          return JSONResponse({"error": "already_exists"}, status_code=400) @@ -40,7 +39,7 @@ class AdminRoute(Route):      @api.validate(          json=AdminModel,          resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), -        tags=["admin"] +        tags=["admin"],      )      async def post(self, request: Request) -> JSONResponse:          """Grant a user administrator privileges.""" @@ -48,6 +47,7 @@ class AdminRoute(Route):  if not constants.PRODUCTION: +      class AdminDev(Route):          """Adds new admin user with no authentication.""" @@ -57,7 +57,7 @@ if not constants.PRODUCTION:          @api.validate(              json=AdminModel,              resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), -            tags=["admin"] +            tags=["admin"],          )          async def post(self, request: Request) -> JSONResponse:              """ diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 42fb3ec..bc80a7d 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -1,9 +1,6 @@ -""" -Use a token received from the Discord OAuth2 system to fetch user information. -""" +"""Use a token received from the Discord OAuth2 system to fetch user information."""  import datetime -from typing import Union  import httpx  import jwt @@ -35,8 +32,8 @@ class AuthorizeResponse(BaseModel):  async def process_token(      bearer_token: dict, -    request: Request -) -> Union[AuthorizeResponse, AUTH_FAILURE]: +    request: Request, +) -> AuthorizeResponse | responses.JSONResponse:      """Post a bearer token to Discord, and return a JWT and username."""      interaction_start = datetime.datetime.now() @@ -57,7 +54,7 @@ async def process_token(          "refresh": bearer_token["refresh_token"],          "user_details": user_details,          "in_guild": bool(member), -        "expiry": token_expiry.isoformat() +        "expiry": token_expiry.isoformat(),      }      token = jwt.encode(data, SECRET_KEY, algorithm="HS256") @@ -65,18 +62,18 @@ async def process_token(      response = responses.JSONResponse({          "username": user.display_name, -        "expiry": token_expiry.isoformat() +        "expiry": token_expiry.isoformat(),      }) -    await set_response_token(response, request, token, bearer_token["expires_in"]) +    set_response_token(response, request, token, bearer_token["expires_in"])      return response -async def set_response_token( -        response: responses.Response, -        request: Request, -        new_token: str, -        expiry: int +def set_response_token( +    response: responses.Response, +    request: Request, +    new_token: str, +    expiry: int,  ) -> None:      """Helper that handles logic for updating a token in a set-cookie response."""      origin_url = request.headers.get("origin") @@ -94,19 +91,18 @@ async def set_response_token(          samesite = "None"      response.set_cookie( -        "token", f"JWT {new_token}", +        "token", +        f"JWT {new_token}",          secure=constants.PRODUCTION,          httponly=True,          samesite=samesite,          domain=domain, -        max_age=expiry +        max_age=expiry,      )  class AuthorizeRoute(Route): -    """ -    Use the authorization code from Discord to generate a JWT token. -    """ +    """Use the authorization code from Discord to generate a JWT token."""      name = "authorize"      path = "/authorize" @@ -114,7 +110,7 @@ class AuthorizeRoute(Route):      @api.validate(          json=AuthorizeRequest,          resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), -        tags=["auth"] +        tags=["auth"],      )      async def post(self, request: Request) -> responses.JSONResponse:          """Generate an authorization token.""" @@ -129,9 +125,7 @@ class AuthorizeRoute(Route):  class TokenRefreshRoute(Route): -    """ -    Use the refresh code from a JWT to get a new token and generate a new JWT token. -    """ +    """Use the refresh code from a JWT to get a new token and generate a new JWT token."""      name = "refresh"      path = "/refresh" @@ -139,7 +133,7 @@ class TokenRefreshRoute(Route):      @requires(["authenticated"])      @api.validate(          resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), -        tags=["auth"] +        tags=["auth"],      )      async def post(self, request: Request) -> responses.JSONResponse:          """Refresh an authorization token.""" diff --git a/backend/routes/discord.py b/backend/routes/discord.py index bca1edb..53b8af3 100644 --- a/backend/routes/discord.py +++ b/backend/routes/discord.py @@ -10,7 +10,8 @@ from backend import discord, models, route  from backend.validation import ErrorMessage, api  NOT_FOUND_EXCEPTION = JSONResponse( -    {"error": "Could not find the requested resource in the guild or cache."}, status_code=404 +    {"error": "Could not find the requested resource in the guild or cache."}, +    status_code=404,  ) @@ -28,7 +29,7 @@ class RolesRoute(route.Route):      @requires(["authenticated", "admin"])      @api.validate(          resp=Response(HTTP_200=RolesResponse), -        tags=["roles"] +        tags=["roles"],      )      async def patch(self, request: Request) -> JSONResponse:          """Refresh the roles database.""" @@ -54,7 +55,7 @@ class MemberRoute(route.Route):      @api.validate(          resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),          json=MemberRequest, -        tags=["auth"] +        tags=["auth"],      )      async def delete(self, request: Request) -> JSONResponse:          """Force a resync of the cache for the given user.""" @@ -63,21 +64,20 @@ class MemberRoute(route.Route):          if member:              return JSONResponse(member.dict()) -        else: -            return NOT_FOUND_EXCEPTION +        return NOT_FOUND_EXCEPTION      @requires(["authenticated", "admin"])      @api.validate(          resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),          json=MemberRequest, -        tags=["auth"] +        tags=["auth"],      )      async def get(self, request: Request) -> JSONResponse:          """Get a user's roles on the configured server."""          body = await request.json()          member = await discord.get_member(request.state.db, body["user_id"]) -        if member: -            return JSONResponse(member.dict()) -        else: +        if not member:              return NOT_FOUND_EXCEPTION + +        return JSONResponse(member.dict()) diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py index 75ff495..0fe10b5 100644 --- a/backend/routes/forms/discover.py +++ b/backend/routes/forms/discover.py @@ -1,6 +1,5 @@ -""" -Return a list of all publicly discoverable forms to unauthenticated users. -""" +"""Return a list of all publicly discoverable forms to unauthenticated users.""" +  from spectree.response import Response  from starlette.requests import Request  from starlette.responses import JSONResponse @@ -12,7 +11,7 @@ from backend.validation import api  __FEATURES = [      constants.FormFeatures.OPEN.value, -    constants.FormFeatures.REQUIRES_LOGIN.value +    constants.FormFeatures.REQUIRES_LOGIN.value,  ]  if not constants.PRODUCTION:      __FEATURES.append(constants.FormFeatures.DISCOVERABLE.value) @@ -22,7 +21,7 @@ __QUESTION = Question(      name="Click the button below to log into the forms application.",      type="section",      data={"text": ""}, -    required=False +    required=False,  )  AUTH_FORM = Form( @@ -31,14 +30,12 @@ AUTH_FORM = Form(      questions=[__QUESTION],      name="Login",      description="Log into Python Discord Forms.", -    submitted_text="This page can't be submitted." +    submitted_text="This page can't be submitted.",  )  class DiscoverableFormsList(Route): -    """ -    List all discoverable forms that should be shown on the homepage. -    """ +    """List all discoverable forms that should be shown on the homepage."""      name = "discoverable_forms_list"      path = "/discoverable" @@ -46,15 +43,11 @@ class DiscoverableFormsList(Route):      @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])      async def get(self, request: Request) -> JSONResponse:          """List all discoverable forms that should be shown on the homepage.""" -        forms = []          cursor = request.state.db.forms.find({"features": "DISCOVERABLE"}).sort("name")          # Parse it to Form and then back to dictionary          # to replace _id with id -        for form in await cursor.to_list(None): -            forms.append(Form(**form)) - -        forms = [form.dict(admin=False) for form in forms] +        forms = [Form(**form).dict(admin=False) for form in await cursor.to_list(None)]          # Return an empty form in development environments to help with authentication.          if not constants.PRODUCTION: diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index 020193c..410102a 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -1,6 +1,5 @@ -""" -Returns, updates or deletes a single form given an ID. -""" +"""Returns, updates or deletes a single form given an ID.""" +  import json.decoder  import deepmerge @@ -48,7 +47,7 @@ class SingleForm(Route):              admin = False          filters = { -            "_id": form_id +            "_id": form_id,          }          if not admin: @@ -71,7 +70,7 @@ class SingleForm(Route):              HTTP_400=ErrorMessage,              HTTP_404=ErrorMessage,          ), -        tags=["forms"] +        tags=["forms"],      )      async def patch(self, request: Request) -> JSONResponse:          """Updates form by ID.""" @@ -90,7 +89,7 @@ class SingleForm(Route):              # Build Data Merger              merge_strategy = [ -                (dict, ["merge"]) +                (dict, ["merge"]),              ]              merger = deepmerge.Merger(merge_strategy, ["override"], ["override"]) @@ -105,13 +104,12 @@ class SingleForm(Route):              await request.state.db.forms.replace_one({"_id": form_id}, form.dict())              return JSONResponse(form.dict()) -        else: -            return JSONResponse({"error": "not_found"}, status_code=404) +        return JSONResponse({"error": "not_found"}, status_code=404)      @requires(["authenticated", "admin"])      @api.validate(          resp=Response(HTTP_200=OkayResponse, HTTP_401=ErrorMessage, HTTP_404=ErrorMessage), -        tags=["forms"] +        tags=["forms"],      )      async def delete(self, request: Request) -> JSONResponse:          """Deletes form by ID.""" diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py index 38be693..1fdfc48 100644 --- a/backend/routes/forms/index.py +++ b/backend/routes/forms/index.py @@ -1,6 +1,5 @@ -""" -Return a list of all forms to authenticated users. -""" +"""Return a list of all forms to authenticated users.""" +  from spectree.response import Response  from starlette.authentication import requires  from starlette.requests import Request @@ -14,9 +13,7 @@ from backend.validation import ErrorMessage, OkayResponse, api  class FormsList(Route): -    """ -    List all available forms for authorized viewers. -    """ +    """List all available forms for authorized viewers."""      name = "forms_list_create"      path = "/" @@ -25,24 +22,17 @@ class FormsList(Route):      @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])      async def get(self, request: Request) -> JSONResponse:          """Return a list of all forms to authenticated users.""" -        forms = []          cursor = request.state.db.forms.find() -        for form in await cursor.to_list(None): -            forms.append(Form(**form))  # For converting _id to id - -        # Covert them back to dictionaries -        forms = [form.dict() for form in forms] +        forms = [Form(**form).dict() for form in await cursor.to_list(None)] -        return JSONResponse( -            forms -        ) +        return JSONResponse(forms)      @requires(["authenticated", "Helpers"])      @api.validate(          json=Form,          resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), -        tags=["forms"] +        tags=["forms"],      )      async def post(self, request: Request) -> JSONResponse:          """Create a new form.""" @@ -66,9 +56,7 @@ class FormsList(Route):          form = Form(**form_data)          if await request.state.db.forms.find_one({"_id": form.id}): -            return JSONResponse({ -                "error": "id_taken" -            }, status_code=400) +            return JSONResponse({"error": "id_taken"}, status_code=400)          await request.state.db.forms.insert_one(form.dict(by_alias=True))          return JSONResponse(form.dict()) diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py index 565701f..b4f7f04 100644 --- a/backend/routes/forms/response.py +++ b/backend/routes/forms/response.py @@ -1,6 +1,4 @@ -""" -Returns or deletes form response by ID. -""" +"""Returns or deletes form response by ID."""  from spectree import Response as RouteResponse  from starlette.authentication import requires @@ -22,7 +20,7 @@ class Response(Route):      @requires(["authenticated"])      @api.validate(          resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage), -        tags=["forms", "responses"] +        tags=["forms", "responses"],      )      async def get(self, request: Request) -> JSONResponse:          """Return a single form response by ID.""" @@ -32,30 +30,29 @@ class Response(Route):          if raw_response := await request.state.db.responses.find_one(              {                  "_id": request.path_params["response_id"], -                "form_id": form_id -            } +                "form_id": form_id, +            },          ):              response = FormResponse(**raw_response)              return JSONResponse(response.dict()) -        else: -            return JSONResponse({"error": "response_not_found"}, status_code=404) +        return JSONResponse({"error": "response_not_found"}, status_code=404)      @requires(["authenticated", "admin"])      @api.validate(          resp=RouteResponse(HTTP_200=OkayResponse, HTTP_404=ErrorMessage), -        tags=["forms", "responses"] +        tags=["forms", "responses"],      )      async def delete(self, request: Request) -> JSONResponse:          """Delete a form response by ID."""          if not await request.state.db.responses.find_one(              {                  "_id": request.path_params["response_id"], -                "form_id": request.path_params["form_id"] -            } +                "form_id": request.path_params["form_id"], +            },          ):              return JSONResponse({"error": "not_found"}, status_code=404)          await request.state.db.responses.delete_one( -            {"_id": request.path_params["response_id"]} +            {"_id": request.path_params["response_id"]},          )          return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py index 818ebce..85e5af2 100644 --- a/backend/routes/forms/responses.py +++ b/backend/routes/forms/responses.py @@ -1,6 +1,5 @@ -""" -Returns all form responses by form ID. -""" +"""Returns all form responses by form ID.""" +  from pydantic import BaseModel  from spectree import Response  from starlette.authentication import requires @@ -18,9 +17,7 @@ class ResponseIdList(BaseModel):  class Responses(Route): -    """ -    Returns all form responses by form ID. -    """ +    """Returns all form responses by form ID."""      name = "form_responses"      path = "/{form_id:str}/responses" @@ -28,7 +25,7 @@ class Responses(Route):      @requires(["authenticated"])      @api.validate(          resp=Response(HTTP_200=ResponseList), -        tags=["forms", "responses"] +        tags=["forms", "responses"],      )      async def get(self, request: Request) -> JSONResponse:          """Returns all form responses by form ID.""" @@ -36,11 +33,9 @@ class Responses(Route):          await discord.verify_response_access(form_id, request)          cursor = request.state.db.responses.find( -            {"form_id": form_id} +            {"form_id": form_id},          ) -        responses = [ -            FormResponse(**response) for response in await cursor.to_list(None) -        ] +        responses = [FormResponse(**response) for response in await cursor.to_list(None)]          return JSONResponse([response.dict() for response in responses])      @requires(["authenticated", "admin"]) @@ -49,14 +44,14 @@ class Responses(Route):          resp=Response(              HTTP_200=OkayResponse,              HTTP_404=ErrorMessage, -            HTTP_400=ErrorMessage +            HTTP_400=ErrorMessage,          ), -        tags=["forms", "responses"] +        tags=["forms", "responses"],      )      async def delete(self, request: Request) -> JSONResponse:          """Bulk deletes form responses by IDs."""          if not await request.state.db.forms.find_one( -            {"_id": request.path_params["form_id"]} +            {"_id": request.path_params["form_id"]},          ):              return JSONResponse({"error": "not_found"}, status_code=404) @@ -67,37 +62,34 @@ class Responses(Route):          ids = set(response_ids.ids)          cursor = request.state.db.responses.find( -            {"_id": {"$in": list(ids)}}  # Convert here back to list, may throw error. +            {"_id": {"$in": list(ids)}},  # Convert here back to list, may throw error.          ) -        entries = [ -            FormResponse(**submission) for submission in await cursor.to_list(None) -        ] +        entries = [FormResponse(**submission) for submission in await cursor.to_list(None)]          actual_ids = {entry.id for entry in entries}          if len(ids) != len(actual_ids):              return JSONResponse(                  {                      "error": "responses_not_found", -                    "ids": list(ids - actual_ids) +                    "ids": list(ids - actual_ids),                  }, -                status_code=404 +                status_code=404,              )          if any(entry.form_id != request.path_params["form_id"] for entry in entries):              return JSONResponse(                  {                      "error": "wrong_form", -                    "ids": list( -                        entry.id for entry in entries -                        if entry.id != request.path_params["form_id"] -                    ) +                    "ids": [ +                        entry.id for entry in entries if entry.id != request.path_params["form_id"] +                    ],                  }, -                status_code=400 +                status_code=400,              )          await request.state.db.responses.delete_many(              { -                "_id": {"$in": list(actual_ids)} -            } +                "_id": {"$in": list(actual_ids)}, +            },          )          return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 765856e..8f01e2b 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -1,6 +1,4 @@ -""" -Submit a form. -""" +"""Submit a form."""  import asyncio  import binascii @@ -8,10 +6,9 @@ import datetime  import hashlib  import typing  import uuid -from typing import Any, Optional +from typing import Any  import httpx -import pymongo.database  import sentry_sdk  from pydantic import ValidationError  from pydantic.main import BaseModel @@ -29,13 +26,16 @@ from backend.routes.forms.discover import AUTH_FORM  from backend.routes.forms.unittesting import BypassDetectedError, execute_unittest  from backend.validation import ErrorMessage, api +if typing.TYPE_CHECKING: +    import pymongo.database +  HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"  HCAPTCHA_HEADERS = { -    "Content-Type": "application/x-www-form-urlencoded" +    "Content-Type": "application/x-www-form-urlencoded",  }  DISCORD_HEADERS = { -    "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}" +    "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}",  } @@ -46,7 +46,7 @@ class SubmissionResponse(BaseModel):  class PartialSubmission(BaseModel):      response: dict[str, Any] -    captcha: Optional[str] +    captcha: str | None  class UnittestError(BaseModel): @@ -62,9 +62,7 @@ class UnittestErrorMessage(ErrorMessage):  class SubmitForm(Route): -    """ -    Submit a form with the provided form ID. -    """ +    """Submit a form with the provided form ID."""      name = "submit_form"      path = "/submit/{form_id:str}" @@ -75,9 +73,9 @@ class SubmitForm(Route):              HTTP_200=SubmissionResponse,              HTTP_404=ErrorMessage,              HTTP_400=ErrorMessage, -            HTTP_422=UnittestErrorMessage +            HTTP_422=UnittestErrorMessage,          ), -        tags=["forms", "responses"] +        tags=["forms", "responses"],      )      async def post(self, request: Request) -> JSONResponse:          """Submit a response to the form.""" @@ -92,7 +90,7 @@ class SubmitForm(Route):                  if old != request.user.token:                      try:                          expiry = datetime.datetime.fromisoformat( -                            request.user.decoded_token.get("expiry") +                            request.user.decoded_token.get("expiry"),                          )                      except ValueError:                          expiry = None @@ -117,7 +115,7 @@ class SubmitForm(Route):                  id="not-submitted",                  form_id=AUTH_FORM.id,                  response={question.id: None for question in AUTH_FORM.questions}, -                timestamp=datetime.datetime.now().isoformat() +                timestamp=datetime.datetime.now().isoformat(),              ).dict()              return JSONResponse({"form": AUTH_FORM.dict(admin=False), "response": response}) @@ -131,8 +129,9 @@ class SubmitForm(Route):                  ip_hash_ctx = hashlib.md5()                  ip_hash_ctx.update(                      request.headers.get( -                        "Cf-Connecting-IP", request.client.host -                    ).encode() +                        "Cf-Connecting-IP", +                        request.client.host, +                    ).encode(),                  )                  ip_hash = binascii.hexlify(ip_hash_ctx.digest())                  user_agent_hash_ctx = hashlib.md5() @@ -142,12 +141,12 @@ class SubmitForm(Route):                  async with httpx.AsyncClient() as client:                      query_params = {                          "secret": constants.HCAPTCHA_API_SECRET, -                        "response": data.get("captcha") +                        "response": data.get("captcha"),                      }                      r = await client.post(                          HCAPTCHA_VERIFY_URL,                          params=query_params, -                        headers=HCAPTCHA_HEADERS +                        headers=HCAPTCHA_HEADERS,                      )                      r.raise_for_status()                      captcha_data = r.json() @@ -155,7 +154,7 @@ class SubmitForm(Route):                  response["antispam"] = {                      "ip_hash": ip_hash.decode(),                      "user_agent_hash": user_agent_hash.decode(), -                    "captcha_pass": captcha_data["success"] +                    "captcha_pass": captcha_data["success"],                  }              if constants.FormFeatures.REQUIRES_LOGIN.value in form.features: @@ -164,16 +163,12 @@ class SubmitForm(Route):                      response["user"]["admin"] = request.user.admin                      if ( -                            constants.FormFeatures.COLLECT_EMAIL.value in form.features -                            and "email" not in response["user"] +                        constants.FormFeatures.COLLECT_EMAIL.value in form.features +                        and "email" not in response["user"]                      ): -                        return JSONResponse({ -                            "error": "email_required" -                        }, status_code=400) +                        return JSONResponse({"error": "email_required"}, status_code=400)                  else: -                    return JSONResponse({ -                        "error": "missing_discord_data" -                    }, status_code=400) +                    return JSONResponse({"error": "missing_discord_data"}, status_code=400)              missing_fields = []              for question in form.questions: @@ -184,10 +179,13 @@ class SubmitForm(Route):                          missing_fields.append(question.id)              if missing_fields: -                return JSONResponse({ -                    "error": "missing_fields", -                    "fields": missing_fields -                }, status_code=400) +                return JSONResponse( +                    { +                        "error": "missing_fields", +                        "fields": missing_fields, +                    }, +                    status_code=400, +                )              try:                  response_obj = FormResponse(**response) @@ -200,10 +198,12 @@ class SubmitForm(Route):                  if len(errors):                      username = getattr(request.user, "user_id", "Unknown") -                    sentry_sdk.capture_exception(BypassDetectedError( -                        f"Detected unittest bypass attempt on form {form.id} by {username}. " -                        f"Submission has been written to reporting database ({response_obj.id})." -                    )) +                    sentry_sdk.capture_exception( +                        BypassDetectedError( +                            f"Detected unittest bypass attempt on form {form.id} by {username}. " +                            f"Submission has been written to reporting database ({response_obj.id}).", +                        ) +                    )                      database: pymongo.database.Database = request.state.db                      await database.get_collection("violations").insert_one({                          "user": username, @@ -219,7 +219,7 @@ class SubmitForm(Route):                  for test in unittest_results:                      response_obj.response[test.question_id] = {                          "value": response_obj.response[test.question_id], -                        "passed": test.passed +                        "passed": test.passed,                      }                      if test.return_code == 0: @@ -238,9 +238,8 @@ class SubmitForm(Route):                      # Report a failure on internal errors,                      # or if the test suite doesn't allow failures                      if not test.passed: -                        allow_failure = ( -                            form.questions[test.question_index].data["unittests"]["allow_failure"] -                        ) +                        question = form.questions[test.question_index] +                        allow_failure = question.data["unittests"]["allow_failure"]                          # An error while communicating with the test runner                          if test.return_code == 99: @@ -251,15 +250,16 @@ class SubmitForm(Route):                              failures.append(test)                  if len(failures): -                    return JSONResponse({ -                        "error": "failed_tests", -                        "test_results": [ -                            test._asdict() for test in failures -                        ] -                    }, status_code=status_code) +                    return JSONResponse( +                        { +                            "error": "failed_tests", +                            "test_results": [test._asdict() for test in failures], +                        }, +                        status_code=status_code, +                    )              await request.state.db.responses.insert_one( -                response_obj.dict(by_alias=True) +                response_obj.dict(by_alias=True),              )              tasks = BackgroundTasks() @@ -272,36 +272,37 @@ class SubmitForm(Route):                      self.send_submission_webhook,                      form=form,                      response=response_obj, -                    request_user=request_user +                    request_user=request_user,                  )              if constants.FormFeatures.ASSIGN_ROLE.value in form.features:                  tasks.add_task(                      self.assign_role,                      form=form, -                    request_user=request.user +                    request_user=request.user,                  ) -            return JSONResponse({ -                "form": form.dict(admin=False), -                "response": response_obj.dict() -            }, background=tasks) +            return JSONResponse( +                { +                    "form": form.dict(admin=False), +                    "response": response_obj.dict(), +                }, +                background=tasks, +            ) -        else: -            return JSONResponse({ -                "error": "Open form not found" -            }, status_code=404) +        return JSONResponse({"error": "Open form not found"}, status_code=404)      @staticmethod      async def send_submission_webhook( -            form: Form, -            response: FormResponse, -            request_user: typing.Optional[User] +        form: Form, +        response: FormResponse, +        request_user: User | None,      ) -> None:          """Helper to send a submission message to a discord webhook."""          # Stop if webhook is not available          if form.webhook is None: -            raise ValueError("Got empty webhook.") +            msg = "Got empty webhook." +            raise ValueError(msg)          try:              mention = request_user.discord_mention @@ -330,7 +331,7 @@ class SubmitForm(Route):          hook = {              "embeds": [embed],              "allowed_mentions": {"parse": ["users", "roles"]}, -            "username": form.name or "Python Discord Forms" +            "username": form.name or "Python Discord Forms",          }          # Set hook message @@ -345,8 +346,8 @@ class SubmitForm(Route):                  "time": response.timestamp,              } -            for key in ctx: -                message = message.replace(f"{{{key}}}", str(ctx[key])) +            for key, val in ctx.items(): +                message = message.replace(f"{{{key}}}", str(val))              hook["content"] = message.replace("_USER_MENTION_", mention) @@ -359,11 +360,12 @@ class SubmitForm(Route):      async def assign_role(form: Form, request_user: User) -> None:          """Assigns Discord role to user when user submitted response."""          if not form.discord_role: -            raise ValueError("Got empty Discord role ID.") +            msg = "Got empty Discord role ID." +            raise ValueError(msg)          url = (              f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" -            f"/members/{request_user.payload['id']}/roles/{form.discord_role}" +            f"/members/{request_user.payload["id"]}/roles/{form.discord_role}"          )          async with httpx.AsyncClient() as client: diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a02afea..3239d35 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -1,7 +1,8 @@  import base64 -from collections import namedtuple  from itertools import count +from pathlib import Path  from textwrap import indent +from typing import NamedTuple  import httpx  from httpx import HTTPStatusError @@ -9,7 +10,7 @@ from httpx import HTTPStatusError  from backend.constants import SNEKBOX_URL  from backend.models import Form, FormResponse -with open("resources/unittest_template.py") as file: +with Path("resources/unittest_template.py").open(encoding="utf8") as file:      TEST_TEMPLATE = file.read() @@ -17,9 +18,12 @@ class BypassDetectedError(Exception):      """Detected an attempt at bypassing the unittests.""" -UnittestResult = namedtuple( -    "UnittestResult", "question_id question_index return_code passed result" -) +class UnittestResult(NamedTuple): +    question_id: str +    question_index: int +    return_code: int +    passed: bool +    result: str  def filter_unittests(form: Form) -> Form: @@ -46,11 +50,11 @@ def _make_unit_code(units: dict[str, str]) -> str:          elif unit_name == "tearDown":              result += "\ndef tearDown(self):"          else: -            name = f"test_{unit_name.removeprefix('#').removeprefix('test_')}" +            name = f"test_{unit_name.removeprefix("#").removeprefix("test_")}"              result += f"\nasync def {name}(self):"          # Unite code -        result += f"\n{indent(unit_code, '    ')}" +        result += f"\n{indent(unit_code, "    ")}"      return indent(result, "    ") @@ -72,7 +76,8 @@ async def _post_eval(code: str) -> dict[str, str]:  async def execute_unittest( -        form_response: FormResponse, form: Form +    form_response: FormResponse, +    form: Form,  ) -> tuple[list[UnittestResult], list[BypassDetectedError]]:      """Execute all the unittests in this form and return the results."""      unittest_results = [] @@ -80,16 +85,17 @@ async def execute_unittest(      for index, question in enumerate(form.questions):          if question.type == "code": -              # Exit early if the suite doesn't have any tests              if question.data["unittests"] is None: -                unittest_results.append(UnittestResult( -                    question_id=question.id, -                    question_index=index, -                    return_code=0, -                    passed=True, -                    result="" -                )) +                unittest_results.append( +                    UnittestResult( +                        question_id=question.id, +                        question_index=index, +                        return_code=0, +                        passed=True, +                        result="", +                    ) +                )                  continue              passed = False @@ -98,7 +104,7 @@ async def execute_unittest(              hidden_test_counter = count(1)              hidden_tests = {                  test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter) -                for test in question.data["unittests"]["tests"].keys() +                for test in question.data["unittests"]["tests"]                  if test.startswith("#")              } @@ -124,18 +130,18 @@ async def execute_unittest(                          try:                              passed = bool(int(stdout[0]))                          except ValueError: -                            raise BypassDetectedError("Detected a bypass when reading result code.") +                            msg = "Detected a bypass when reading result code." +                            raise BypassDetectedError(msg)                          if passed and stdout.strip() != "1":                              # Most likely a bypass attempt                              # A 1 was written to stdout to indicate success,                              # followed by the actual output -                            raise BypassDetectedError( -                                "Detected improper value for stdout in unittest." -                            ) +                            msg = "Detected improper value for stdout in unittest." +                            raise BypassDetectedError(msg)                          # If the test failed, we have to populate the result string. -                        elif not passed: +                        if not passed:                              failed_tests = stdout[1:].strip().split(";")                              # Redact failed hidden tests @@ -146,7 +152,7 @@ async def execute_unittest(                              result = ";".join(failed_tests)                          else:                              result = "" -                    elif return_code in (5, 6, 99): +                    elif return_code in {5, 6, 99}:                          result = response["stdout"]                      # Killed by NsJail                      elif return_code == 137: @@ -162,12 +168,14 @@ async def execute_unittest(                  errors.append(error)                  passed = False -            unittest_results.append(UnittestResult( -                question_id=question.id, -                question_index=index, -                return_code=return_code, -                passed=passed, -                result=result -            )) +            unittest_results.append( +                UnittestResult( +                    question_id=question.id, +                    question_index=index, +                    return_code=return_code, +                    passed=passed, +                    result=result, +                ) +            )      return unittest_results, errors diff --git a/backend/routes/index.py b/backend/routes/index.py index 207c36a..c6e38ea 100644 --- a/backend/routes/index.py +++ b/backend/routes/index.py @@ -1,6 +1,5 @@ -""" -Index route for the forms API. -""" +"""Index route for the forms API.""" +  import platform  from pydantic import BaseModel @@ -20,13 +19,13 @@ class IndexResponse(BaseModel):          description=(              "The connecting client, in production this will"              " be an IP of our internal load balancer" -        ) +        ),      )      sha: str = Field( -        description="Current release Git SHA in production." +        description="Current release Git SHA in production.",      )      node: str = Field( -        description="The node that processed the request." +        description="The node that processed the request.",      ) @@ -42,24 +41,22 @@ class IndexRoute(Route):      @api.validate(resp=Response(HTTP_200=IndexResponse))      def get(self, request: Request) -> JSONResponse: -        """ -        Return a hello from Python Discord forms! -        """ +        """Return a hello from Python Discord forms!."""          response_data = {              "message": "Hello, world!",              "client": request.client.host,              "user": { -                "authenticated": False +                "authenticated": False,              },              "sha": GIT_SHA, -            "node": platform.uname().node +            "node": platform.uname().node,          }          if request.user.is_authenticated:              response_data["user"] = {                  "authenticated": True,                  "user": request.user.payload, -                "scopes": request.auth.scopes +                "scopes": request.auth.scopes,              }          return JSONResponse(response_data) | 
