aboutsummaryrefslogtreecommitdiffstats
path: root/backend/routes
diff options
context:
space:
mode:
Diffstat (limited to 'backend/routes')
-rw-r--r--backend/routes/auth/authorize.py49
-rw-r--r--backend/routes/forms/submit.py9
2 files changed, 41 insertions, 17 deletions
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 26d8622..1e773d6 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -10,9 +10,9 @@ import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
+from starlette import responses
from starlette.authentication import requires
from starlette.requests import Request
-from starlette.responses import JSONResponse
from backend import constants
from backend.authentication.user import User
@@ -21,7 +21,7 @@ from backend.discord import fetch_bearer_token, fetch_user_details
from backend.route import Route
from backend.validation import ErrorMessage, api
-AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400)
+AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400)
class AuthorizeRequest(BaseModel):
@@ -33,7 +33,7 @@ class AuthorizeResponse(BaseModel):
expiry: str = Field("ISO formatted timestamp of expiry.")
-async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+async def process_token(bearer_token: dict, origin: str) -> Union[AuthorizeResponse, AUTH_FAILURE]:
"""Post a bearer token to Discord, and return a JWT and username."""
interaction_start = datetime.datetime.now()
@@ -56,17 +56,42 @@ async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAI
token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
user = User(token, user_details)
- response = JSONResponse({
+ response = responses.JSONResponse({
"username": user.display_name,
"expiry": token_expiry.isoformat()
})
+ await set_response_token(response, origin, token, bearer_token["expires_in"])
+ return response
+
+
+async def set_response_token(
+ response: responses.Response,
+ origin_url: str,
+ new_token: str,
+ expiry: int
+) -> None:
+ """Helper that handles logic for updating a token in a set-cookie response."""
+ if origin_url == constants.PRODUCTION_URL:
+ domain = constants.PRODUCTION_URL
+ samesite = "strict"
+
+ elif not constants.PRODUCTION:
+ domain = None
+ samesite = "strict"
+
+ else:
+ domain = origin_url
+ samesite = "None"
+
response.set_cookie(
- "token", f"JWT {token}",
- secure=constants.PRODUCTION, httponly=True, samesite="strict",
- max_age=bearer_token["expires_in"]
+ "token", f"JWT {new_token}",
+ secure=constants.PRODUCTION,
+ httponly=True,
+ samesite=samesite,
+ domain=domain,
+ max_age=expiry
)
- return response
class AuthorizeRoute(Route):
@@ -82,7 +107,7 @@ class AuthorizeRoute(Route):
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"]
)
- async def post(self, request: Request) -> JSONResponse:
+ async def post(self, request: Request) -> responses.JSONResponse:
"""Generate an authorization token."""
data = await request.json()
try:
@@ -91,7 +116,7 @@ class AuthorizeRoute(Route):
except httpx.HTTPStatusError:
return AUTH_FAILURE
- return await process_token(bearer_token)
+ return await process_token(bearer_token, url)
class TokenRefreshRoute(Route):
@@ -107,7 +132,7 @@ class TokenRefreshRoute(Route):
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"]
)
- async def post(self, request: Request) -> JSONResponse:
+ async def post(self, request: Request) -> responses.JSONResponse:
"""Refresh an authorization token."""
try:
token = request.user.decoded_token.get("refresh")
@@ -116,4 +141,4 @@ class TokenRefreshRoute(Route):
except httpx.HTTPStatusError:
return AUTH_FAILURE
- return await process_token(bearer_token)
+ return await process_token(bearer_token, url)
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index 8680b2d..975307b 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -20,6 +20,7 @@ from backend import constants
from backend.authentication.user import User
from backend.models import Form, FormResponse
from backend.route import Route
+from backend.routes.auth.authorize import set_response_token
from backend.routes.forms.unittesting import execute_unittest
from backend.validation import ErrorMessage, api
@@ -74,11 +75,9 @@ class SubmitForm(Route):
except ValueError:
expiry = None
- response.set_cookie(
- "token", f"JWT {request.user.token}",
- secure=constants.PRODUCTION, httponly=True, samesite="strict",
- max_age=(expiry - datetime.datetime.now()).seconds
- )
+ origin = request.headers.get("origin")
+ expiry_seconds = (expiry - datetime.datetime.now()).seconds
+ await set_response_token(response, origin, request.user.token, expiry_seconds)
except httpx.HTTPStatusError:
pass