diff options
Diffstat (limited to 'thallium-backend/src')
| -rw-r--r-- | thallium-backend/src/app.py | 20 | ||||
| -rw-r--r-- | thallium-backend/src/auth.py | 110 | ||||
| -rw-r--r-- | thallium-backend/src/dto/__init__.py | 5 | ||||
| -rw-r--r-- | thallium-backend/src/dto/login.py | 13 | ||||
| -rw-r--r-- | thallium-backend/src/dto/users.py | 25 | ||||
| -rw-r--r-- | thallium-backend/src/dto/vouchers.py | 21 | ||||
| -rw-r--r-- | thallium-backend/src/orm/__init__.py | 2 | ||||
| -rw-r--r-- | thallium-backend/src/orm/base.py | 6 | ||||
| -rw-r--r-- | thallium-backend/src/orm/users.py | 5 | ||||
| -rw-r--r-- | thallium-backend/src/orm/vouchers.py | 25 | ||||
| -rw-r--r-- | thallium-backend/src/routes/__init__.py | 4 | ||||
| -rw-r--r-- | thallium-backend/src/routes/debug.py | 17 | ||||
| -rw-r--r-- | thallium-backend/src/routes/login.py | 31 | ||||
| -rw-r--r-- | thallium-backend/src/routes/vouchers.py | 24 | ||||
| -rw-r--r-- | thallium-backend/src/settings.py | 1 |
15 files changed, 300 insertions, 9 deletions
diff --git a/thallium-backend/src/app.py b/thallium-backend/src/app.py index 6060ec3..3e5847c 100644 --- a/thallium-backend/src/app.py +++ b/thallium-backend/src/app.py @@ -1,6 +1,8 @@ import logging +import time +from collections.abc import Awaitable, Callable -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -24,3 +26,19 @@ def pydantic_validation_error(request: Request, error: RequestValidationError) - """Raise a warning for pydantic validation errors, before returning.""" log.warning("Error from %s: %s", request.url, error) return JSONResponse({"error": str(error)}, status_code=422) + + +@fastapi_app.middleware("http") +async def add_process_time_and_security_headers( + request: Request, + call_next: Callable[[Request], Awaitable[Response]], +) -> Response: + """Add process time and some security headers before sending the response.""" + start_time = time.perf_counter() + response = await call_next(request) + response.headers["X-Process-Time"] = str(time.perf_counter() - start_time) + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Strict-Transport-Security"] = "max-age=31536000" + response.headers["X-Content-Type-Options"] = "nosniff" + return response diff --git a/thallium-backend/src/auth.py b/thallium-backend/src/auth.py new file mode 100644 index 0000000..1dd1e23 --- /dev/null +++ b/thallium-backend/src/auth.py @@ -0,0 +1,110 @@ +import logging +import typing as t +from datetime import UTC, datetime, timedelta +from enum import IntFlag +from uuid import uuid4 + +import jwt +from fastapi import HTTPException, Request +from fastapi.security import HTTPAuthorizationCredentials +from fastapi.security.http import HTTPBase + +from src.dto.users import User, UserPermission +from src.settings import CONFIG + +log = logging.getLogger(__name__) + + +class UserTypes(IntFlag): + """All types of users.""" + + VOUCHER_USER = 2**0 + REGULAR_USER = 2**1 + + +class TokenAuth(HTTPBase): + """Ensure all requests with this auth enabled include an auth header with the expected token.""" + + def __init__( + self, + *, + auto_error: bool = True, + allow_vouchers: bool = False, + allow_regular_users: bool = False, + ) -> None: + super().__init__(scheme="token", auto_error=auto_error) + self.allow_vouchers = allow_vouchers + self.allow_regular_users = allow_regular_users + + async def __call__(self, request: Request) -> HTTPAuthorizationCredentials: + """Parse the token in the auth header, and check it matches with the expected token.""" + creds: HTTPAuthorizationCredentials = await super().__call__(request) + if creds.scheme.lower() != "token": + raise HTTPException( + status_code=401, + detail="Incorrect scheme passed", + ) + if self.allow_regular_users and creds.credentials == CONFIG.super_admin_token.get_secret_value(): + request.state.user = User(user_id=uuid4(), permissions=~UserPermission(0)) + return + + jwt_data = verify_jwt( + creds.credentials, + allow_vouchers=self.allow_vouchers, + allow_regular_users=self.allow_regular_users, + ) + if not jwt_data: + raise HTTPException( + status_code=403, + detail="Invalid authentication credentials", + ) + if jwt_data["iss"] == "thallium:user": + request.state.user_id = jwt_data["sub"] + else: + request.state.voucher_id = jwt_data["sub"] + + +def build_jwt( + identifier: str, + user_type: t.Literal["voucher", "user"], +) -> str: + """Build & sign a jwt.""" + return jwt.encode( + payload={ + "sub": identifier, + "iss": f"thallium:{user_type}", + "exp": datetime.now(tz=UTC) + timedelta(minutes=30), + "nbf": datetime.now(tz=UTC) - timedelta(minutes=1), + }, + key=CONFIG.signing_key.get_secret_value(), + ) + + +def verify_jwt( + jwt_data: str, + *, + allow_vouchers: bool, + allow_regular_users: bool, +) -> dict | None: + """Return and verify the given JWT.""" + issuers = [] + if allow_vouchers: + issuers.append("thallium:voucher") + if allow_regular_users: + issuers.append("thallium:user") + try: + return jwt.decode( + jwt_data, + key=CONFIG.signing_key.get_secret_value(), + issuer=issuers, + algorithms=("HS256",), + options={"require": ["exp", "iss", "sub", "nbf"]}, + ) + except jwt.InvalidIssuerError as e: + raise HTTPException(403, "Your user type does not have access to this resource") from e + except jwt.InvalidSignatureError as e: + raise HTTPException(401, "Invalid JWT signature") from e + except (jwt.DecodeError, jwt.MissingRequiredClaimError, jwt.InvalidAlgorithmError) as e: + raise HTTPException(401, "Invalid JWT passed") from e + except (jwt.ImmatureSignatureError, jwt.ExpiredSignatureError) as e: + raise HTTPException(401, "JWT not valid for current time") from e diff --git a/thallium-backend/src/dto/__init__.py b/thallium-backend/src/dto/__init__.py new file mode 100644 index 0000000..92d3914 --- /dev/null +++ b/thallium-backend/src/dto/__init__.py @@ -0,0 +1,5 @@ +from .login import VoucherClaim, VoucherLogin +from .users import User +from .vouchers import Voucher + +__all__ = ("LoginData", "User", "Voucher", "VoucherClaim", "VoucherLogin") diff --git a/thallium-backend/src/dto/login.py b/thallium-backend/src/dto/login.py new file mode 100644 index 0000000..8f27acb --- /dev/null +++ b/thallium-backend/src/dto/login.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class VoucherLogin(BaseModel): + """The data needed to login with a voucher.""" + + voucher_code: str + + +class VoucherClaim(VoucherLogin): + """A JWT for a verified voucher.""" + + jwt: str diff --git a/thallium-backend/src/dto/users.py b/thallium-backend/src/dto/users.py new file mode 100644 index 0000000..0d1cdac --- /dev/null +++ b/thallium-backend/src/dto/users.py @@ -0,0 +1,25 @@ +from enum import IntFlag +from uuid import UUID + +from pydantic import BaseModel + + +class UserPermission(IntFlag): + """The permissions a user has.""" + + VIEW_VOUCHERS = 2**0 + ISSUE_VOUCHERS = 2**1 + REVOKE_VOUCHERS = 2**1 + VIEW_PRODUCTS = 2**2 + MANAGE_USERS = 2**3 + + +class User(BaseModel): + """An user authenticated with the backend.""" + + id: UUID + permissions: int + + def has_permission(self, permission: UserPermission) -> bool: + """Whether the user has the given permission.""" + return (self.permissions & permission) == permission diff --git a/thallium-backend/src/dto/vouchers.py b/thallium-backend/src/dto/vouchers.py new file mode 100644 index 0000000..81dfe02 --- /dev/null +++ b/thallium-backend/src/dto/vouchers.py @@ -0,0 +1,21 @@ +from datetime import datetime +from decimal import Decimal +from uuid import UUID + +from pydantic import BaseModel + + +class VoucherCreate(BaseModel): + """The data required to create a new Voucher.""" + + voucher_code: str + balance: Decimal + + +class Voucher(VoucherCreate): + """A voucher as stored in the database.""" + + id: UUID + created_at: datetime + updated_at: datetime + active: bool diff --git a/thallium-backend/src/orm/__init__.py b/thallium-backend/src/orm/__init__.py index ed803e8..cf70ddd 100644 --- a/thallium-backend/src/orm/__init__.py +++ b/thallium-backend/src/orm/__init__.py @@ -3,10 +3,12 @@ from .base import AuditBase, Base from .products import Product from .users import User +from .vouchers import Voucher __all__ = ( "AuditBase", "Base", "Product", "User", + "Voucher", ) diff --git a/thallium-backend/src/orm/base.py b/thallium-backend/src/orm/base.py index a1642c7..ec79d99 100644 --- a/thallium-backend/src/orm/base.py +++ b/thallium-backend/src/orm/base.py @@ -2,13 +2,13 @@ import re from datetime import datetime -from uuid import UUID, uuid4 +from uuid import UUID from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import MetaData -from sqlalchemy.sql import func +from sqlalchemy.sql import func, text from sqlalchemy.types import DateTime NAMING_CONVENTIONS = { @@ -35,7 +35,7 @@ class Base(AsyncAttrs, DeclarativeBase): class AuditBase: """Common columns for a table with UUID PK and datetime audit columns.""" - id: Mapped[UUID] = mapped_column(default=uuid4, primary_key=True) + id: Mapped[UUID] = mapped_column(server_default=text("gen_random_uuid()"), primary_key=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), diff --git a/thallium-backend/src/orm/users.py b/thallium-backend/src/orm/users.py index 065519a..8f78387 100644 --- a/thallium-backend/src/orm/users.py +++ b/thallium-backend/src/orm/users.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped from .base import AuditBase, Base @@ -8,5 +8,4 @@ class User(AuditBase, Base): __tablename__ = "users" - user_id: Mapped[int] = mapped_column(primary_key=True) - is_admin: Mapped[bool] + permissions: Mapped[int] diff --git a/thallium-backend/src/orm/vouchers.py b/thallium-backend/src/orm/vouchers.py new file mode 100644 index 0000000..2b3af77 --- /dev/null +++ b/thallium-backend/src/orm/vouchers.py @@ -0,0 +1,25 @@ +from decimal import Decimal + +from sqlalchemy import Index, text +from sqlalchemy.orm import Mapped, mapped_column + +from .base import AuditBase, Base + + +class Voucher(AuditBase, Base): + """A valid voucher in the database.""" + + __tablename__ = "vouchers" + + voucher_code: Mapped[str] = mapped_column() + active: Mapped[bool] = mapped_column(default=True) + balance: Mapped[Decimal] + + __table_args__ = ( + Index( + "ix_unique_active_voucher_code", + voucher_code, + unique=True, + postgresql_where=text("active"), + ), + ) diff --git a/thallium-backend/src/routes/__init__.py b/thallium-backend/src/routes/__init__.py index afa02af..2dc76c8 100644 --- a/thallium-backend/src/routes/__init__.py +++ b/thallium-backend/src/routes/__init__.py @@ -1,8 +1,12 @@ from fastapi import APIRouter from src.routes.debug import router as debug_router +from src.routes.login import router as login_router +from src.routes.vouchers import router as voucher_router from src.settings import CONFIG top_level_router = APIRouter() +top_level_router.include_router(login_router) +top_level_router.include_router(voucher_router) if CONFIG.debug: top_level_router.include_router(debug_router) diff --git a/thallium-backend/src/routes/debug.py b/thallium-backend/src/routes/debug.py index fac40d7..60c8643 100644 --- a/thallium-backend/src/routes/debug.py +++ b/thallium-backend/src/routes/debug.py @@ -1,10 +1,13 @@ import logging from fastapi import APIRouter +from sqlalchemy import select -from src.settings import PrintfulClient +from src.dto import Voucher +from src.orm import Voucher as DBVoucher +from src.settings import DBSession, PrintfulClient -router = APIRouter(tags=["debug"]) +router = APIRouter(tags=["debug"], prefix="/debug") log = logging.getLogger(__name__) @@ -34,3 +37,13 @@ async def get_v2_oauth_scopes(client: PrintfulClient) -> dict: """Return all templates in printful.""" resp = await client.get("/v2/oauth-scopes") return resp.json() + + [email protected]("/vouchers") +async def get_vouchers(db: DBSession, *, only_active: bool = True) -> list[Voucher]: + """Return all templates in printful.""" + stmt = select(DBVoucher) + if only_active: + stmt = stmt.where(DBVoucher.active) + res = await db.execute(stmt) + return res.scalars().all() diff --git a/thallium-backend/src/routes/login.py b/thallium-backend/src/routes/login.py new file mode 100644 index 0000000..7eeb2cf --- /dev/null +++ b/thallium-backend/src/routes/login.py @@ -0,0 +1,31 @@ +import logging + +from fastapi import APIRouter, HTTPException +from sqlalchemy import and_, select + +from src.auth import build_jwt +from src.dto import VoucherClaim, VoucherLogin +from src.orm import Voucher as DBVoucher +from src.settings import DBSession + +router = APIRouter(tags=["Login"]) +log = logging.getLogger(__name__) + + [email protected]("/voucher-login") +async def handle_voucher_login(login_payload: VoucherLogin, db: DBSession) -> VoucherClaim: + """Return a signed JWT if the given voucher is present in the database.""" + stmt = select(DBVoucher).where( + and_( + DBVoucher.voucher_code == login_payload.voucher_code, + DBVoucher.active, + ) + ) + voucher = await db.scalar(stmt) + if not voucher: + raise HTTPException(422, "Voucher not found") + + return VoucherClaim( + voucher_code=login_payload.voucher_code, + jwt=build_jwt(str(voucher.id), "voucher"), + ) diff --git a/thallium-backend/src/routes/vouchers.py b/thallium-backend/src/routes/vouchers.py new file mode 100644 index 0000000..97b9fef --- /dev/null +++ b/thallium-backend/src/routes/vouchers.py @@ -0,0 +1,24 @@ +import logging + +from fastapi import APIRouter, Depends, Request +from sqlalchemy import select + +from src.auth import TokenAuth +from src.dto import Voucher +from src.orm import Voucher as DBVoucher +from src.settings import DBSession + +router = APIRouter( + prefix="/vouchers", + tags=["Voucher users"], + dependencies=[Depends(TokenAuth(allow_vouchers=True))], +) +log = logging.getLogger(__name__) + + [email protected]("/me") +async def get_vouchers(request: Request, db: DBSession) -> Voucher | None: + """Get the voucher for the currently authenticated voucher id.""" + stmt = select(DBVoucher).where(DBVoucher.id == request.state.voucher_id) + res = await db.execute(stmt) + return res.scalars().one_or_none() diff --git a/thallium-backend/src/settings.py b/thallium-backend/src/settings.py index f6e144c..81a6335 100644 --- a/thallium-backend/src/settings.py +++ b/thallium-backend/src/settings.py @@ -23,6 +23,7 @@ class _Config( debug: bool = False git_sha: str = "development" + signing_key: pydantic.SecretStr database_url: pydantic.SecretStr super_admin_token: pydantic.SecretStr |