diff options
| author | 2022-03-20 17:25:06 -0400 | |
|---|---|---|
| committer | 2022-03-20 17:25:06 -0400 | |
| commit | 25fce5e0161c2d84d4a6b710aa5c83a863766f98 (patch) | |
| tree | e3c15dad453f8d518bbf5335a14eddedf2c2d054 /backend | |
| parent | Merge pull request #151 from python-discord/dependabot/pip/sentry-sdk-1.5.7 (diff) | |
| parent | Merge branch 'main' into roles (diff) | |
Merge pull request #135 from python-discord/roles
Overhaul Access System
Diffstat (limited to '')
| -rw-r--r-- | backend/authentication/backend.py | 9 | ||||
| -rw-r--r-- | backend/authentication/user.py | 41 | ||||
| -rw-r--r-- | backend/discord.py | 171 | ||||
| -rw-r--r-- | backend/models/__init__.py | 5 | ||||
| -rw-r--r-- | backend/models/discord_role.py | 40 | ||||
| -rw-r--r-- | backend/models/discord_user.py | 34 | ||||
| -rw-r--r-- | backend/models/form.py | 13 | ||||
| -rw-r--r-- | backend/routes/auth/authorize.py | 12 | ||||
| -rw-r--r-- | backend/routes/discord.py | 83 | ||||
| -rw-r--r-- | backend/routes/forms/discover.py | 2 | ||||
| -rw-r--r-- | backend/routes/forms/form.py | 61 | ||||
| -rw-r--r-- | backend/routes/forms/index.py | 6 | ||||
| -rw-r--r-- | backend/routes/forms/response.py | 11 | ||||
| -rw-r--r-- | backend/routes/forms/responses.py | 15 | ||||
| -rw-r--r-- | backend/routes/forms/submit.py | 2 | 
15 files changed, 434 insertions, 71 deletions
| diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index c7590e9..54385e2 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -5,6 +5,7 @@ from starlette import authentication  from starlette.requests import Request  from backend import constants +from backend import discord  # We must import user such way here to avoid circular imports  from .user import User @@ -60,8 +61,12 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):          except Exception:              raise authentication.AuthenticationError("Could not parse user details.") -        user = User(token, user_details) -        if await user.fetch_admin_status(request): +        user = User( +            token, user_details, await discord.get_member(request.state.db, user_details["id"]) +        ) +        if await user.fetch_admin_status(request.state.db):              scopes.append("admin") +        scopes.extend(await user.get_user_roles(request.state.db)) +          return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index 857c2ed..6256cae 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,20 +1,27 @@ +import typing  import typing as t  import jwt +from pymongo.database import Database  from starlette.authentication import BaseUser -from starlette.requests import Request +from backend import discord, models  from backend.constants import SECRET_KEY -from backend.discord import fetch_user_details  class User(BaseUser):      """Starlette BaseUser implementation for JWT authentication.""" -    def __init__(self, token: str, payload: dict[str, t.Any]) -> None: +    def __init__( +        self, +        token: str, +        payload: dict[str, t.Any], +        member: typing.Optional[models.DiscordMember], +    ) -> None:          self.token = token          self.payload = payload          self.admin = False +        self.member = member      @property      def is_authenticated(self) -> bool: @@ -34,16 +41,36 @@ class User(BaseUser):      def decoded_token(self) -> dict[str, any]:          return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) -    async def fetch_admin_status(self, request: Request) -> bool: -        self.admin = await request.state.db.admins.find_one( +    async def get_user_roles(self, database: Database) -> list[str]: +        """Get a list of the user's discord roles.""" +        if not self.member: +            return [] + +        server_roles = await discord.get_roles(database) +        roles = [role.name for role in server_roles if role.id in self.member.roles] + +        if "admin" in roles: +            # Protect against collision with the forms admin role +            roles.remove("admin") +            roles.append("discord admin") + +        return roles + +    async def fetch_admin_status(self, database: Database) -> bool: +        self.admin = await database.admins.find_one(              {"_id": self.payload["id"]}          ) is not None          return self.admin -    async def refresh_data(self) -> None: +    async def refresh_data(self, database: Database) -> None:          """Fetches user data from discord, and updates the instance.""" -        self.payload = await fetch_user_details(self.decoded_token.get("token")) +        self.member = await discord.get_member(database, self.payload["id"]) + +        if self.member: +            self.payload = self.member.user.dict() +        else: +            self.payload = await discord.fetch_user_details(self.decoded_token.get("token"))          updated_info = self.decoded_token          updated_info["user_details"] = self.payload diff --git a/backend/discord.py b/backend/discord.py index e5c7f8f..be12109 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -1,16 +1,22 @@  """Various utilities for working with the Discord API.""" + +import datetime +import json +import typing +  import httpx +import starlette.requests +from pymongo.database import Database +from starlette import exceptions -from backend.constants import ( -    DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET -) +from backend import constants, models  async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:      async with httpx.AsyncClient() as client:          data = { -            "client_id": OAUTH2_CLIENT_ID, -            "client_secret": OAUTH2_CLIENT_SECRET, +            "client_id": constants.OAUTH2_CLIENT_ID, +            "client_secret": constants.OAUTH2_CLIENT_SECRET,              "redirect_uri": f"{redirect}/callback"          } @@ -21,7 +27,7 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict              data["grant_type"] = "authorization_code"              data["code"] = code -        r = await client.post(f"{DISCORD_API_BASE_URL}/oauth2/token", headers={ +        r = await client.post(f"{constants.DISCORD_API_BASE_URL}/oauth2/token", headers={              "Content-Type": "application/x-www-form-urlencoded"          }, data=data) @@ -32,10 +38,161 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict  async def fetch_user_details(bearer_token: str) -> dict:      async with httpx.AsyncClient() as client: -        r = await client.get(f"{DISCORD_API_BASE_URL}/users/@me", headers={ +        r = await client.get(f"{constants.DISCORD_API_BASE_URL}/users/@me", headers={              "Authorization": f"Bearer {bearer_token}"          })          r.raise_for_status()          return r.json() + + +async def _get_role_info() -> list[models.DiscordRole]: +    """Get information about the roles in the configured guild.""" +    async with httpx.AsyncClient() as client: +        r = await client.get( +            f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}/roles", +            headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} +        ) + +        r.raise_for_status() +        return [models.DiscordRole(**role) for role in r.json()] + + +async def get_roles( +    database: Database, *, force_refresh: bool = False +) -> list[models.DiscordRole]: +    """ +    Get a list of all roles from the cache, or discord API if not available. + +    If `force_refresh` is True, the cache is skipped and the roles are updated. +    """ +    collection = database.get_collection("roles") + +    if force_refresh: +        # Drop all values in the collection +        await collection.delete_many({}) + +    # `create_index` creates the index if it does not exist, or passes +    # This handles TTL on role objects +    await collection.create_index( +        "inserted_at", +        expireAfterSeconds=60 * 60 * 24,  # 1 day +        name="inserted_at", +    ) + +    roles = [models.DiscordRole(**json.loads(role["data"])) async for role in collection.find()] + +    if len(roles) == 0: +        # Fetch roles from the API and insert into the database +        roles = await _get_role_info() +        await collection.insert_many({ +            "name": role.name, +            "id": role.id, +            "data": role.json(), +            "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), +        } for role in roles) + +    return roles + + +async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]: +    """Get a member by ID from the configured guild using the discord API.""" +    async with httpx.AsyncClient() as client: +        r = await client.get( +            f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" +            f"/members/{member_id}", +            headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} +        ) + +        if r.status_code == 404: +            return None + +        r.raise_for_status() +        return models.DiscordMember(**r.json()) + + +async def get_member( +    database: Database, user_id: str, *, force_refresh: bool = False +) -> typing.Optional[models.DiscordMember]: +    """ +    Get a member from the cache, or from the discord API. + +    If `force_refresh` is True, the cache is skipped and the entry is updated. +    None may be returned if the member object does not exist. +    """ +    collection = database.get_collection("discord_members") + +    if force_refresh: +        await collection.delete_one({"user": user_id}) + +    # `create_index` creates the index if it does not exist, or passes +    # This handles TTL on member objects +    await collection.create_index( +        "inserted_at", +        expireAfterSeconds=60 * 60,  # 1 hour +        name="inserted_at", +    ) + +    result = await collection.find_one({"user": user_id}) + +    if result is not None: +        return models.DiscordMember(**json.loads(result["data"])) + +    member = await _fetch_member_api(user_id) + +    if not member: +        return None + +    await collection.insert_one({ +        "user": user_id, +        "data": member.json(), +        "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), +    }) +    return member + + +class FormNotFoundError(exceptions.HTTPException): +    """The requested form was not found.""" + + +class UnauthorizedError(exceptions.HTTPException): +    """You are not authorized to use this resource.""" + + +async def _verify_access_helper( +    form_id: str, request: starlette.requests.Request, attribute: str +) -> None: +    """A low level helper to validate access to a form resource based on the user's scopes.""" +    form = await request.state.db.forms.find_one({"_id": form_id}) + +    if not form: +        raise FormNotFoundError(status_code=404) + +    # Short circuit all resources for forms admins +    if "admin" in request.auth.scopes: +        return + +    form = models.Form(**form) + +    for role_id in getattr(form, attribute, []): +        role = await request.state.db.roles.find_one({"id": role_id}) +        if not role: +            continue + +        role = models.DiscordRole(**json.loads(role["data"])) + +        if role.name in request.auth.scopes: +            return + +    raise UnauthorizedError(status_code=401) + + +async def verify_response_access(form_id: str, request: starlette.requests.Request) -> None: +    """Ensure the user can access responses on the requested resource.""" +    await _verify_access_helper(form_id, request, "response_readers") + + +async def verify_edit_access(form_id: str, request: starlette.requests.Request) -> None: +    """Ensure the user can view and modify the requested resource.""" +    await _verify_access_helper(form_id, request, "editors") diff --git a/backend/models/__init__.py b/backend/models/__init__.py index 8ad7f7f..a9f76e0 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -1,12 +1,15 @@  from .antispam import AntiSpam -from .discord_user import DiscordUser +from .discord_role import DiscordRole +from .discord_user import DiscordMember, DiscordUser  from .form import Form, FormList  from .form_response import FormResponse, ResponseList  from .question import CodeQuestion, Question  __all__ = [      "AntiSpam", +    "DiscordRole",      "DiscordUser", +    "DiscordMember",      "Form",      "FormResponse",      "CodeQuestion", diff --git a/backend/models/discord_role.py b/backend/models/discord_role.py new file mode 100644 index 0000000..c05c9de --- /dev/null +++ b/backend/models/discord_role.py @@ -0,0 +1,40 @@ +import typing + +from pydantic import BaseModel + + +class RoleTags(BaseModel): +    """Meta information about a discord role.""" + +    bot_id: typing.Optional[str] +    integration_id: typing.Optional[str] +    premium_subscriber: bool + +    def __init__(self, **data: typing.Any) -> None: +        """ +        Handle the terrible discord API. + +        Discord only returns the premium_subscriber field if it's true, +        meaning the typical validation process wouldn't work. + +        We manually parse the raw data to determine if the field exists, and give it a useful +        bool value. +        """ +        data["premium_subscriber"] = "premium_subscriber" in data.keys() +        super().__init__(**data) + + +class DiscordRole(BaseModel): +    """Schema model of Discord guild roles.""" + +    id: str +    name: str +    color: int +    hoist: bool +    icon: typing.Optional[str] +    unicode_emoji: typing.Optional[str] +    position: int +    permissions: str +    managed: bool +    mentionable: bool +    tags: typing.Optional[RoleTags] diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py index 9f246ba..0eca15b 100644 --- a/backend/models/discord_user.py +++ b/backend/models/discord_user.py @@ -1,10 +1,11 @@ +import datetime  import typing as t  from pydantic import BaseModel -class DiscordUser(BaseModel): -    """Schema model of Discord user for form response.""" +class _User(BaseModel): +    """Base for discord users and members."""      # Discord default fields.      username: str @@ -20,5 +21,34 @@ class DiscordUser(BaseModel):      premium_type: t.Optional[int]      public_flags: t.Optional[int] + +class DiscordUser(_User): +    """Schema model of Discord user for form response.""" +      # Custom fields      admin: bool + + +class DiscordMember(BaseModel): +    """A discord guild member.""" + +    user: _User +    nick: t.Optional[str] +    avatar: t.Optional[str] +    roles: list[str] +    joined_at: datetime.datetime +    premium_since: t.Optional[datetime.datetime] +    deaf: bool +    mute: bool +    pending: t.Optional[bool] +    permissions: t.Optional[str] +    communication_disabled_until: t.Optional[datetime.datetime] + +    def dict(self, *args, **kwargs) -> dict[str, t.Any]: +        """Convert the model to a python dict, and encode timestamps in a serializable format.""" +        data = super().dict(*args, **kwargs) +        for field, value in data.items(): +            if isinstance(value, datetime.datetime): +                data[field] = value.isoformat() + +        return data diff --git a/backend/models/form.py b/backend/models/form.py index f19ed85..f888d6e 100644 --- a/backend/models/form.py +++ b/backend/models/form.py @@ -1,10 +1,10 @@  import typing as t  import httpx -from pydantic import constr, BaseModel, Field, root_validator, validator +from pydantic import BaseModel, Field, constr, root_validator, validator  from pydantic.error_wrappers import ErrorWrapper, ValidationError -from backend.constants import FormFeatures, WebHook +from backend.constants import DISCORD_GUILD, FormFeatures, WebHook  from .question import Question  PUBLIC_FIELDS = [ @@ -43,6 +43,8 @@ class Form(BaseModel):      submitted_text: t.Optional[str] = None      webhook: _WebHook = None      discord_role: t.Optional[str] +    response_readers: t.Optional[list[str]] +    editors: t.Optional[list[str]]      class Config:          allow_population_by_field_name = True @@ -67,6 +69,13 @@ class Form(BaseModel):          return value +    @validator("response_readers", "editors") +    def validate_role_scoping(cls, value: t.Optional[list[str]]) -> t.Optional[list[str]]: +        """Ensure special role based permissions aren't granted to the @everyone role.""" +        if value and str(DISCORD_GUILD) in value: +            raise ValueError("You can not add the everyone role as an access scope.") +        return value +      @root_validator      def validate_role(cls, values: dict[str, t.Any]) -> t.Optional[dict[str, t.Any]]:          """Validates does Discord role provided when flag provided.""" diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index d4587f0..42fb3ec 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -17,7 +17,7 @@ from starlette.requests import Request  from backend import constants  from backend.authentication.user import User  from backend.constants import SECRET_KEY -from backend.discord import fetch_bearer_token, fetch_user_details +from backend.discord import fetch_bearer_token, fetch_user_details, get_member  from backend.route import Route  from backend.validation import ErrorMessage, api @@ -34,8 +34,8 @@ class AuthorizeResponse(BaseModel):  async def process_token( -        bearer_token: dict, -        request: Request +    bearer_token: dict, +    request: Request  ) -> Union[AuthorizeResponse, AUTH_FAILURE]:      """Post a bearer token to Discord, and return a JWT and username."""      interaction_start = datetime.datetime.now() @@ -46,6 +46,9 @@ async def process_token(          AUTH_FAILURE.delete_cookie("token")          return AUTH_FAILURE +    user_id = user_details["id"] +    member = await get_member(request.state.db, user_id, force_refresh=True) +      max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))      token_expiry = interaction_start + max_age @@ -53,11 +56,12 @@ async def process_token(          "token": bearer_token["access_token"],          "refresh": bearer_token["refresh_token"],          "user_details": user_details, +        "in_guild": bool(member),          "expiry": token_expiry.isoformat()      }      token = jwt.encode(data, SECRET_KEY, algorithm="HS256") -    user = User(token, user_details) +    user = User(token, user_details, member)      response = responses.JSONResponse({          "username": user.display_name, diff --git a/backend/routes/discord.py b/backend/routes/discord.py new file mode 100644 index 0000000..bca1edb --- /dev/null +++ b/backend/routes/discord.py @@ -0,0 +1,83 @@ +"""Routes which directly interact with discord related data.""" + +import pydantic +from spectree import Response +from starlette.authentication import requires +from starlette.responses import JSONResponse +from starlette.routing import Request + +from backend import discord, models, route +from backend.validation import ErrorMessage, api + +NOT_FOUND_EXCEPTION = JSONResponse( +    {"error": "Could not find the requested resource in the guild or cache."}, status_code=404 +) + + +class RolesRoute(route.Route): +    """Refreshes the roles database.""" + +    name = "roles" +    path = "/roles" + +    class RolesResponse(pydantic.BaseModel): +        """A list of all roles on the configured server.""" + +        roles: list[models.DiscordRole] + +    @requires(["authenticated", "admin"]) +    @api.validate( +        resp=Response(HTTP_200=RolesResponse), +        tags=["roles"] +    ) +    async def patch(self, request: Request) -> JSONResponse: +        """Refresh the roles database.""" +        roles = await discord.get_roles(request.state.db, force_refresh=True) + +        return JSONResponse( +            {"roles": [role.dict() for role in roles]}, +        ) + + +class MemberRoute(route.Route): +    """Retrieve information about a server member.""" + +    name = "member" +    path = "/member" + +    class MemberRequest(pydantic.BaseModel): +        """An ID of the member to update.""" + +        user_id: str + +    @requires(["authenticated", "admin"]) +    @api.validate( +        resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), +        json=MemberRequest, +        tags=["auth"] +    ) +    async def delete(self, request: Request) -> JSONResponse: +        """Force a resync of the cache for the given user.""" +        body = await request.json() +        member = await discord.get_member(request.state.db, body["user_id"], force_refresh=True) + +        if member: +            return JSONResponse(member.dict()) +        else: +            return NOT_FOUND_EXCEPTION + +    @requires(["authenticated", "admin"]) +    @api.validate( +        resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), +        json=MemberRequest, +        tags=["auth"] +    ) +    async def get(self, request: Request) -> JSONResponse: +        """Get a user's roles on the configured server.""" +        body = await request.json() +        member = await discord.get_member(request.state.db, body["user_id"]) + +        if member: +            return JSONResponse(member.dict()) +        else: +            return NOT_FOUND_EXCEPTION diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py index d7351d5..b993075 100644 --- a/backend/routes/forms/discover.py +++ b/backend/routes/forms/discover.py @@ -29,7 +29,7 @@ EMPTY_FORM = Form(      features=__FEATURES,      questions=[__QUESTION],      name="Auth form", -    description="An empty form to help you get a token." +    description="An empty form to help you get a token.",  ) diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index 0f96b85..567c197 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -10,13 +10,15 @@ from starlette.authentication import requires  from starlette.requests import Request  from starlette.responses import JSONResponse -from backend import constants +from backend import constants, discord  from backend.models import Form  from backend.route import Route  from backend.routes.forms.discover import EMPTY_FORM  from backend.routes.forms.unittesting import filter_unittests  from backend.validation import ErrorMessage, OkayResponse, api +PUBLIC_FORM_FEATURES = (constants.FormFeatures.OPEN, constants.FormFeatures.DISCOVERABLE) +  class SingleForm(Route):      """ @@ -31,9 +33,19 @@ class SingleForm(Route):      @api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])      async def get(self, request: Request) -> JSONResponse:          """Returns single form information by ID.""" -        admin = request.user.admin if request.user.is_authenticated else False          form_id = request.path_params["form_id"].lower() +        try: +            await discord.verify_edit_access(form_id, request) +            admin = True +        except discord.FormNotFoundError: +            if not constants.PRODUCTION and form_id == EMPTY_FORM.id: +                # Empty form to help with authentication in development. +                return JSONResponse(EMPTY_FORM.dict(admin=False)) +            raise +        except discord.UnauthorizedError: +            admin = False +          filters = {              "_id": form_id          } @@ -41,25 +53,18 @@ class SingleForm(Route):          if not admin:              filters["features"] = {"$in": ["OPEN", "DISCOVERABLE"]} -        if raw_form := await request.state.db.forms.find_one(filters): -            form = Form(**raw_form) -            if not admin: -                form = filter_unittests(form) - -            return JSONResponse(form.dict(admin=admin)) - -        elif not constants.PRODUCTION and form_id == EMPTY_FORM.id: -            # Empty form to help with authentication in development. -            return JSONResponse(EMPTY_FORM.dict(admin=admin)) +        form = Form(**await request.state.db.forms.find_one(filters)) +        if not admin: +            form = filter_unittests(form) -        return JSONResponse({"error": "not_found"}, status_code=404) +        return JSONResponse(form.dict(admin=admin)) -    @requires(["authenticated", "admin"]) +    @requires(["authenticated"])      @api.validate(          resp=Response(              HTTP_200=OkayResponse,              HTTP_400=ErrorMessage, -            HTTP_404=ErrorMessage +            HTTP_404=ErrorMessage,          ),          tags=["forms"]      ) @@ -70,10 +75,12 @@ class SingleForm(Route):          except json.decoder.JSONDecodeError:              return JSONResponse("Expected a JSON body.", 400) -        form_id = {"_id": request.path_params["form_id"].lower()} -        if raw_form := await request.state.db.forms.find_one(form_id): +        form_id = request.path_params["form_id"].lower() +        await discord.verify_edit_access(form_id, request) + +        if raw_form := await request.state.db.forms.find_one({"_id": form_id}):              if "_id" in data or "id" in data: -                if (data.get("id") or data.get("_id")) != form_id["_id"]: +                if (data.get("id") or data.get("_id")) != form_id:                      return JSONResponse({"error": "locked_field"}, status_code=400)              # Build Data Merger @@ -90,7 +97,7 @@ class SingleForm(Route):              except ValidationError as e:                  return JSONResponse(e.errors(), status_code=422) -            await request.state.db.forms.replace_one(form_id, form.dict()) +            await request.state.db.forms.replace_one({"_id": form_id}, form.dict())              return JSONResponse(form.dict())          else: @@ -98,21 +105,15 @@ class SingleForm(Route):      @requires(["authenticated", "admin"])      @api.validate( -        resp=Response(HTTP_200=OkayResponse, HTTP_404=ErrorMessage), +        resp=Response(HTTP_200=OkayResponse, HTTP_401=ErrorMessage, HTTP_404=ErrorMessage),          tags=["forms"]      )      async def delete(self, request: Request) -> JSONResponse:          """Deletes form by ID.""" -        if not await request.state.db.forms.find_one( -            {"_id": request.path_params["form_id"].lower()} -        ): -            return JSONResponse({"error": "not_found"}, status_code=404) +        form_id = request.path_params["form_id"].lower() +        await discord.verify_edit_access(form_id, request) -        await request.state.db.forms.delete_one( -            {"_id": request.path_params["form_id"].lower()} -        ) -        await request.state.db.responses.delete_many( -            {"form_id": request.path_params["form_id"].lower()} -        ) +        await request.state.db.forms.delete_one({"_id": form_id}) +        await request.state.db.responses.delete_many({"form_id": form_id})          return JSONResponse({"status": "ok"}) diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py index 22171fa..38be693 100644 --- a/backend/routes/forms/index.py +++ b/backend/routes/forms/index.py @@ -15,13 +15,13 @@ from backend.validation import ErrorMessage, OkayResponse, api  class FormsList(Route):      """ -    List all available forms for administrator viewing. +    List all available forms for authorized viewers.      """      name = "forms_list_create"      path = "/" -    @requires(["authenticated", "admin"]) +    @requires(["authenticated", "Admins"])      @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])      async def get(self, request: Request) -> JSONResponse:          """Return a list of all forms to authenticated users.""" @@ -38,7 +38,7 @@ class FormsList(Route):              forms          ) -    @requires(["authenticated", "admin"]) +    @requires(["authenticated", "Helpers"])      @api.validate(          json=Form,          resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py index d8d8d17..565701f 100644 --- a/backend/routes/forms/response.py +++ b/backend/routes/forms/response.py @@ -1,11 +1,13 @@  """  Returns or deletes form response by ID.  """ +  from spectree import Response as RouteResponse  from starlette.authentication import requires  from starlette.requests import Request  from starlette.responses import JSONResponse +from backend import discord  from backend.models import FormResponse  from backend.route import Route  from backend.validation import ErrorMessage, OkayResponse, api @@ -17,23 +19,26 @@ class Response(Route):      name = "response"      path = "/{form_id:str}/responses/{response_id:str}" -    @requires(["authenticated", "admin"]) +    @requires(["authenticated"])      @api.validate(          resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage),          tags=["forms", "responses"]      )      async def get(self, request: Request) -> JSONResponse:          """Return a single form response by ID.""" +        form_id = request.path_params["form_id"] +        await discord.verify_response_access(form_id, request) +          if raw_response := await request.state.db.responses.find_one(              {                  "_id": request.path_params["response_id"], -                "form_id": request.path_params["form_id"] +                "form_id": form_id              }          ):              response = FormResponse(**raw_response)              return JSONResponse(response.dict())          else: -            return JSONResponse({"error": "not_found"}, status_code=404) +            return JSONResponse({"error": "response_not_found"}, status_code=404)      @requires(["authenticated", "admin"])      @api.validate( diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py index f3c4cd7..818ebce 100644 --- a/backend/routes/forms/responses.py +++ b/backend/routes/forms/responses.py @@ -7,9 +7,10 @@ from starlette.authentication import requires  from starlette.requests import Request  from starlette.responses import JSONResponse +from backend import discord  from backend.models import FormResponse, ResponseList  from backend.route import Route -from backend.validation import api, ErrorMessage, OkayResponse +from backend.validation import ErrorMessage, OkayResponse, api  class ResponseIdList(BaseModel): @@ -24,20 +25,18 @@ class Responses(Route):      name = "form_responses"      path = "/{form_id:str}/responses" -    @requires(["authenticated", "admin"]) +    @requires(["authenticated"])      @api.validate( -        resp=Response(HTTP_200=ResponseList, HTTP_404=ErrorMessage), +        resp=Response(HTTP_200=ResponseList),          tags=["forms", "responses"]      )      async def get(self, request: Request) -> JSONResponse:          """Returns all form responses by form ID.""" -        if not await request.state.db.forms.find_one( -            {"_id": request.path_params["form_id"]} -        ): -            return JSONResponse({"error": "not_found"}, 404) +        form_id = request.path_params["form_id"] +        await discord.verify_response_access(form_id, request)          cursor = request.state.db.responses.find( -            {"form_id": request.path_params["form_id"]} +            {"form_id": form_id}          )          responses = [              FormResponse(**response) for response in await cursor.to_list(None) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 95e30b0..baf403d 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -83,7 +83,7 @@ class SubmitForm(Route):          try:              if hasattr(request.user, User.refresh_data.__name__):                  old = request.user.token -                await request.user.refresh_data() +                await request.user.refresh_data(request.state.db)                  if old != request.user.token:                      try: | 
