diff options
author | 2024-07-07 02:29:26 +0100 | |
---|---|---|
committer | 2024-07-08 15:00:10 +0100 | |
commit | d0e09d2ba567f23d91ac76d1844966bafb9b063a (patch) | |
tree | 9e825e3f09df02ab32e401c7e9555df26356dd4c /backend | |
parent | Change linting config to Ruff (diff) |
Apply fixable lint settings with Ruff
Diffstat (limited to 'backend')
26 files changed, 380 insertions, 396 deletions
diff --git a/backend/__init__.py b/backend/__init__.py index dcbdcdf..67015d7 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -27,7 +27,7 @@ sentry_sdk.init( dsn=constants.FORMS_BACKEND_DSN, send_default_pii=True, release=SENTRY_RELEASE, - environment=SENTRY_RELEASE + environment=SENTRY_RELEASE, ) middleware = [ @@ -36,10 +36,10 @@ middleware = [ allow_origins=["https://forms.pythondiscord.com"], allow_origin_regex=ALLOW_ORIGIN_REGEX, allow_headers=[ - "Content-Type" + "Content-Type", ], allow_methods=["*"], - allow_credentials=True + allow_credentials=True, ), Middleware(DatabaseMiddleware), Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()), diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index 54385e2..2512761 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,11 +1,9 @@ -import typing as t - import jwt from starlette import authentication from starlette.requests import Request -from backend import constants -from backend import discord +from backend import constants, discord + # We must import user such way here to avoid circular imports from .user import User @@ -19,20 +17,19 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): try: prefix, token = cookie.split() except ValueError: - raise authentication.AuthenticationError( - "Unable to split prefix and token from authorization cookie." - ) + msg = "Unable to split prefix and token from authorization cookie." + raise authentication.AuthenticationError(msg) if prefix.upper() != "JWT": - raise authentication.AuthenticationError( - f"Invalid authorization cookie prefix '{prefix}'." - ) + msg = f"Invalid authorization cookie prefix '{prefix}'." + raise authentication.AuthenticationError(msg) return token async def authenticate( - self, request: Request - ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]: + self, + request: Request, + ) -> tuple[authentication.AuthCredentials, authentication.BaseUser] | None: """Handles JWT authentication process.""" cookie = request.cookies.get("token") if not cookie: @@ -48,21 +45,25 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): scopes = ["authenticated"] if not payload.get("token"): - raise authentication.AuthenticationError("Token is missing from JWT.") + msg = "Token is missing from JWT." + raise authentication.AuthenticationError(msg) if not payload.get("refresh"): - raise authentication.AuthenticationError( - "Refresh token is missing from JWT." - ) + msg = "Refresh token is missing from JWT." + raise authentication.AuthenticationError(msg) try: user_details = payload.get("user_details") if not user_details or not user_details.get("id"): - raise authentication.AuthenticationError("Improper user details.") - except Exception: - raise authentication.AuthenticationError("Could not parse user details.") + msg = "Improper user details." + raise authentication.AuthenticationError(msg) # noqa: TRY301 + except Exception: # noqa: BLE001 + msg = "Could not parse user details." + raise authentication.AuthenticationError(msg) user = User( - token, user_details, await discord.get_member(request.state.db, user_details["id"]) + token, + user_details, + await discord.get_member(request.state.db, user_details["id"]), ) if await user.fetch_admin_status(request.state.db): scopes.append("admin") diff --git a/backend/authentication/user.py b/backend/authentication/user.py index cd5a249..c81b7a9 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,4 +1,3 @@ -import typing import typing as t import jwt @@ -16,7 +15,7 @@ class User(BaseUser): self, token: str, payload: dict[str, t.Any], - member: typing.Optional[models.DiscordMember], + member: models.DiscordMember | None, ) -> None: self.token = token self.payload = payload @@ -31,11 +30,11 @@ class User(BaseUser): @property def display_name(self) -> str: """Return username and discriminator as display name.""" - return f"{self.payload['username']}#{self.payload['discriminator']}" + return f"{self.payload["username"]}#{self.payload["discriminator"]}" @property def discord_mention(self) -> str: - return f"<@{self.payload['id']}>" + return f"<@{self.payload["id"]}>" @property def user_id(self) -> str: @@ -61,9 +60,10 @@ class User(BaseUser): return roles async def fetch_admin_status(self, database: Database) -> bool: - self.admin = await database.admins.find_one( - {"_id": self.payload["id"]} - ) is not None + query = {"_id": self.payload["id"]} + found_admin = await database.admins.find_one(query) + + self.admin = found_admin is not None return self.admin diff --git a/backend/constants.py b/backend/constants.py index e1c38d3..8089077 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -18,7 +18,8 @@ PRODUCTION_URL = "https://forms.pythondiscord.com" OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID") OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET") OAUTH2_REDIRECT_URI = os.getenv( - "OAUTH2_REDIRECT_URI", "https://forms.pythondiscord.com/callback" + "OAUTH2_REDIRECT_URI", + "https://forms.pythondiscord.com/callback", ) GIT_SHA = os.getenv("GIT_SHA", "dev") @@ -28,7 +29,7 @@ DOCS_PASSWORD = os.getenv("DOCS_PASSWORD") SECRET_KEY = os.getenv("SECRET_KEY", binascii.hexlify(os.urandom(30)).decode()) DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN") -DISCORD_GUILD = os.getenv("DISCORD_GUILD", 267624335836053506) +DISCORD_GUILD = os.getenv("DISCORD_GUILD", "267624335836053506") HCAPTCHA_API_SECRET = os.getenv("HCAPTCHA_API_SECRET") diff --git a/backend/discord.py b/backend/discord.py index ff6c1bb..dc5989a 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -2,7 +2,6 @@ import datetime import json -import typing import httpx import starlette.requests @@ -17,7 +16,7 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict data = { "client_id": constants.OAUTH2_CLIENT_ID, "client_secret": constants.OAUTH2_CLIENT_SECRET, - "redirect_uri": f"{redirect}/callback" + "redirect_uri": f"{redirect}/callback", } if refresh: @@ -27,9 +26,13 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict data["grant_type"] = "authorization_code" data["code"] = code - r = await client.post(f"{constants.DISCORD_API_BASE_URL}/oauth2/token", headers={ - "Content-Type": "application/x-www-form-urlencoded" - }, data=data) + r = await client.post( + f"{constants.DISCORD_API_BASE_URL}/oauth2/token", + headers={ + "Content-Type": "application/x-www-form-urlencoded", + }, + data=data, + ) r.raise_for_status() @@ -38,9 +41,12 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict async def fetch_user_details(bearer_token: str) -> dict: async with httpx.AsyncClient() as client: - r = await client.get(f"{constants.DISCORD_API_BASE_URL}/users/@me", headers={ - "Authorization": f"Bearer {bearer_token}" - }) + r = await client.get( + f"{constants.DISCORD_API_BASE_URL}/users/@me", + headers={ + "Authorization": f"Bearer {bearer_token}", + }, + ) r.raise_for_status() @@ -52,7 +58,7 @@ async def _get_role_info() -> list[models.DiscordRole]: async with httpx.AsyncClient() as client: r = await client.get( f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}/roles", - headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} + headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}, ) r.raise_for_status() @@ -60,7 +66,9 @@ async def _get_role_info() -> list[models.DiscordRole]: async def get_roles( - database: Database, *, force_refresh: bool = False + database: Database, + *, + force_refresh: bool = False, ) -> list[models.DiscordRole]: """ Get a list of all roles from the cache, or discord API if not available. @@ -86,23 +94,26 @@ async def get_roles( if len(roles) == 0: # Fetch roles from the API and insert into the database roles = await _get_role_info() - await collection.insert_many({ - "name": role.name, - "id": role.id, - "data": role.json(), - "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), - } for role in roles) + await collection.insert_many( + { + "name": role.name, + "id": role.id, + "data": role.json(), + "inserted_at": datetime.datetime.now(tz=datetime.UTC), + } + for role in roles + ) return roles -async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]: +async def _fetch_member_api(member_id: str) -> models.DiscordMember | None: """Get a member by ID from the configured guild using the discord API.""" async with httpx.AsyncClient() as client: r = await client.get( f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" f"/members/{member_id}", - headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} + headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}, ) if r.status_code == 404: @@ -113,8 +124,11 @@ async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMem async def get_member( - database: Database, user_id: str, *, force_refresh: bool = False -) -> typing.Optional[models.DiscordMember]: + database: Database, + user_id: str, + *, + force_refresh: bool = False, +) -> models.DiscordMember | None: """ Get a member from the cache, or from the discord API. @@ -147,7 +161,7 @@ async def get_member( await collection.insert_one({ "user": user_id, "data": member.json(), - "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), + "inserted_at": datetime.datetime.now(tz=datetime.UTC), }) return member @@ -161,7 +175,9 @@ class UnauthorizedError(exceptions.HTTPException): async def _verify_access_helper( - form_id: str, request: starlette.requests.Request, attribute: str + form_id: str, + request: starlette.requests.Request, + attribute: str, ) -> None: """A low level helper to validate access to a form resource based on the user's scopes.""" form = await request.state.db.forms.find_one({"_id": form_id}) diff --git a/backend/middleware.py b/backend/middleware.py index 7a3bdc8..0b08859 100644 --- a/backend/middleware.py +++ b/backend/middleware.py @@ -7,14 +7,13 @@ from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE class DatabaseMiddleware: - def __init__(self, app: ASGIApp) -> None: self._app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: client: AsyncIOMotorClient = AsyncIOMotorClient( DATABASE_URL, - tlsAllowInvalidCertificates=True + tlsAllowInvalidCertificates=True, ) db = client[MONGO_DATABASE] Request(scope).state.db = db @@ -22,7 +21,6 @@ class DatabaseMiddleware: class ProtectedDocsMiddleware: - def __init__(self, app: ASGIApp) -> None: self._app = app diff --git a/backend/models/__init__.py b/backend/models/__init__.py index a9f76e0..336e28b 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -7,13 +7,13 @@ from .question import CodeQuestion, Question __all__ = [ "AntiSpam", + "CodeQuestion", + "DiscordMember", "DiscordRole", "DiscordUser", - "DiscordMember", "Form", + "FormList", "FormResponse", - "CodeQuestion", "Question", - "FormList", - "ResponseList" + "ResponseList", ] diff --git a/backend/models/discord_role.py b/backend/models/discord_role.py index ada35ef..195f557 100644 --- a/backend/models/discord_role.py +++ b/backend/models/discord_role.py @@ -1,13 +1,11 @@ -import typing - from pydantic import BaseModel class RoleTags(BaseModel): """Meta information about a discord role.""" - bot_id: typing.Optional[str] - integration_id: typing.Optional[str] + bot_id: str | None + integration_id: str | None premium_subscriber: bool def __init__(self, **data) -> None: @@ -20,7 +18,7 @@ class RoleTags(BaseModel): We manually parse the raw data to determine if the field exists, and give it a useful bool value. """ - data["premium_subscriber"] = "premium_subscriber" in data.keys() + data["premium_subscriber"] = "premium_subscriber" in data super().__init__(**data) @@ -31,10 +29,10 @@ class DiscordRole(BaseModel): name: str color: int hoist: bool - icon: typing.Optional[str] - unicode_emoji: typing.Optional[str] + icon: str | None + unicode_emoji: str | None position: int permissions: str managed: bool mentionable: bool - tags: typing.Optional[RoleTags] + tags: RoleTags | None diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py index 0eca15b..be10672 100644 --- a/backend/models/discord_user.py +++ b/backend/models/discord_user.py @@ -11,15 +11,15 @@ class _User(BaseModel): username: str id: str discriminator: str - avatar: t.Optional[str] - bot: t.Optional[bool] - system: t.Optional[bool] - locale: t.Optional[str] - verified: t.Optional[bool] - email: t.Optional[str] - flags: t.Optional[int] - premium_type: t.Optional[int] - public_flags: t.Optional[int] + avatar: str | None + bot: bool | None + system: bool | None + locale: str | None + verified: bool | None + email: str | None + flags: int | None + premium_type: int | None + public_flags: int | None class DiscordUser(_User): @@ -33,16 +33,16 @@ class DiscordMember(BaseModel): """A discord guild member.""" user: _User - nick: t.Optional[str] - avatar: t.Optional[str] + nick: str | None + avatar: str | None roles: list[str] joined_at: datetime.datetime - premium_since: t.Optional[datetime.datetime] + premium_since: datetime.datetime | None deaf: bool mute: bool - pending: t.Optional[bool] - permissions: t.Optional[str] - communication_disabled_until: t.Optional[datetime.datetime] + pending: bool | None + permissions: str | None + communication_disabled_until: datetime.datetime | None def dict(self, *args, **kwargs) -> dict[str, t.Any]: """Convert the model to a python dict, and encode timestamps in a serializable format.""" diff --git a/backend/models/form.py b/backend/models/form.py index 10c8bfd..3db267e 100644 --- a/backend/models/form.py +++ b/backend/models/form.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, constr, root_validator, validator from pydantic.error_wrappers import ErrorWrapper, ValidationError from backend.constants import DISCORD_GUILD, FormFeatures, WebHook + from .question import Question PUBLIC_FIELDS = [ @@ -14,20 +15,22 @@ PUBLIC_FIELDS = [ "name", "description", "submitted_text", - "discord_role" + "discord_role", ] class _WebHook(BaseModel): """Schema model of discord webhooks.""" + url: str - message: t.Optional[str] + message: str | None @validator("url") def validate_url(cls, url: str) -> str: """Validates URL parameter.""" if "discord.com/api/webhooks/" not in url: - raise ValueError("URL must be a discord webhook.") + msg = "URL must be a discord webhook." + raise ValueError(msg) return url @@ -40,56 +43,55 @@ class Form(BaseModel): questions: list[Question] name: str description: str - submitted_text: t.Optional[str] = None + submitted_text: str | None = None webhook: _WebHook = None - discord_role: t.Optional[str] - response_readers: t.Optional[list[str]] - editors: t.Optional[list[str]] + discord_role: str | None + response_readers: list[str] | None + editors: list[str] | None class Config: allow_population_by_field_name = True @validator("features") - def validate_features(cls, value: list[str]) -> t.Optional[list[str]]: + def validate_features(cls, value: list[str]) -> list[str]: """Validates is all features in allowed list.""" # Uppercase everything to avoid mixed case in DB value = [v.upper() for v in value] allowed_values = [v.value for v in FormFeatures.__members__.values()] if any(v not in allowed_values for v in value): - raise ValueError("Form features list contains one or more invalid values.") + msg = "Form features list contains one or more invalid values." + raise ValueError(msg) if FormFeatures.REQUIRES_LOGIN.value not in value: if FormFeatures.COLLECT_EMAIL.value in value: - raise ValueError( - "COLLECT_EMAIL feature require REQUIRES_LOGIN feature." - ) + msg = "COLLECT_EMAIL feature require REQUIRES_LOGIN feature." + raise ValueError(msg) if FormFeatures.ASSIGN_ROLE.value in value: - raise ValueError("ASSIGN_ROLE feature require REQUIRES_LOGIN feature.") + msg = "ASSIGN_ROLE feature require REQUIRES_LOGIN feature." + raise ValueError(msg) return value @validator("response_readers", "editors") - def validate_role_scoping(cls, value: t.Optional[list[str]]) -> t.Optional[list[str]]: + def validate_role_scoping(cls, value: list[str] | None) -> list[str]: """Ensure special role based permissions aren't granted to the @everyone role.""" - if value and str(DISCORD_GUILD) in value: - raise ValueError("You can not add the everyone role as an access scope.") + if value and DISCORD_GUILD in value: + msg = "You can not add the everyone role as an access scope." + raise ValueError(msg) return value @root_validator - def validate_role(cls, values: dict[str, t.Any]) -> t.Optional[dict[str, t.Any]]: + def validate_role(cls, values: dict[str, t.Any]) -> dict[str, t.Any]: """Validates does Discord role provided when flag provided.""" - if ( - FormFeatures.ASSIGN_ROLE.value in values.get("features", []) - and not values.get("discord_role") - ): - raise ValueError( - "discord_role field is required when ASSIGN_ROLE flag is provided." - ) + is_role_assigner = FormFeatures.ASSIGN_ROLE.value in values.get("features", []) + if is_role_assigner and not values.get("discord_role"): + msg = "discord_role field is required when ASSIGN_ROLE flag is provided." + raise ValueError(msg) return values - def dict(self, admin: bool = True, **kwargs) -> dict[str, t.Any]: + def dict(self, admin: bool = True, **kwargs) -> dict[str, t.Any]: # noqa: FBT001, FBT002 """Wrapper for original function to exclude private data for public access.""" data = super().dict(**kwargs) @@ -97,10 +99,7 @@ class Form(BaseModel): if not admin: for field in PUBLIC_FIELDS: - if field == "id" and kwargs.get("by_alias"): - fetch_field = "_id" - else: - fetch_field = field + fetch_field = "_id" if field == "id" and kwargs.get("by_alias") else field returned_data[field] = data[fetch_field] else: @@ -110,17 +109,20 @@ class Form(BaseModel): class FormList(BaseModel): - __root__: t.List[Form] + __root__: list[Form] -async def validate_hook_url(url: str) -> t.Optional[ValidationError]: +async def validate_hook_url(url: str) -> ValidationError | None: """Validator for discord webhook urls.""" - async def validate() -> t.Optional[str]: + + async def validate() -> str | None: if not isinstance(url, str): - raise ValueError("Webhook URL must be a string.") + msg = "Webhook URL must be a string." + raise TypeError(msg) if "discord.com/api/webhooks/" not in url: - raise ValueError("URL must be a discord webhook.") + msg = "URL must be a discord webhook." + raise ValueError(msg) try: async with httpx.AsyncClient() as client: @@ -129,36 +131,32 @@ async def validate_hook_url(url: str) -> t.Optional[ValidationError]: except httpx.RequestError as error: # Catch exceptions in request format - raise ValueError( - f"Encountered error while trying to connect to url: `{error}`" - ) + msg = f"Encountered error while trying to connect to url: `{error}`" + raise ValueError(msg) except httpx.HTTPStatusError as error: # Catch exceptions in response status = error.response.status_code if status == 401: - raise ValueError( - "Could not authenticate with target. Please check the webhook url." - ) - elif status == 404: - raise ValueError( - "Target could not find webhook url. Please check the webhook url." - ) - else: - raise ValueError( - f"Unknown error ({status}) while connecting to target: {error}" - ) + msg = "Could not authenticate with target. Please check the webhook url." + raise ValueError(msg) + if status == 404: + msg = "Target could not find webhook url. Please check the webhook url." + raise ValueError(msg) + + msg = f"Unknown error ({status}) while connecting to target: {error}" + raise ValueError(msg) return url # Validate, and return errors, if any try: await validate() - except Exception as e: + except Exception as e: # noqa: BLE001 loc = ( WebHook.__name__.lower(), - WebHook.URL.value + WebHook.URL.value, ) return ValidationError([ErrorWrapper(e, loc=loc)], _WebHook) diff --git a/backend/models/form_response.py b/backend/models/form_response.py index 933f5e4..3c8297b 100644 --- a/backend/models/form_response.py +++ b/backend/models/form_response.py @@ -11,19 +11,20 @@ class FormResponse(BaseModel): """Schema model for form response.""" id: str = Field(alias="_id") - user: t.Optional[DiscordUser] - antispam: t.Optional[AntiSpam] + user: DiscordUser | None + antispam: AntiSpam | None response: dict[str, t.Any] form_id: str timestamp: str @validator("timestamp", pre=True) - def set_timestamp(cls, iso_string: t.Optional[str]) -> t.Optional[str]: + def set_timestamp(cls, iso_string: str | None) -> str: if iso_string is None: - return datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + return datetime.datetime.now(tz=datetime.UTC).isoformat() - elif not isinstance(iso_string, str): - raise ValueError("Submission timestamp must be a string.") + if not isinstance(iso_string, str): + msg = "Submission timestamp must be a string." + raise TypeError(msg) # Convert to datetime and back to ensure string is valid return datetime.datetime.fromisoformat(iso_string).isoformat() @@ -33,4 +34,4 @@ class FormResponse(BaseModel): class ResponseList(BaseModel): - __root__: t.List[FormResponse] + __root__: list[FormResponse] diff --git a/backend/models/question.py b/backend/models/question.py index 201aa51..a13ce93 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -4,11 +4,12 @@ from pydantic import BaseModel, Field, root_validator, validator from backend.constants import QUESTION_TYPES, REQUIRED_QUESTION_TYPE_DATA -_TESTS_TYPE = t.Union[t.Dict[str, str], int] +_TESTS_TYPE = dict[str, str] | int class Unittests(BaseModel): """Schema model for unittest suites in code questions.""" + allow_failure: bool = False tests: _TESTS_TYPE @@ -16,17 +17,19 @@ class Unittests(BaseModel): def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE: """Confirm that at least one test exists in a test suite.""" if isinstance(value, dict): - keys = len(value.keys()) - (1 if "setUp" in value.keys() else 0) + keys = len(value.keys()) - (1 if "setUp" in value else 0) if keys == 0: - raise ValueError("Must have at least one test in a test suite.") + msg = "Must have at least one test in a test suite." + raise ValueError(msg) return value class CodeQuestion(BaseModel): """Schema model for questions of type `code`.""" + language: str - unittests: t.Optional[Unittests] + unittests: Unittests | None class Question(BaseModel): @@ -42,22 +45,20 @@ class Question(BaseModel): allow_population_by_field_name = True @validator("type", pre=True) - def validate_question_type(cls, value: str) -> t.Optional[str]: + def validate_question_type(cls, value: str) -> str: """Checks if question type in currently allowed types list.""" value = value.lower() if value not in QUESTION_TYPES: - raise ValueError( - f"{value} is not valid question type. " - f"Allowed question types: {QUESTION_TYPES}." - ) + msg = f"{value} is not valid question type. Allowed question types: {QUESTION_TYPES}." + raise ValueError(msg) return value @root_validator def validate_question_data( - cls, - value: dict[str, t.Any] - ) -> t.Optional[dict[str, t.Any]]: + cls, + value: dict[str, t.Any], + ) -> dict[str, t.Any]: """Check does required data exists for question type and remove other data.""" # When question type don't need data, don't add anything to keep DB clean. if value.get("type") not in REQUIRED_QUESTION_TYPE_DATA: @@ -65,13 +66,15 @@ class Question(BaseModel): for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value["type"]].items(): if key not in value.get("data", {}): - raise ValueError(f"Required question data key '{key}' not provided.") + msg = f"Required question data key '{key}' not provided." + raise ValueError(msg) if not isinstance(value["data"][key], data_type): - raise ValueError( + msg = ( f"Question data key '{key}' expects {data_type.__name__}, " - f"got {type(value['data'][key]).__name__} instead." + f"got {type(value["data"][key]).__name__} instead." ) + raise TypeError(msg) # Validate unittest options if value.get("type").lower() == "code": diff --git a/backend/route.py b/backend/route.py index d778bf0..a9ea7ad 100644 --- a/backend/route.py +++ b/backend/route.py @@ -1,6 +1,5 @@ -""" -Base class for implementing dynamic routing. -""" +"""Base class for implementing dynamic routing.""" + from starlette.endpoints import HTTPEndpoint @@ -11,7 +10,9 @@ class Route(HTTPEndpoint): @classmethod def check_parameters(cls) -> None: if not hasattr(cls, "name"): - raise ValueError(f"Route {cls.__name__} has not defined a name") + msg = f"Route {cls.__name__} has not defined a name" + raise ValueError(msg) if not hasattr(cls, "path"): - raise ValueError(f"Route {cls.__name__} has not defined a path") + msg = f"Route {cls.__name__} has not defined a path" + raise ValueError(msg) diff --git a/backend/route_manager.py b/backend/route_manager.py index 2d95bb2..b35ca0b 100644 --- a/backend/route_manager.py +++ b/backend/route_manager.py @@ -1,6 +1,4 @@ -""" -Module to dynamically generate a Starlette routing map based on a directory tree. -""" +"""Module to dynamically generate a Starlette routing map based on a directory tree.""" import importlib import inspect @@ -27,7 +25,7 @@ def construct_route_map_from_dict(route_dict: dict) -> list[BaseRoute]: return route_map -def is_route_class(member: t.Any) -> bool: # noqa: ANN401 +def is_route_class(member: t.Any) -> bool: return inspect.isclass(member) and issubclass(member, Route) and member != Route @@ -35,7 +33,7 @@ def route_classes() -> t.Iterator[tuple[Path, type[Route]]]: routes_directory = Path("backend") / "routes" for module_path in routes_directory.rglob("*.py"): - import_name = f"{'.'.join(module_path.parent.parts)}.{module_path.stem}" + import_name = f"{".".join(module_path.parent.parts)}.{module_path.stem}" route_module = importlib.import_module(import_name) for _member_name, member in inspect.getmembers(route_module): if is_route_class(member): @@ -47,7 +45,7 @@ def create_route_map() -> list[BaseRoute]: route_dict = nested_dict() for module_path, member in route_classes(): - # module_path == Path("backend/routes/foo/bar/baz/bin.py") + # For Path: "backend/routes/foo/bar/baz/bin.py" # => levels == ["foo", "bar", "baz"] levels = module_path.parent.parts[2:] current_level = None diff --git a/backend/routes/admin.py b/backend/routes/admin.py index 0fd0700..848abce 100644 --- a/backend/routes/admin.py +++ b/backend/routes/admin.py @@ -1,6 +1,5 @@ -""" -Adds new admin user. -""" +"""Adds new admin user.""" + from pydantic import BaseModel, Field from spectree import Response from starlette.authentication import requires @@ -22,7 +21,7 @@ async def grant(request: Request) -> JSONResponse: admin = AdminModel(**data) if await request.state.db.admins.find_one( - {"_id": admin.id} + {"_id": admin.id}, ): return JSONResponse({"error": "already_exists"}, status_code=400) @@ -40,7 +39,7 @@ class AdminRoute(Route): @api.validate( json=AdminModel, resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), - tags=["admin"] + tags=["admin"], ) async def post(self, request: Request) -> JSONResponse: """Grant a user administrator privileges.""" @@ -48,6 +47,7 @@ class AdminRoute(Route): if not constants.PRODUCTION: + class AdminDev(Route): """Adds new admin user with no authentication.""" @@ -57,7 +57,7 @@ if not constants.PRODUCTION: @api.validate( json=AdminModel, resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), - tags=["admin"] + tags=["admin"], ) async def post(self, request: Request) -> JSONResponse: """ diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 42fb3ec..bc80a7d 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -1,9 +1,6 @@ -""" -Use a token received from the Discord OAuth2 system to fetch user information. -""" +"""Use a token received from the Discord OAuth2 system to fetch user information.""" import datetime -from typing import Union import httpx import jwt @@ -35,8 +32,8 @@ class AuthorizeResponse(BaseModel): async def process_token( bearer_token: dict, - request: Request -) -> Union[AuthorizeResponse, AUTH_FAILURE]: + request: Request, +) -> AuthorizeResponse | responses.JSONResponse: """Post a bearer token to Discord, and return a JWT and username.""" interaction_start = datetime.datetime.now() @@ -57,7 +54,7 @@ async def process_token( "refresh": bearer_token["refresh_token"], "user_details": user_details, "in_guild": bool(member), - "expiry": token_expiry.isoformat() + "expiry": token_expiry.isoformat(), } token = jwt.encode(data, SECRET_KEY, algorithm="HS256") @@ -65,18 +62,18 @@ async def process_token( response = responses.JSONResponse({ "username": user.display_name, - "expiry": token_expiry.isoformat() + "expiry": token_expiry.isoformat(), }) - await set_response_token(response, request, token, bearer_token["expires_in"]) + set_response_token(response, request, token, bearer_token["expires_in"]) return response -async def set_response_token( - response: responses.Response, - request: Request, - new_token: str, - expiry: int +def set_response_token( + response: responses.Response, + request: Request, + new_token: str, + expiry: int, ) -> None: """Helper that handles logic for updating a token in a set-cookie response.""" origin_url = request.headers.get("origin") @@ -94,19 +91,18 @@ async def set_response_token( samesite = "None" response.set_cookie( - "token", f"JWT {new_token}", + "token", + f"JWT {new_token}", secure=constants.PRODUCTION, httponly=True, samesite=samesite, domain=domain, - max_age=expiry + max_age=expiry, ) class AuthorizeRoute(Route): - """ - Use the authorization code from Discord to generate a JWT token. - """ + """Use the authorization code from Discord to generate a JWT token.""" name = "authorize" path = "/authorize" @@ -114,7 +110,7 @@ class AuthorizeRoute(Route): @api.validate( json=AuthorizeRequest, resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), - tags=["auth"] + tags=["auth"], ) async def post(self, request: Request) -> responses.JSONResponse: """Generate an authorization token.""" @@ -129,9 +125,7 @@ class AuthorizeRoute(Route): class TokenRefreshRoute(Route): - """ - Use the refresh code from a JWT to get a new token and generate a new JWT token. - """ + """Use the refresh code from a JWT to get a new token and generate a new JWT token.""" name = "refresh" path = "/refresh" @@ -139,7 +133,7 @@ class TokenRefreshRoute(Route): @requires(["authenticated"]) @api.validate( resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), - tags=["auth"] + tags=["auth"], ) async def post(self, request: Request) -> responses.JSONResponse: """Refresh an authorization token.""" diff --git a/backend/routes/discord.py b/backend/routes/discord.py index bca1edb..53b8af3 100644 --- a/backend/routes/discord.py +++ b/backend/routes/discord.py @@ -10,7 +10,8 @@ from backend import discord, models, route from backend.validation import ErrorMessage, api NOT_FOUND_EXCEPTION = JSONResponse( - {"error": "Could not find the requested resource in the guild or cache."}, status_code=404 + {"error": "Could not find the requested resource in the guild or cache."}, + status_code=404, ) @@ -28,7 +29,7 @@ class RolesRoute(route.Route): @requires(["authenticated", "admin"]) @api.validate( resp=Response(HTTP_200=RolesResponse), - tags=["roles"] + tags=["roles"], ) async def patch(self, request: Request) -> JSONResponse: """Refresh the roles database.""" @@ -54,7 +55,7 @@ class MemberRoute(route.Route): @api.validate( resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), json=MemberRequest, - tags=["auth"] + tags=["auth"], ) async def delete(self, request: Request) -> JSONResponse: """Force a resync of the cache for the given user.""" @@ -63,21 +64,20 @@ class MemberRoute(route.Route): if member: return JSONResponse(member.dict()) - else: - return NOT_FOUND_EXCEPTION + return NOT_FOUND_EXCEPTION @requires(["authenticated", "admin"]) @api.validate( resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), json=MemberRequest, - tags=["auth"] + tags=["auth"], ) async def get(self, request: Request) -> JSONResponse: """Get a user's roles on the configured server.""" body = await request.json() member = await discord.get_member(request.state.db, body["user_id"]) - if member: - return JSONResponse(member.dict()) - else: + if not member: return NOT_FOUND_EXCEPTION + + return JSONResponse(member.dict()) diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py index 75ff495..0fe10b5 100644 --- a/backend/routes/forms/discover.py +++ b/backend/routes/forms/discover.py @@ -1,6 +1,5 @@ -""" -Return a list of all publicly discoverable forms to unauthenticated users. -""" +"""Return a list of all publicly discoverable forms to unauthenticated users.""" + from spectree.response import Response from starlette.requests import Request from starlette.responses import JSONResponse @@ -12,7 +11,7 @@ from backend.validation import api __FEATURES = [ constants.FormFeatures.OPEN.value, - constants.FormFeatures.REQUIRES_LOGIN.value + constants.FormFeatures.REQUIRES_LOGIN.value, ] if not constants.PRODUCTION: __FEATURES.append(constants.FormFeatures.DISCOVERABLE.value) @@ -22,7 +21,7 @@ __QUESTION = Question( name="Click the button below to log into the forms application.", type="section", data={"text": ""}, - required=False + required=False, ) AUTH_FORM = Form( @@ -31,14 +30,12 @@ AUTH_FORM = Form( questions=[__QUESTION], name="Login", description="Log into Python Discord Forms.", - submitted_text="This page can't be submitted." + submitted_text="This page can't be submitted.", ) class DiscoverableFormsList(Route): - """ - List all discoverable forms that should be shown on the homepage. - """ + """List all discoverable forms that should be shown on the homepage.""" name = "discoverable_forms_list" path = "/discoverable" @@ -46,15 +43,11 @@ class DiscoverableFormsList(Route): @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: """List all discoverable forms that should be shown on the homepage.""" - forms = [] cursor = request.state.db.forms.find({"features": "DISCOVERABLE"}).sort("name") # Parse it to Form and then back to dictionary # to replace _id with id - for form in await cursor.to_list(None): - forms.append(Form(**form)) - - forms = [form.dict(admin=False) for form in forms] + forms = [Form(**form).dict(admin=False) for form in await cursor.to_list(None)] # Return an empty form in development environments to help with authentication. if not constants.PRODUCTION: diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index 020193c..410102a 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -1,6 +1,5 @@ -""" -Returns, updates or deletes a single form given an ID. -""" +"""Returns, updates or deletes a single form given an ID.""" + import json.decoder import deepmerge @@ -48,7 +47,7 @@ class SingleForm(Route): admin = False filters = { - "_id": form_id + "_id": form_id, } if not admin: @@ -71,7 +70,7 @@ class SingleForm(Route): HTTP_400=ErrorMessage, HTTP_404=ErrorMessage, ), - tags=["forms"] + tags=["forms"], ) async def patch(self, request: Request) -> JSONResponse: """Updates form by ID.""" @@ -90,7 +89,7 @@ class SingleForm(Route): # Build Data Merger merge_strategy = [ - (dict, ["merge"]) + (dict, ["merge"]), ] merger = deepmerge.Merger(merge_strategy, ["override"], ["override"]) @@ -105,13 +104,12 @@ class SingleForm(Route): await request.state.db.forms.replace_one({"_id": form_id}, form.dict()) return JSONResponse(form.dict()) - else: - return JSONResponse({"error": "not_found"}, status_code=404) + return JSONResponse({"error": "not_found"}, status_code=404) @requires(["authenticated", "admin"]) @api.validate( resp=Response(HTTP_200=OkayResponse, HTTP_401=ErrorMessage, HTTP_404=ErrorMessage), - tags=["forms"] + tags=["forms"], ) async def delete(self, request: Request) -> JSONResponse: """Deletes form by ID.""" diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py index 38be693..1fdfc48 100644 --- a/backend/routes/forms/index.py +++ b/backend/routes/forms/index.py @@ -1,6 +1,5 @@ -""" -Return a list of all forms to authenticated users. -""" +"""Return a list of all forms to authenticated users.""" + from spectree.response import Response from starlette.authentication import requires from starlette.requests import Request @@ -14,9 +13,7 @@ from backend.validation import ErrorMessage, OkayResponse, api class FormsList(Route): - """ - List all available forms for authorized viewers. - """ + """List all available forms for authorized viewers.""" name = "forms_list_create" path = "/" @@ -25,24 +22,17 @@ class FormsList(Route): @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: """Return a list of all forms to authenticated users.""" - forms = [] cursor = request.state.db.forms.find() - for form in await cursor.to_list(None): - forms.append(Form(**form)) # For converting _id to id - - # Covert them back to dictionaries - forms = [form.dict() for form in forms] + forms = [Form(**form).dict() for form in await cursor.to_list(None)] - return JSONResponse( - forms - ) + return JSONResponse(forms) @requires(["authenticated", "Helpers"]) @api.validate( json=Form, resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), - tags=["forms"] + tags=["forms"], ) async def post(self, request: Request) -> JSONResponse: """Create a new form.""" @@ -66,9 +56,7 @@ class FormsList(Route): form = Form(**form_data) if await request.state.db.forms.find_one({"_id": form.id}): - return JSONResponse({ - "error": "id_taken" - }, status_code=400) + return JSONResponse({"error": "id_taken"}, status_code=400) await request.state.db.forms.insert_one(form.dict(by_alias=True)) return JSONResponse(form.dict()) diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py index 565701f..b4f7f04 100644 --- a/backend/routes/forms/response.py +++ b/backend/routes/forms/response.py @@ -1,6 +1,4 @@ -""" -Returns or deletes form response by ID. -""" +"""Returns or deletes form response by ID.""" from spectree import Response as RouteResponse from starlette.authentication import requires @@ -22,7 +20,7 @@ class Response(Route): @requires(["authenticated"]) @api.validate( resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def get(self, request: Request) -> JSONResponse: """Return a single form response by ID.""" @@ -32,30 +30,29 @@ class Response(Route): if raw_response := await request.state.db.responses.find_one( { "_id": request.path_params["response_id"], - "form_id": form_id - } + "form_id": form_id, + }, ): response = FormResponse(**raw_response) return JSONResponse(response.dict()) - else: - return JSONResponse({"error": "response_not_found"}, status_code=404) + return JSONResponse({"error": "response_not_found"}, status_code=404) @requires(["authenticated", "admin"]) @api.validate( resp=RouteResponse(HTTP_200=OkayResponse, HTTP_404=ErrorMessage), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def delete(self, request: Request) -> JSONResponse: """Delete a form response by ID.""" if not await request.state.db.responses.find_one( { "_id": request.path_params["response_id"], - "form_id": request.path_params["form_id"] - } + "form_id": request.path_params["form_id"], + }, ): return JSONResponse({"error": "not_found"}, status_code=404) await request.state.db.responses.delete_one( - {"_id": request.path_params["response_id"]} + {"_id": request.path_params["response_id"]}, ) return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py index 818ebce..85e5af2 100644 --- a/backend/routes/forms/responses.py +++ b/backend/routes/forms/responses.py @@ -1,6 +1,5 @@ -""" -Returns all form responses by form ID. -""" +"""Returns all form responses by form ID.""" + from pydantic import BaseModel from spectree import Response from starlette.authentication import requires @@ -18,9 +17,7 @@ class ResponseIdList(BaseModel): class Responses(Route): - """ - Returns all form responses by form ID. - """ + """Returns all form responses by form ID.""" name = "form_responses" path = "/{form_id:str}/responses" @@ -28,7 +25,7 @@ class Responses(Route): @requires(["authenticated"]) @api.validate( resp=Response(HTTP_200=ResponseList), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def get(self, request: Request) -> JSONResponse: """Returns all form responses by form ID.""" @@ -36,11 +33,9 @@ class Responses(Route): await discord.verify_response_access(form_id, request) cursor = request.state.db.responses.find( - {"form_id": form_id} + {"form_id": form_id}, ) - responses = [ - FormResponse(**response) for response in await cursor.to_list(None) - ] + responses = [FormResponse(**response) for response in await cursor.to_list(None)] return JSONResponse([response.dict() for response in responses]) @requires(["authenticated", "admin"]) @@ -49,14 +44,14 @@ class Responses(Route): resp=Response( HTTP_200=OkayResponse, HTTP_404=ErrorMessage, - HTTP_400=ErrorMessage + HTTP_400=ErrorMessage, ), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def delete(self, request: Request) -> JSONResponse: """Bulk deletes form responses by IDs.""" if not await request.state.db.forms.find_one( - {"_id": request.path_params["form_id"]} + {"_id": request.path_params["form_id"]}, ): return JSONResponse({"error": "not_found"}, status_code=404) @@ -67,37 +62,34 @@ class Responses(Route): ids = set(response_ids.ids) cursor = request.state.db.responses.find( - {"_id": {"$in": list(ids)}} # Convert here back to list, may throw error. + {"_id": {"$in": list(ids)}}, # Convert here back to list, may throw error. ) - entries = [ - FormResponse(**submission) for submission in await cursor.to_list(None) - ] + entries = [FormResponse(**submission) for submission in await cursor.to_list(None)] actual_ids = {entry.id for entry in entries} if len(ids) != len(actual_ids): return JSONResponse( { "error": "responses_not_found", - "ids": list(ids - actual_ids) + "ids": list(ids - actual_ids), }, - status_code=404 + status_code=404, ) if any(entry.form_id != request.path_params["form_id"] for entry in entries): return JSONResponse( { "error": "wrong_form", - "ids": list( - entry.id for entry in entries - if entry.id != request.path_params["form_id"] - ) + "ids": [ + entry.id for entry in entries if entry.id != request.path_params["form_id"] + ], }, - status_code=400 + status_code=400, ) await request.state.db.responses.delete_many( { - "_id": {"$in": list(actual_ids)} - } + "_id": {"$in": list(actual_ids)}, + }, ) return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 765856e..8f01e2b 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -1,6 +1,4 @@ -""" -Submit a form. -""" +"""Submit a form.""" import asyncio import binascii @@ -8,10 +6,9 @@ import datetime import hashlib import typing import uuid -from typing import Any, Optional +from typing import Any import httpx -import pymongo.database import sentry_sdk from pydantic import ValidationError from pydantic.main import BaseModel @@ -29,13 +26,16 @@ from backend.routes.forms.discover import AUTH_FORM from backend.routes.forms.unittesting import BypassDetectedError, execute_unittest from backend.validation import ErrorMessage, api +if typing.TYPE_CHECKING: + import pymongo.database + HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify" HCAPTCHA_HEADERS = { - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", } DISCORD_HEADERS = { - "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}" + "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}", } @@ -46,7 +46,7 @@ class SubmissionResponse(BaseModel): class PartialSubmission(BaseModel): response: dict[str, Any] - captcha: Optional[str] + captcha: str | None class UnittestError(BaseModel): @@ -62,9 +62,7 @@ class UnittestErrorMessage(ErrorMessage): class SubmitForm(Route): - """ - Submit a form with the provided form ID. - """ + """Submit a form with the provided form ID.""" name = "submit_form" path = "/submit/{form_id:str}" @@ -75,9 +73,9 @@ class SubmitForm(Route): HTTP_200=SubmissionResponse, HTTP_404=ErrorMessage, HTTP_400=ErrorMessage, - HTTP_422=UnittestErrorMessage + HTTP_422=UnittestErrorMessage, ), - tags=["forms", "responses"] + tags=["forms", "responses"], ) async def post(self, request: Request) -> JSONResponse: """Submit a response to the form.""" @@ -92,7 +90,7 @@ class SubmitForm(Route): if old != request.user.token: try: expiry = datetime.datetime.fromisoformat( - request.user.decoded_token.get("expiry") + request.user.decoded_token.get("expiry"), ) except ValueError: expiry = None @@ -117,7 +115,7 @@ class SubmitForm(Route): id="not-submitted", form_id=AUTH_FORM.id, response={question.id: None for question in AUTH_FORM.questions}, - timestamp=datetime.datetime.now().isoformat() + timestamp=datetime.datetime.now().isoformat(), ).dict() return JSONResponse({"form": AUTH_FORM.dict(admin=False), "response": response}) @@ -131,8 +129,9 @@ class SubmitForm(Route): ip_hash_ctx = hashlib.md5() ip_hash_ctx.update( request.headers.get( - "Cf-Connecting-IP", request.client.host - ).encode() + "Cf-Connecting-IP", + request.client.host, + ).encode(), ) ip_hash = binascii.hexlify(ip_hash_ctx.digest()) user_agent_hash_ctx = hashlib.md5() @@ -142,12 +141,12 @@ class SubmitForm(Route): async with httpx.AsyncClient() as client: query_params = { "secret": constants.HCAPTCHA_API_SECRET, - "response": data.get("captcha") + "response": data.get("captcha"), } r = await client.post( HCAPTCHA_VERIFY_URL, params=query_params, - headers=HCAPTCHA_HEADERS + headers=HCAPTCHA_HEADERS, ) r.raise_for_status() captcha_data = r.json() @@ -155,7 +154,7 @@ class SubmitForm(Route): response["antispam"] = { "ip_hash": ip_hash.decode(), "user_agent_hash": user_agent_hash.decode(), - "captcha_pass": captcha_data["success"] + "captcha_pass": captcha_data["success"], } if constants.FormFeatures.REQUIRES_LOGIN.value in form.features: @@ -164,16 +163,12 @@ class SubmitForm(Route): response["user"]["admin"] = request.user.admin if ( - constants.FormFeatures.COLLECT_EMAIL.value in form.features - and "email" not in response["user"] + constants.FormFeatures.COLLECT_EMAIL.value in form.features + and "email" not in response["user"] ): - return JSONResponse({ - "error": "email_required" - }, status_code=400) + return JSONResponse({"error": "email_required"}, status_code=400) else: - return JSONResponse({ - "error": "missing_discord_data" - }, status_code=400) + return JSONResponse({"error": "missing_discord_data"}, status_code=400) missing_fields = [] for question in form.questions: @@ -184,10 +179,13 @@ class SubmitForm(Route): missing_fields.append(question.id) if missing_fields: - return JSONResponse({ - "error": "missing_fields", - "fields": missing_fields - }, status_code=400) + return JSONResponse( + { + "error": "missing_fields", + "fields": missing_fields, + }, + status_code=400, + ) try: response_obj = FormResponse(**response) @@ -200,10 +198,12 @@ class SubmitForm(Route): if len(errors): username = getattr(request.user, "user_id", "Unknown") - sentry_sdk.capture_exception(BypassDetectedError( - f"Detected unittest bypass attempt on form {form.id} by {username}. " - f"Submission has been written to reporting database ({response_obj.id})." - )) + sentry_sdk.capture_exception( + BypassDetectedError( + f"Detected unittest bypass attempt on form {form.id} by {username}. " + f"Submission has been written to reporting database ({response_obj.id}).", + ) + ) database: pymongo.database.Database = request.state.db await database.get_collection("violations").insert_one({ "user": username, @@ -219,7 +219,7 @@ class SubmitForm(Route): for test in unittest_results: response_obj.response[test.question_id] = { "value": response_obj.response[test.question_id], - "passed": test.passed + "passed": test.passed, } if test.return_code == 0: @@ -238,9 +238,8 @@ class SubmitForm(Route): # Report a failure on internal errors, # or if the test suite doesn't allow failures if not test.passed: - allow_failure = ( - form.questions[test.question_index].data["unittests"]["allow_failure"] - ) + question = form.questions[test.question_index] + allow_failure = question.data["unittests"]["allow_failure"] # An error while communicating with the test runner if test.return_code == 99: @@ -251,15 +250,16 @@ class SubmitForm(Route): failures.append(test) if len(failures): - return JSONResponse({ - "error": "failed_tests", - "test_results": [ - test._asdict() for test in failures - ] - }, status_code=status_code) + return JSONResponse( + { + "error": "failed_tests", + "test_results": [test._asdict() for test in failures], + }, + status_code=status_code, + ) await request.state.db.responses.insert_one( - response_obj.dict(by_alias=True) + response_obj.dict(by_alias=True), ) tasks = BackgroundTasks() @@ -272,36 +272,37 @@ class SubmitForm(Route): self.send_submission_webhook, form=form, response=response_obj, - request_user=request_user + request_user=request_user, ) if constants.FormFeatures.ASSIGN_ROLE.value in form.features: tasks.add_task( self.assign_role, form=form, - request_user=request.user + request_user=request.user, ) - return JSONResponse({ - "form": form.dict(admin=False), - "response": response_obj.dict() - }, background=tasks) + return JSONResponse( + { + "form": form.dict(admin=False), + "response": response_obj.dict(), + }, + background=tasks, + ) - else: - return JSONResponse({ - "error": "Open form not found" - }, status_code=404) + return JSONResponse({"error": "Open form not found"}, status_code=404) @staticmethod async def send_submission_webhook( - form: Form, - response: FormResponse, - request_user: typing.Optional[User] + form: Form, + response: FormResponse, + request_user: User | None, ) -> None: """Helper to send a submission message to a discord webhook.""" # Stop if webhook is not available if form.webhook is None: - raise ValueError("Got empty webhook.") + msg = "Got empty webhook." + raise ValueError(msg) try: mention = request_user.discord_mention @@ -330,7 +331,7 @@ class SubmitForm(Route): hook = { "embeds": [embed], "allowed_mentions": {"parse": ["users", "roles"]}, - "username": form.name or "Python Discord Forms" + "username": form.name or "Python Discord Forms", } # Set hook message @@ -345,8 +346,8 @@ class SubmitForm(Route): "time": response.timestamp, } - for key in ctx: - message = message.replace(f"{{{key}}}", str(ctx[key])) + for key, val in ctx.items(): + message = message.replace(f"{{{key}}}", str(val)) hook["content"] = message.replace("_USER_MENTION_", mention) @@ -359,11 +360,12 @@ class SubmitForm(Route): async def assign_role(form: Form, request_user: User) -> None: """Assigns Discord role to user when user submitted response.""" if not form.discord_role: - raise ValueError("Got empty Discord role ID.") + msg = "Got empty Discord role ID." + raise ValueError(msg) url = ( f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" - f"/members/{request_user.payload['id']}/roles/{form.discord_role}" + f"/members/{request_user.payload["id"]}/roles/{form.discord_role}" ) async with httpx.AsyncClient() as client: diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a02afea..3239d35 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -1,7 +1,8 @@ import base64 -from collections import namedtuple from itertools import count +from pathlib import Path from textwrap import indent +from typing import NamedTuple import httpx from httpx import HTTPStatusError @@ -9,7 +10,7 @@ from httpx import HTTPStatusError from backend.constants import SNEKBOX_URL from backend.models import Form, FormResponse -with open("resources/unittest_template.py") as file: +with Path("resources/unittest_template.py").open(encoding="utf8") as file: TEST_TEMPLATE = file.read() @@ -17,9 +18,12 @@ class BypassDetectedError(Exception): """Detected an attempt at bypassing the unittests.""" -UnittestResult = namedtuple( - "UnittestResult", "question_id question_index return_code passed result" -) +class UnittestResult(NamedTuple): + question_id: str + question_index: int + return_code: int + passed: bool + result: str def filter_unittests(form: Form) -> Form: @@ -46,11 +50,11 @@ def _make_unit_code(units: dict[str, str]) -> str: elif unit_name == "tearDown": result += "\ndef tearDown(self):" else: - name = f"test_{unit_name.removeprefix('#').removeprefix('test_')}" + name = f"test_{unit_name.removeprefix("#").removeprefix("test_")}" result += f"\nasync def {name}(self):" # Unite code - result += f"\n{indent(unit_code, ' ')}" + result += f"\n{indent(unit_code, " ")}" return indent(result, " ") @@ -72,7 +76,8 @@ async def _post_eval(code: str) -> dict[str, str]: async def execute_unittest( - form_response: FormResponse, form: Form + form_response: FormResponse, + form: Form, ) -> tuple[list[UnittestResult], list[BypassDetectedError]]: """Execute all the unittests in this form and return the results.""" unittest_results = [] @@ -80,16 +85,17 @@ async def execute_unittest( for index, question in enumerate(form.questions): if question.type == "code": - # Exit early if the suite doesn't have any tests if question.data["unittests"] is None: - unittest_results.append(UnittestResult( - question_id=question.id, - question_index=index, - return_code=0, - passed=True, - result="" - )) + unittest_results.append( + UnittestResult( + question_id=question.id, + question_index=index, + return_code=0, + passed=True, + result="", + ) + ) continue passed = False @@ -98,7 +104,7 @@ async def execute_unittest( hidden_test_counter = count(1) hidden_tests = { test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter) - for test in question.data["unittests"]["tests"].keys() + for test in question.data["unittests"]["tests"] if test.startswith("#") } @@ -124,18 +130,18 @@ async def execute_unittest( try: passed = bool(int(stdout[0])) except ValueError: - raise BypassDetectedError("Detected a bypass when reading result code.") + msg = "Detected a bypass when reading result code." + raise BypassDetectedError(msg) if passed and stdout.strip() != "1": # Most likely a bypass attempt # A 1 was written to stdout to indicate success, # followed by the actual output - raise BypassDetectedError( - "Detected improper value for stdout in unittest." - ) + msg = "Detected improper value for stdout in unittest." + raise BypassDetectedError(msg) # If the test failed, we have to populate the result string. - elif not passed: + if not passed: failed_tests = stdout[1:].strip().split(";") # Redact failed hidden tests @@ -146,7 +152,7 @@ async def execute_unittest( result = ";".join(failed_tests) else: result = "" - elif return_code in (5, 6, 99): + elif return_code in {5, 6, 99}: result = response["stdout"] # Killed by NsJail elif return_code == 137: @@ -162,12 +168,14 @@ async def execute_unittest( errors.append(error) passed = False - unittest_results.append(UnittestResult( - question_id=question.id, - question_index=index, - return_code=return_code, - passed=passed, - result=result - )) + unittest_results.append( + UnittestResult( + question_id=question.id, + question_index=index, + return_code=return_code, + passed=passed, + result=result, + ) + ) return unittest_results, errors diff --git a/backend/routes/index.py b/backend/routes/index.py index 207c36a..c6e38ea 100644 --- a/backend/routes/index.py +++ b/backend/routes/index.py @@ -1,6 +1,5 @@ -""" -Index route for the forms API. -""" +"""Index route for the forms API.""" + import platform from pydantic import BaseModel @@ -20,13 +19,13 @@ class IndexResponse(BaseModel): description=( "The connecting client, in production this will" " be an IP of our internal load balancer" - ) + ), ) sha: str = Field( - description="Current release Git SHA in production." + description="Current release Git SHA in production.", ) node: str = Field( - description="The node that processed the request." + description="The node that processed the request.", ) @@ -42,24 +41,22 @@ class IndexRoute(Route): @api.validate(resp=Response(HTTP_200=IndexResponse)) def get(self, request: Request) -> JSONResponse: - """ - Return a hello from Python Discord forms! - """ + """Return a hello from Python Discord forms!.""" response_data = { "message": "Hello, world!", "client": request.client.host, "user": { - "authenticated": False + "authenticated": False, }, "sha": GIT_SHA, - "node": platform.uname().node + "node": platform.uname().node, } if request.user.is_authenticated: response_data["user"] = { "authenticated": True, "user": request.user.payload, - "scopes": request.auth.scopes + "scopes": request.auth.scopes, } return JSONResponse(response_data) diff --git a/backend/validation.py b/backend/validation.py index 8771924..0560701 100644 --- a/backend/validation.py +++ b/backend/validation.py @@ -7,7 +7,7 @@ from spectree import SpecTree api = SpecTree( "starlette", TITLE="Python Discord Forms", - PATH="docs" + PATH="docs", ) |