diff options
Diffstat (limited to 'backend/authentication')
-rw-r--r-- | backend/authentication/backend.py | 37 | ||||
-rw-r--r-- | backend/authentication/user.py | 26 |
2 files changed, 52 insertions, 11 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index f1d2ece..c7590e9 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,6 +1,6 @@ -import jwt import typing as t +import jwt from starlette import authentication from starlette.requests import Request @@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): """Custom Starlette authentication backend for JWT.""" @staticmethod - def get_token_from_header(header: str) -> str: - """Parse JWT token from header value.""" + def get_token_from_cookie(cookie: str) -> str: + """Parse JWT token from cookie.""" try: - prefix, token = header.split() + prefix, token = cookie.split() except ValueError: raise authentication.AuthenticationError( - "Unable to split prefix and token from Authorization header." + "Unable to split prefix and token from authorization cookie." ) if prefix.upper() != "JWT": raise authentication.AuthenticationError( - f"Invalid Authorization header prefix '{prefix}'." + f"Invalid authorization cookie prefix '{prefix}'." ) return token @@ -33,11 +33,11 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): self, request: Request ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]: """Handles JWT authentication process.""" - if "Authorization" not in request.headers: + cookie = request.cookies.get("token") + if not cookie: return None - auth = request.headers["Authorization"] - token = self.get_token_from_header(auth) + token = self.get_token_from_cookie(cookie) try: payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"]) @@ -46,7 +46,22 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): scopes = ["authenticated"] - if payload.get("admin") is True: + if not payload.get("token"): + raise authentication.AuthenticationError("Token is missing from JWT.") + if not payload.get("refresh"): + raise authentication.AuthenticationError( + "Refresh token is missing from JWT." + ) + + try: + user_details = payload.get("user_details") + if not user_details or not user_details.get("id"): + raise authentication.AuthenticationError("Improper user details.") + except Exception: + raise authentication.AuthenticationError("Could not parse user details.") + + user = User(token, user_details) + if await user.fetch_admin_status(request): scopes.append("admin") - return authentication.AuthCredentials(scopes), User(token, payload) + return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index f40c68c..857c2ed 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,6 +1,11 @@ import typing as t +import jwt from starlette.authentication import BaseUser +from starlette.requests import Request + +from backend.constants import SECRET_KEY +from backend.discord import fetch_user_details class User(BaseUser): @@ -9,6 +14,7 @@ class User(BaseUser): def __init__(self, token: str, payload: dict[str, t.Any]) -> None: self.token = token self.payload = payload + self.admin = False @property def is_authenticated(self) -> bool: @@ -23,3 +29,23 @@ class User(BaseUser): @property def discord_mention(self) -> str: return f"<@{self.payload['id']}>" + + @property + def decoded_token(self) -> dict[str, any]: + return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) + + async def fetch_admin_status(self, request: Request) -> bool: + self.admin = await request.state.db.admins.find_one( + {"_id": self.payload["id"]} + ) is not None + + return self.admin + + async def refresh_data(self) -> None: + """Fetches user data from discord, and updates the instance.""" + self.payload = await fetch_user_details(self.decoded_token.get("token")) + + updated_info = self.decoded_token + updated_info["user_details"] = self.payload + + self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256") |