aboutsummaryrefslogtreecommitdiffstats
path: root/backend/routes/auth/authorize.py
blob: 7f18cb48e3c2f1864f2ea5212c83fcf157dadea9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Use a token received from the Discord OAuth2 system to fetch user information."""

import datetime

import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
from starlette import responses
from starlette.authentication import requires
from starlette.requests import Request

from backend import constants
from backend.authentication.user import User
from backend.constants import SECRET_KEY
from backend.discord import fetch_bearer_token, fetch_user_details, get_member
from backend.route import Route
from backend.validation import ErrorMessage, api

# Generic 400 reply used whenever Discord rejects a code/token. This single
# instance is returned by several handlers, so it should not be mutated
# (e.g. with set_cookie/delete_cookie): any header added to it would be sent
# with every later failure response as well.
AUTH_FAILURE = responses.JSONResponse(
    content={"error": "auth_failure"},
    status_code=400,
)


class AuthorizeRequest(BaseModel):
    # JSON body accepted by POST /authorize; ``token`` is required.
    # NOTE(review): the route passes this value to
    # fetch_bearer_token(..., refresh=False) and its docstring calls it an
    # authorization code, so "access token" in the description may be
    # inaccurate — confirm before changing the public API docs.
    token: str = Field(..., description="The access token received from Discord.")


class AuthorizeResponse(BaseModel):
    """Successful (HTTP 200) body returned by the authorize and refresh routes."""

    # Pydantic's ``Field`` takes the *default value* as its first positional
    # argument, so descriptions must be passed by keyword; otherwise they
    # become default values and the fields are documented as optional.
    username: str = Field(description="Discord display name.")
    expiry: str = Field(description="ISO formatted timestamp of expiry.")


async def process_token(
    bearer_token: dict,
    request: Request,
) -> responses.JSONResponse:
    """Post a bearer token to Discord, and return a JWT and username.

    Fetches the user's Discord details and guild membership, encodes them
    together with the Discord access/refresh tokens into an HS256-signed JWT,
    and sets that JWT as the ``token`` cookie on the returned response.

    :param bearer_token: Discord OAuth2 token payload; must contain
        ``access_token``, ``refresh_token`` and ``expires_in`` (seconds).
    :param request: The incoming request, used to choose cookie attributes.
    :return: A JSON response with ``username`` and ``expiry`` on success, or a
        400 ``auth_failure`` response (which also clears the ``token`` cookie)
        if fetching the user details fails with an HTTP error status.
    """
    # The expiry is computed relative to this timestamp, taken before the
    # Discord API calls below.
    interaction_start = datetime.datetime.now()

    try:
        user_details = await fetch_user_details(bearer_token["access_token"])
    except httpx.HTTPStatusError:
        # Build a fresh response: calling delete_cookie on the shared
        # AUTH_FAILURE object would permanently append a Set-Cookie header to
        # it, leaking into every later failure response.
        failure = responses.JSONResponse({"error": "auth_failure"}, status_code=400)
        failure.delete_cookie("token")
        return failure

    user_id = user_details["id"]
    member = await get_member(user_id, force_refresh=True)

    expires_in = int(bearer_token["expires_in"])
    token_expiry = interaction_start + datetime.timedelta(seconds=expires_in)

    data = {
        "token": bearer_token["access_token"],
        "refresh": bearer_token["refresh_token"],
        "user_details": user_details,
        "in_guild": bool(member),
        # Legacy key, we should use exp and use JWT expiry as below it.
        "expiry": token_expiry.isoformat(),
        # Correct JWT expiry key. PyJWT converts a naive datetime as if it
        # were already UTC, so convert the local-time expiry to an aware UTC
        # datetime; otherwise "exp" is off by the server's UTC offset.
        "exp": token_expiry.astimezone(datetime.timezone.utc),
    }

    token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
    user = User(token, user_details, member)

    response = responses.JSONResponse({
        "username": user.display_name,
        "expiry": token_expiry.isoformat(),
    })

    set_response_token(response, request, token, expires_in)
    return response


def set_response_token(
    response: responses.Response,
    request: Request,
    new_token: str,
    expiry: int,
) -> None:
    """Helper that handles logic for updating a token in a set-cookie response.

    Cookie attributes depend on the request's ``origin`` header:

    * origin equals ``constants.PRODUCTION_URL``: Domain is this API's host,
      ``SameSite=strict``;
    * not running in production: host-only cookie (no Domain),
      ``SameSite=strict``;
    * any other origin in production: Domain is this API's host,
      ``SameSite=None`` (``secure`` is set in production, which browsers
      require for ``SameSite=None``).

    The cookie is always ``HttpOnly`` and ``Secure`` only in production.

    :param response: Response the ``Set-Cookie`` header is added to.
    :param request: Incoming request; its ``origin`` header and URL host are read.
    :param new_token: Encoded JWT, stored as ``"JWT <token>"``.
    :param expiry: Cookie lifetime in seconds (``Max-Age``).
    """
    origin_url = request.headers.get("origin")
    # A cookie Domain attribute must be a bare host name (RFC 6265). ``netloc``
    # may carry a port or userinfo, which makes browsers ignore the cookie, so
    # use ``hostname`` instead.
    api_host = request.url.hostname

    if origin_url == constants.PRODUCTION_URL:
        domain, samesite = api_host, "strict"
    elif not constants.PRODUCTION:
        domain, samesite = None, "strict"
    else:
        domain, samesite = api_host, "None"

    response.set_cookie(
        "token",
        f"JWT {new_token}",
        secure=constants.PRODUCTION,
        httponly=True,
        samesite=samesite,
        domain=domain,
        max_age=expiry,
    )


class AuthorizeRoute(Route):
    """Use the authorization code from Discord to generate a JWT token."""

    name = "authorize"
    path = "/authorize"

    @api.validate(
        json=AuthorizeRequest,
        resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
        tags=["auth"],
    )
    async def post(self, request: Request) -> responses.JSONResponse:
        """Generate an authorization token."""
        payload = await request.json()
        origin = request.headers.get("origin")

        try:
            bearer_token = await fetch_bearer_token(
                payload["token"], origin, refresh=False,
            )
        except httpx.HTTPStatusError:
            # Discord refused to exchange the submitted value.
            return AUTH_FAILURE
        else:
            # Errors raised while building the JWT are deliberately not
            # caught by the handler above.
            return await process_token(bearer_token, request)


class TokenRefreshRoute(Route):
    """Use the refresh code from a JWT to get a new token and generate a new JWT token."""

    name = "refresh"
    path = "/refresh"

    @requires(["authenticated"])
    @api.validate(
        resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
        tags=["auth"],
    )
    async def post(self, request: Request) -> responses.JSONResponse:
        """Refresh an authorization token."""
        try:
            # The Discord refresh token was stored in the JWT under "refresh".
            refresh_token = request.user.decoded_token.get("refresh")
            origin = request.headers.get("origin")
            bearer_token = await fetch_bearer_token(
                refresh_token, origin, refresh=True,
            )
        except httpx.HTTPStatusError:
            # Discord refused the refresh request.
            return AUTH_FAILURE
        else:
            return await process_token(bearer_token, request)