diff options
Diffstat (limited to 'thallium-backend/src/routes/debug.py')
| -rw-r--r-- | thallium-backend/src/routes/debug.py | 43 | 
1 files changed, 40 insertions, 3 deletions
diff --git a/thallium-backend/src/routes/debug.py b/thallium-backend/src/routes/debug.py index 4d81741..691ac3f 100644 --- a/thallium-backend/src/routes/debug.py +++ b/thallium-backend/src/routes/debug.py @@ -1,15 +1,19 @@  import logging +from datetime import UTC, datetime -from fastapi import APIRouter +import argon2 +from fastapi import APIRouter, HTTPException  from sqlalchemy import select +from sqlalchemy.exc import IntegrityError  from src.auth import build_jwt -from src.dto import Voucher -from src.orm import Voucher as DBVoucher +from src.dto import UserPermission, Voucher +from src.orm import User as DBUser, Voucher as DBVoucher  from src.settings import DBSession, PrintfulClient  router = APIRouter(tags=["debug"], prefix="/debug")  log = logging.getLogger(__name__) +ph = argon2.PasswordHasher()  @router.get("/templates") @@ -54,3 +58,36 @@ async def get_vouchers(db: DBSession, *, only_active: bool = True) -> list[Vouch  async def get_user_jwt(user_id: str) -> str:      """Return the user_id's JWT."""      return build_jwt(user_id, "user") + + [email protected]("/user") +async def create_user(  # noqa: PLR0913 +    db: DBSession, +    username: str, +    password: str, +    *, +    require_password_change: bool = True, +    password_reset_code: str | None = None, +    active: bool = True, +    permissions: int = ~UserPermission(0), +) -> dict: +    """Create a user with the given username & pass.""" +    db_user = DBUser( +        username=username, +        password_hash=ph.hash(password), +        permissions=permissions, +        require_password_change=require_password_change, +        password_reset_code=password_reset_code, +        active=active, +        password_set_at=datetime.now(UTC), +    ) +    db.add(db_user) + +    try: +        await db.flush() +    except IntegrityError as e: +        raise HTTPException(400, detail=str(e)) from e + +    stmt = select(DBUser).where(DBUser.username == username) +    db_user = await db.scalar(stmt) +    return {key: val for key, val in db_user.__dict__.items() if not key.startswith("_")}  |