aboutsummaryrefslogtreecommitdiffstats
path: root/thallium-backend/src
diff options
context:
space:
mode:
Diffstat (limited to 'thallium-backend/src')
-rw-r--r--thallium-backend/src/app.py20
-rw-r--r--thallium-backend/src/auth.py110
-rw-r--r--thallium-backend/src/dto/__init__.py5
-rw-r--r--thallium-backend/src/dto/login.py13
-rw-r--r--thallium-backend/src/dto/users.py25
-rw-r--r--thallium-backend/src/dto/vouchers.py21
-rw-r--r--thallium-backend/src/orm/__init__.py2
-rw-r--r--thallium-backend/src/orm/base.py6
-rw-r--r--thallium-backend/src/orm/users.py5
-rw-r--r--thallium-backend/src/orm/vouchers.py25
-rw-r--r--thallium-backend/src/routes/__init__.py4
-rw-r--r--thallium-backend/src/routes/debug.py17
-rw-r--r--thallium-backend/src/routes/login.py31
-rw-r--r--thallium-backend/src/routes/vouchers.py24
-rw-r--r--thallium-backend/src/settings.py1
15 files changed, 300 insertions, 9 deletions
diff --git a/thallium-backend/src/app.py b/thallium-backend/src/app.py
index 6060ec3..3e5847c 100644
--- a/thallium-backend/src/app.py
+++ b/thallium-backend/src/app.py
@@ -1,6 +1,8 @@
import logging
+import time
+from collections.abc import Awaitable, Callable
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@@ -24,3 +26,19 @@ def pydantic_validation_error(request: Request, error: RequestValidationError) -
"""Raise a warning for pydantic validation errors, before returning."""
log.warning("Error from %s: %s", request.url, error)
return JSONResponse({"error": str(error)}, status_code=422)
+
+
+@fastapi_app.middleware("http")
+async def add_process_time_and_security_headers(
+ request: Request,
+ call_next: Callable[[Request], Awaitable[Response]],
+) -> Response:
+ """Add process time and some security headers before sending the response."""
+ start_time = time.perf_counter()
+ response = await call_next(request)
+ response.headers["X-Process-Time"] = str(time.perf_counter() - start_time)
+ response.headers["X-Frame-Options"] = "DENY"
+ response.headers["X-XSS-Protection"] = "1; mode=block"
+ response.headers["Strict-Transport-Security"] = "max-age=31536000"
+ response.headers["X-Content-Type-Options"] = "nosniff"
+ return response
diff --git a/thallium-backend/src/auth.py b/thallium-backend/src/auth.py
new file mode 100644
index 0000000..1dd1e23
--- /dev/null
+++ b/thallium-backend/src/auth.py
@@ -0,0 +1,110 @@
+import logging
+import typing as t
+from datetime import UTC, datetime, timedelta
+from enum import IntFlag
+from uuid import uuid4
+
+import jwt
+from fastapi import HTTPException, Request
+from fastapi.security import HTTPAuthorizationCredentials
+from fastapi.security.http import HTTPBase
+
+from src.dto.users import User, UserPermission
+from src.settings import CONFIG
+
+log = logging.getLogger(__name__)
+
+
+class UserTypes(IntFlag):
+ """All types of users."""
+
+ VOUCHER_USER = 2**0
+ REGULAR_USER = 2**1
+
+
+class TokenAuth(HTTPBase):
+ """Ensure all requests with this auth enabled include an auth header with the expected token."""
+
+ def __init__(
+ self,
+ *,
+ auto_error: bool = True,
+ allow_vouchers: bool = False,
+ allow_regular_users: bool = False,
+ ) -> None:
+ super().__init__(scheme="token", auto_error=auto_error)
+ self.allow_vouchers = allow_vouchers
+ self.allow_regular_users = allow_regular_users
+
+ async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
+ """Parse the token in the auth header, and check it matches with the expected token."""
+ creds: HTTPAuthorizationCredentials = await super().__call__(request)
+ if creds.scheme.lower() != "token":
+ raise HTTPException(
+ status_code=401,
+ detail="Incorrect scheme passed",
+ )
+ if self.allow_regular_users and creds.credentials == CONFIG.super_admin_token.get_secret_value():
+ request.state.user = User(user_id=uuid4(), permissions=~UserPermission(0))
+ return
+
+ jwt_data = verify_jwt(
+ creds.credentials,
+ allow_vouchers=self.allow_vouchers,
+ allow_regular_users=self.allow_regular_users,
+ )
+ if not jwt_data:
+ raise HTTPException(
+ status_code=403,
+ detail="Invalid authentication credentials",
+ )
+ if jwt_data["iss"] == "thallium:user":
+ request.state.user_id = jwt_data["sub"]
+ else:
+ request.state.voucher_id = jwt_data["sub"]
+
+
+def build_jwt(
+ identifier: str,
+ user_type: t.Literal["voucher", "user"],
+) -> str:
+ """Build & sign a jwt."""
+ return jwt.encode(
+ payload={
+ "sub": identifier,
+ "iss": f"thallium:{user_type}",
+ "exp": datetime.now(tz=UTC) + timedelta(minutes=30),
+ "nbf": datetime.now(tz=UTC) - timedelta(minutes=1),
+ },
+ key=CONFIG.signing_key.get_secret_value(),
+ )
+
+
+def verify_jwt(
+ jwt_data: str,
+ *,
+ allow_vouchers: bool,
+ allow_regular_users: bool,
+) -> dict | None:
+ """Return and verify the given JWT."""
+ issuers = []
+ if allow_vouchers:
+ issuers.append("thallium:voucher")
+ if allow_regular_users:
+ issuers.append("thallium:user")
+ try:
+ return jwt.decode(
+ jwt_data,
+ key=CONFIG.signing_key.get_secret_value(),
+ issuer=issuers,
+ algorithms=("HS256",),
+ options={"require": ["exp", "iss", "sub", "nbf"]},
+ )
+ except jwt.InvalidIssuerError as e:
+ raise HTTPException(403, "Your user type does not have access to this resource") from e
+ except jwt.InvalidSignatureError as e:
+ raise HTTPException(401, "Invalid JWT signature") from e
+ except (jwt.DecodeError, jwt.MissingRequiredClaimError, jwt.InvalidAlgorithmError) as e:
+ raise HTTPException(401, "Invalid JWT passed") from e
+ except (jwt.ImmatureSignatureError, jwt.ExpiredSignatureError) as e:
+ raise HTTPException(401, "JWT not valid for current time") from e
diff --git a/thallium-backend/src/dto/__init__.py b/thallium-backend/src/dto/__init__.py
new file mode 100644
index 0000000..92d3914
--- /dev/null
+++ b/thallium-backend/src/dto/__init__.py
@@ -0,0 +1,5 @@
+from .login import VoucherClaim, VoucherLogin
+from .users import User
+from .vouchers import Voucher
+
+__all__ = ("LoginData", "User", "Voucher", "VoucherClaim", "VoucherLogin")
diff --git a/thallium-backend/src/dto/login.py b/thallium-backend/src/dto/login.py
new file mode 100644
index 0000000..8f27acb
--- /dev/null
+++ b/thallium-backend/src/dto/login.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel
+
+
+class VoucherLogin(BaseModel):
+ """The data needed to login with a voucher."""
+
+ voucher_code: str
+
+
+class VoucherClaim(VoucherLogin):
+ """A JWT for a verified voucher."""
+
+ jwt: str
diff --git a/thallium-backend/src/dto/users.py b/thallium-backend/src/dto/users.py
new file mode 100644
index 0000000..0d1cdac
--- /dev/null
+++ b/thallium-backend/src/dto/users.py
@@ -0,0 +1,25 @@
+from enum import IntFlag
+from uuid import UUID
+
+from pydantic import BaseModel
+
+
+class UserPermission(IntFlag):
+ """The permissions a user has."""
+
+ VIEW_VOUCHERS = 2**0
+ ISSUE_VOUCHERS = 2**1
+ REVOKE_VOUCHERS = 2**1
+ VIEW_PRODUCTS = 2**2
+ MANAGE_USERS = 2**3
+
+
+class User(BaseModel):
+ """An user authenticated with the backend."""
+
+ id: UUID
+ permissions: int
+
+ def has_permission(self, permission: UserPermission) -> bool:
+ """Whether the user has the given permission."""
+ return (self.permissions & permission) == permission
diff --git a/thallium-backend/src/dto/vouchers.py b/thallium-backend/src/dto/vouchers.py
new file mode 100644
index 0000000..81dfe02
--- /dev/null
+++ b/thallium-backend/src/dto/vouchers.py
@@ -0,0 +1,21 @@
+from datetime import datetime
+from decimal import Decimal
+from uuid import UUID
+
+from pydantic import BaseModel
+
+
+class VoucherCreate(BaseModel):
+ """The data required to create a new Voucher."""
+
+ voucher_code: str
+ balance: Decimal
+
+
+class Voucher(VoucherCreate):
+ """A voucher as stored in the database."""
+
+ id: UUID
+ created_at: datetime
+ updated_at: datetime
+ active: bool
diff --git a/thallium-backend/src/orm/__init__.py b/thallium-backend/src/orm/__init__.py
index ed803e8..cf70ddd 100644
--- a/thallium-backend/src/orm/__init__.py
+++ b/thallium-backend/src/orm/__init__.py
@@ -3,10 +3,12 @@
from .base import AuditBase, Base
from .products import Product
from .users import User
+from .vouchers import Voucher
__all__ = (
"AuditBase",
"Base",
"Product",
"User",
+ "Voucher",
)
diff --git a/thallium-backend/src/orm/base.py b/thallium-backend/src/orm/base.py
index a1642c7..ec79d99 100644
--- a/thallium-backend/src/orm/base.py
+++ b/thallium-backend/src/orm/base.py
@@ -2,13 +2,13 @@
import re
from datetime import datetime
-from uuid import UUID, uuid4
+from uuid import UUID
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.schema import MetaData
-from sqlalchemy.sql import func
+from sqlalchemy.sql import func, text
from sqlalchemy.types import DateTime
NAMING_CONVENTIONS = {
@@ -35,7 +35,7 @@ class Base(AsyncAttrs, DeclarativeBase):
class AuditBase:
"""Common columns for a table with UUID PK and datetime audit columns."""
- id: Mapped[UUID] = mapped_column(default=uuid4, primary_key=True)
+ id: Mapped[UUID] = mapped_column(server_default=text("gen_random_uuid()"), primary_key=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
diff --git a/thallium-backend/src/orm/users.py b/thallium-backend/src/orm/users.py
index 065519a..8f78387 100644
--- a/thallium-backend/src/orm/users.py
+++ b/thallium-backend/src/orm/users.py
@@ -1,4 +1,4 @@
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped
from .base import AuditBase, Base
@@ -8,5 +8,4 @@ class User(AuditBase, Base):
__tablename__ = "users"
- user_id: Mapped[int] = mapped_column(primary_key=True)
- is_admin: Mapped[bool]
+ permissions: Mapped[int]
diff --git a/thallium-backend/src/orm/vouchers.py b/thallium-backend/src/orm/vouchers.py
new file mode 100644
index 0000000..2b3af77
--- /dev/null
+++ b/thallium-backend/src/orm/vouchers.py
@@ -0,0 +1,25 @@
+from decimal import Decimal
+
+from sqlalchemy import Index, text
+from sqlalchemy.orm import Mapped, mapped_column
+
+from .base import AuditBase, Base
+
+
+class Voucher(AuditBase, Base):
+ """A valid voucher in the database."""
+
+ __tablename__ = "vouchers"
+
+ voucher_code: Mapped[str] = mapped_column()
+ active: Mapped[bool] = mapped_column(default=True)
+ balance: Mapped[Decimal]
+
+ __table_args__ = (
+ Index(
+ "ix_unique_active_voucher_code",
+ voucher_code,
+ unique=True,
+ postgresql_where=text("active"),
+ ),
+ )
diff --git a/thallium-backend/src/routes/__init__.py b/thallium-backend/src/routes/__init__.py
index afa02af..2dc76c8 100644
--- a/thallium-backend/src/routes/__init__.py
+++ b/thallium-backend/src/routes/__init__.py
@@ -1,8 +1,12 @@
from fastapi import APIRouter
from src.routes.debug import router as debug_router
+from src.routes.login import router as login_router
+from src.routes.vouchers import router as voucher_router
from src.settings import CONFIG
top_level_router = APIRouter()
+top_level_router.include_router(login_router)
+top_level_router.include_router(voucher_router)
if CONFIG.debug:
top_level_router.include_router(debug_router)
diff --git a/thallium-backend/src/routes/debug.py b/thallium-backend/src/routes/debug.py
index fac40d7..60c8643 100644
--- a/thallium-backend/src/routes/debug.py
+++ b/thallium-backend/src/routes/debug.py
@@ -1,10 +1,13 @@
import logging
from fastapi import APIRouter
+from sqlalchemy import select
-from src.settings import PrintfulClient
+from src.dto import Voucher
+from src.orm import Voucher as DBVoucher
+from src.settings import DBSession, PrintfulClient
-router = APIRouter(tags=["debug"])
+router = APIRouter(tags=["debug"], prefix="/debug")
log = logging.getLogger(__name__)
@@ -34,3 +37,13 @@ async def get_v2_oauth_scopes(client: PrintfulClient) -> dict:
"""Return all templates in printful."""
resp = await client.get("/v2/oauth-scopes")
return resp.json()
+
+
[email protected]("/vouchers")
+async def get_vouchers(db: DBSession, *, only_active: bool = True) -> list[Voucher]:
+ """Return all templates in printful."""
+ stmt = select(DBVoucher)
+ if only_active:
+ stmt = stmt.where(DBVoucher.active)
+ res = await db.execute(stmt)
+ return res.scalars().all()
diff --git a/thallium-backend/src/routes/login.py b/thallium-backend/src/routes/login.py
new file mode 100644
index 0000000..7eeb2cf
--- /dev/null
+++ b/thallium-backend/src/routes/login.py
@@ -0,0 +1,31 @@
+import logging
+
+from fastapi import APIRouter, HTTPException
+from sqlalchemy import and_, select
+
+from src.auth import build_jwt
+from src.dto import VoucherClaim, VoucherLogin
+from src.orm import Voucher as DBVoucher
+from src.settings import DBSession
+
+router = APIRouter(tags=["Login"])
+log = logging.getLogger(__name__)
+
+
[email protected]("/voucher-login")
+async def handle_voucher_login(login_payload: VoucherLogin, db: DBSession) -> VoucherClaim:
+ """Return a signed JWT if the given voucher is present in the database."""
+ stmt = select(DBVoucher).where(
+ and_(
+ DBVoucher.voucher_code == login_payload.voucher_code,
+ DBVoucher.active,
+ )
+ )
+ voucher = await db.scalar(stmt)
+ if not voucher:
+ raise HTTPException(422, "Voucher not found")
+
+ return VoucherClaim(
+ voucher_code=login_payload.voucher_code,
+ jwt=build_jwt(str(voucher.id), "voucher"),
+ )
diff --git a/thallium-backend/src/routes/vouchers.py b/thallium-backend/src/routes/vouchers.py
new file mode 100644
index 0000000..97b9fef
--- /dev/null
+++ b/thallium-backend/src/routes/vouchers.py
@@ -0,0 +1,24 @@
+import logging
+
+from fastapi import APIRouter, Depends, Request
+from sqlalchemy import select
+
+from src.auth import TokenAuth
+from src.dto import Voucher
+from src.orm import Voucher as DBVoucher
+from src.settings import DBSession
+
+router = APIRouter(
+ prefix="/vouchers",
+ tags=["Voucher users"],
+ dependencies=[Depends(TokenAuth(allow_vouchers=True))],
+)
+log = logging.getLogger(__name__)
+
+
+async def get_vouchers(request: Request, db: DBSession) -> Voucher | None:
+ """Get the voucher for the currently authenticated voucher id."""
+ stmt = select(DBVoucher).where(DBVoucher.id == request.state.voucher_id)
+ res = await db.execute(stmt)
+ return res.scalars().one_or_none()
diff --git a/thallium-backend/src/settings.py b/thallium-backend/src/settings.py
index f6e144c..81a6335 100644
--- a/thallium-backend/src/settings.py
+++ b/thallium-backend/src/settings.py
@@ -23,6 +23,7 @@ class _Config(
debug: bool = False
git_sha: str = "development"
+ signing_key: pydantic.SecretStr
database_url: pydantic.SecretStr
super_admin_token: pydantic.SecretStr