1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
import hmac
import logging
import typing as t
from datetime import UTC, datetime, timedelta
from enum import IntFlag
from uuid import uuid4

import jwt
from fastapi import HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials
from fastapi.security.http import HTTPBearer
from sqlalchemy import select

from src.dto import User, UserPermission, Voucher
from src.orm import User as DBUser, Voucher as DBVoucher
from src.settings import CONFIG, DBSession
# Module-scoped logger, named after this module's import path.
log = logging.getLogger(name=__name__)
class UserTypes(IntFlag):
    """Bit flags enumerating every kind of user; combine them with ``|``."""

    # Each member occupies its own bit so members can be OR-ed together.
    VOUCHER_USER = 1 << 0
    REGULAR_USER = 1 << 1
class TokenAuth(HTTPBearer):
    """
    Bearer-token auth dependency that attaches the requester to ``request.state``.

    Two kinds of credential are accepted:

    * the configured super-admin token (only when ``allow_regular_users`` is set),
      which yields a synthetic ``User`` holding every permission bit;
    * a JWT (see ``verify_jwt``) whose issuer matches one of the allowed user types,
      after which the matching user or voucher row is loaded from the database.
    """

    def __init__(
        self,
        *,
        auto_error: bool = True,
        allow_vouchers: bool = False,
        allow_regular_users: bool = False,
    ) -> None:
        """
        Configure which user types this dependency lets through.

        :param auto_error: Forwarded unchanged to ``HTTPBearer``.
        :param allow_vouchers: Accept JWTs issued for vouchers.
        :param allow_regular_users: Accept JWTs issued for regular users, and the
            super-admin token.
        """
        super().__init__(scheme_name="bearer", bearerFormat="JWT", auto_error=auto_error)
        self.allow_vouchers = allow_vouchers
        self.allow_regular_users = allow_regular_users

    async def __call__(self, request: Request, db: DBSession) -> HTTPAuthorizationCredentials | None:
        """
        Authenticate the request and attach the requester to ``request.state``.

        :return: The parsed bearer credentials, or ``None`` when the parent
            ``HTTPBearer`` produced none (it does so instead of raising when
            ``auto_error=False``).
        :raises HTTPException: 401 for a non-bearer scheme or an invalid JWT;
            403 if the token's user type is not allowed or its user no longer exists.
        """
        creds: HTTPAuthorizationCredentials | None = await super().__call__(request)
        if creds is None:
            # Without this guard, `creds.scheme` below raises AttributeError (HTTP 500)
            # for optional-auth routes that were called without credentials.
            return None
        if creds.scheme.lower() != "bearer":
            raise HTTPException(
                status_code=401,
                detail="Incorrect scheme passed",
            )
        if self.allow_regular_users and hmac.compare_digest(
            # Constant-time comparison so response timing does not leak the secret.
            # Compare bytes: compare_digest rejects non-ASCII `str` arguments.
            creds.credentials.encode(),
            CONFIG.super_admin_token.get_secret_value().encode(),
        ):
            # ~UserPermission(0) sets every permission bit.
            request.state.user = User(id=uuid4(), permissions=~UserPermission(0))
            return creds
        jwt_data = verify_jwt(
            creds.credentials,
            allow_vouchers=self.allow_vouchers,
            allow_regular_users=self.allow_regular_users,
        )
        if not jwt_data:
            raise HTTPException(
                status_code=403,
                detail="Invalid authentication credentials",
            )
        await self.attach_auth_info_to_state(jwt_data, request, db)
        return creds

    async def attach_auth_info_to_state(self, jwt_data: dict, request: Request, db: DBSession) -> None:
        """
        Load the principal named by the JWT and store it on ``request.state``.

        The ``iss`` claim selects the table: ``"thallium:user"`` means the user
        table; any other issuer (only ``"thallium:voucher"`` passes ``verify_jwt``)
        means the voucher table. The row with ``id == sub`` is stored as
        ``request.state.user`` or ``request.state.voucher`` respectively.

        :raises HTTPException: 403 if no row with that id exists any more.
        """
        requester_id = jwt_data["sub"]
        table = DBUser if jwt_data["iss"] == "thallium:user" else DBVoucher
        stmt = select(table).where(table.id == requester_id)
        res = await db.scalar(stmt)
        if isinstance(res, DBUser):
            request.state.user = User.model_validate(res.__dict__)
        elif isinstance(res, DBVoucher):
            request.state.voucher = Voucher.model_validate(res.__dict__)
        else:
            raise HTTPException(403, "Your user no longer exists")
