aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Hassan Abouelela <[email protected]>2021-02-19 09:01:38 +0300
committerGravatar Hassan Abouelela <[email protected]>2021-02-19 09:09:24 +0300
commit7a16a6b129f754a5486c441f2602a8d593edb85f (patch)
tree4153446b809706370824682ca0925e0ad2378ba4 /backend
parentAdds Production Constant (diff)
Adds Token Refresh Route
Signed-off-by: Hassan Abouelela <[email protected]>
Diffstat (limited to 'backend')
-rw-r--r--backend/authentication/user.py17
-rw-r--r--backend/discord.py11
-rw-r--r--backend/routes/auth/authorize.py81
3 files changed, 93 insertions, 16 deletions
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index f40c68c..a1d78e5 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,7 +1,11 @@
import typing as t
+import jwt
from starlette.authentication import BaseUser
+from backend.constants import SECRET_KEY
+from backend.discord import fetch_user_details
+
class User(BaseUser):
"""Starlette BaseUser implementation for JWT authentication."""
@@ -23,3 +27,16 @@ class User(BaseUser):
@property
def discord_mention(self) -> str:
return f"<@{self.payload['id']}>"
+
+ @property
+ def decoded_token(self) -> dict[str, any]:
+ return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])
+
+ async def refresh_data(self) -> None:
+ """Fetches user data from discord, and updates the instance."""
+ self.payload = await fetch_user_details(self.decoded_token.get("token"))
+
+ updated_info = self.decoded_token
+ updated_info["user_details"] = self.payload
+
+ self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256")
diff --git a/backend/discord.py b/backend/discord.py
index d6310b7..9cdd2c4 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -8,16 +8,21 @@ from backend.constants import (
API_BASE_URL = "https://discord.com/api/v8"
-async def fetch_bearer_token(access_code: str) -> dict:
+async def fetch_bearer_token(code: str, *, refresh: bool) -> dict:
async with httpx.AsyncClient() as client:
data = {
"client_id": OAUTH2_CLIENT_ID,
"client_secret": OAUTH2_CLIENT_SECRET,
- "grant_type": "authorization_code",
- "code": access_code,
"redirect_uri": OAUTH2_REDIRECT_URI
}
+ if refresh:
+ data["grant_type"] = "refresh_token"
+ data["refresh_token"] = code
+ else:
+ data["grant_type"] = "authorization_code"
+ data["code"] = code
+
r = await client.post(f"{API_BASE_URL}/oauth2/token", headers={
"Content-Type": "application/x-www-form-urlencoded"
}, data=data)
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 975936a..2244152 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -2,17 +2,23 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""
+import datetime
+from typing import Union
+
import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
+from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
+from backend import constants
+from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.route import Route
from backend.validation import ErrorMessage, api
@@ -21,7 +27,42 @@ class AuthorizeRequest(BaseModel):
class AuthorizeResponse(BaseModel):
- token: str = Field(description="A JWT token containing the user information")
+ username: str = Field("Discord display name.")
+
+
+AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400)
+
+
+async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+ """Post a bearer token to Discord, and return a JWT and username."""
+ interaction_start = datetime.datetime.now()
+
+ try:
+ user_details = await fetch_user_details(bearer_token["access_token"])
+ except httpx.HTTPStatusError:
+ AUTH_FAILURE.delete_cookie("BackendToken")
+ return AUTH_FAILURE
+
+ max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
+ token_expiry = interaction_start + max_age
+
+ data = {
+ "token": bearer_token["access_token"],
+ "refresh": bearer_token["refresh_token"],
+ "user_details": user_details,
+ "expiry": token_expiry.isoformat()
+ }
+
+ token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
+ user = User(token, user_details)
+
+ response = JSONResponse({"username": user.display_name})
+ response.set_cookie(
+ "BackendToken", f"JWT {token}",
+ secure=constants.PRODUCTION, httponly=True, samesite="strict",
+ max_age=bearer_token["expires_in"]
+ )
+ return response
class AuthorizeRoute(Route):
@@ -40,19 +81,33 @@ class AuthorizeRoute(Route):
async def post(self, request: Request) -> JSONResponse:
"""Generate an authorization token."""
data = await request.json()
-
try:
- bearer_token = await fetch_bearer_token(data["token"])
- user_details = await fetch_user_details(bearer_token["access_token"])
+ bearer_token = await fetch_bearer_token(data["token"], refresh=False)
except httpx.HTTPStatusError:
- return JSONResponse({
- "error": "auth_failure"
- }, status_code=400)
+ return AUTH_FAILURE
+
+ return await process_token(bearer_token)
- user_details["admin"] = await request.state.db.admins.find_one(
- {"_id": user_details["id"]}
- ) is not None
- token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")
+class TokenRefreshRoute(Route):
+ """
+ Use the refresh code from a JWT to get a new token and generate a new JWT token.
+ """
+
+ name = "refresh"
+ path = "/refresh"
+
+ @requires(["authenticated"])
+ @api.validate(
+ resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
+ tags=["auth"]
+ )
+ async def post(self, request: Request) -> JSONResponse:
+ """Refresh an authorization token."""
+ try:
+ token = request.user.decoded_token.get("refresh")
+ bearer_token = await fetch_bearer_token(token, refresh=True)
+ except httpx.HTTPStatusError:
+ return AUTH_FAILURE
- return JSONResponse({"token": token})
+ return await process_token(bearer_token)