diff options
Diffstat (limited to 'backend/routes/auth/authorize.py')
-rw-r--r-- | backend/routes/auth/authorize.py | 49 |
1 files changed, 37 insertions, 12 deletions
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 26d8622..1e773d6 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -10,9 +10,9 @@ import jwt from pydantic.fields import Field from pydantic.main import BaseModel from spectree.response import Response +from starlette import responses from starlette.authentication import requires from starlette.requests import Request -from starlette.responses import JSONResponse from backend import constants from backend.authentication.user import User @@ -21,7 +21,7 @@ from backend.discord import fetch_bearer_token, fetch_user_details from backend.route import Route from backend.validation import ErrorMessage, api -AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400) +AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400) class AuthorizeRequest(BaseModel): @@ -33,7 +33,7 @@ class AuthorizeResponse(BaseModel): expiry: str = Field("ISO formatted timestamp of expiry.") -async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]: +async def process_token(bearer_token: dict, origin: str) -> Union[AuthorizeResponse, AUTH_FAILURE]: """Post a bearer token to Discord, and return a JWT and username.""" interaction_start = datetime.datetime.now() @@ -56,17 +56,42 @@ async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAI token = jwt.encode(data, SECRET_KEY, algorithm="HS256") user = User(token, user_details) - response = JSONResponse({ + response = responses.JSONResponse({ "username": user.display_name, "expiry": token_expiry.isoformat() }) + await set_response_token(response, origin, token, bearer_token["expires_in"]) + return response + + +async def set_response_token( + response: responses.Response, + origin_url: str, + new_token: str, + expiry: int +) -> None: + """Helper that handles logic for updating a token in a set-cookie response.""" + if origin_url == constants.PRODUCTION_URL: + domain = constants.PRODUCTION_URL + samesite = "strict" + + elif not constants.PRODUCTION: + domain = None + samesite = "strict" + + else: + domain = origin_url + samesite = "None" + response.set_cookie( - "token", f"JWT {token}", - secure=constants.PRODUCTION, httponly=True, samesite="strict", - max_age=bearer_token["expires_in"] + "token", f"JWT {new_token}", + secure=constants.PRODUCTION, + httponly=True, + samesite=samesite, + domain=domain, + max_age=expiry ) - return response class AuthorizeRoute(Route): @@ -82,7 +107,7 @@ class AuthorizeRoute(Route): resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), tags=["auth"] ) - async def post(self, request: Request) -> JSONResponse: + async def post(self, request: Request) -> responses.JSONResponse: """Generate an authorization token.""" data = await request.json() try: @@ -91,7 +116,7 @@ class AuthorizeRoute(Route): except httpx.HTTPStatusError: return AUTH_FAILURE - return await process_token(bearer_token) + return await process_token(bearer_token, url) class TokenRefreshRoute(Route): @@ -107,7 +132,7 @@ class TokenRefreshRoute(Route): resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), tags=["auth"] ) - async def post(self, request: Request) -> JSONResponse: + async def post(self, request: Request) -> responses.JSONResponse: """Refresh an authorization token.""" try: token = request.user.decoded_token.get("refresh") @@ -116,4 +141,4 @@ class TokenRefreshRoute(Route): except httpx.HTTPStatusError: return AUTH_FAILURE - return await process_token(bearer_token) + return await process_token(bearer_token, url) |