aboutsummaryrefslogtreecommitdiffstats
path: root/thallium-backend/src/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'thallium-backend/src/auth.py')
-rw-r--r--thallium-backend/src/auth.py110
1 files changed, 110 insertions, 0 deletions
diff --git a/thallium-backend/src/auth.py b/thallium-backend/src/auth.py
new file mode 100644
index 0000000..1dd1e23
--- /dev/null
+++ b/thallium-backend/src/auth.py
@@ -0,0 +1,110 @@
+import logging
+import typing as t
+from datetime import UTC, datetime, timedelta
+from enum import IntFlag
+from uuid import uuid4
+
+import jwt
+from fastapi import HTTPException, Request
+from fastapi.security import HTTPAuthorizationCredentials
+from fastapi.security.http import HTTPBase
+
+from src.dto.users import User, UserPermission
+from src.settings import CONFIG
+
+log = logging.getLogger(__name__)
+
+
+class UserTypes(IntFlag):
+ """All types of users."""
+
+ VOUCHER_USER = 2**0
+ REGULAR_USER = 2**1
+
+
+class TokenAuth(HTTPBase):
+ """Ensure all requests with this auth enabled include an auth header with the expected token."""
+
+ def __init__(
+ self,
+ *,
+ auto_error: bool = True,
+ allow_vouchers: bool = False,
+ allow_regular_users: bool = False,
+ ) -> None:
+ super().__init__(scheme="token", auto_error=auto_error)
+ self.allow_vouchers = allow_vouchers
+ self.allow_regular_users = allow_regular_users
+
+ async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
+ """Parse the token in the auth header, and check it matches with the expected token."""
+ creds: HTTPAuthorizationCredentials = await super().__call__(request)
+ if creds.scheme.lower() != "token":
+ raise HTTPException(
+ status_code=401,
+ detail="Incorrect scheme passed",
+ )
+ if self.allow_regular_users and creds.credentials == CONFIG.super_admin_token.get_secret_value():
+ request.state.user = User(user_id=uuid4(), permissions=~UserPermission(0))
+ return
+
+ jwt_data = verify_jwt(
+ creds.credentials,
+ allow_vouchers=self.allow_vouchers,
+ allow_regular_users=self.allow_regular_users,
+ )
+ if not jwt_data:
+ raise HTTPException(
+ status_code=403,
+ detail="Invalid authentication credentials",
+ )
+ if jwt_data["iss"] == "thallium:user":
+ request.state.user_id = jwt_data["sub"]
+ else:
+ request.state.voucher_id = jwt_data["sub"]
+
+
+def build_jwt(
+ identifier: str,
+ user_type: t.Literal["voucher", "user"],
+) -> str:
+ """Build & sign a jwt."""
+ return jwt.encode(
+ payload={
+ "sub": identifier,
+ "iss": f"thallium:{user_type}",
+ "exp": datetime.now(tz=UTC) + timedelta(minutes=30),
+ "nbf": datetime.now(tz=UTC) - timedelta(minutes=1),
+ },
+ key=CONFIG.signing_key.get_secret_value(),
+ )
+
+
+def verify_jwt(
+ jwt_data: str,
+ *,
+ allow_vouchers: bool,
+ allow_regular_users: bool,
+) -> dict | None:
+ """Return and verify the given JWT."""
+ issuers = []
+ if allow_vouchers:
+ issuers.append("thallium:voucher")
+ if allow_regular_users:
+ issuers.append("thallium:user")
+ try:
+ return jwt.decode(
+ jwt_data,
+ key=CONFIG.signing_key.get_secret_value(),
+ issuer=issuers,
+ algorithms=("HS256",),
+ options={"require": ["exp", "iss", "sub", "nbf"]},
+ )
+ except jwt.InvalidIssuerError as e:
+ raise HTTPException(403, "Your user type does not have access to this resource") from e
+ except jwt.InvalidSignatureError as e:
+ raise HTTPException(401, "Invalid JWT signature") from e
+ except (jwt.DecodeError, jwt.MissingRequiredClaimError, jwt.InvalidAlgorithmError) as e:
+ raise HTTPException(401, "Invalid JWT passed") from e
+ except (jwt.ImmatureSignatureError, jwt.ExpiredSignatureError) as e:
+ raise HTTPException(401, "JWT not valid for current time") from e