aboutsummaryrefslogtreecommitdiffstats
path: root/backend/authentication/backend.py
blob: e4699bda35ba055cf551bfcd349a9f789901c0b0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import jwt
import typing as t
from abc import ABC

from starlette import authentication
from starlette.requests import Request

from backend import constants
# We must import User this way here to avoid circular imports
from .user import User


class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
    """Custom Starlette authentication backend for JWT.

    Reads the ``Authorization`` request header, expects it in the form
    ``JWT <token>``, and decodes the token with ``constants.SECRET_KEY``
    using the HS256 algorithm. A successfully authenticated request is
    granted the ``authenticated`` scope, plus ``admin`` when the decoded
    payload contains ``"admin": true``.
    """

    @staticmethod
    def get_token_from_header(header: str) -> str:
        """Parse the JWT token out of an ``Authorization`` header value.

        :param header: Raw header value, expected to be ``"<prefix> <token>"``
            where the prefix is ``JWT`` (compared case-insensitively).
        :return: The token part of the header.
        :raises authentication.AuthenticationError: If the value does not split
            into exactly two whitespace-separated parts, or if the prefix is
            not ``JWT``.
        """
        try:
            prefix, token = header.split()
        except ValueError as err:
            # Chain the unpacking error so tracebacks show whether there were
            # too few or too many parts.
            raise authentication.AuthenticationError(
                "Unable to split prefix and token from Authorization header."
            ) from err

        if prefix.upper() != "JWT":
            raise authentication.AuthenticationError(
                f"Invalid Authorization header prefix '{prefix}'."
            )

        return token

    async def authenticate(
        self, request: Request
    ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
        """Authenticate ``request`` using its JWT ``Authorization`` header.

        :param request: Incoming Starlette request.
        :return: ``None`` when no ``Authorization`` header is present;
            otherwise the granted credentials and a ``User`` built from the
            raw token and its decoded payload.
        :raises authentication.AuthenticationError: If the header is malformed
            or the token fails JWT validation.
        """
        auth = request.headers.get("Authorization")
        if auth is None:
            # No credentials supplied: leave the request unauthenticated.
            return None

        token = self.get_token_from_header(auth)

        try:
            payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"])
        except jwt.InvalidTokenError as err:
            raise authentication.AuthenticationError(str(err)) from err

        scopes = ["authenticated"]

        # Only a literal boolean ``True`` grants admin; other truthy values
        # (e.g. ``1`` or ``"yes"``) deliberately do not.
        if payload.get("admin", False) is True:
            scopes.append("admin")

        return authentication.AuthCredentials(scopes), User(token, payload)