diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/__init__.py | 5 | ||||
| -rw-r--r-- | backend/authentication/__init__.py | 4 | ||||
| -rw-r--r-- | backend/authentication/backend.py | 53 | ||||
| -rw-r--r-- | backend/authentication/user.py | 22 | ||||
| -rw-r--r-- | backend/routes/index.py | 18 | 
5 files changed, 98 insertions, 4 deletions
| diff --git a/backend/__init__.py b/backend/__init__.py index 6215961..7b4cb4d 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -2,8 +2,10 @@ import os  from starlette.applications import Starlette  from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware  from starlette.middleware.cors import CORSMiddleware +from backend.authentication import JWTAuthenticationBackend  from backend.route_manager import create_route_map  from backend.middleware import DatabaseMiddleware @@ -19,7 +21,8 @@ middleware = [          ],          allow_methods=["*"]      ), -    Middleware(DatabaseMiddleware) +    Middleware(DatabaseMiddleware), +    Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend())  ]  app = Starlette(routes=create_route_map(), middleware=middleware) diff --git a/backend/authentication/__init__.py b/backend/authentication/__init__.py new file mode 100644 index 0000000..43601a7 --- /dev/null +++ b/backend/authentication/__init__.py @@ -0,0 +1,4 @@ +from .backend import JWTAuthenticationBackend +from .user import User + +__all__ = ["JWTAuthenticationBackend", "User"] diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py new file mode 100644 index 0000000..38668eb --- /dev/null +++ b/backend/authentication/backend.py @@ -0,0 +1,53 @@ +import jwt +import typing as t +from abc import ABC + +from starlette import authentication +from starlette.requests import Request + +from backend import constants +# We must import user such way here to avoid circular imports +from .user import User + + +class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC): +    """Custom Starlette authentication backend for JWT.""" + +    @staticmethod +    def get_token_from_header(header: str) -> t.Optional[str]: +        """Parse JWT token from header value.""" +        try: +            prefix, token = header.split() +        except ValueError: +            raise authentication.AuthenticationError( +                "Unable to split prefix and token from Authorization header." +            ) + +        if prefix.upper() != "JWT": +            raise authentication.AuthenticationError( +                f"Invalid Authorization header prefix '{prefix}'." +            ) + +        return token + +    async def authenticate( +        self, request: Request +    ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]: +        """Handles JWT authentication process.""" +        if "Authorization" not in request.headers: +            return + +        auth = request.headers["Authorization"] +        token = self.get_token_from_header(auth) + +        try: +            payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"]) +        except jwt.InvalidTokenError as e: +            raise authentication.AuthenticationError(str(e)) + +        scopes = ["authenticated"] + +        if payload.get("admin", False) is True: +            scopes.append("admin") + +        return authentication.AuthCredentials(scopes), User(token, payload) diff --git a/backend/authentication/user.py b/backend/authentication/user.py new file mode 100644 index 0000000..afa243f --- /dev/null +++ b/backend/authentication/user.py @@ -0,0 +1,22 @@ +import typing as t +from abc import ABC + +from starlette.authentication import BaseUser + + +class User(BaseUser, ABC): +    """Starlette BaseUser implementation for JWT authentication.""" + +    def __init__(self, token: str, payload: t.Dict) -> None: +        self.token = token +        self.payload = payload + +    @property +    def is_authenticated(self) -> bool: +        """Returns True because user is always authenticated at this stage.""" +        return True + +    @property +    def display_name(self) -> str: +        """Return username and discriminator as display name.""" +        return f"{self.payload['username']}#{self.payload['discriminator']}" diff --git a/backend/routes/index.py b/backend/routes/index.py index 8144723..b37f381 100644 --- a/backend/routes/index.py +++ b/backend/routes/index.py @@ -18,7 +18,19 @@ class IndexRoute(Route):      path = "/"      def get(self, request: Request) -> JSONResponse: -        return JSONResponse({ +        response_data = {              "message": "Hello, world!", -            "client": request.client.host -        }) +            "client": request.client.host, +            "user": { +                "authenticated": False +            } +        } + +        if request.user.is_authenticated: +            response_data["user"] = { +                "authenticated": True, +                "user": request.user.payload, +                "scopes": request.auth.scopes +            } + +        return JSONResponse(response_data) | 
