diff options
Diffstat (limited to 'thallium-backend/src/auth.py')
| -rw-r--r-- | thallium-backend/src/auth.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/thallium-backend/src/auth.py b/thallium-backend/src/auth.py new file mode 100644 index 0000000..1dd1e23 --- /dev/null +++ b/thallium-backend/src/auth.py @@ -0,0 +1,110 @@ +import logging +import typing as t +from datetime import UTC, datetime, timedelta +from enum import IntFlag +from uuid import uuid4 + +import jwt +from fastapi import HTTPException, Request +from fastapi.security import HTTPAuthorizationCredentials +from fastapi.security.http import HTTPBase + +from src.dto.users import User, UserPermission +from src.settings import CONFIG + +log = logging.getLogger(__name__) + + +class UserTypes(IntFlag): + """All types of users.""" + + VOUCHER_USER = 2**0 + REGULAR_USER = 2**1 + + +class TokenAuth(HTTPBase): + """Ensure all requests with this auth enabled include an auth header with the expected token.""" + + def __init__( + self, + *, + auto_error: bool = True, + allow_vouchers: bool = False, + allow_regular_users: bool = False, + ) -> None: + super().__init__(scheme="token", auto_error=auto_error) + self.allow_vouchers = allow_vouchers + self.allow_regular_users = allow_regular_users + + async def __call__(self, request: Request) -> HTTPAuthorizationCredentials: + """Parse the token in the auth header, and check it matches with the expected token.""" + creds: HTTPAuthorizationCredentials = await super().__call__(request) + if creds.scheme.lower() != "token": + raise HTTPException( + status_code=401, + detail="Incorrect scheme passed", + ) + if self.allow_regular_users and creds.credentials == CONFIG.super_admin_token.get_secret_value(): + request.state.user = User(user_id=uuid4(), permissions=~UserPermission(0)) + return + + jwt_data = verify_jwt( + creds.credentials, + allow_vouchers=self.allow_vouchers, + allow_regular_users=self.allow_regular_users, + ) + if not jwt_data: + raise HTTPException( + status_code=403, + detail="Invalid authentication credentials", + ) + if jwt_data["iss"] == "thallium:user": + request.state.user_id = jwt_data["sub"] + else: + request.state.voucher_id = jwt_data["sub"] + + +def build_jwt( + identifier: str, + user_type: t.Literal["voucher", "user"], +) -> str: + """Build & sign a jwt.""" + return jwt.encode( + payload={ + "sub": identifier, + "iss": f"thallium:{user_type}", + "exp": datetime.now(tz=UTC) + timedelta(minutes=30), + "nbf": datetime.now(tz=UTC) - timedelta(minutes=1), + }, + key=CONFIG.signing_key.get_secret_value(), + ) + + +def verify_jwt( + jwt_data: str, + *, + allow_vouchers: bool, + allow_regular_users: bool, +) -> dict | None: + """Return and verify the given JWT.""" + issuers = [] + if allow_vouchers: + issuers.append("thallium:voucher") + if allow_regular_users: + issuers.append("thallium:user") + try: + return jwt.decode( + jwt_data, + key=CONFIG.signing_key.get_secret_value(), + issuer=issuers, + algorithms=("HS256",), + options={"require": ["exp", "iss", "sub", "nbf"]}, + ) + except jwt.InvalidIssuerError as e: + raise HTTPException(403, "Your user type does not have access to this resource") from e + except jwt.InvalidSignatureError as e: + raise HTTPException(401, "Invalid JWT signature") from e + except (jwt.DecodeError, jwt.MissingRequiredClaimError, jwt.InvalidAlgorithmError) as e: + raise HTTPException(401, "Invalid JWT passed") from e + except (jwt.ImmatureSignatureError, jwt.ExpiredSignatureError) as e: + raise HTTPException(401, "JWT not valid for current time") from e |