diff options
Diffstat (limited to 'backend')
-rw-r--r-- | backend/authentication/user.py | 17 | ||||
-rw-r--r-- | backend/discord.py | 11 | ||||
-rw-r--r-- | backend/routes/auth/authorize.py | 81 |
3 files changed, 93 insertions, 16 deletions
diff --git a/backend/authentication/user.py b/backend/authentication/user.py index f40c68c..a1d78e5 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,7 +1,11 @@ import typing as t +import jwt from starlette.authentication import BaseUser +from backend.constants import SECRET_KEY +from backend.discord import fetch_user_details + class User(BaseUser): """Starlette BaseUser implementation for JWT authentication.""" @@ -23,3 +27,16 @@ class User(BaseUser): @property def discord_mention(self) -> str: return f"<@{self.payload['id']}>" + + @property + def decoded_token(self) -> dict[str, any]: + return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) + + async def refresh_data(self) -> None: + """Fetches user data from discord, and updates the instance.""" + self.payload = await fetch_user_details(self.decoded_token.get("token")) + + updated_info = self.decoded_token + updated_info["user_details"] = self.payload + + self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256") diff --git a/backend/discord.py b/backend/discord.py index d6310b7..9cdd2c4 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -8,16 +8,21 @@ from backend.constants import ( API_BASE_URL = "https://discord.com/api/v8" -async def fetch_bearer_token(access_code: str) -> dict: +async def fetch_bearer_token(code: str, *, refresh: bool) -> dict: async with httpx.AsyncClient() as client: data = { "client_id": OAUTH2_CLIENT_ID, "client_secret": OAUTH2_CLIENT_SECRET, - "grant_type": "authorization_code", - "code": access_code, "redirect_uri": OAUTH2_REDIRECT_URI } + if refresh: + data["grant_type"] = "refresh_token" + data["refresh_token"] = code + else: + data["grant_type"] = "authorization_code" + data["code"] = code + r = await client.post(f"{API_BASE_URL}/oauth2/token", headers={ "Content-Type": "application/x-www-form-urlencoded" }, data=data) diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 975936a..2244152 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -2,17 +2,23 @@ Use a token received from the Discord OAuth2 system to fetch user information. """ +import datetime +from typing import Union + import httpx import jwt from pydantic.fields import Field from pydantic.main import BaseModel from spectree.response import Response +from starlette.authentication import requires from starlette.requests import Request from starlette.responses import JSONResponse +from backend import constants +from backend.authentication.user import User from backend.constants import SECRET_KEY -from backend.route import Route from backend.discord import fetch_bearer_token, fetch_user_details +from backend.route import Route from backend.validation import ErrorMessage, api @@ -21,7 +27,42 @@ class AuthorizeRequest(BaseModel): class AuthorizeResponse(BaseModel): - token: str = Field(description="A JWT token containing the user information") + username: str = Field("Discord display name.") + + +AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400) + + +async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]: + """Post a bearer token to Discord, and return a JWT and username.""" + interaction_start = datetime.datetime.now() + + try: + user_details = await fetch_user_details(bearer_token["access_token"]) + except httpx.HTTPStatusError: + AUTH_FAILURE.delete_cookie("BackendToken") + return AUTH_FAILURE + + max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"])) + token_expiry = interaction_start + max_age + + data = { + "token": bearer_token["access_token"], + "refresh": bearer_token["refresh_token"], + "user_details": user_details, + "expiry": token_expiry.isoformat() + } + + token = jwt.encode(data, SECRET_KEY, algorithm="HS256") + user = User(token, user_details) + + response = JSONResponse({"username": user.display_name}) + response.set_cookie( + "BackendToken", f"JWT {token}", + secure=constants.PRODUCTION, httponly=True, samesite="strict", + max_age=bearer_token["expires_in"] + ) + return response class AuthorizeRoute(Route): @@ -40,19 +81,33 @@ class AuthorizeRoute(Route): async def post(self, request: Request) -> JSONResponse: """Generate an authorization token.""" data = await request.json() - try: - bearer_token = await fetch_bearer_token(data["token"]) - user_details = await fetch_user_details(bearer_token["access_token"]) + bearer_token = await fetch_bearer_token(data["token"], refresh=False) except httpx.HTTPStatusError: - return JSONResponse({ - "error": "auth_failure" - }, status_code=400) + return AUTH_FAILURE + + return await process_token(bearer_token) - user_details["admin"] = await request.state.db.admins.find_one( - {"_id": user_details["id"]} - ) is not None - token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256") +class TokenRefreshRoute(Route): + """ + Use the refresh code from a JWT to get a new token and generate a new JWT token. + """ + + name = "refresh" + path = "/refresh" + + @requires(["authenticated"]) + @api.validate( + resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), + tags=["auth"] + ) + async def post(self, request: Request) -> JSONResponse: + """Refresh an authorization token.""" + try: + token = request.user.decoded_token.get("refresh") + bearer_token = await fetch_bearer_token(token, refresh=True) + except httpx.HTTPStatusError: + return AUTH_FAILURE - return JSONResponse({"token": token}) + return await process_token(bearer_token) |