1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
"""Use a token received from the Discord OAuth2 system to fetch user information."""
import datetime
import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
from starlette import responses
from starlette.authentication import requires
from starlette.requests import Request
from backend import constants
from backend.authentication.user import User
from backend.constants import SECRET_KEY
from backend.discord import fetch_bearer_token, fetch_user_details, get_member
from backend.route import Route
from backend.validation import ErrorMessage, api
AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400)
class AuthorizeRequest(BaseModel):
token: str = Field(description="The access token received from Discord.")
class AuthorizeResponse(BaseModel):
username: str = Field("Discord display name.")
expiry: str = Field("ISO formatted timestamp of expiry.")
async def process_token(
bearer_token: dict,
request: Request,
) -> AuthorizeResponse | responses.JSONResponse:
"""Post a bearer token to Discord, and return a JWT and username."""
interaction_start = datetime.datetime.now()
try:
user_details = await fetch_user_details(bearer_token["access_token"])
except httpx.HTTPStatusError:
AUTH_FAILURE.delete_cookie("token")
return AUTH_FAILURE
user_id = user_details["id"]
member = await get_member(user_id, force_refresh=True)
max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
token_expiry = interaction_start + max_age
data = {
"token": bearer_token["access_token"],
"refresh": bearer_token["refresh_token"],
"user_details": user_details,
"in_guild": bool(member),
# Legacy key, we should use exp and use JWT expiry as below it.
"expiry": token_expiry.isoformat(),
# Correct JWT expiry key:
"exp": token_expiry
}
token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
user = User(token, user_details, member)
response = responses.JSONResponse({
"username": user.display_name,
"expiry": token_expiry.isoformat(),
})
set_response_token(response, request, token, bearer_token["expires_in"])
return response
def set_response_token(
response: responses.Response,
request: Request,
new_token: str,
expiry: int,
) -> None:
"""Helper that handles logic for updating a token in a set-cookie response."""
origin_url = request.headers.get("origin")
if origin_url == constants.PRODUCTION_URL:
domain = request.url.netloc
samesite = "strict"
elif not constants.PRODUCTION:
domain = None
samesite = "strict"
else:
domain = request.url.netloc
samesite = "None"
response.set_cookie(
"token",
f"JWT {new_token}",
secure=constants.PRODUCTION,
httponly=True,
samesite=samesite,
domain=domain,
max_age=expiry,
)
class AuthorizeRoute(Route):
"""Use the authorization code from Discord to generate a JWT token."""
name = "authorize"
path = "/authorize"
@api.validate(
json=AuthorizeRequest,
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"],
)
async def post(self, request: Request) -> responses.JSONResponse:
"""Generate an authorization token."""
data = await request.json()
try:
url = request.headers.get("origin")
bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)
except httpx.HTTPStatusError:
return AUTH_FAILURE
return await process_token(bearer_token, request)
class TokenRefreshRoute(Route):
"""Use the refresh code from a JWT to get a new token and generate a new JWT token."""
name = "refresh"
path = "/refresh"
@requires(["authenticated"])
@api.validate(
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"],
)
async def post(self, request: Request) -> responses.JSONResponse:
"""Refresh an authorization token."""
try:
token = request.user.decoded_token.get("refresh")
url = request.headers.get("origin")
bearer_token = await fetch_bearer_token(token, url, refresh=True)
except httpx.HTTPStatusError:
return AUTH_FAILURE
return await process_token(bearer_token, request)
|