aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2024-07-07 02:29:26 +0100
committerGravatar Joe Banks <[email protected]>2024-07-08 15:00:10 +0100
commitd0e09d2ba567f23d91ac76d1844966bafb9b063a (patch)
tree9e825e3f09df02ab32e401c7e9555df26356dd4c /backend
parentChange linting config to Ruff (diff)
Apply fixable lint settings with Ruff
Diffstat (limited to 'backend')
-rw-r--r--backend/__init__.py6
-rw-r--r--backend/authentication/backend.py41
-rw-r--r--backend/authentication/user.py14
-rw-r--r--backend/constants.py5
-rw-r--r--backend/discord.py60
-rw-r--r--backend/middleware.py4
-rw-r--r--backend/models/__init__.py8
-rw-r--r--backend/models/discord_role.py14
-rw-r--r--backend/models/discord_user.py30
-rw-r--r--backend/models/form.py98
-rw-r--r--backend/models/form_response.py15
-rw-r--r--backend/models/question.py33
-rw-r--r--backend/route.py11
-rw-r--r--backend/route_manager.py10
-rw-r--r--backend/routes/admin.py12
-rw-r--r--backend/routes/auth/authorize.py42
-rw-r--r--backend/routes/discord.py18
-rw-r--r--backend/routes/forms/discover.py21
-rw-r--r--backend/routes/forms/form.py16
-rw-r--r--backend/routes/forms/index.py26
-rw-r--r--backend/routes/forms/response.py21
-rw-r--r--backend/routes/forms/responses.py46
-rw-r--r--backend/routes/forms/submit.py134
-rw-r--r--backend/routes/forms/unittesting.py68
-rw-r--r--backend/routes/index.py21
-rw-r--r--backend/validation.py2
26 files changed, 380 insertions, 396 deletions
diff --git a/backend/__init__.py b/backend/__init__.py
index dcbdcdf..67015d7 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -27,7 +27,7 @@ sentry_sdk.init(
dsn=constants.FORMS_BACKEND_DSN,
send_default_pii=True,
release=SENTRY_RELEASE,
- environment=SENTRY_RELEASE
+ environment=SENTRY_RELEASE,
)
middleware = [
@@ -36,10 +36,10 @@ middleware = [
allow_origins=["https://forms.pythondiscord.com"],
allow_origin_regex=ALLOW_ORIGIN_REGEX,
allow_headers=[
- "Content-Type"
+ "Content-Type",
],
allow_methods=["*"],
- allow_credentials=True
+ allow_credentials=True,
),
Middleware(DatabaseMiddleware),
Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()),
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index 54385e2..2512761 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -1,11 +1,9 @@
-import typing as t
-
import jwt
from starlette import authentication
from starlette.requests import Request
-from backend import constants
-from backend import discord
+from backend import constants, discord
+
# We must import user such way here to avoid circular imports
from .user import User
@@ -19,20 +17,19 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
try:
prefix, token = cookie.split()
except ValueError:
- raise authentication.AuthenticationError(
- "Unable to split prefix and token from authorization cookie."
- )
+ msg = "Unable to split prefix and token from authorization cookie."
+ raise authentication.AuthenticationError(msg)
if prefix.upper() != "JWT":
- raise authentication.AuthenticationError(
- f"Invalid authorization cookie prefix '{prefix}'."
- )
+ msg = f"Invalid authorization cookie prefix '{prefix}'."
+ raise authentication.AuthenticationError(msg)
return token
async def authenticate(
- self, request: Request
- ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
+ self,
+ request: Request,
+ ) -> tuple[authentication.AuthCredentials, authentication.BaseUser] | None:
"""Handles JWT authentication process."""
cookie = request.cookies.get("token")
if not cookie:
@@ -48,21 +45,25 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
scopes = ["authenticated"]
if not payload.get("token"):
- raise authentication.AuthenticationError("Token is missing from JWT.")
+ msg = "Token is missing from JWT."
+ raise authentication.AuthenticationError(msg)
if not payload.get("refresh"):
- raise authentication.AuthenticationError(
- "Refresh token is missing from JWT."
- )
+ msg = "Refresh token is missing from JWT."
+ raise authentication.AuthenticationError(msg)
try:
user_details = payload.get("user_details")
if not user_details or not user_details.get("id"):
- raise authentication.AuthenticationError("Improper user details.")
- except Exception:
- raise authentication.AuthenticationError("Could not parse user details.")
+ msg = "Improper user details."
+ raise authentication.AuthenticationError(msg) # noqa: TRY301
+ except Exception: # noqa: BLE001
+ msg = "Could not parse user details."
+ raise authentication.AuthenticationError(msg)
user = User(
- token, user_details, await discord.get_member(request.state.db, user_details["id"])
+ token,
+ user_details,
+ await discord.get_member(request.state.db, user_details["id"]),
)
if await user.fetch_admin_status(request.state.db):
scopes.append("admin")
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index cd5a249..c81b7a9 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,4 +1,3 @@
-import typing
import typing as t
import jwt
@@ -16,7 +15,7 @@ class User(BaseUser):
self,
token: str,
payload: dict[str, t.Any],
- member: typing.Optional[models.DiscordMember],
+ member: models.DiscordMember | None,
) -> None:
self.token = token
self.payload = payload
@@ -31,11 +30,11 @@ class User(BaseUser):
@property
def display_name(self) -> str:
"""Return username and discriminator as display name."""
- return f"{self.payload['username']}#{self.payload['discriminator']}"
+ return f"{self.payload["username"]}#{self.payload["discriminator"]}"
@property
def discord_mention(self) -> str:
- return f"<@{self.payload['id']}>"
+ return f"<@{self.payload["id"]}>"
@property
def user_id(self) -> str:
@@ -61,9 +60,10 @@ class User(BaseUser):
return roles
async def fetch_admin_status(self, database: Database) -> bool:
- self.admin = await database.admins.find_one(
- {"_id": self.payload["id"]}
- ) is not None
+ query = {"_id": self.payload["id"]}
+ found_admin = await database.admins.find_one(query)
+
+ self.admin = found_admin is not None
return self.admin
diff --git a/backend/constants.py b/backend/constants.py
index e1c38d3..8089077 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -18,7 +18,8 @@ PRODUCTION_URL = "https://forms.pythondiscord.com"
OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")
OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET")
OAUTH2_REDIRECT_URI = os.getenv(
- "OAUTH2_REDIRECT_URI", "https://forms.pythondiscord.com/callback"
+ "OAUTH2_REDIRECT_URI",
+ "https://forms.pythondiscord.com/callback",
)
GIT_SHA = os.getenv("GIT_SHA", "dev")
@@ -28,7 +29,7 @@ DOCS_PASSWORD = os.getenv("DOCS_PASSWORD")
SECRET_KEY = os.getenv("SECRET_KEY", binascii.hexlify(os.urandom(30)).decode())
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
-DISCORD_GUILD = os.getenv("DISCORD_GUILD", 267624335836053506)
+DISCORD_GUILD = os.getenv("DISCORD_GUILD", "267624335836053506")
HCAPTCHA_API_SECRET = os.getenv("HCAPTCHA_API_SECRET")
diff --git a/backend/discord.py b/backend/discord.py
index ff6c1bb..dc5989a 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -2,7 +2,6 @@
import datetime
import json
-import typing
import httpx
import starlette.requests
@@ -17,7 +16,7 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict
data = {
"client_id": constants.OAUTH2_CLIENT_ID,
"client_secret": constants.OAUTH2_CLIENT_SECRET,
- "redirect_uri": f"{redirect}/callback"
+ "redirect_uri": f"{redirect}/callback",
}
if refresh:
@@ -27,9 +26,13 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict
data["grant_type"] = "authorization_code"
data["code"] = code
- r = await client.post(f"{constants.DISCORD_API_BASE_URL}/oauth2/token", headers={
- "Content-Type": "application/x-www-form-urlencoded"
- }, data=data)
+ r = await client.post(
+ f"{constants.DISCORD_API_BASE_URL}/oauth2/token",
+ headers={
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ data=data,
+ )
r.raise_for_status()
@@ -38,9 +41,12 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict
async def fetch_user_details(bearer_token: str) -> dict:
async with httpx.AsyncClient() as client:
- r = await client.get(f"{constants.DISCORD_API_BASE_URL}/users/@me", headers={
- "Authorization": f"Bearer {bearer_token}"
- })
+ r = await client.get(
+ f"{constants.DISCORD_API_BASE_URL}/users/@me",
+ headers={
+ "Authorization": f"Bearer {bearer_token}",
+ },
+ )
r.raise_for_status()
@@ -52,7 +58,7 @@ async def _get_role_info() -> list[models.DiscordRole]:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}/roles",
- headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}
+ headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"},
)
r.raise_for_status()
@@ -60,7 +66,9 @@ async def _get_role_info() -> list[models.DiscordRole]:
async def get_roles(
- database: Database, *, force_refresh: bool = False
+ database: Database,
+ *,
+ force_refresh: bool = False,
) -> list[models.DiscordRole]:
"""
Get a list of all roles from the cache, or discord API if not available.
@@ -86,23 +94,26 @@ async def get_roles(
if len(roles) == 0:
# Fetch roles from the API and insert into the database
roles = await _get_role_info()
- await collection.insert_many({
- "name": role.name,
- "id": role.id,
- "data": role.json(),
- "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
- } for role in roles)
+ await collection.insert_many(
+ {
+ "name": role.name,
+ "id": role.id,
+ "data": role.json(),
+ "inserted_at": datetime.datetime.now(tz=datetime.UTC),
+ }
+ for role in roles
+ )
return roles
-async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]:
+async def _fetch_member_api(member_id: str) -> models.DiscordMember | None:
"""Get a member by ID from the configured guild using the discord API."""
async with httpx.AsyncClient() as client:
r = await client.get(
f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}"
f"/members/{member_id}",
- headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}
+ headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"},
)
if r.status_code == 404:
@@ -113,8 +124,11 @@ async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMem
async def get_member(
- database: Database, user_id: str, *, force_refresh: bool = False
-) -> typing.Optional[models.DiscordMember]:
+ database: Database,
+ user_id: str,
+ *,
+ force_refresh: bool = False,
+) -> models.DiscordMember | None:
"""
Get a member from the cache, or from the discord API.
@@ -147,7 +161,7 @@ async def get_member(
await collection.insert_one({
"user": user_id,
"data": member.json(),
- "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
+ "inserted_at": datetime.datetime.now(tz=datetime.UTC),
})
return member
@@ -161,7 +175,9 @@ class UnauthorizedError(exceptions.HTTPException):
async def _verify_access_helper(
- form_id: str, request: starlette.requests.Request, attribute: str
+ form_id: str,
+ request: starlette.requests.Request,
+ attribute: str,
) -> None:
"""A low level helper to validate access to a form resource based on the user's scopes."""
form = await request.state.db.forms.find_one({"_id": form_id})
diff --git a/backend/middleware.py b/backend/middleware.py
index 7a3bdc8..0b08859 100644
--- a/backend/middleware.py
+++ b/backend/middleware.py
@@ -7,14 +7,13 @@ from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE
class DatabaseMiddleware:
-
def __init__(self, app: ASGIApp) -> None:
self._app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
client: AsyncIOMotorClient = AsyncIOMotorClient(
DATABASE_URL,
- tlsAllowInvalidCertificates=True
+ tlsAllowInvalidCertificates=True,
)
db = client[MONGO_DATABASE]
Request(scope).state.db = db
@@ -22,7 +21,6 @@ class DatabaseMiddleware:
class ProtectedDocsMiddleware:
-
def __init__(self, app: ASGIApp) -> None:
self._app = app
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
index a9f76e0..336e28b 100644
--- a/backend/models/__init__.py
+++ b/backend/models/__init__.py
@@ -7,13 +7,13 @@ from .question import CodeQuestion, Question
__all__ = [
"AntiSpam",
+ "CodeQuestion",
+ "DiscordMember",
"DiscordRole",
"DiscordUser",
- "DiscordMember",
"Form",
+ "FormList",
"FormResponse",
- "CodeQuestion",
"Question",
- "FormList",
- "ResponseList"
+ "ResponseList",
]
diff --git a/backend/models/discord_role.py b/backend/models/discord_role.py
index ada35ef..195f557 100644
--- a/backend/models/discord_role.py
+++ b/backend/models/discord_role.py
@@ -1,13 +1,11 @@
-import typing
-
from pydantic import BaseModel
class RoleTags(BaseModel):
"""Meta information about a discord role."""
- bot_id: typing.Optional[str]
- integration_id: typing.Optional[str]
+ bot_id: str | None
+ integration_id: str | None
premium_subscriber: bool
def __init__(self, **data) -> None:
@@ -20,7 +18,7 @@ class RoleTags(BaseModel):
We manually parse the raw data to determine if the field exists, and give it a useful
bool value.
"""
- data["premium_subscriber"] = "premium_subscriber" in data.keys()
+ data["premium_subscriber"] = "premium_subscriber" in data
super().__init__(**data)
@@ -31,10 +29,10 @@ class DiscordRole(BaseModel):
name: str
color: int
hoist: bool
- icon: typing.Optional[str]
- unicode_emoji: typing.Optional[str]
+ icon: str | None
+ unicode_emoji: str | None
position: int
permissions: str
managed: bool
mentionable: bool
- tags: typing.Optional[RoleTags]
+ tags: RoleTags | None
diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py
index 0eca15b..be10672 100644
--- a/backend/models/discord_user.py
+++ b/backend/models/discord_user.py
@@ -11,15 +11,15 @@ class _User(BaseModel):
username: str
id: str
discriminator: str
- avatar: t.Optional[str]
- bot: t.Optional[bool]
- system: t.Optional[bool]
- locale: t.Optional[str]
- verified: t.Optional[bool]
- email: t.Optional[str]
- flags: t.Optional[int]
- premium_type: t.Optional[int]
- public_flags: t.Optional[int]
+ avatar: str | None
+ bot: bool | None
+ system: bool | None
+ locale: str | None
+ verified: bool | None
+ email: str | None
+ flags: int | None
+ premium_type: int | None
+ public_flags: int | None
class DiscordUser(_User):
@@ -33,16 +33,16 @@ class DiscordMember(BaseModel):
"""A discord guild member."""
user: _User
- nick: t.Optional[str]
- avatar: t.Optional[str]
+ nick: str | None
+ avatar: str | None
roles: list[str]
joined_at: datetime.datetime
- premium_since: t.Optional[datetime.datetime]
+ premium_since: datetime.datetime | None
deaf: bool
mute: bool
- pending: t.Optional[bool]
- permissions: t.Optional[str]
- communication_disabled_until: t.Optional[datetime.datetime]
+ pending: bool | None
+ permissions: str | None
+ communication_disabled_until: datetime.datetime | None
def dict(self, *args, **kwargs) -> dict[str, t.Any]:
"""Convert the model to a python dict, and encode timestamps in a serializable format."""
diff --git a/backend/models/form.py b/backend/models/form.py
index 10c8bfd..3db267e 100644
--- a/backend/models/form.py
+++ b/backend/models/form.py
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, constr, root_validator, validator
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from backend.constants import DISCORD_GUILD, FormFeatures, WebHook
+
from .question import Question
PUBLIC_FIELDS = [
@@ -14,20 +15,22 @@ PUBLIC_FIELDS = [
"name",
"description",
"submitted_text",
- "discord_role"
+ "discord_role",
]
class _WebHook(BaseModel):
"""Schema model of discord webhooks."""
+
url: str
- message: t.Optional[str]
+ message: str | None
@validator("url")
def validate_url(cls, url: str) -> str:
"""Validates URL parameter."""
if "discord.com/api/webhooks/" not in url:
- raise ValueError("URL must be a discord webhook.")
+ msg = "URL must be a discord webhook."
+ raise ValueError(msg)
return url
@@ -40,56 +43,55 @@ class Form(BaseModel):
questions: list[Question]
name: str
description: str
- submitted_text: t.Optional[str] = None
+ submitted_text: str | None = None
webhook: _WebHook = None
- discord_role: t.Optional[str]
- response_readers: t.Optional[list[str]]
- editors: t.Optional[list[str]]
+ discord_role: str | None
+ response_readers: list[str] | None
+ editors: list[str] | None
class Config:
allow_population_by_field_name = True
@validator("features")
- def validate_features(cls, value: list[str]) -> t.Optional[list[str]]:
+ def validate_features(cls, value: list[str]) -> list[str]:
"""Validates is all features in allowed list."""
# Uppercase everything to avoid mixed case in DB
value = [v.upper() for v in value]
allowed_values = [v.value for v in FormFeatures.__members__.values()]
if any(v not in allowed_values for v in value):
- raise ValueError("Form features list contains one or more invalid values.")
+ msg = "Form features list contains one or more invalid values."
+ raise ValueError(msg)
if FormFeatures.REQUIRES_LOGIN.value not in value:
if FormFeatures.COLLECT_EMAIL.value in value:
- raise ValueError(
- "COLLECT_EMAIL feature require REQUIRES_LOGIN feature."
- )
+ msg = "COLLECT_EMAIL feature require REQUIRES_LOGIN feature."
+ raise ValueError(msg)
if FormFeatures.ASSIGN_ROLE.value in value:
- raise ValueError("ASSIGN_ROLE feature require REQUIRES_LOGIN feature.")
+ msg = "ASSIGN_ROLE feature require REQUIRES_LOGIN feature."
+ raise ValueError(msg)
return value
@validator("response_readers", "editors")
- def validate_role_scoping(cls, value: t.Optional[list[str]]) -> t.Optional[list[str]]:
+ def validate_role_scoping(cls, value: list[str] | None) -> list[str]:
"""Ensure special role based permissions aren't granted to the @everyone role."""
- if value and str(DISCORD_GUILD) in value:
- raise ValueError("You can not add the everyone role as an access scope.")
+ if value and DISCORD_GUILD in value:
+ msg = "You can not add the everyone role as an access scope."
+ raise ValueError(msg)
return value
@root_validator
- def validate_role(cls, values: dict[str, t.Any]) -> t.Optional[dict[str, t.Any]]:
+ def validate_role(cls, values: dict[str, t.Any]) -> dict[str, t.Any]:
"""Validates does Discord role provided when flag provided."""
- if (
- FormFeatures.ASSIGN_ROLE.value in values.get("features", [])
- and not values.get("discord_role")
- ):
- raise ValueError(
- "discord_role field is required when ASSIGN_ROLE flag is provided."
- )
+ is_role_assigner = FormFeatures.ASSIGN_ROLE.value in values.get("features", [])
+ if is_role_assigner and not values.get("discord_role"):
+ msg = "discord_role field is required when ASSIGN_ROLE flag is provided."
+ raise ValueError(msg)
return values
- def dict(self, admin: bool = True, **kwargs) -> dict[str, t.Any]:
+ def dict(self, admin: bool = True, **kwargs) -> dict[str, t.Any]: # noqa: FBT001, FBT002
"""Wrapper for original function to exclude private data for public access."""
data = super().dict(**kwargs)
@@ -97,10 +99,7 @@ class Form(BaseModel):
if not admin:
for field in PUBLIC_FIELDS:
- if field == "id" and kwargs.get("by_alias"):
- fetch_field = "_id"
- else:
- fetch_field = field
+ fetch_field = "_id" if field == "id" and kwargs.get("by_alias") else field
returned_data[field] = data[fetch_field]
else:
@@ -110,17 +109,20 @@ class Form(BaseModel):
class FormList(BaseModel):
- __root__: t.List[Form]
+ __root__: list[Form]
-async def validate_hook_url(url: str) -> t.Optional[ValidationError]:
+async def validate_hook_url(url: str) -> ValidationError | None:
"""Validator for discord webhook urls."""
- async def validate() -> t.Optional[str]:
+
+ async def validate() -> str | None:
if not isinstance(url, str):
- raise ValueError("Webhook URL must be a string.")
+ msg = "Webhook URL must be a string."
+ raise TypeError(msg)
if "discord.com/api/webhooks/" not in url:
- raise ValueError("URL must be a discord webhook.")
+ msg = "URL must be a discord webhook."
+ raise ValueError(msg)
try:
async with httpx.AsyncClient() as client:
@@ -129,36 +131,32 @@ async def validate_hook_url(url: str) -> t.Optional[ValidationError]:
except httpx.RequestError as error:
# Catch exceptions in request format
- raise ValueError(
- f"Encountered error while trying to connect to url: `{error}`"
- )
+ msg = f"Encountered error while trying to connect to url: `{error}`"
+ raise ValueError(msg)
except httpx.HTTPStatusError as error:
# Catch exceptions in response
status = error.response.status_code
if status == 401:
- raise ValueError(
- "Could not authenticate with target. Please check the webhook url."
- )
- elif status == 404:
- raise ValueError(
- "Target could not find webhook url. Please check the webhook url."
- )
- else:
- raise ValueError(
- f"Unknown error ({status}) while connecting to target: {error}"
- )
+ msg = "Could not authenticate with target. Please check the webhook url."
+ raise ValueError(msg)
+ if status == 404:
+ msg = "Target could not find webhook url. Please check the webhook url."
+ raise ValueError(msg)
+
+ msg = f"Unknown error ({status}) while connecting to target: {error}"
+ raise ValueError(msg)
return url
# Validate, and return errors, if any
try:
await validate()
- except Exception as e:
+ except Exception as e: # noqa: BLE001
loc = (
WebHook.__name__.lower(),
- WebHook.URL.value
+ WebHook.URL.value,
)
return ValidationError([ErrorWrapper(e, loc=loc)], _WebHook)
diff --git a/backend/models/form_response.py b/backend/models/form_response.py
index 933f5e4..3c8297b 100644
--- a/backend/models/form_response.py
+++ b/backend/models/form_response.py
@@ -11,19 +11,20 @@ class FormResponse(BaseModel):
"""Schema model for form response."""
id: str = Field(alias="_id")
- user: t.Optional[DiscordUser]
- antispam: t.Optional[AntiSpam]
+ user: DiscordUser | None
+ antispam: AntiSpam | None
response: dict[str, t.Any]
form_id: str
timestamp: str
@validator("timestamp", pre=True)
- def set_timestamp(cls, iso_string: t.Optional[str]) -> t.Optional[str]:
+ def set_timestamp(cls, iso_string: str | None) -> str:
if iso_string is None:
- return datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
+ return datetime.datetime.now(tz=datetime.UTC).isoformat()
- elif not isinstance(iso_string, str):
- raise ValueError("Submission timestamp must be a string.")
+ if not isinstance(iso_string, str):
+ msg = "Submission timestamp must be a string."
+ raise TypeError(msg)
# Convert to datetime and back to ensure string is valid
return datetime.datetime.fromisoformat(iso_string).isoformat()
@@ -33,4 +34,4 @@ class FormResponse(BaseModel):
class ResponseList(BaseModel):
- __root__: t.List[FormResponse]
+ __root__: list[FormResponse]
diff --git a/backend/models/question.py b/backend/models/question.py
index 201aa51..a13ce93 100644
--- a/backend/models/question.py
+++ b/backend/models/question.py
@@ -4,11 +4,12 @@ from pydantic import BaseModel, Field, root_validator, validator
from backend.constants import QUESTION_TYPES, REQUIRED_QUESTION_TYPE_DATA
-_TESTS_TYPE = t.Union[t.Dict[str, str], int]
+_TESTS_TYPE = dict[str, str] | int
class Unittests(BaseModel):
"""Schema model for unittest suites in code questions."""
+
allow_failure: bool = False
tests: _TESTS_TYPE
@@ -16,17 +17,19 @@ class Unittests(BaseModel):
def validate_tests(cls, value: _TESTS_TYPE) -> _TESTS_TYPE:
"""Confirm that at least one test exists in a test suite."""
if isinstance(value, dict):
- keys = len(value.keys()) - (1 if "setUp" in value.keys() else 0)
+ keys = len(value.keys()) - (1 if "setUp" in value else 0)
if keys == 0:
- raise ValueError("Must have at least one test in a test suite.")
+ msg = "Must have at least one test in a test suite."
+ raise ValueError(msg)
return value
class CodeQuestion(BaseModel):
"""Schema model for questions of type `code`."""
+
language: str
- unittests: t.Optional[Unittests]
+ unittests: Unittests | None
class Question(BaseModel):
@@ -42,22 +45,20 @@ class Question(BaseModel):
allow_population_by_field_name = True
@validator("type", pre=True)
- def validate_question_type(cls, value: str) -> t.Optional[str]:
+ def validate_question_type(cls, value: str) -> str:
"""Checks if question type in currently allowed types list."""
value = value.lower()
if value not in QUESTION_TYPES:
- raise ValueError(
- f"{value} is not valid question type. "
- f"Allowed question types: {QUESTION_TYPES}."
- )
+ msg = f"{value} is not valid question type. Allowed question types: {QUESTION_TYPES}."
+ raise ValueError(msg)
return value
@root_validator
def validate_question_data(
- cls,
- value: dict[str, t.Any]
- ) -> t.Optional[dict[str, t.Any]]:
+ cls,
+ value: dict[str, t.Any],
+ ) -> dict[str, t.Any]:
"""Check does required data exists for question type and remove other data."""
# When question type don't need data, don't add anything to keep DB clean.
if value.get("type") not in REQUIRED_QUESTION_TYPE_DATA:
@@ -65,13 +66,15 @@ class Question(BaseModel):
for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value["type"]].items():
if key not in value.get("data", {}):
- raise ValueError(f"Required question data key '{key}' not provided.")
+ msg = f"Required question data key '{key}' not provided."
+ raise ValueError(msg)
if not isinstance(value["data"][key], data_type):
- raise ValueError(
+ msg = (
f"Question data key '{key}' expects {data_type.__name__}, "
- f"got {type(value['data'][key]).__name__} instead."
+ f"got {type(value["data"][key]).__name__} instead."
)
+ raise TypeError(msg)
# Validate unittest options
if value.get("type").lower() == "code":
diff --git a/backend/route.py b/backend/route.py
index d778bf0..a9ea7ad 100644
--- a/backend/route.py
+++ b/backend/route.py
@@ -1,6 +1,5 @@
-"""
-Base class for implementing dynamic routing.
-"""
+"""Base class for implementing dynamic routing."""
+
from starlette.endpoints import HTTPEndpoint
@@ -11,7 +10,9 @@ class Route(HTTPEndpoint):
@classmethod
def check_parameters(cls) -> None:
if not hasattr(cls, "name"):
- raise ValueError(f"Route {cls.__name__} has not defined a name")
+ msg = f"Route {cls.__name__} has not defined a name"
+ raise ValueError(msg)
if not hasattr(cls, "path"):
- raise ValueError(f"Route {cls.__name__} has not defined a path")
+ msg = f"Route {cls.__name__} has not defined a path"
+ raise ValueError(msg)
diff --git a/backend/route_manager.py b/backend/route_manager.py
index 2d95bb2..b35ca0b 100644
--- a/backend/route_manager.py
+++ b/backend/route_manager.py
@@ -1,6 +1,4 @@
-"""
-Module to dynamically generate a Starlette routing map based on a directory tree.
-"""
+"""Module to dynamically generate a Starlette routing map based on a directory tree."""
import importlib
import inspect
@@ -27,7 +25,7 @@ def construct_route_map_from_dict(route_dict: dict) -> list[BaseRoute]:
return route_map
-def is_route_class(member: t.Any) -> bool: # noqa: ANN401
+def is_route_class(member: t.Any) -> bool:
return inspect.isclass(member) and issubclass(member, Route) and member != Route
@@ -35,7 +33,7 @@ def route_classes() -> t.Iterator[tuple[Path, type[Route]]]:
routes_directory = Path("backend") / "routes"
for module_path in routes_directory.rglob("*.py"):
- import_name = f"{'.'.join(module_path.parent.parts)}.{module_path.stem}"
+ import_name = f"{".".join(module_path.parent.parts)}.{module_path.stem}"
route_module = importlib.import_module(import_name)
for _member_name, member in inspect.getmembers(route_module):
if is_route_class(member):
@@ -47,7 +45,7 @@ def create_route_map() -> list[BaseRoute]:
route_dict = nested_dict()
for module_path, member in route_classes():
- # module_path == Path("backend/routes/foo/bar/baz/bin.py")
+ # For Path: "backend/routes/foo/bar/baz/bin.py"
# => levels == ["foo", "bar", "baz"]
levels = module_path.parent.parts[2:]
current_level = None
diff --git a/backend/routes/admin.py b/backend/routes/admin.py
index 0fd0700..848abce 100644
--- a/backend/routes/admin.py
+++ b/backend/routes/admin.py
@@ -1,6 +1,5 @@
-"""
-Adds new admin user.
-"""
+"""Adds new admin user."""
+
from pydantic import BaseModel, Field
from spectree import Response
from starlette.authentication import requires
@@ -22,7 +21,7 @@ async def grant(request: Request) -> JSONResponse:
admin = AdminModel(**data)
if await request.state.db.admins.find_one(
- {"_id": admin.id}
+ {"_id": admin.id},
):
return JSONResponse({"error": "already_exists"}, status_code=400)
@@ -40,7 +39,7 @@ class AdminRoute(Route):
@api.validate(
json=AdminModel,
resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage),
- tags=["admin"]
+ tags=["admin"],
)
async def post(self, request: Request) -> JSONResponse:
"""Grant a user administrator privileges."""
@@ -48,6 +47,7 @@ class AdminRoute(Route):
if not constants.PRODUCTION:
+
class AdminDev(Route):
"""Adds new admin user with no authentication."""
@@ -57,7 +57,7 @@ if not constants.PRODUCTION:
@api.validate(
json=AdminModel,
resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage),
- tags=["admin"]
+ tags=["admin"],
)
async def post(self, request: Request) -> JSONResponse:
"""
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 42fb3ec..bc80a7d 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -1,9 +1,6 @@
-"""
-Use a token received from the Discord OAuth2 system to fetch user information.
-"""
+"""Use a token received from the Discord OAuth2 system to fetch user information."""
import datetime
-from typing import Union
import httpx
import jwt
@@ -35,8 +32,8 @@ class AuthorizeResponse(BaseModel):
async def process_token(
bearer_token: dict,
- request: Request
-) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+ request: Request,
+) -> AuthorizeResponse | responses.JSONResponse:
"""Post a bearer token to Discord, and return a JWT and username."""
interaction_start = datetime.datetime.now()
@@ -57,7 +54,7 @@ async def process_token(
"refresh": bearer_token["refresh_token"],
"user_details": user_details,
"in_guild": bool(member),
- "expiry": token_expiry.isoformat()
+ "expiry": token_expiry.isoformat(),
}
token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
@@ -65,18 +62,18 @@ async def process_token(
response = responses.JSONResponse({
"username": user.display_name,
- "expiry": token_expiry.isoformat()
+ "expiry": token_expiry.isoformat(),
})
- await set_response_token(response, request, token, bearer_token["expires_in"])
+ set_response_token(response, request, token, bearer_token["expires_in"])
return response
-async def set_response_token(
- response: responses.Response,
- request: Request,
- new_token: str,
- expiry: int
+def set_response_token(
+ response: responses.Response,
+ request: Request,
+ new_token: str,
+ expiry: int,
) -> None:
"""Helper that handles logic for updating a token in a set-cookie response."""
origin_url = request.headers.get("origin")
@@ -94,19 +91,18 @@ async def set_response_token(
samesite = "None"
response.set_cookie(
- "token", f"JWT {new_token}",
+ "token",
+ f"JWT {new_token}",
secure=constants.PRODUCTION,
httponly=True,
samesite=samesite,
domain=domain,
- max_age=expiry
+ max_age=expiry,
)
class AuthorizeRoute(Route):
- """
- Use the authorization code from Discord to generate a JWT token.
- """
+ """Use the authorization code from Discord to generate a JWT token."""
name = "authorize"
path = "/authorize"
@@ -114,7 +110,7 @@ class AuthorizeRoute(Route):
@api.validate(
json=AuthorizeRequest,
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
- tags=["auth"]
+ tags=["auth"],
)
async def post(self, request: Request) -> responses.JSONResponse:
"""Generate an authorization token."""
@@ -129,9 +125,7 @@ class AuthorizeRoute(Route):
class TokenRefreshRoute(Route):
- """
- Use the refresh code from a JWT to get a new token and generate a new JWT token.
- """
+ """Use the refresh code from a JWT to get a new token and generate a new JWT token."""
name = "refresh"
path = "/refresh"
@@ -139,7 +133,7 @@ class TokenRefreshRoute(Route):
@requires(["authenticated"])
@api.validate(
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
- tags=["auth"]
+ tags=["auth"],
)
async def post(self, request: Request) -> responses.JSONResponse:
"""Refresh an authorization token."""
diff --git a/backend/routes/discord.py b/backend/routes/discord.py
index bca1edb..53b8af3 100644
--- a/backend/routes/discord.py
+++ b/backend/routes/discord.py
@@ -10,7 +10,8 @@ from backend import discord, models, route
from backend.validation import ErrorMessage, api
NOT_FOUND_EXCEPTION = JSONResponse(
- {"error": "Could not find the requested resource in the guild or cache."}, status_code=404
+ {"error": "Could not find the requested resource in the guild or cache."},
+ status_code=404,
)
@@ -28,7 +29,7 @@ class RolesRoute(route.Route):
@requires(["authenticated", "admin"])
@api.validate(
resp=Response(HTTP_200=RolesResponse),
- tags=["roles"]
+ tags=["roles"],
)
async def patch(self, request: Request) -> JSONResponse:
"""Refresh the roles database."""
@@ -54,7 +55,7 @@ class MemberRoute(route.Route):
@api.validate(
resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),
json=MemberRequest,
- tags=["auth"]
+ tags=["auth"],
)
async def delete(self, request: Request) -> JSONResponse:
"""Force a resync of the cache for the given user."""
@@ -63,21 +64,20 @@ class MemberRoute(route.Route):
if member:
return JSONResponse(member.dict())
- else:
- return NOT_FOUND_EXCEPTION
+ return NOT_FOUND_EXCEPTION
@requires(["authenticated", "admin"])
@api.validate(
resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),
json=MemberRequest,
- tags=["auth"]
+ tags=["auth"],
)
async def get(self, request: Request) -> JSONResponse:
"""Get a user's roles on the configured server."""
body = await request.json()
member = await discord.get_member(request.state.db, body["user_id"])
- if member:
- return JSONResponse(member.dict())
- else:
+ if not member:
return NOT_FOUND_EXCEPTION
+
+ return JSONResponse(member.dict())
diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py
index 75ff495..0fe10b5 100644
--- a/backend/routes/forms/discover.py
+++ b/backend/routes/forms/discover.py
@@ -1,6 +1,5 @@
-"""
-Return a list of all publicly discoverable forms to unauthenticated users.
-"""
+"""Return a list of all publicly discoverable forms to unauthenticated users."""
+
from spectree.response import Response
from starlette.requests import Request
from starlette.responses import JSONResponse
@@ -12,7 +11,7 @@ from backend.validation import api
__FEATURES = [
constants.FormFeatures.OPEN.value,
- constants.FormFeatures.REQUIRES_LOGIN.value
+ constants.FormFeatures.REQUIRES_LOGIN.value,
]
if not constants.PRODUCTION:
__FEATURES.append(constants.FormFeatures.DISCOVERABLE.value)
@@ -22,7 +21,7 @@ __QUESTION = Question(
name="Click the button below to log into the forms application.",
type="section",
data={"text": ""},
- required=False
+ required=False,
)
AUTH_FORM = Form(
@@ -31,14 +30,12 @@ AUTH_FORM = Form(
questions=[__QUESTION],
name="Login",
description="Log into Python Discord Forms.",
- submitted_text="This page can't be submitted."
+ submitted_text="This page can't be submitted.",
)
class DiscoverableFormsList(Route):
- """
- List all discoverable forms that should be shown on the homepage.
- """
+ """List all discoverable forms that should be shown on the homepage."""
name = "discoverable_forms_list"
path = "/discoverable"
@@ -46,15 +43,11 @@ class DiscoverableFormsList(Route):
@api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""List all discoverable forms that should be shown on the homepage."""
- forms = []
cursor = request.state.db.forms.find({"features": "DISCOVERABLE"}).sort("name")
# Parse it to Form and then back to dictionary
# to replace _id with id
- for form in await cursor.to_list(None):
- forms.append(Form(**form))
-
- forms = [form.dict(admin=False) for form in forms]
+ forms = [Form(**form).dict(admin=False) for form in await cursor.to_list(None)]
# Return an empty form in development environments to help with authentication.
if not constants.PRODUCTION:
diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py
index 020193c..410102a 100644
--- a/backend/routes/forms/form.py
+++ b/backend/routes/forms/form.py
@@ -1,6 +1,5 @@
-"""
-Returns, updates or deletes a single form given an ID.
-"""
+"""Returns, updates or deletes a single form given an ID."""
+
import json.decoder
import deepmerge
@@ -48,7 +47,7 @@ class SingleForm(Route):
admin = False
filters = {
- "_id": form_id
+ "_id": form_id,
}
if not admin:
@@ -71,7 +70,7 @@ class SingleForm(Route):
HTTP_400=ErrorMessage,
HTTP_404=ErrorMessage,
),
- tags=["forms"]
+ tags=["forms"],
)
async def patch(self, request: Request) -> JSONResponse:
"""Updates form by ID."""
@@ -90,7 +89,7 @@ class SingleForm(Route):
# Build Data Merger
merge_strategy = [
- (dict, ["merge"])
+ (dict, ["merge"]),
]
merger = deepmerge.Merger(merge_strategy, ["override"], ["override"])
@@ -105,13 +104,12 @@ class SingleForm(Route):
await request.state.db.forms.replace_one({"_id": form_id}, form.dict())
return JSONResponse(form.dict())
- else:
- return JSONResponse({"error": "not_found"}, status_code=404)
+ return JSONResponse({"error": "not_found"}, status_code=404)
@requires(["authenticated", "admin"])
@api.validate(
resp=Response(HTTP_200=OkayResponse, HTTP_401=ErrorMessage, HTTP_404=ErrorMessage),
- tags=["forms"]
+ tags=["forms"],
)
async def delete(self, request: Request) -> JSONResponse:
"""Deletes form by ID."""
diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py
index 38be693..1fdfc48 100644
--- a/backend/routes/forms/index.py
+++ b/backend/routes/forms/index.py
@@ -1,6 +1,5 @@
-"""
-Return a list of all forms to authenticated users.
-"""
+"""Return a list of all forms to authenticated users."""
+
from spectree.response import Response
from starlette.authentication import requires
from starlette.requests import Request
@@ -14,9 +13,7 @@ from backend.validation import ErrorMessage, OkayResponse, api
class FormsList(Route):
- """
- List all available forms for authorized viewers.
- """
+ """List all available forms for authorized viewers."""
name = "forms_list_create"
path = "/"
@@ -25,24 +22,17 @@ class FormsList(Route):
@api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Return a list of all forms to authenticated users."""
- forms = []
cursor = request.state.db.forms.find()
- for form in await cursor.to_list(None):
- forms.append(Form(**form)) # For converting _id to id
-
- # Covert them back to dictionaries
- forms = [form.dict() for form in forms]
+ forms = [Form(**form).dict() for form in await cursor.to_list(None)]
- return JSONResponse(
- forms
- )
+ return JSONResponse(forms)
@requires(["authenticated", "Helpers"])
@api.validate(
json=Form,
resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage),
- tags=["forms"]
+ tags=["forms"],
)
async def post(self, request: Request) -> JSONResponse:
"""Create a new form."""
@@ -66,9 +56,7 @@ class FormsList(Route):
form = Form(**form_data)
if await request.state.db.forms.find_one({"_id": form.id}):
- return JSONResponse({
- "error": "id_taken"
- }, status_code=400)
+ return JSONResponse({"error": "id_taken"}, status_code=400)
await request.state.db.forms.insert_one(form.dict(by_alias=True))
return JSONResponse(form.dict())
diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py
index 565701f..b4f7f04 100644
--- a/backend/routes/forms/response.py
+++ b/backend/routes/forms/response.py
@@ -1,6 +1,4 @@
-"""
-Returns or deletes form response by ID.
-"""
+"""Returns or deletes form response by ID."""
from spectree import Response as RouteResponse
from starlette.authentication import requires
@@ -22,7 +20,7 @@ class Response(Route):
@requires(["authenticated"])
@api.validate(
resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage),
- tags=["forms", "responses"]
+ tags=["forms", "responses"],
)
async def get(self, request: Request) -> JSONResponse:
"""Return a single form response by ID."""
@@ -32,30 +30,29 @@ class Response(Route):
if raw_response := await request.state.db.responses.find_one(
{
"_id": request.path_params["response_id"],
- "form_id": form_id
- }
+ "form_id": form_id,
+ },
):
response = FormResponse(**raw_response)
return JSONResponse(response.dict())
- else:
- return JSONResponse({"error": "response_not_found"}, status_code=404)
+ return JSONResponse({"error": "response_not_found"}, status_code=404)
@requires(["authenticated", "admin"])
@api.validate(
resp=RouteResponse(HTTP_200=OkayResponse, HTTP_404=ErrorMessage),
- tags=["forms", "responses"]
+ tags=["forms", "responses"],
)
async def delete(self, request: Request) -> JSONResponse:
"""Delete a form response by ID."""
if not await request.state.db.responses.find_one(
{
"_id": request.path_params["response_id"],
- "form_id": request.path_params["form_id"]
- }
+ "form_id": request.path_params["form_id"],
+ },
):
return JSONResponse({"error": "not_found"}, status_code=404)
await request.state.db.responses.delete_one(
- {"_id": request.path_params["response_id"]}
+ {"_id": request.path_params["response_id"]},
)
return JSONResponse({"status": "ok"})
diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py
index 818ebce..85e5af2 100644
--- a/backend/routes/forms/responses.py
+++ b/backend/routes/forms/responses.py
@@ -1,6 +1,5 @@
-"""
-Returns all form responses by form ID.
-"""
+"""Returns all form responses by form ID."""
+
from pydantic import BaseModel
from spectree import Response
from starlette.authentication import requires
@@ -18,9 +17,7 @@ class ResponseIdList(BaseModel):
class Responses(Route):
- """
- Returns all form responses by form ID.
- """
+ """Returns all form responses by form ID."""
name = "form_responses"
path = "/{form_id:str}/responses"
@@ -28,7 +25,7 @@ class Responses(Route):
@requires(["authenticated"])
@api.validate(
resp=Response(HTTP_200=ResponseList),
- tags=["forms", "responses"]
+ tags=["forms", "responses"],
)
async def get(self, request: Request) -> JSONResponse:
"""Returns all form responses by form ID."""
@@ -36,11 +33,9 @@ class Responses(Route):
await discord.verify_response_access(form_id, request)
cursor = request.state.db.responses.find(
- {"form_id": form_id}
+ {"form_id": form_id},
)
- responses = [
- FormResponse(**response) for response in await cursor.to_list(None)
- ]
+ responses = [FormResponse(**response) for response in await cursor.to_list(None)]
return JSONResponse([response.dict() for response in responses])
@requires(["authenticated", "admin"])
@@ -49,14 +44,14 @@ class Responses(Route):
resp=Response(
HTTP_200=OkayResponse,
HTTP_404=ErrorMessage,
- HTTP_400=ErrorMessage
+ HTTP_400=ErrorMessage,
),
- tags=["forms", "responses"]
+ tags=["forms", "responses"],
)
async def delete(self, request: Request) -> JSONResponse:
"""Bulk deletes form responses by IDs."""
if not await request.state.db.forms.find_one(
- {"_id": request.path_params["form_id"]}
+ {"_id": request.path_params["form_id"]},
):
return JSONResponse({"error": "not_found"}, status_code=404)
@@ -67,37 +62,34 @@ class Responses(Route):
ids = set(response_ids.ids)
cursor = request.state.db.responses.find(
- {"_id": {"$in": list(ids)}} # Convert here back to list, may throw error.
+ {"_id": {"$in": list(ids)}}, # Convert here back to list, may throw error.
)
- entries = [
- FormResponse(**submission) for submission in await cursor.to_list(None)
- ]
+ entries = [FormResponse(**submission) for submission in await cursor.to_list(None)]
actual_ids = {entry.id for entry in entries}
if len(ids) != len(actual_ids):
return JSONResponse(
{
"error": "responses_not_found",
- "ids": list(ids - actual_ids)
+ "ids": list(ids - actual_ids),
},
- status_code=404
+ status_code=404,
)
if any(entry.form_id != request.path_params["form_id"] for entry in entries):
return JSONResponse(
{
"error": "wrong_form",
- "ids": list(
- entry.id for entry in entries
- if entry.id != request.path_params["form_id"]
- )
+ "ids": [
+ entry.id for entry in entries if entry.id != request.path_params["form_id"]
+ ],
},
- status_code=400
+ status_code=400,
)
await request.state.db.responses.delete_many(
{
- "_id": {"$in": list(actual_ids)}
- }
+ "_id": {"$in": list(actual_ids)},
+ },
)
return JSONResponse({"status": "ok"})
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index 765856e..8f01e2b 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -1,6 +1,4 @@
-"""
-Submit a form.
-"""
+"""Submit a form."""
import asyncio
import binascii
@@ -8,10 +6,9 @@ import datetime
import hashlib
import typing
import uuid
-from typing import Any, Optional
+from typing import Any
import httpx
-import pymongo.database
import sentry_sdk
from pydantic import ValidationError
from pydantic.main import BaseModel
@@ -29,13 +26,16 @@ from backend.routes.forms.discover import AUTH_FORM
from backend.routes.forms.unittesting import BypassDetectedError, execute_unittest
from backend.validation import ErrorMessage, api
+if typing.TYPE_CHECKING:
+ import pymongo.database
+
HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"
HCAPTCHA_HEADERS = {
- "Content-Type": "application/x-www-form-urlencoded"
+ "Content-Type": "application/x-www-form-urlencoded",
}
DISCORD_HEADERS = {
- "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"
+ "Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}",
}
@@ -46,7 +46,7 @@ class SubmissionResponse(BaseModel):
class PartialSubmission(BaseModel):
response: dict[str, Any]
- captcha: Optional[str]
+ captcha: str | None
class UnittestError(BaseModel):
@@ -62,9 +62,7 @@ class UnittestErrorMessage(ErrorMessage):
class SubmitForm(Route):
- """
- Submit a form with the provided form ID.
- """
+ """Submit a form with the provided form ID."""
name = "submit_form"
path = "/submit/{form_id:str}"
@@ -75,9 +73,9 @@ class SubmitForm(Route):
HTTP_200=SubmissionResponse,
HTTP_404=ErrorMessage,
HTTP_400=ErrorMessage,
- HTTP_422=UnittestErrorMessage
+ HTTP_422=UnittestErrorMessage,
),
- tags=["forms", "responses"]
+ tags=["forms", "responses"],
)
async def post(self, request: Request) -> JSONResponse:
"""Submit a response to the form."""
@@ -92,7 +90,7 @@ class SubmitForm(Route):
if old != request.user.token:
try:
expiry = datetime.datetime.fromisoformat(
- request.user.decoded_token.get("expiry")
+ request.user.decoded_token.get("expiry"),
)
except ValueError:
expiry = None
@@ -117,7 +115,7 @@ class SubmitForm(Route):
id="not-submitted",
form_id=AUTH_FORM.id,
response={question.id: None for question in AUTH_FORM.questions},
- timestamp=datetime.datetime.now().isoformat()
+ timestamp=datetime.datetime.now().isoformat(),
).dict()
return JSONResponse({"form": AUTH_FORM.dict(admin=False), "response": response})
@@ -131,8 +129,9 @@ class SubmitForm(Route):
ip_hash_ctx = hashlib.md5()
ip_hash_ctx.update(
request.headers.get(
- "Cf-Connecting-IP", request.client.host
- ).encode()
+ "Cf-Connecting-IP",
+ request.client.host,
+ ).encode(),
)
ip_hash = binascii.hexlify(ip_hash_ctx.digest())
user_agent_hash_ctx = hashlib.md5()
@@ -142,12 +141,12 @@ class SubmitForm(Route):
async with httpx.AsyncClient() as client:
query_params = {
"secret": constants.HCAPTCHA_API_SECRET,
- "response": data.get("captcha")
+ "response": data.get("captcha"),
}
r = await client.post(
HCAPTCHA_VERIFY_URL,
params=query_params,
- headers=HCAPTCHA_HEADERS
+ headers=HCAPTCHA_HEADERS,
)
r.raise_for_status()
captcha_data = r.json()
@@ -155,7 +154,7 @@ class SubmitForm(Route):
response["antispam"] = {
"ip_hash": ip_hash.decode(),
"user_agent_hash": user_agent_hash.decode(),
- "captcha_pass": captcha_data["success"]
+ "captcha_pass": captcha_data["success"],
}
if constants.FormFeatures.REQUIRES_LOGIN.value in form.features:
@@ -164,16 +163,12 @@ class SubmitForm(Route):
response["user"]["admin"] = request.user.admin
if (
- constants.FormFeatures.COLLECT_EMAIL.value in form.features
- and "email" not in response["user"]
+ constants.FormFeatures.COLLECT_EMAIL.value in form.features
+ and "email" not in response["user"]
):
- return JSONResponse({
- "error": "email_required"
- }, status_code=400)
+ return JSONResponse({"error": "email_required"}, status_code=400)
else:
- return JSONResponse({
- "error": "missing_discord_data"
- }, status_code=400)
+ return JSONResponse({"error": "missing_discord_data"}, status_code=400)
missing_fields = []
for question in form.questions:
@@ -184,10 +179,13 @@ class SubmitForm(Route):
missing_fields.append(question.id)
if missing_fields:
- return JSONResponse({
- "error": "missing_fields",
- "fields": missing_fields
- }, status_code=400)
+ return JSONResponse(
+ {
+ "error": "missing_fields",
+ "fields": missing_fields,
+ },
+ status_code=400,
+ )
try:
response_obj = FormResponse(**response)
@@ -200,10 +198,12 @@ class SubmitForm(Route):
if len(errors):
username = getattr(request.user, "user_id", "Unknown")
- sentry_sdk.capture_exception(BypassDetectedError(
- f"Detected unittest bypass attempt on form {form.id} by {username}. "
- f"Submission has been written to reporting database ({response_obj.id})."
- ))
+ sentry_sdk.capture_exception(
+ BypassDetectedError(
+ f"Detected unittest bypass attempt on form {form.id} by {username}. "
+ f"Submission has been written to reporting database ({response_obj.id}).",
+ )
+ )
database: pymongo.database.Database = request.state.db
await database.get_collection("violations").insert_one({
"user": username,
@@ -219,7 +219,7 @@ class SubmitForm(Route):
for test in unittest_results:
response_obj.response[test.question_id] = {
"value": response_obj.response[test.question_id],
- "passed": test.passed
+ "passed": test.passed,
}
if test.return_code == 0:
@@ -238,9 +238,8 @@ class SubmitForm(Route):
# Report a failure on internal errors,
# or if the test suite doesn't allow failures
if not test.passed:
- allow_failure = (
- form.questions[test.question_index].data["unittests"]["allow_failure"]
- )
+ question = form.questions[test.question_index]
+ allow_failure = question.data["unittests"]["allow_failure"]
# An error while communicating with the test runner
if test.return_code == 99:
@@ -251,15 +250,16 @@ class SubmitForm(Route):
failures.append(test)
if len(failures):
- return JSONResponse({
- "error": "failed_tests",
- "test_results": [
- test._asdict() for test in failures
- ]
- }, status_code=status_code)
+ return JSONResponse(
+ {
+ "error": "failed_tests",
+ "test_results": [test._asdict() for test in failures],
+ },
+ status_code=status_code,
+ )
await request.state.db.responses.insert_one(
- response_obj.dict(by_alias=True)
+ response_obj.dict(by_alias=True),
)
tasks = BackgroundTasks()
@@ -272,36 +272,37 @@ class SubmitForm(Route):
self.send_submission_webhook,
form=form,
response=response_obj,
- request_user=request_user
+ request_user=request_user,
)
if constants.FormFeatures.ASSIGN_ROLE.value in form.features:
tasks.add_task(
self.assign_role,
form=form,
- request_user=request.user
+ request_user=request.user,
)
- return JSONResponse({
- "form": form.dict(admin=False),
- "response": response_obj.dict()
- }, background=tasks)
+ return JSONResponse(
+ {
+ "form": form.dict(admin=False),
+ "response": response_obj.dict(),
+ },
+ background=tasks,
+ )
- else:
- return JSONResponse({
- "error": "Open form not found"
- }, status_code=404)
+ return JSONResponse({"error": "Open form not found"}, status_code=404)
@staticmethod
async def send_submission_webhook(
- form: Form,
- response: FormResponse,
- request_user: typing.Optional[User]
+ form: Form,
+ response: FormResponse,
+ request_user: User | None,
) -> None:
"""Helper to send a submission message to a discord webhook."""
# Stop if webhook is not available
if form.webhook is None:
- raise ValueError("Got empty webhook.")
+ msg = "Got empty webhook."
+ raise ValueError(msg)
try:
mention = request_user.discord_mention
@@ -330,7 +331,7 @@ class SubmitForm(Route):
hook = {
"embeds": [embed],
"allowed_mentions": {"parse": ["users", "roles"]},
- "username": form.name or "Python Discord Forms"
+ "username": form.name or "Python Discord Forms",
}
# Set hook message
@@ -345,8 +346,8 @@ class SubmitForm(Route):
"time": response.timestamp,
}
- for key in ctx:
- message = message.replace(f"{{{key}}}", str(ctx[key]))
+ for key, val in ctx.items():
+ message = message.replace(f"{{{key}}}", str(val))
hook["content"] = message.replace("_USER_MENTION_", mention)
@@ -359,11 +360,12 @@ class SubmitForm(Route):
async def assign_role(form: Form, request_user: User) -> None:
"""Assigns Discord role to user when user submitted response."""
if not form.discord_role:
- raise ValueError("Got empty Discord role ID.")
+ msg = "Got empty Discord role ID."
+ raise ValueError(msg)
url = (
f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}"
- f"/members/{request_user.payload['id']}/roles/{form.discord_role}"
+ f"/members/{request_user.payload["id"]}/roles/{form.discord_role}"
)
async with httpx.AsyncClient() as client:
diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py
index a02afea..3239d35 100644
--- a/backend/routes/forms/unittesting.py
+++ b/backend/routes/forms/unittesting.py
@@ -1,7 +1,8 @@
import base64
-from collections import namedtuple
from itertools import count
+from pathlib import Path
from textwrap import indent
+from typing import NamedTuple
import httpx
from httpx import HTTPStatusError
@@ -9,7 +10,7 @@ from httpx import HTTPStatusError
from backend.constants import SNEKBOX_URL
from backend.models import Form, FormResponse
-with open("resources/unittest_template.py") as file:
+with Path("resources/unittest_template.py").open(encoding="utf8") as file:
TEST_TEMPLATE = file.read()
@@ -17,9 +18,12 @@ class BypassDetectedError(Exception):
"""Detected an attempt at bypassing the unittests."""
-UnittestResult = namedtuple(
- "UnittestResult", "question_id question_index return_code passed result"
-)
+class UnittestResult(NamedTuple):
+ question_id: str
+ question_index: int
+ return_code: int
+ passed: bool
+ result: str
def filter_unittests(form: Form) -> Form:
@@ -46,11 +50,11 @@ def _make_unit_code(units: dict[str, str]) -> str:
elif unit_name == "tearDown":
result += "\ndef tearDown(self):"
else:
- name = f"test_{unit_name.removeprefix('#').removeprefix('test_')}"
+ name = f"test_{unit_name.removeprefix("#").removeprefix("test_")}"
result += f"\nasync def {name}(self):"
# Unite code
- result += f"\n{indent(unit_code, ' ')}"
+ result += f"\n{indent(unit_code, " ")}"
return indent(result, " ")
@@ -72,7 +76,8 @@ async def _post_eval(code: str) -> dict[str, str]:
async def execute_unittest(
- form_response: FormResponse, form: Form
+ form_response: FormResponse,
+ form: Form,
) -> tuple[list[UnittestResult], list[BypassDetectedError]]:
"""Execute all the unittests in this form and return the results."""
unittest_results = []
@@ -80,16 +85,17 @@ async def execute_unittest(
for index, question in enumerate(form.questions):
if question.type == "code":
-
# Exit early if the suite doesn't have any tests
if question.data["unittests"] is None:
- unittest_results.append(UnittestResult(
- question_id=question.id,
- question_index=index,
- return_code=0,
- passed=True,
- result=""
- ))
+ unittest_results.append(
+ UnittestResult(
+ question_id=question.id,
+ question_index=index,
+ return_code=0,
+ passed=True,
+ result="",
+ )
+ )
continue
passed = False
@@ -98,7 +104,7 @@ async def execute_unittest(
hidden_test_counter = count(1)
hidden_tests = {
test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter)
- for test in question.data["unittests"]["tests"].keys()
+ for test in question.data["unittests"]["tests"]
if test.startswith("#")
}
@@ -124,18 +130,18 @@ async def execute_unittest(
try:
passed = bool(int(stdout[0]))
except ValueError:
- raise BypassDetectedError("Detected a bypass when reading result code.")
+ msg = "Detected a bypass when reading result code."
+ raise BypassDetectedError(msg)
if passed and stdout.strip() != "1":
# Most likely a bypass attempt
# A 1 was written to stdout to indicate success,
# followed by the actual output
- raise BypassDetectedError(
- "Detected improper value for stdout in unittest."
- )
+ msg = "Detected improper value for stdout in unittest."
+ raise BypassDetectedError(msg)
# If the test failed, we have to populate the result string.
- elif not passed:
+ if not passed:
failed_tests = stdout[1:].strip().split(";")
# Redact failed hidden tests
@@ -146,7 +152,7 @@ async def execute_unittest(
result = ";".join(failed_tests)
else:
result = ""
- elif return_code in (5, 6, 99):
+ elif return_code in {5, 6, 99}:
result = response["stdout"]
# Killed by NsJail
elif return_code == 137:
@@ -162,12 +168,14 @@ async def execute_unittest(
errors.append(error)
passed = False
- unittest_results.append(UnittestResult(
- question_id=question.id,
- question_index=index,
- return_code=return_code,
- passed=passed,
- result=result
- ))
+ unittest_results.append(
+ UnittestResult(
+ question_id=question.id,
+ question_index=index,
+ return_code=return_code,
+ passed=passed,
+ result=result,
+ )
+ )
return unittest_results, errors
diff --git a/backend/routes/index.py b/backend/routes/index.py
index 207c36a..c6e38ea 100644
--- a/backend/routes/index.py
+++ b/backend/routes/index.py
@@ -1,6 +1,5 @@
-"""
-Index route for the forms API.
-"""
+"""Index route for the forms API."""
+
import platform
from pydantic import BaseModel
@@ -20,13 +19,13 @@ class IndexResponse(BaseModel):
description=(
"The connecting client, in production this will"
" be an IP of our internal load balancer"
- )
+ ),
)
sha: str = Field(
- description="Current release Git SHA in production."
+ description="Current release Git SHA in production.",
)
node: str = Field(
- description="The node that processed the request."
+ description="The node that processed the request.",
)
@@ -42,24 +41,22 @@ class IndexRoute(Route):
@api.validate(resp=Response(HTTP_200=IndexResponse))
def get(self, request: Request) -> JSONResponse:
- """
- Return a hello from Python Discord forms!
- """
+ """Return a hello from Python Discord forms!."""
response_data = {
"message": "Hello, world!",
"client": request.client.host,
"user": {
- "authenticated": False
+ "authenticated": False,
},
"sha": GIT_SHA,
- "node": platform.uname().node
+ "node": platform.uname().node,
}
if request.user.is_authenticated:
response_data["user"] = {
"authenticated": True,
"user": request.user.payload,
- "scopes": request.auth.scopes
+ "scopes": request.auth.scopes,
}
return JSONResponse(response_data)
diff --git a/backend/validation.py b/backend/validation.py
index 8771924..0560701 100644
--- a/backend/validation.py
+++ b/backend/validation.py
@@ -7,7 +7,7 @@ from spectree import SpecTree
api = SpecTree(
"starlette",
TITLE="Python Discord Forms",
- PATH="docs"
+ PATH="docs",
)