diff options
Diffstat (limited to 'backend/routes/auth')
-rw-r--r-- | backend/routes/auth/authorize.py | 121 |
1 files changed, 106 insertions, 15 deletions
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 975936a..d4587f0 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -2,26 +2,101 @@ Use a token received from the Discord OAuth2 system to fetch user information. """ +import datetime +from typing import Union + import httpx import jwt from pydantic.fields import Field from pydantic.main import BaseModel from spectree.response import Response +from starlette import responses +from starlette.authentication import requires from starlette.requests import Request -from starlette.responses import JSONResponse +from backend import constants +from backend.authentication.user import User from backend.constants import SECRET_KEY -from backend.route import Route from backend.discord import fetch_bearer_token, fetch_user_details +from backend.route import Route from backend.validation import ErrorMessage, api +AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400) + class AuthorizeRequest(BaseModel): token: str = Field(description="The access token received from Discord.") class AuthorizeResponse(BaseModel): - token: str = Field(description="A JWT token containing the user information") + username: str = Field("Discord display name.") + expiry: str = Field("ISO formatted timestamp of expiry.") + + +async def process_token( + bearer_token: dict, + request: Request +) -> Union[AuthorizeResponse, AUTH_FAILURE]: + """Post a bearer token to Discord, and return a JWT and username.""" + interaction_start = datetime.datetime.now() + + try: + user_details = await fetch_user_details(bearer_token["access_token"]) + except httpx.HTTPStatusError: + AUTH_FAILURE.delete_cookie("token") + return AUTH_FAILURE + + max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"])) + token_expiry = interaction_start + max_age + + data = { + "token": bearer_token["access_token"], + "refresh": bearer_token["refresh_token"], + "user_details": user_details, + "expiry": token_expiry.isoformat() + } + + token = jwt.encode(data, SECRET_KEY, algorithm="HS256") + user = User(token, user_details) + + response = responses.JSONResponse({ + "username": user.display_name, + "expiry": token_expiry.isoformat() + }) + + await set_response_token(response, request, token, bearer_token["expires_in"]) + return response + + +async def set_response_token( + response: responses.Response, + request: Request, + new_token: str, + expiry: int +) -> None: + """Helper that handles logic for updating a token in a set-cookie response.""" + origin_url = request.headers.get("origin") + + if origin_url == constants.PRODUCTION_URL: + domain = request.url.netloc + samesite = "strict" + + elif not constants.PRODUCTION: + domain = None + samesite = "strict" + + else: + domain = request.url.netloc + samesite = "None" + + response.set_cookie( + "token", f"JWT {new_token}", + secure=constants.PRODUCTION, + httponly=True, + samesite=samesite, + domain=domain, + max_age=expiry + ) class AuthorizeRoute(Route): @@ -37,22 +112,38 @@ class AuthorizeRoute(Route): resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), tags=["auth"] ) - async def post(self, request: Request) -> JSONResponse: + async def post(self, request: Request) -> responses.JSONResponse: """Generate an authorization token.""" data = await request.json() - try: - bearer_token = await fetch_bearer_token(data["token"]) - user_details = await fetch_user_details(bearer_token["access_token"]) + url = request.headers.get("origin") + bearer_token = await fetch_bearer_token(data["token"], url, refresh=False) except httpx.HTTPStatusError: - return JSONResponse({ - "error": "auth_failure" - }, status_code=400) + return AUTH_FAILURE - user_details["admin"] = await request.state.db.admins.find_one( - {"_id": user_details["id"]} - ) is not None + return await process_token(bearer_token, request) + + +class TokenRefreshRoute(Route): + """ + Use the refresh code from a JWT to get a new token and generate a new JWT token. + """ - token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256") + name = "refresh" + path = "/refresh" + + @requires(["authenticated"]) + @api.validate( + resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), + tags=["auth"] + ) + async def post(self, request: Request) -> responses.JSONResponse: + """Refresh an authorization token.""" + try: + token = request.user.decoded_token.get("refresh") + url = request.headers.get("origin") + bearer_token = await fetch_bearer_token(token, url, refresh=True) + except httpx.HTTPStatusError: + return AUTH_FAILURE - return JSONResponse({"token": token}) + return await process_token(bearer_token, request) |