aboutsummaryrefslogtreecommitdiffstats
path: root/backend/routes/auth
diff options
context:
space:
mode:
Diffstat (limited to 'backend/routes/auth')
-rw-r--r--backend/routes/auth/authorize.py121
1 files changed, 106 insertions, 15 deletions
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 975936a..d4587f0 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -2,26 +2,101 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""
+import datetime
+from typing import Union
+
import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
+from starlette import responses
+from starlette.authentication import requires
from starlette.requests import Request
-from starlette.responses import JSONResponse
+from backend import constants
+from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.route import Route
from backend.validation import ErrorMessage, api
+AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400)
+
class AuthorizeRequest(BaseModel):
token: str = Field(description="The access token received from Discord.")
class AuthorizeResponse(BaseModel):
- token: str = Field(description="A JWT token containing the user information")
+ username: str = Field("Discord display name.")
+ expiry: str = Field("ISO formatted timestamp of expiry.")
+
+
+async def process_token(
+ bearer_token: dict,
+ request: Request
+) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+ """Post a bearer token to Discord, and return a JWT and username."""
+ interaction_start = datetime.datetime.now()
+
+ try:
+ user_details = await fetch_user_details(bearer_token["access_token"])
+ except httpx.HTTPStatusError:
+ AUTH_FAILURE.delete_cookie("token")
+ return AUTH_FAILURE
+
+ max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
+ token_expiry = interaction_start + max_age
+
+ data = {
+ "token": bearer_token["access_token"],
+ "refresh": bearer_token["refresh_token"],
+ "user_details": user_details,
+ "expiry": token_expiry.isoformat()
+ }
+
+ token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
+ user = User(token, user_details)
+
+ response = responses.JSONResponse({
+ "username": user.display_name,
+ "expiry": token_expiry.isoformat()
+ })
+
+ await set_response_token(response, request, token, bearer_token["expires_in"])
+ return response
+
+
+async def set_response_token(
+ response: responses.Response,
+ request: Request,
+ new_token: str,
+ expiry: int
+) -> None:
+ """Helper that handles logic for updating a token in a set-cookie response."""
+ origin_url = request.headers.get("origin")
+
+ if origin_url == constants.PRODUCTION_URL:
+ domain = request.url.netloc
+ samesite = "strict"
+
+ elif not constants.PRODUCTION:
+ domain = None
+ samesite = "strict"
+
+ else:
+ domain = request.url.netloc
+ samesite = "None"
+
+ response.set_cookie(
+ "token", f"JWT {new_token}",
+ secure=constants.PRODUCTION,
+ httponly=True,
+ samesite=samesite,
+ domain=domain,
+ max_age=expiry
+ )
class AuthorizeRoute(Route):
@@ -37,22 +112,38 @@ class AuthorizeRoute(Route):
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"]
)
- async def post(self, request: Request) -> JSONResponse:
+ async def post(self, request: Request) -> responses.JSONResponse:
"""Generate an authorization token."""
data = await request.json()
-
try:
- bearer_token = await fetch_bearer_token(data["token"])
- user_details = await fetch_user_details(bearer_token["access_token"])
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)
except httpx.HTTPStatusError:
- return JSONResponse({
- "error": "auth_failure"
- }, status_code=400)
+ return AUTH_FAILURE
- user_details["admin"] = await request.state.db.admins.find_one(
- {"_id": user_details["id"]}
- ) is not None
+ return await process_token(bearer_token, request)
+
+
+class TokenRefreshRoute(Route):
+ """
+ Use the refresh code from a JWT to get a new token and generate a new JWT token.
+ """
- token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")
+ name = "refresh"
+ path = "/refresh"
+
+ @requires(["authenticated"])
+ @api.validate(
+ resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
+ tags=["auth"]
+ )
+ async def post(self, request: Request) -> responses.JSONResponse:
+ """Refresh an authorization token."""
+ try:
+ token = request.user.decoded_token.get("refresh")
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(token, url, refresh=True)
+ except httpx.HTTPStatusError:
+ return AUTH_FAILURE
- return JSONResponse({"token": token})
+ return await process_token(bearer_token, request)