From 3b5d3e4424c23c8ab27e2c469e3d10af860bbf2e Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 8 Jul 2024 21:35:06 +0100 Subject: Update middleware to use SQLA to create db sessions --- backend/__init__.py | 7 +++++++ backend/constants.py | 11 +++++++++-- backend/middleware.py | 8 +++++--- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/backend/__init__.py b/backend/__init__.py index c2e1335..eb276c0 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -1,3 +1,6 @@ +import asyncio +import os + import sentry_sdk from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from starlette.applications import Starlette @@ -14,6 +17,10 @@ from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware from backend.route_manager import create_route_map from backend.validation import api +# On Windows, the selector event loop is required for psycopg. +if os.name == "nt": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + ORIGINS = [ r"(https://[^.?#]*--pydis-forms\.netlify\.app)", # Netlify Previews r"(https?://[^.?#]*.forms-frontend.pages.dev)", # Cloudflare Previews diff --git a/backend/constants.py b/backend/constants.py index 1e55cd2..eb0c68f 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -4,17 +4,24 @@ from enum import Enum from dotenv import load_dotenv from redis.asyncio import Redis as _Redis +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine load_dotenv() FRONTEND_URL = os.getenv("FRONTEND_URL", "https://forms.pythondiscord.com") -DATABASE_URL = os.getenv("DATABASE_URL") -MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms") SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval") REDIS_CLIENT = _Redis.from_url(os.getenv("REDIS_URL"), encoding="utf-8") +MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms") +MONGO_DATABASE_URL = os.getenv("MONGO_DATABASE_URL") + +PSQL_DATABASE_URL = os.getenv("PSQL_DATABASE_URL") +DATABASE_ECHO = os.getenv("DATABASE_ECHO", "false").lower() == "true" +_DB_ENGINE = create_async_engine(PSQL_DATABASE_URL, echo=DATABASE_ECHO) +DB_SESSION_MAKER = async_sessionmaker(_DB_ENGINE) + PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false" PRODUCTION_URL = "https://forms.pythondiscord.com" diff --git a/backend/middleware.py b/backend/middleware.py index 0b08859..5b36473 100644 --- a/backend/middleware.py +++ b/backend/middleware.py @@ -3,7 +3,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.types import ASGIApp, Receive, Scope, Send -from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE +from backend.constants import DB_SESSION_MAKER, DOCS_PASSWORD, MONGO_DATABASE, MONGO_DATABASE_URL class DatabaseMiddleware: @@ -12,12 +12,14 @@ class DatabaseMiddleware: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: client: AsyncIOMotorClient = AsyncIOMotorClient( - DATABASE_URL, + MONGO_DATABASE_URL, tlsAllowInvalidCertificates=True, ) db = client[MONGO_DATABASE] Request(scope).state.db = db - await self._app(scope, receive, send) + async with DB_SESSION_MAKER() as session, session.begin(): + Request(scope).state.psql_db = session + await self._app(scope, receive, send) class ProtectedDocsMiddleware: -- cgit v1.2.3