def build_jwt(
    identifier: str,
    user_type: t.Literal["voucher", "user"],
) -> str:
    """
    Build and sign a short-lived JWT for a user or voucher.

    :param identifier: ID of the principal; stored (as a string) in the ``sub`` claim.
    :param user_type: Which kind of principal the ID refers to; encoded in the
        ``iss`` claim as ``"thallium:<user_type>"``.
    :return: The encoded token, signed with HS256 using the configured signing key.
        It is valid from one minute ago until 30 minutes from now.
    """
    # Take one timestamp so `exp` and `nbf` are computed from the same instant.
    now = datetime.now(tz=UTC)
    return jwt.encode(
        payload={
            # str() lets callers pass e.g. a UUID, which is not JSON-serialisable as-is.
            "sub": str(identifier),
            "iss": f"thallium:{user_type}",
            "exp": now + timedelta(minutes=30),
            # Backdated by a minute, presumably to tolerate clock skew between hosts.
            "nbf": now - timedelta(minutes=1),
        },
        key=CONFIG.signing_key.get_secret_value(),
        # Pin the algorithm explicitly so it always matches what verify_jwt accepts.
        algorithm="HS256",
    )
def verify_jwt(
    jwt_data: str,
    *,
    allow_vouchers: bool,
    allow_regular_users: bool,
) -> dict | None:
    """
    Verify the given JWT and return its decoded claims.

    The token must be HS256-signed with the configured signing key and carry
    ``exp``, ``iss``, ``sub`` and ``nbf`` claims.

    :param jwt_data: The encoded token.
    :param allow_vouchers: Accept tokens issued as ``"thallium:voucher"``.
    :param allow_regular_users: Accept tokens issued as ``"thallium:user"``.
    :return: The decoded claims.
    :raises HTTPException: 403 if the issuer is not an allowed user type;
        401 for any other invalid token (bad signature, malformed, missing
        claims, outside its validity window, ...).
    """
    # With both flags False the list is empty, so no issuer can match and every
    # token is rejected with 403.
    issuers = []
    if allow_vouchers:
        issuers.append("thallium:voucher")
    if allow_regular_users:
        issuers.append("thallium:user")
    try:
        return jwt.decode(
            jwt_data,
            key=CONFIG.signing_key.get_secret_value(),
            issuer=issuers,
            algorithms=("HS256",),
            options={"require": ["exp", "iss", "sub", "nbf"]},
        )
    except jwt.InvalidIssuerError as e:
        raise HTTPException(403, "Your user type does not have access to this resource") from e
    except jwt.InvalidSignatureError as e:
        # Must precede DecodeError, of which it is a subclass.
        raise HTTPException(401, "Invalid JWT signature") from e
    except (jwt.DecodeError, jwt.MissingRequiredClaimError, jwt.InvalidAlgorithmError) as e:
        raise HTTPException(401, "Invalid JWT passed") from e
    except (jwt.ImmatureSignatureError, jwt.ExpiredSignatureError) as e:
        raise HTTPException(401, "JWT not valid for current time") from e
    except jwt.InvalidTokenError as e:
        # Catch-all for remaining PyJWT validation errors (e.g. InvalidSubjectError
        # in newer PyJWT releases) so they surface as 401 rather than an HTTP 500.
        raise HTTPException(401, "Invalid JWT passed") from e
class MissingPermissionsError(HTTPException):
    """HTTP 403 error signalling that the requester lacks a required permission."""

    def __init__(self) -> None:
        """Create the error with a fixed 403 status and message."""
        super().__init__(
            status_code=403,
            detail="Missing permissions for this resource.",
        )
class HasPermission:
    """
    Dependency checking that the requester holds all of the given permissions.

    Reads the principal from ``request.state.user`` / ``request.state.voucher``,
    which ``TokenAuth`` populates, so it must run after the auth dependency.
    Raises MissingPermissionsError (HTTP 403) when the check fails.
    """

    def __init__(self, required_permissions: UserPermission, *, allow_vouchers: bool = False) -> None:
        """
        :param required_permissions: Flags that must ALL be set on the user.
        :param allow_vouchers: If True, any request authenticated as a voucher
            passes without a permission check.
        """
        self.required_permissions = required_permissions
        self.allow_vouchers = allow_vouchers

    async def __call__(self, request: Request) -> None:
        """Raise MissingPermissionsError unless the requester may access the resource."""
        state = request.state
        if self.allow_vouchers and hasattr(state, "voucher"):
            return
        user = getattr(state, "user", None)
        if user is None:
            # No user is attached: either a voucher on a route that disallows vouchers,
            # or no one was authenticated. Accessing `state.user` directly would raise
            # AttributeError (HTTP 500) instead of the intended 403.
            raise MissingPermissionsError
        # Every required bit must be present; parenthesised for clarity.
        if (user.permissions & self.required_permissions) != self.required_permissions:
            raise MissingPermissionsError
|