aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Kieran Siek <[email protected]>2022-03-20 17:25:06 -0400
committerGravatar GitHub <[email protected]>2022-03-20 17:25:06 -0400
commit25fce5e0161c2d84d4a6b710aa5c83a863766f98 (patch)
treee3c15dad453f8d518bbf5335a14eddedf2c2d054 /backend
parentMerge pull request #151 from python-discord/dependabot/pip/sentry-sdk-1.5.7 (diff)
parentMerge branch 'main' into roles (diff)
Merge pull request #135 from python-discord/roles
Overhaul Access System
Diffstat (limited to 'backend')
-rw-r--r--backend/authentication/backend.py9
-rw-r--r--backend/authentication/user.py41
-rw-r--r--backend/discord.py171
-rw-r--r--backend/models/__init__.py5
-rw-r--r--backend/models/discord_role.py40
-rw-r--r--backend/models/discord_user.py34
-rw-r--r--backend/models/form.py13
-rw-r--r--backend/routes/auth/authorize.py12
-rw-r--r--backend/routes/discord.py83
-rw-r--r--backend/routes/forms/discover.py2
-rw-r--r--backend/routes/forms/form.py61
-rw-r--r--backend/routes/forms/index.py6
-rw-r--r--backend/routes/forms/response.py11
-rw-r--r--backend/routes/forms/responses.py15
-rw-r--r--backend/routes/forms/submit.py2
15 files changed, 434 insertions, 71 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index c7590e9..54385e2 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -5,6 +5,7 @@ from starlette import authentication
from starlette.requests import Request
from backend import constants
+from backend import discord
# We must import user such way here to avoid circular imports
from .user import User
@@ -60,8 +61,12 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
except Exception:
raise authentication.AuthenticationError("Could not parse user details.")
- user = User(token, user_details)
- if await user.fetch_admin_status(request):
+ user = User(
+ token, user_details, await discord.get_member(request.state.db, user_details["id"])
+ )
+ if await user.fetch_admin_status(request.state.db):
scopes.append("admin")
+ scopes.extend(await user.get_user_roles(request.state.db))
+
return authentication.AuthCredentials(scopes), user
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index 857c2ed..6256cae 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,20 +1,27 @@
+import typing
import typing as t
import jwt
+from pymongo.database import Database
from starlette.authentication import BaseUser
-from starlette.requests import Request
+from backend import discord, models
from backend.constants import SECRET_KEY
-from backend.discord import fetch_user_details
class User(BaseUser):
"""Starlette BaseUser implementation for JWT authentication."""
- def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
+ def __init__(
+ self,
+ token: str,
+ payload: dict[str, t.Any],
+ member: typing.Optional[models.DiscordMember],
+ ) -> None:
self.token = token
self.payload = payload
self.admin = False
+ self.member = member
@property
def is_authenticated(self) -> bool:
@@ -34,16 +41,36 @@ class User(BaseUser):
def decoded_token(self) -> dict[str, any]:
return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])
- async def fetch_admin_status(self, request: Request) -> bool:
- self.admin = await request.state.db.admins.find_one(
+ async def get_user_roles(self, database: Database) -> list[str]:
+ """Get a list of the user's discord roles."""
+ if not self.member:
+ return []
+
+ server_roles = await discord.get_roles(database)
+ roles = [role.name for role in server_roles if role.id in self.member.roles]
+
+ if "admin" in roles:
+ # Protect against collision with the forms admin role
+ roles.remove("admin")
+ roles.append("discord admin")
+
+ return roles
+
+ async def fetch_admin_status(self, database: Database) -> bool:
+ self.admin = await database.admins.find_one(
{"_id": self.payload["id"]}
) is not None
return self.admin
- async def refresh_data(self) -> None:
+ async def refresh_data(self, database: Database) -> None:
"""Fetches user data from discord, and updates the instance."""
- self.payload = await fetch_user_details(self.decoded_token.get("token"))
+ self.member = await discord.get_member(database, self.payload["id"])
+
+ if self.member:
+ self.payload = self.member.user.dict()
+ else:
+ self.payload = await discord.fetch_user_details(self.decoded_token.get("token"))
updated_info = self.decoded_token
updated_info["user_details"] = self.payload
diff --git a/backend/discord.py b/backend/discord.py
index e5c7f8f..be12109 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -1,16 +1,22 @@
"""Various utilities for working with the Discord API."""
+
+import datetime
+import json
+import typing
+
import httpx
+import starlette.requests
+from pymongo.database import Database
+from starlette import exceptions
-from backend.constants import (
- DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET
-)
+from backend import constants, models
async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:
async with httpx.AsyncClient() as client:
data = {
- "client_id": OAUTH2_CLIENT_ID,
- "client_secret": OAUTH2_CLIENT_SECRET,
+ "client_id": constants.OAUTH2_CLIENT_ID,
+ "client_secret": constants.OAUTH2_CLIENT_SECRET,
"redirect_uri": f"{redirect}/callback"
}
@@ -21,7 +27,7 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict
data["grant_type"] = "authorization_code"
data["code"] = code
- r = await client.post(f"{DISCORD_API_BASE_URL}/oauth2/token", headers={
+ r = await client.post(f"{constants.DISCORD_API_BASE_URL}/oauth2/token", headers={
"Content-Type": "application/x-www-form-urlencoded"
}, data=data)
@@ -32,10 +38,161 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict
async def fetch_user_details(bearer_token: str) -> dict:
async with httpx.AsyncClient() as client:
- r = await client.get(f"{DISCORD_API_BASE_URL}/users/@me", headers={
+ r = await client.get(f"{constants.DISCORD_API_BASE_URL}/users/@me", headers={
"Authorization": f"Bearer {bearer_token}"
})
r.raise_for_status()
return r.json()
+
+
+async def _get_role_info() -> list[models.DiscordRole]:
+ """Get information about the roles in the configured guild."""
+ async with httpx.AsyncClient() as client:
+ r = await client.get(
+ f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}/roles",
+ headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}
+ )
+
+ r.raise_for_status()
+ return [models.DiscordRole(**role) for role in r.json()]
+
+
+async def get_roles(
+ database: Database, *, force_refresh: bool = False
+) -> list[models.DiscordRole]:
+ """
+ Get a list of all roles from the cache, or discord API if not available.
+
+ If `force_refresh` is True, the cache is skipped and the roles are updated.
+ """
+ collection = database.get_collection("roles")
+
+ if force_refresh:
+ # Drop all values in the collection
+ await collection.delete_many({})
+
+ # `create_index` creates the index if it does not exist, or passes
+ # This handles TTL on role objects
+ await collection.create_index(
+ "inserted_at",
+ expireAfterSeconds=60 * 60 * 24, # 1 day
+ name="inserted_at",
+ )
+
+ roles = [models.DiscordRole(**json.loads(role["data"])) async for role in collection.find()]
+
+ if len(roles) == 0:
+ # Fetch roles from the API and insert into the database
+ roles = await _get_role_info()
+ await collection.insert_many({
+ "name": role.name,
+ "id": role.id,
+ "data": role.json(),
+ "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
+ } for role in roles)
+
+ return roles
+
+
+async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]:
+ """Get a member by ID from the configured guild using the discord API."""
+ async with httpx.AsyncClient() as client:
+ r = await client.get(
+ f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}"
+ f"/members/{member_id}",
+ headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}
+ )
+
+ if r.status_code == 404:
+ return None
+
+ r.raise_for_status()
+ return models.DiscordMember(**r.json())
+
+
+async def get_member(
+ database: Database, user_id: str, *, force_refresh: bool = False
+) -> typing.Optional[models.DiscordMember]:
+ """
+ Get a member from the cache, or from the discord API.
+
+ If `force_refresh` is True, the cache is skipped and the entry is updated.
+ None may be returned if the member object does not exist.
+ """
+ collection = database.get_collection("discord_members")
+
+ if force_refresh:
+ await collection.delete_one({"user": user_id})
+
+ # `create_index` creates the index if it does not exist, or passes
+ # This handles TTL on member objects
+ await collection.create_index(
+ "inserted_at",
+ expireAfterSeconds=60 * 60, # 1 hour
+ name="inserted_at",
+ )
+
+ result = await collection.find_one({"user": user_id})
+
+ if result is not None:
+ return models.DiscordMember(**json.loads(result["data"]))
+
+ member = await _fetch_member_api(user_id)
+
+ if not member:
+ return None
+
+ await collection.insert_one({
+ "user": user_id,
+ "data": member.json(),
+ "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
+ })
+ return member
+
+
+class FormNotFoundError(exceptions.HTTPException):
+ """The requested form was not found."""
+
+
+class UnauthorizedError(exceptions.HTTPException):
+ """You are not authorized to use this resource."""
+
+
+async def _verify_access_helper(
+ form_id: str, request: starlette.requests.Request, attribute: str
+) -> None:
+ """A low level helper to validate access to a form resource based on the user's scopes."""
+ form = await request.state.db.forms.find_one({"_id": form_id})
+
+ if not form:
+ raise FormNotFoundError(status_code=404)
+
+ # Short circuit all resources for forms admins
+ if "admin" in request.auth.scopes:
+ return
+
+ form = models.Form(**form)
+
+ for role_id in getattr(form, attribute, []):
+ role = await request.state.db.roles.find_one({"id": role_id})
+ if not role:
+ continue
+
+ role = models.DiscordRole(**json.loads(role["data"]))
+
+ if role.name in request.auth.scopes:
+ return
+
+ raise UnauthorizedError(status_code=401)
+
+
+async def verify_response_access(form_id: str, request: starlette.requests.Request) -> None:
+ """Ensure the user can access responses on the requested resource."""
+ await _verify_access_helper(form_id, request, "response_readers")
+
+
+async def verify_edit_access(form_id: str, request: starlette.requests.Request) -> None:
+ """Ensure the user can view and modify the requested resource."""
+ await _verify_access_helper(form_id, request, "editors")
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
index 8ad7f7f..a9f76e0 100644
--- a/backend/models/__init__.py
+++ b/backend/models/__init__.py
@@ -1,12 +1,15 @@
from .antispam import AntiSpam
-from .discord_user import DiscordUser
+from .discord_role import DiscordRole
+from .discord_user import DiscordMember, DiscordUser
from .form import Form, FormList
from .form_response import FormResponse, ResponseList
from .question import CodeQuestion, Question
__all__ = [
"AntiSpam",
+ "DiscordRole",
"DiscordUser",
+ "DiscordMember",
"Form",
"FormResponse",
"CodeQuestion",
diff --git a/backend/models/discord_role.py b/backend/models/discord_role.py
new file mode 100644
index 0000000..c05c9de
--- /dev/null
+++ b/backend/models/discord_role.py
@@ -0,0 +1,40 @@
+import typing
+
+from pydantic import BaseModel
+
+
+class RoleTags(BaseModel):
+ """Meta information about a discord role."""
+
+ bot_id: typing.Optional[str]
+ integration_id: typing.Optional[str]
+ premium_subscriber: bool
+
+ def __init__(self, **data: typing.Any) -> None:
+ """
+ Handle the terrible discord API.
+
+ Discord only returns the premium_subscriber field if it's true,
+ meaning the typical validation process wouldn't work.
+
+ We manually parse the raw data to determine if the field exists, and give it a useful
+ bool value.
+ """
+ data["premium_subscriber"] = "premium_subscriber" in data.keys()
+ super().__init__(**data)
+
+
+class DiscordRole(BaseModel):
+ """Schema model of Discord guild roles."""
+
+ id: str
+ name: str
+ color: int
+ hoist: bool
+ icon: typing.Optional[str]
+ unicode_emoji: typing.Optional[str]
+ position: int
+ permissions: str
+ managed: bool
+ mentionable: bool
+ tags: typing.Optional[RoleTags]
diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py
index 9f246ba..0eca15b 100644
--- a/backend/models/discord_user.py
+++ b/backend/models/discord_user.py
@@ -1,10 +1,11 @@
+import datetime
import typing as t
from pydantic import BaseModel
-class DiscordUser(BaseModel):
- """Schema model of Discord user for form response."""
+class _User(BaseModel):
+ """Base for discord users and members."""
# Discord default fields.
username: str
@@ -20,5 +21,34 @@ class DiscordUser(BaseModel):
premium_type: t.Optional[int]
public_flags: t.Optional[int]
+
+class DiscordUser(_User):
+ """Schema model of Discord user for form response."""
+
# Custom fields
admin: bool
+
+
+class DiscordMember(BaseModel):
+ """A discord guild member."""
+
+ user: _User
+ nick: t.Optional[str]
+ avatar: t.Optional[str]
+ roles: list[str]
+ joined_at: datetime.datetime
+ premium_since: t.Optional[datetime.datetime]
+ deaf: bool
+ mute: bool
+ pending: t.Optional[bool]
+ permissions: t.Optional[str]
+ communication_disabled_until: t.Optional[datetime.datetime]
+
+ def dict(self, *args, **kwargs) -> dict[str, t.Any]:
+ """Convert the model to a python dict, and encode timestamps in a serializable format."""
+ data = super().dict(*args, **kwargs)
+ for field, value in data.items():
+ if isinstance(value, datetime.datetime):
+ data[field] = value.isoformat()
+
+ return data
diff --git a/backend/models/form.py b/backend/models/form.py
index f19ed85..f888d6e 100644
--- a/backend/models/form.py
+++ b/backend/models/form.py
@@ -1,10 +1,10 @@
import typing as t
import httpx
-from pydantic import constr, BaseModel, Field, root_validator, validator
+from pydantic import BaseModel, Field, constr, root_validator, validator
from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from backend.constants import FormFeatures, WebHook
+from backend.constants import DISCORD_GUILD, FormFeatures, WebHook
from .question import Question
PUBLIC_FIELDS = [
@@ -43,6 +43,8 @@ class Form(BaseModel):
submitted_text: t.Optional[str] = None
webhook: _WebHook = None
discord_role: t.Optional[str]
+ response_readers: t.Optional[list[str]]
+ editors: t.Optional[list[str]]
class Config:
allow_population_by_field_name = True
@@ -67,6 +69,13 @@ class Form(BaseModel):
return value
+ @validator("response_readers", "editors")
+ def validate_role_scoping(cls, value: t.Optional[list[str]]) -> t.Optional[list[str]]:
+ """Ensure special role based permissions aren't granted to the @everyone role."""
+ if value and str(DISCORD_GUILD) in value:
+ raise ValueError("You can not add the everyone role as an access scope.")
+ return value
+
@root_validator
def validate_role(cls, values: dict[str, t.Any]) -> t.Optional[dict[str, t.Any]]:
"""Validates does Discord role provided when flag provided."""
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index d4587f0..42fb3ec 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -17,7 +17,7 @@ from starlette.requests import Request
from backend import constants
from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.discord import fetch_bearer_token, fetch_user_details, get_member
from backend.route import Route
from backend.validation import ErrorMessage, api
@@ -34,8 +34,8 @@ class AuthorizeResponse(BaseModel):
async def process_token(
- bearer_token: dict,
- request: Request
+ bearer_token: dict,
+ request: Request
) -> Union[AuthorizeResponse, AUTH_FAILURE]:
"""Post a bearer token to Discord, and return a JWT and username."""
interaction_start = datetime.datetime.now()
@@ -46,6 +46,9 @@ async def process_token(
AUTH_FAILURE.delete_cookie("token")
return AUTH_FAILURE
+ user_id = user_details["id"]
+ member = await get_member(request.state.db, user_id, force_refresh=True)
+
max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
token_expiry = interaction_start + max_age
@@ -53,11 +56,12 @@ async def process_token(
"token": bearer_token["access_token"],
"refresh": bearer_token["refresh_token"],
"user_details": user_details,
+ "in_guild": bool(member),
"expiry": token_expiry.isoformat()
}
token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
- user = User(token, user_details)
+ user = User(token, user_details, member)
response = responses.JSONResponse({
"username": user.display_name,
diff --git a/backend/routes/discord.py b/backend/routes/discord.py
new file mode 100644
index 0000000..bca1edb
--- /dev/null
+++ b/backend/routes/discord.py
@@ -0,0 +1,83 @@
+"""Routes which directly interact with discord related data."""
+
+import pydantic
+from spectree import Response
+from starlette.authentication import requires
+from starlette.responses import JSONResponse
+from starlette.routing import Request
+
+from backend import discord, models, route
+from backend.validation import ErrorMessage, api
+
+NOT_FOUND_EXCEPTION = JSONResponse(
+ {"error": "Could not find the requested resource in the guild or cache."}, status_code=404
+)
+
+
+class RolesRoute(route.Route):
+ """Refreshes the roles database."""
+
+ name = "roles"
+ path = "/roles"
+
+ class RolesResponse(pydantic.BaseModel):
+ """A list of all roles on the configured server."""
+
+ roles: list[models.DiscordRole]
+
+ @requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=RolesResponse),
+ tags=["roles"]
+ )
+ async def patch(self, request: Request) -> JSONResponse:
+ """Refresh the roles database."""
+ roles = await discord.get_roles(request.state.db, force_refresh=True)
+
+ return JSONResponse(
+ {"roles": [role.dict() for role in roles]},
+ )
+
+
+class MemberRoute(route.Route):
+ """Retrieve information about a server member."""
+
+ name = "member"
+ path = "/member"
+
+ class MemberRequest(pydantic.BaseModel):
+ """An ID of the member to update."""
+
+ user_id: str
+
+ @requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),
+ json=MemberRequest,
+ tags=["auth"]
+ )
+ async def delete(self, request: Request) -> JSONResponse:
+ """Force a resync of the cache for the given user."""
+ body = await request.json()
+ member = await discord.get_member(request.state.db, body["user_id"], force_refresh=True)
+
+ if member:
+ return JSONResponse(member.dict())
+ else:
+ return NOT_FOUND_EXCEPTION
+
+ @requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),
+ json=MemberRequest,
+ tags=["auth"]
+ )
+ async def get(self, request: Request) -> JSONResponse:
+ """Get a user's roles on the configured server."""
+ body = await request.json()
+ member = await discord.get_member(request.state.db, body["user_id"])
+
+ if member:
+ return JSONResponse(member.dict())
+ else:
+ return NOT_FOUND_EXCEPTION
diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py
index d7351d5..b993075 100644
--- a/backend/routes/forms/discover.py
+++ b/backend/routes/forms/discover.py
@@ -29,7 +29,7 @@ EMPTY_FORM = Form(
features=__FEATURES,
questions=[__QUESTION],
name="Auth form",
- description="An empty form to help you get a token."
+ description="An empty form to help you get a token.",
)
diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py
index 0f96b85..567c197 100644
--- a/backend/routes/forms/form.py
+++ b/backend/routes/forms/form.py
@@ -10,13 +10,15 @@ from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
-from backend import constants
+from backend import constants, discord
from backend.models import Form
from backend.route import Route
from backend.routes.forms.discover import EMPTY_FORM
from backend.routes.forms.unittesting import filter_unittests
from backend.validation import ErrorMessage, OkayResponse, api
+PUBLIC_FORM_FEATURES = (constants.FormFeatures.OPEN, constants.FormFeatures.DISCOVERABLE)
+
class SingleForm(Route):
"""
@@ -31,9 +33,19 @@ class SingleForm(Route):
@api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Returns single form information by ID."""
- admin = request.user.admin if request.user.is_authenticated else False
form_id = request.path_params["form_id"].lower()
+ try:
+ await discord.verify_edit_access(form_id, request)
+ admin = True
+ except discord.FormNotFoundError:
+ if not constants.PRODUCTION and form_id == EMPTY_FORM.id:
+ # Empty form to help with authentication in development.
+ return JSONResponse(EMPTY_FORM.dict(admin=False))
+ raise
+ except discord.UnauthorizedError:
+ admin = False
+
filters = {
"_id": form_id
}
@@ -41,25 +53,18 @@ class SingleForm(Route):
if not admin:
filters["features"] = {"$in": ["OPEN", "DISCOVERABLE"]}
- if raw_form := await request.state.db.forms.find_one(filters):
- form = Form(**raw_form)
- if not admin:
- form = filter_unittests(form)
-
- return JSONResponse(form.dict(admin=admin))
-
- elif not constants.PRODUCTION and form_id == EMPTY_FORM.id:
- # Empty form to help with authentication in development.
- return JSONResponse(EMPTY_FORM.dict(admin=admin))
+ form = Form(**await request.state.db.forms.find_one(filters))
+ if not admin:
+ form = filter_unittests(form)
- return JSONResponse({"error": "not_found"}, status_code=404)
+ return JSONResponse(form.dict(admin=admin))
- @requires(["authenticated", "admin"])
+ @requires(["authenticated"])
@api.validate(
resp=Response(
HTTP_200=OkayResponse,
HTTP_400=ErrorMessage,
- HTTP_404=ErrorMessage
+ HTTP_404=ErrorMessage,
),
tags=["forms"]
)
@@ -70,10 +75,12 @@ class SingleForm(Route):
except json.decoder.JSONDecodeError:
return JSONResponse("Expected a JSON body.", 400)
- form_id = {"_id": request.path_params["form_id"].lower()}
- if raw_form := await request.state.db.forms.find_one(form_id):
+ form_id = request.path_params["form_id"].lower()
+ await discord.verify_edit_access(form_id, request)
+
+ if raw_form := await request.state.db.forms.find_one({"_id": form_id}):
if "_id" in data or "id" in data:
- if (data.get("id") or data.get("_id")) != form_id["_id"]:
+ if (data.get("id") or data.get("_id")) != form_id:
return JSONResponse({"error": "locked_field"}, status_code=400)
# Build Data Merger
@@ -90,7 +97,7 @@ class SingleForm(Route):
except ValidationError as e:
return JSONResponse(e.errors(), status_code=422)
- await request.state.db.forms.replace_one(form_id, form.dict())
+ await request.state.db.forms.replace_one({"_id": form_id}, form.dict())
return JSONResponse(form.dict())
else:
@@ -98,21 +105,15 @@ class SingleForm(Route):
@requires(["authenticated", "admin"])
@api.validate(
- resp=Response(HTTP_200=OkayResponse, HTTP_404=ErrorMessage),
+ resp=Response(HTTP_200=OkayResponse, HTTP_401=ErrorMessage, HTTP_404=ErrorMessage),
tags=["forms"]
)
async def delete(self, request: Request) -> JSONResponse:
"""Deletes form by ID."""
- if not await request.state.db.forms.find_one(
- {"_id": request.path_params["form_id"].lower()}
- ):
- return JSONResponse({"error": "not_found"}, status_code=404)
+ form_id = request.path_params["form_id"].lower()
+ await discord.verify_edit_access(form_id, request)
- await request.state.db.forms.delete_one(
- {"_id": request.path_params["form_id"].lower()}
- )
- await request.state.db.responses.delete_many(
- {"form_id": request.path_params["form_id"].lower()}
- )
+ await request.state.db.forms.delete_one({"_id": form_id})
+ await request.state.db.responses.delete_many({"form_id": form_id})
return JSONResponse({"status": "ok"})
diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py
index 22171fa..38be693 100644
--- a/backend/routes/forms/index.py
+++ b/backend/routes/forms/index.py
@@ -15,13 +15,13 @@ from backend.validation import ErrorMessage, OkayResponse, api
class FormsList(Route):
"""
- List all available forms for administrator viewing.
+ List all available forms for authorized viewers.
"""
name = "forms_list_create"
path = "/"
- @requires(["authenticated", "admin"])
+ @requires(["authenticated", "Admins"])
@api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Return a list of all forms to authenticated users."""
@@ -38,7 +38,7 @@ class FormsList(Route):
forms
)
- @requires(["authenticated", "admin"])
+ @requires(["authenticated", "Helpers"])
@api.validate(
json=Form,
resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage),
diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py
index d8d8d17..565701f 100644
--- a/backend/routes/forms/response.py
+++ b/backend/routes/forms/response.py
@@ -1,11 +1,13 @@
"""
Returns or deletes form response by ID.
"""
+
from spectree import Response as RouteResponse
from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
+from backend import discord
from backend.models import FormResponse
from backend.route import Route
from backend.validation import ErrorMessage, OkayResponse, api
@@ -17,23 +19,26 @@ class Response(Route):
name = "response"
path = "/{form_id:str}/responses/{response_id:str}"
- @requires(["authenticated", "admin"])
+ @requires(["authenticated"])
@api.validate(
resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage),
tags=["forms", "responses"]
)
async def get(self, request: Request) -> JSONResponse:
"""Return a single form response by ID."""
+ form_id = request.path_params["form_id"]
+ await discord.verify_response_access(form_id, request)
+
if raw_response := await request.state.db.responses.find_one(
{
"_id": request.path_params["response_id"],
- "form_id": request.path_params["form_id"]
+ "form_id": form_id
}
):
response = FormResponse(**raw_response)
return JSONResponse(response.dict())
else:
- return JSONResponse({"error": "not_found"}, status_code=404)
+ return JSONResponse({"error": "response_not_found"}, status_code=404)
@requires(["authenticated", "admin"])
@api.validate(
diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py
index f3c4cd7..818ebce 100644
--- a/backend/routes/forms/responses.py
+++ b/backend/routes/forms/responses.py
@@ -7,9 +7,10 @@ from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
+from backend import discord
from backend.models import FormResponse, ResponseList
from backend.route import Route
-from backend.validation import api, ErrorMessage, OkayResponse
+from backend.validation import ErrorMessage, OkayResponse, api
class ResponseIdList(BaseModel):
@@ -24,20 +25,18 @@ class Responses(Route):
name = "form_responses"
path = "/{form_id:str}/responses"
- @requires(["authenticated", "admin"])
+ @requires(["authenticated"])
@api.validate(
- resp=Response(HTTP_200=ResponseList, HTTP_404=ErrorMessage),
+ resp=Response(HTTP_200=ResponseList),
tags=["forms", "responses"]
)
async def get(self, request: Request) -> JSONResponse:
"""Returns all form responses by form ID."""
- if not await request.state.db.forms.find_one(
- {"_id": request.path_params["form_id"]}
- ):
- return JSONResponse({"error": "not_found"}, 404)
+ form_id = request.path_params["form_id"]
+ await discord.verify_response_access(form_id, request)
cursor = request.state.db.responses.find(
- {"form_id": request.path_params["form_id"]}
+ {"form_id": form_id}
)
responses = [
FormResponse(**response) for response in await cursor.to_list(None)
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index 95e30b0..baf403d 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -83,7 +83,7 @@ class SubmitForm(Route):
try:
if hasattr(request.user, User.refresh_data.__name__):
old = request.user.token
- await request.user.refresh_data()
+ await request.user.refresh_data(request.state.db)
if old != request.user.token:
try: