diff options
Diffstat (limited to 'backend/authentication')
| -rw-r--r-- | backend/authentication/backend.py | 37 | ||||
| -rw-r--r-- | backend/authentication/user.py | 26 | 
2 files changed, 52 insertions, 11 deletions
| diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index f1d2ece..c7590e9 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,6 +1,6 @@ -import jwt  import typing as t +import jwt  from starlette import authentication  from starlette.requests import Request @@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):      """Custom Starlette authentication backend for JWT."""      @staticmethod -    def get_token_from_header(header: str) -> str: -        """Parse JWT token from header value.""" +    def get_token_from_cookie(cookie: str) -> str: +        """Parse JWT token from cookie."""          try: -            prefix, token = header.split() +            prefix, token = cookie.split()          except ValueError:              raise authentication.AuthenticationError( -                "Unable to split prefix and token from Authorization header." +                "Unable to split prefix and token from authorization cookie."              )          if prefix.upper() != "JWT":              raise authentication.AuthenticationError( -                f"Invalid Authorization header prefix '{prefix}'." +                f"Invalid authorization cookie prefix '{prefix}'."              )          return token @@ -33,11 +33,11 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):          self, request: Request      ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:          """Handles JWT authentication process.""" -        if "Authorization" not in request.headers: +        cookie = request.cookies.get("token") +        if not cookie:              return None -        auth = request.headers["Authorization"] -        token = self.get_token_from_header(auth) +        token = self.get_token_from_cookie(cookie)          try:              payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"]) @@ -46,7 +46,22 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):          scopes = ["authenticated"] -        if payload.get("admin") is True: +        if not payload.get("token"): +            raise authentication.AuthenticationError("Token is missing from JWT.") +        if not payload.get("refresh"): +            raise authentication.AuthenticationError( +                "Refresh token is missing from JWT." +            ) + +        try: +            user_details = payload.get("user_details") +            if not user_details or not user_details.get("id"): +                raise authentication.AuthenticationError("Improper user details.") +        except Exception: +            raise authentication.AuthenticationError("Could not parse user details.") + +        user = User(token, user_details) +        if await user.fetch_admin_status(request):              scopes.append("admin") -        return authentication.AuthCredentials(scopes), User(token, payload) +        return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index f40c68c..857c2ed 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,6 +1,11 @@  import typing as t +import jwt  from starlette.authentication import BaseUser +from starlette.requests import Request + +from backend.constants import SECRET_KEY +from backend.discord import fetch_user_details  class User(BaseUser): @@ -9,6 +14,7 @@ class User(BaseUser):      def __init__(self, token: str, payload: dict[str, t.Any]) -> None:          self.token = token          self.payload = payload +        self.admin = False      @property      def is_authenticated(self) -> bool: @@ -23,3 +29,23 @@ class User(BaseUser):      @property      def discord_mention(self) -> str:          return f"<@{self.payload['id']}>" + +    @property +    def decoded_token(self) -> dict[str, any]: +        return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) + +    async def fetch_admin_status(self, request: Request) -> bool: +        self.admin = await request.state.db.admins.find_one( +            {"_id": self.payload["id"]} +        ) is not None + +        return self.admin + +    async def refresh_data(self) -> None: +        """Fetches user data from discord, and updates the instance.""" +        self.payload = await fetch_user_details(self.decoded_token.get("token")) + +        updated_info = self.decoded_token +        updated_info["user_details"] = self.payload + +        self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256") | 
