From d0e09d2ba567f23d91ac76d1844966bafb9b063a Mon Sep 17 00:00:00 2001 From: Joe Banks Date: Sun, 7 Jul 2024 02:29:26 +0100 Subject: Apply fixable lint settings with Ruff --- backend/routes/admin.py | 12 ++-- backend/routes/auth/authorize.py | 42 +++++------ backend/routes/discord.py | 18 ++--- backend/routes/forms/discover.py | 21 ++---- backend/routes/forms/form.py | 16 ++--- backend/routes/forms/index.py | 26 ++----- backend/routes/forms/response.py | 21 +++--- backend/routes/forms/responses.py | 46 +++++-------- backend/routes/forms/submit.py | 134 ++++++++++++++++++------------------ backend/routes/forms/unittesting.py | 68 ++++++++++-------- backend/routes/index.py | 21 +++--- 11 files changed, 197 insertions(+), 228 deletions(-) (limited to 'backend/routes') diff --git a/backend/routes/admin.py b/backend/routes/admin.py index 0fd0700..848abce 100644 --- a/backend/routes/admin.py +++ b/backend/routes/admin.py @@ -1,6 +1,5 @@ -""" -Adds new admin user. -""" +"""Adds new admin user.""" + from pydantic import BaseModel, Field from spectree import Response from starlette.authentication import requires @@ -22,7 +21,7 @@ async def grant(request: Request) -> JSONResponse: admin = AdminModel(**data) if await request.state.db.admins.find_one( - {"_id": admin.id} + {"_id": admin.id}, ): return JSONResponse({"error": "already_exists"}, status_code=400) @@ -40,7 +39,7 @@ class AdminRoute(Route): @api.validate( json=AdminModel, resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), - tags=["admin"] + tags=["admin"], ) async def post(self, request: Request) -> JSONResponse: """Grant a user administrator privileges.""" @@ -48,6 +47,7 @@ class AdminRoute(Route): if not constants.PRODUCTION: + class AdminDev(Route): """Adds new admin user with no authentication.""" @@ -57,7 +57,7 @@ if not constants.PRODUCTION: @api.validate( json=AdminModel, resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), - tags=["admin"] + tags=["admin"], ) async def post(self, request: Request) -> JSONResponse: """ diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 42fb3ec..bc80a7d 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -1,9 +1,6 @@ -""" -Use a token received from the Discord OAuth2 system to fetch user information. -""" +"""Use a token received from the Discord OAuth2 system to fetch user information.""" import datetime -from typing import Union import httpx import jwt @@ -35,8 +32,8 @@ class AuthorizeResponse(BaseModel): async def process_token( bearer_token: dict, - request: Request -) -> Union[AuthorizeResponse, AUTH_FAILURE]: + request: Request, +) -> AuthorizeResponse | responses.JSONResponse: """Post a bearer token to Discord, and return a JWT and username.""" interaction_start = datetime.datetime.now() @@ -57,7 +54,7 @@ async def process_token( "refresh": bearer_token["refresh_token"], "user_details": user_details, "in_guild": bool(member), - "expiry": token_expiry.isoformat() + "expiry": token_expiry.isoformat(), } token = jwt.encode(data, SECRET_KEY, algorithm="HS256") @@ -65,18 +62,18 @@ async def process_token( response = responses.JSONResponse({ "username": user.display_name, - "expiry": token_expiry.isoformat() + "expiry": token_expiry.isoformat(), }) - await set_response_token(response, request, token, bearer_token["expires_in"]) + set_response_token(response, request, token, bearer_token["expires_in"]) return response -async def set_response_token( - response: responses.Response, - request: Request, - new_token: str, - expiry: int +def set_response_token( + response: responses.Response, + request: Request, + new_token: str, + expiry: int, ) -> None: """Helper that handles logic for updating a token in a set-cookie response.""" origin_url = request.headers.get("origin") @@ -94,19 +91,18 @@ async def set_response_token( samesite = "None" response.set_cookie( - "token", f"JWT {new_token}", + "token", + f"JWT {new_token}", secure=constants.PRODUCTION, httponly=True, samesite=samesite, domain=domain, - max_age=expiry + max_age=expiry, ) class AuthorizeRoute(Route): - """ - Use the authorization code from Discord to generate a JWT token. - """ + """Use the authorization code from Discord to generate a JWT token.""" name = "authorize" path = "/authorize" @@ -114,7 +110,7 @@ class AuthorizeRoute(Route): @api.validate( json=AuthorizeRequest, resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), - tags=["auth"] + tags=["auth"], ) async def post(self, request: Request) -> responses.JSONResponse: """Generate an authorization token.""" @@ -129,9 +125,7 @@ class AuthorizeRoute(Route): class TokenRefreshRoute(Route): - """ - Use the refresh code from a JWT to get a new token and generate a new JWT token. - """ + """Use the refresh code from a JWT to get a new token and generate a new JWT token.""" name = "refresh" path = "/refresh" @@ -139,7 +133,7 @@ class TokenRefreshRoute(Route): @requires(["authenticated"]) @api.validate( resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), - tags=["auth"] + tags=["auth"], ) async def post(self, request: Request) -> responses.JSONResponse: """Refresh an authorization token.""" diff --git a/backend/routes/discord.py b/backend/routes/discord.py index bca1edb..53b8af3 100644 --- a/backend/routes/discord.py +++ b/backend/routes/discord.py @@ -10,7 +10,8 @@ from backend import discord, models, route from backend.validation import ErrorMessage, api NOT_FOUND_EXCEPTION = JSONResponse( - {"error": "Could not find the requested resource in the guild or cache."}, status_code=404 + {"error": "Could not find the requested resource in the guild or cache."}, + status_code=404, ) @@ -28,7 +29,7 @@ class RolesRoute(route.Route): @requires(["authenticated", "admin"]) @api.validate( resp=Response(HTTP_200=RolesResponse), - tags=["roles"] + tags=["roles"], ) async def patch(self, request: Request) -> JSONResponse: """Refresh the roles database.""" @@ -54,7 +55,7 @@ class MemberRoute(route.Route): @api.validate( resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), json=MemberRequest, - tags=["auth"] + tags=["auth"], ) async def delete(self, request: Request) -> JSONResponse: """Force a resync of the cache for the given user.""" @@ -63,21 +64,20 @@ class MemberRoute(route.Route): if member: return JSONResponse(member.dict()) - else: - return NOT_FOUND_EXCEPTION + return NOT_FOUND_EXCEPTION @requires(["authenticated", "admin"]) @api.validate( resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), json=MemberRequest, - tags=["auth"] + tags=["auth"], ) async def get(self, request: Request) -> JSONResponse: """Get a user's roles on the configured server.""" body = await request.json() member = await discord.get_member(request.state.db, body["user_id"]) - if member: - return JSONResponse(member.dict()) - else: + if not member: return NOT_FOUND_EXCEPTION + + return JSONResponse(member.dict()) diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py index 75ff495..0fe10b5 100644 --- a/backend/routes/forms/discover.py +++ b/backend/routes/forms/discover.py @@ -1,6 +1,5 @@ -""" -Return a list of all publicly discoverable forms to unauthenticated users. -""" +"""Return a list of all publicly discoverable forms to unauthenticated users.""" + from spectree.response import Response from starlette.requests import Request from starlette.responses import JSONResponse @@ -12,7 +11,7 @@ from backend.validation import api __FEATURES = [ constants.FormFeatures.OPEN.value, - constants.FormFeatures.REQUIRES_LOGIN.value + constants.FormFeatures.REQUIRES_LOGIN.value, ] if not constants.PRODUCTION: __FEATURES.append(constants.FormFeatures.DISCOVERABLE.value) @@ -22,7 +21,7 @@ __QUESTION = Question( name="Click the button below to log into the forms application.", type="section", data={"text": ""}, - required=False + required=False, ) AUTH_FORM = Form( @@ -31,14 +30,12 @@ AUTH_FORM = Form( questions=[__QUESTION], name="Login", description="Log into Python Discord Forms.", - submitted_text="This page can't be submitted." + submitted_text="This page can't be submitted.", ) class DiscoverableFormsList(Route): - """ - List all discoverable forms that should be shown on the homepage. - """ + """List all discoverable forms that should be shown on the homepage.""" name = "discoverable_forms_list" path = "/discoverable" @@ -46,15 +43,11 @@ class DiscoverableFormsList(Route): @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: """List all discoverable forms that should be shown on the homepage.""" - forms = [] cursor = request.state.db.forms.find({"features": "DISCOVERABLE"}).sort("name") # Parse it to Form and then back to dictionary # to replace _id with id - for form in await cursor.to_list(None): - forms.append(Form(**form)) - - forms = [form.dict(admin=False) for form in forms] + forms = [Form(**form).dict(admin=False) for form in await cursor.to_list(None)] # Return an empty form in development environments to help with authentication. if not constants.PRODUCTION: diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index 020193c..410102a 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -1,6 +1,5 @@ -""" -Returns, updates or deletes a single form given an ID. -""" +"""Returns, updates or deletes a single form given an ID.""" + import json.decoder import deepmerge @@ -48,7 +47,7 @@ class SingleForm(Route): admin = False filters = { - "_id": form_id + "_id": form_id, } if not admin: @@ -71,7 +70,7 @@ class SingleForm(Route): HTTP_400=ErrorMessage, HTTP_404=ErrorMessage, ), - tags=["forms"] + tags=["forms"], ) async def patch(self, request: Request) -> JSONResponse: """Updates form by ID.""" @@ -90,7 +89,7 @@ class SingleForm(Route): # Build Data Merger merge_strategy = [ - (dict, ["merge"]) + (dict, ["merge"]), ] merger = deepmerge.Merger(merge_strategy, ["override"], ["override"]) @@ -105,13 +104,12 @@ class SingleForm(Route): await request.state.db.forms.replace_one({"_id": form_id}, form.dict()) return JSONResponse(form.dict()) - else: - return JSONResponse({"error": "not_found"}, status_code=404) + return JSONResponse({"error": "not_found"}, status_code=404) @requires(["authenticated", "admin"]) @api.validate( resp=Response(HTTP_200=OkayResponse, HTTP_401=ErrorMessage, HTTP_404=ErrorMessage), - tags=["forms"] + tags=["forms"], ) async def delete(self, request: Request) -> JSONResponse: """Deletes form by ID.""" diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py index 38be693..1fdfc48 100644 --- a/backend/routes/forms/index.py +++ b/backend/routes/forms/index.py @@ -1,6 +1,5 @@ -""" -Return a list of all forms to authenticated users. -""" +"""Return a list of all forms to authenticated users.""" + from spectree.response import Response from starlette.authentication import requires from starlette.requests import Request @@ -14,9 +13,7 @@ from backend.validation import ErrorMessage, OkayResponse, api class FormsList(Route): - """ - List all available forms for authorized viewers. - """ + """List all available forms for authorized viewers.""" name = "forms_list_create" path = "/" @@ -25,24 +22,17 @@ class FormsList(Route): @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: """Return a list of all forms to authenticated users.""" - forms = [] cursor = request.state.db.forms.find() - for form in await cursor.to_list(None): - forms.append(Form(**form)) # For converting _id to id - - # Covert them back to dictionaries - forms = [form.dict() for form in forms] + forms = [Form(**form).dict() for form in await cursor.to_list(None)] - return JSONResponse( - forms - ) + return JSONResponse(forms) @requires(["authenticated", "Helpers"]) @api.validate( json=Form, resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), - tags=["forms"] + tags=["forms"], ) async def post(self, request: Request) -> JSONResponse: """Create a new form.""" @@ -66,9 +56,7 @@ class FormsList(Route): form = Form(**form_data) if await request.state.db.forms.find_one({"_id": form.id}): - return JSONResponse({ - "error": "id_taken" - }, status_code=400) + return JSONResponse({"error": "id_taken"}, status_code=400) await request.state.db.forms.insert_one(form.dict(by_alias=True)) return JSONResponse(form.dict()) diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py index 565701f..b4f7f04 100644 --- a/backend/routes/forms/response.py +++ b/backend/routes/forms/response.py @@ -1,6 +1,4 @@ -""" -Returns or deletes form response by ID. -""" +"""Returns or deletes form response by ID.""" from spectree import Response as RouteResponse from starlette.authentication import requires @@ -22,7 +20,7 @@ class Response(Route): @requires(["authenticated"]) @api.validate( resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def get(self, request: Request) -> JSONResponse: """Return a single form response by ID.""" @@ -32,30 +30,29 @@ class Response(Route): if raw_response := await request.state.db.responses.find_one( { "_id": request.path_params["response_id"], - "form_id": form_id - } + "form_id": form_id, + }, ): response = FormResponse(**raw_response) return JSONResponse(response.dict()) - else: - return JSONResponse({"error": "response_not_found"}, status_code=404) + return JSONResponse({"error": "response_not_found"}, status_code=404) @requires(["authenticated", "admin"]) @api.validate( resp=RouteResponse(HTTP_200=OkayResponse, HTTP_404=ErrorMessage), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def delete(self, request: Request) -> JSONResponse: """Delete a form response by ID.""" if not await request.state.db.responses.find_one( { "_id": request.path_params["response_id"], - "form_id": request.path_params["form_id"] - } + "form_id": request.path_params["form_id"], + }, ): return JSONResponse({"error": "not_found"}, status_code=404) await request.state.db.responses.delete_one( - {"_id": request.path_params["response_id"]} + {"_id": request.path_params["response_id"]}, ) return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py index 818ebce..85e5af2 100644 --- a/backend/routes/forms/responses.py +++ b/backend/routes/forms/responses.py @@ -1,6 +1,5 @@ -""" -Returns all form responses by form ID. -""" +"""Returns all form responses by form ID.""" + from pydantic import BaseModel from spectree import Response from starlette.authentication import requires @@ -18,9 +17,7 @@ class ResponseIdList(BaseModel): class Responses(Route): - """ - Returns all form responses by form ID. - """ + """Returns all form responses by form ID.""" name = "form_responses" path = "/{form_id:str}/responses" @@ -28,7 +25,7 @@ class Responses(Route): @requires(["authenticated"]) @api.validate( resp=Response(HTTP_200=ResponseList), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def get(self, request: Request) -> JSONResponse: """Returns all form responses by form ID.""" @@ -36,11 +33,9 @@ class Responses(Route): await discord.verify_response_access(form_id, request) cursor = request.state.db.responses.find( - {"form_id": form_id} + {"form_id": form_id}, ) - responses = [ - FormResponse(**response) for response in await cursor.to_list(None) - ] + responses = [FormResponse(**response) for response in await cursor.to_list(None)] return JSONResponse([response.dict() for response in responses]) @requires(["authenticated", "admin"]) @@ -49,14 +44,14 @@ class Responses(Route): resp=Response( HTTP_200=OkayResponse, HTTP_404=ErrorMessage, - HTTP_400=ErrorMessage + HTTP_400=ErrorMessage, ), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def delete(self, request: Request) -> JSONResponse: """Bulk deletes form responses by IDs.""" if not await request.state.db.forms.find_one( - {"_id": request.path_params["form_id"]} + {"_id": request.path_params["form_id"]}, ): return JSONResponse({"error": "not_found"}, status_code=404) @@ -67,37 +62,34 @@ class Responses(Route): ids = set(response_ids.ids) cursor = request.state.db.responses.find( - {"_id": {"$in": list(ids)}} # Convert here back to list, may throw error. + {"_id": {"$in": list(ids)}}, # Convert here back to list, may throw error. ) - entries = [ - FormResponse(**submission) for submission in await cursor.to_list(None) - ] + entries = [FormResponse(**submission) for submission in await cursor.to_list(None)] actual_ids = {entry.id for entry in entries} if len(ids) != len(actual_ids): return JSONResponse( { "error": "responses_not_found", - "ids": list(ids - actual_ids) + "ids": list(ids - actual_ids), }, - status_code=404 + status_code=404, ) if any(entry.form_id != request.path_params["form_id"] for entry in entries): return JSONResponse( { "error": "wrong_form", - "ids": list( - entry.id for entry in entries - if entry.id != request.path_params["form_id"] - ) + "ids": [ + entry.id for entry in entries if entry.id != request.path_params["form_id"] + ], }, - status_code=400 + status_code=400, ) await request.state.db.responses.delete_many( { - "_id": {"$in": list(actual_ids)} - } + "_id": {"$in": list(actual_ids)}, + }, ) return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 765856e..8f01e2b 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -1,6 +1,4 @@ -""" -Submit a form. -""" +"""Submit a form.""" import asyncio import binascii @@ -8,10 +6,9 @@ import datetime import hashlib import typing import uuid -from typing import Any, Optional +from typing import Any import httpx -import pymongo.database import sentry_sdk from pydantic import ValidationError from pydantic.main import BaseModel @@ -29,13 +26,16 @@ from backend.routes.forms.discover import AUTH_FORM from backend.routes.forms.unittesting import BypassDetectedError, execute_unittest from backend.validation import ErrorMessage, api +if typing.TYPE_CHECKING: + import pymongo.database + HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify" HCAPTCHA_HEADERS = { - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", } DISCORD_HEADERS = { - "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}" + "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}", } @@ -46,7 +46,7 @@ class SubmissionResponse(BaseModel): class PartialSubmission(BaseModel): response: dict[str, Any] - captcha: Optional[str] + captcha: str | None class UnittestError(BaseModel): @@ -62,9 +62,7 @@ class UnittestErrorMessage(ErrorMessage): class SubmitForm(Route): - """ - Submit a form with the provided form ID. - """ + """Submit a form with the provided form ID.""" name = "submit_form" path = "/submit/{form_id:str}" @@ -75,9 +73,9 @@ class SubmitForm(Route): HTTP_200=SubmissionResponse, HTTP_404=ErrorMessage, HTTP_400=ErrorMessage, - HTTP_422=UnittestErrorMessage + HTTP_422=UnittestErrorMessage, ), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def post(self, request: Request) -> JSONResponse: """Submit a response to the form.""" @@ -92,7 +90,7 @@ class SubmitForm(Route): if old != request.user.token: try: expiry = datetime.datetime.fromisoformat( - request.user.decoded_token.get("expiry") + request.user.decoded_token.get("expiry"), ) except ValueError: expiry = None @@ -117,7 +115,7 @@ class SubmitForm(Route): id="not-submitted", form_id=AUTH_FORM.id, response={question.id: None for question in AUTH_FORM.questions}, - timestamp=datetime.datetime.now().isoformat() + timestamp=datetime.datetime.now().isoformat(), ).dict() return JSONResponse({"form": AUTH_FORM.dict(admin=False), "response": response}) @@ -131,8 +129,9 @@ class SubmitForm(Route): ip_hash_ctx = hashlib.md5() ip_hash_ctx.update( request.headers.get( - "Cf-Connecting-IP", request.client.host - ).encode() + "Cf-Connecting-IP", + request.client.host, + ).encode(), ) ip_hash = binascii.hexlify(ip_hash_ctx.digest()) user_agent_hash_ctx = hashlib.md5() @@ -142,12 +141,12 @@ class SubmitForm(Route): async with httpx.AsyncClient() as client: query_params = { "secret": constants.HCAPTCHA_API_SECRET, - "response": data.get("captcha") + "response": data.get("captcha"), } r = await client.post( HCAPTCHA_VERIFY_URL, params=query_params, - headers=HCAPTCHA_HEADERS + headers=HCAPTCHA_HEADERS, ) r.raise_for_status() captcha_data = r.json() @@ -155,7 +154,7 @@ class SubmitForm(Route): response["antispam"] = { "ip_hash": ip_hash.decode(), "user_agent_hash": user_agent_hash.decode(), - "captcha_pass": captcha_data["success"] + "captcha_pass": captcha_data["success"], } if constants.FormFeatures.REQUIRES_LOGIN.value in form.features: @@ -164,16 +163,12 @@ class SubmitForm(Route): response["user"]["admin"] = request.user.admin if ( - constants.FormFeatures.COLLECT_EMAIL.value in form.features - and "email" not in response["user"] + constants.FormFeatures.COLLECT_EMAIL.value in form.features + and "email" not in response["user"] ): - return JSONResponse({ - "error": "email_required" - }, status_code=400) + return JSONResponse({"error": "email_required"}, status_code=400) else: - return JSONResponse({ - "error": "missing_discord_data" - }, status_code=400) + return JSONResponse({"error": "missing_discord_data"}, status_code=400) missing_fields = [] for question in form.questions: @@ -184,10 +179,13 @@ class SubmitForm(Route): missing_fields.append(question.id) if missing_fields: - return JSONResponse({ - "error": "missing_fields", - "fields": missing_fields - }, status_code=400) + return JSONResponse( + { + "error": "missing_fields", + "fields": missing_fields, + }, + status_code=400, + ) try: response_obj = FormResponse(**response) @@ -200,10 +198,12 @@ class SubmitForm(Route): if len(errors): username = getattr(request.user, "user_id", "Unknown") - sentry_sdk.capture_exception(BypassDetectedError( - f"Detected unittest bypass attempt on form {form.id} by {username}. " - f"Submission has been written to reporting database ({response_obj.id})." - )) + sentry_sdk.capture_exception( + BypassDetectedError( + f"Detected unittest bypass attempt on form {form.id} by {username}. " + f"Submission has been written to reporting database ({response_obj.id}).", + ) + ) database: pymongo.database.Database = request.state.db await database.get_collection("violations").insert_one({ "user": username, @@ -219,7 +219,7 @@ class SubmitForm(Route): for test in unittest_results: response_obj.response[test.question_id] = { "value": response_obj.response[test.question_id], - "passed": test.passed + "passed": test.passed, } if test.return_code == 0: @@ -238,9 +238,8 @@ class SubmitForm(Route): # Report a failure on internal errors, # or if the test suite doesn't allow failures if not test.passed: - allow_failure = ( - form.questions[test.question_index].data["unittests"]["allow_failure"] - ) + question = form.questions[test.question_index] + allow_failure = question.data["unittests"]["allow_failure"] # An error while communicating with the test runner if test.return_code == 99: @@ -251,15 +250,16 @@ class SubmitForm(Route): failures.append(test) if len(failures): - return JSONResponse({ - "error": "failed_tests", - "test_results": [ - test._asdict() for test in failures - ] - }, status_code=status_code) + return JSONResponse( + { + "error": "failed_tests", + "test_results": [test._asdict() for test in failures], + }, + status_code=status_code, + ) await request.state.db.responses.insert_one( - response_obj.dict(by_alias=True) + response_obj.dict(by_alias=True), ) tasks = BackgroundTasks() @@ -272,36 +272,37 @@ class SubmitForm(Route): self.send_submission_webhook, form=form, response=response_obj, - request_user=request_user + request_user=request_user, ) if constants.FormFeatures.ASSIGN_ROLE.value in form.features: tasks.add_task( self.assign_role, form=form, - request_user=request.user + request_user=request.user, ) - return JSONResponse({ - "form": form.dict(admin=False), - "response": response_obj.dict() - }, background=tasks) + return JSONResponse( + { + "form": form.dict(admin=False), + "response": response_obj.dict(), + }, + background=tasks, + ) - else: - return JSONResponse({ - "error": "Open form not found" - }, status_code=404) + return JSONResponse({"error": "Open form not found"}, status_code=404) @staticmethod async def send_submission_webhook( - form: Form, - response: FormResponse, - request_user: typing.Optional[User] + form: Form, + response: FormResponse, + request_user: User | None, ) -> None: """Helper to send a submission message to a discord webhook.""" # Stop if webhook is not available if form.webhook is None: - raise ValueError("Got empty webhook.") + msg = "Got empty webhook." + raise ValueError(msg) try: mention = request_user.discord_mention @@ -330,7 +331,7 @@ class SubmitForm(Route): hook = { "embeds": [embed], "allowed_mentions": {"parse": ["users", "roles"]}, - "username": form.name or "Python Discord Forms" + "username": form.name or "Python Discord Forms", } # Set hook message @@ -345,8 +346,8 @@ class SubmitForm(Route): "time": response.timestamp, } - for key in ctx: - message = message.replace(f"{{{key}}}", str(ctx[key])) + for key, val in ctx.items(): + message = message.replace(f"{{{key}}}", str(val)) hook["content"] = message.replace("_USER_MENTION_", mention) @@ -359,11 +360,12 @@ class SubmitForm(Route): async def assign_role(form: Form, request_user: User) -> None: """Assigns Discord role to user when user submitted response.""" if not form.discord_role: - raise ValueError("Got empty Discord role ID.") + msg = "Got empty Discord role ID." + raise ValueError(msg) url = ( f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" - f"/members/{request_user.payload['id']}/roles/{form.discord_role}" + f"/members/{request_user.payload["id"]}/roles/{form.discord_role}" ) async with httpx.AsyncClient() as client: diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a02afea..3239d35 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -1,7 +1,8 @@ import base64 -from collections import namedtuple from itertools import count +from pathlib import Path from textwrap import indent +from typing import NamedTuple import httpx from httpx import HTTPStatusError @@ -9,7 +10,7 @@ from httpx import HTTPStatusError from backend.constants import SNEKBOX_URL from backend.models import Form, FormResponse -with open("resources/unittest_template.py") as file: +with Path("resources/unittest_template.py").open(encoding="utf8") as file: TEST_TEMPLATE = file.read() @@ -17,9 +18,12 @@ class BypassDetectedError(Exception): """Detected an attempt at bypassing the unittests.""" -UnittestResult = namedtuple( - "UnittestResult", "question_id question_index return_code passed result" -) +class UnittestResult(NamedTuple): + question_id: str + question_index: int + return_code: int + passed: bool + result: str def filter_unittests(form: Form) -> Form: @@ -46,11 +50,11 @@ def _make_unit_code(units: dict[str, str]) -> str: elif unit_name == "tearDown": result += "\ndef tearDown(self):" else: - name = f"test_{unit_name.removeprefix('#').removeprefix('test_')}" + name = f"test_{unit_name.removeprefix("#").removeprefix("test_")}" result += f"\nasync def {name}(self):" # Unite code - result += f"\n{indent(unit_code, ' ')}" + result += f"\n{indent(unit_code, " ")}" return indent(result, " ") @@ -72,7 +76,8 @@ async def _post_eval(code: str) -> dict[str, str]: async def execute_unittest( - form_response: FormResponse, form: Form + form_response: FormResponse, + form: Form, ) -> tuple[list[UnittestResult], list[BypassDetectedError]]: """Execute all the unittests in this form and return the results.""" unittest_results = [] @@ -80,16 +85,17 @@ async def execute_unittest( for index, question in enumerate(form.questions): if question.type == "code": - # Exit early if the suite doesn't have any tests if question.data["unittests"] is None: - unittest_results.append(UnittestResult( - question_id=question.id, - question_index=index, - return_code=0, - passed=True, - result="" - )) + unittest_results.append( + UnittestResult( + question_id=question.id, + question_index=index, + return_code=0, + passed=True, + result="", + ) + ) continue passed = False @@ -98,7 +104,7 @@ async def execute_unittest( hidden_test_counter = count(1) hidden_tests = { test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter) - for test in question.data["unittests"]["tests"].keys() + for test in question.data["unittests"]["tests"] if test.startswith("#") } @@ -124,18 +130,18 @@ async def execute_unittest( try: passed = bool(int(stdout[0])) except ValueError: - raise BypassDetectedError("Detected a bypass when reading result code.") + msg = "Detected a bypass when reading result code." + raise BypassDetectedError(msg) if passed and stdout.strip() != "1": # Most likely a bypass attempt # A 1 was written to stdout to indicate success, # followed by the actual output - raise BypassDetectedError( - "Detected improper value for stdout in unittest." - ) + msg = "Detected improper value for stdout in unittest." + raise BypassDetectedError(msg) # If the test failed, we have to populate the result string. - elif not passed: + if not passed: failed_tests = stdout[1:].strip().split(";") # Redact failed hidden tests @@ -146,7 +152,7 @@ async def execute_unittest( result = ";".join(failed_tests) else: result = "" - elif return_code in (5, 6, 99): + elif return_code in {5, 6, 99}: result = response["stdout"] # Killed by NsJail elif return_code == 137: @@ -162,12 +168,14 @@ async def execute_unittest( errors.append(error) passed = False - unittest_results.append(UnittestResult( - question_id=question.id, - question_index=index, - return_code=return_code, - passed=passed, - result=result - )) + unittest_results.append( + UnittestResult( + question_id=question.id, + question_index=index, + return_code=return_code, + passed=passed, + result=result, + ) + ) return unittest_results, errors diff --git a/backend/routes/index.py b/backend/routes/index.py index 207c36a..c6e38ea 100644 --- a/backend/routes/index.py +++ b/backend/routes/index.py @@ -1,6 +1,5 @@ -""" -Index route for the forms API. -""" +"""Index route for the forms API.""" + import platform from pydantic import BaseModel @@ -20,13 +19,13 @@ class IndexResponse(BaseModel): description=( "The connecting client, in production this will" " be an IP of our internal load balancer" - ) + ), ) sha: str = Field( - description="Current release Git SHA in production." + description="Current release Git SHA in production.", ) node: str = Field( - description="The node that processed the request." + description="The node that processed the request.", ) @@ -42,24 +41,22 @@ class IndexRoute(Route): @api.validate(resp=Response(HTTP_200=IndexResponse)) def get(self, request: Request) -> JSONResponse: - """ - Return a hello from Python Discord forms! - """ + """Return a hello from Python Discord forms!.""" response_data = { "message": "Hello, world!", "client": request.client.host, "user": { - "authenticated": False + "authenticated": False, }, "sha": GIT_SHA, - "node": platform.uname().node + "node": platform.uname().node, } if request.user.is_authenticated: response_data["user"] = { "authenticated": True, "user": request.user.payload, - "scopes": request.auth.scopes + "scopes": request.auth.scopes, } return JSONResponse(response_data) -- cgit v1.2.3