diff options
Diffstat (limited to 'backend/authentication')
-rw-r--r-- | backend/authentication/__init__.py | 3 | ||||
-rw-r--r-- | backend/authentication/backend.py | 52 |
2 files changed, 54 insertions, 1 deletions
diff --git a/backend/authentication/__init__.py b/backend/authentication/__init__.py index 35b01f3..43601a7 100644 --- a/backend/authentication/__init__.py +++ b/backend/authentication/__init__.py @@ -1,3 +1,4 @@ +from .backend import JWTAuthenticationBackend from .user import User -__all__ = ["User"] +__all__ = ["JWTAuthenticationBackend", "User"] diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py new file mode 100644 index 0000000..c38bfaf --- /dev/null +++ b/backend/authentication/backend.py @@ -0,0 +1,52 @@ +import jwt +import typing as t +from abc import ABC + +from starlette import authentication +from starlette.requests import Request + +from backend import constants +from backend.authentication import User + + +class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC): + """Custom Starlette authentication backend for JWT.""" + + @staticmethod + def get_token_from_header(header: str) -> t.Optional[str]: + """Parse JWT token from header value.""" + try: + prefix, token = header.split() + except ValueError: + raise authentication.AuthenticationError( + "Unable to split prefix and token from Authorization header." + ) + + if prefix.upper() != "JWT": + raise authentication.AuthenticationError( + f"Invalid Authorization header prefix '{prefix}'." + ) + + return token + + async def authenticate( + self, request: Request + ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]: + """Handles JWT authentication process.""" + if "Authorization" not in request.headers: + return + + auth = request.headers["Authorization"] + token = self.get_token_from_header(auth) + + try: + payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"]) + except jwt.InvalidTokenError as e: + raise authentication.AuthenticationError(str(e)) + + scopes = ["authenticated"] + + if payload.get("admin", False) is True: + scopes.append("admin") + + return authentication.AuthCredentials(scopes), User(token, payload) |