aboutsummaryrefslogtreecommitdiffstats
path: root/backend/routes/auth/authorize.py
blob: 25091099dae9d4cd641819bfedb74618c3c1aa42 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Exchange a code received from the Discord OAuth2 flow for the user's information and issue a signed JWT.
"""

import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
from starlette.requests import Request
from starlette.responses import JSONResponse

from backend.constants import SECRET_KEY
from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
from backend.validation import ErrorMessage, api


class AuthorizeRequest(BaseModel):
    # Request body for the authorize route. Despite the field name, this value
    # is not an access token: the route hands it to fetch_bearer_token() and
    # only then reads "access_token" from the result, i.e. it is the OAuth2
    # authorization code that gets exchanged for a bearer token.
    token: str = Field(
        description="The OAuth2 authorization code received from Discord."
    )


class AuthorizeResponse(BaseModel):
    # Successful (HTTP 200) body of the authorize route: the signed JWT.
    token: str = Field(
        ...,  # required field, no default
        description="A JWT token containing the user information",
    )


class AuthorizeRoute(Route):
    """
    Use the authorization code from Discord to generate a JWT token.
    """

    name = "authorize"
    path = "/authorize"

    @api.validate(
        json=AuthorizeRequest,
        resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
        tags=["auth"]
    )
    async def post(self, request: Request) -> JSONResponse:
        """
        Generate an authorization token.

        Exchanges the ``token`` from the request body for a Discord bearer
        token, fetches the user's details with that bearer token, adds an
        ``admin`` flag (whether the user's id is present in the ``admins``
        collection), and returns the details signed as an HS256 JWT under
        the ``token`` key.

        Responds with HTTP 400 and ``{"error": "auth_failure"}`` if either
        Discord call raises ``httpx.HTTPStatusError``.
        """
        data = await request.json()

        try:
            bearer_token = await fetch_bearer_token(data["token"])
            user_details = await fetch_user_details(bearer_token["access_token"])
        except httpx.HTTPStatusError:
            return JSONResponse({
                "error": "auth_failure"
            }, status_code=400)

        # A user is an admin iff a document keyed by their id exists.
        user_details["admin"] = await request.state.db.admins.find_one(
            {"_id": user_details["id"]}
        ) is not None

        # NOTE(review): no "exp" claim is added here, so unless
        # fetch_user_details supplies one, the issued JWT never expires.
        token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")

        # PyJWT < 2.0 returns bytes from jwt.encode while PyJWT >= 2.0 returns
        # str (which has no .decode()); normalise so either version works.
        if isinstance(token, bytes):
            token = token.decode()

        return JSONResponse({
            "token": token
        })