aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--backend/__init__.py7
-rw-r--r--backend/constants.py11
-rw-r--r--backend/middleware.py8
3 files changed, 21 insertions, 5 deletions
diff --git a/backend/__init__.py b/backend/__init__.py
index c2e1335..eb276c0 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -1,3 +1,6 @@
+import asyncio
+import os
+
import sentry_sdk
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from starlette.applications import Starlette
@@ -14,6 +17,10 @@ from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware
from backend.route_manager import create_route_map
from backend.validation import api
+# On Windows, the selector event loop is required for psycopg.
+if os.name == "nt":
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
ORIGINS = [
r"(https://[^.?#]*--pydis-forms\.netlify\.app)", # Netlify Previews
r"(https?://[^.?#]*.forms-frontend.pages.dev)", # Cloudflare Previews
diff --git a/backend/constants.py b/backend/constants.py
index 1e55cd2..eb0c68f 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -4,17 +4,24 @@ from enum import Enum
from dotenv import load_dotenv
from redis.asyncio import Redis as _Redis
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
load_dotenv()
FRONTEND_URL = os.getenv("FRONTEND_URL", "https://forms.pythondiscord.com")
-DATABASE_URL = os.getenv("DATABASE_URL")
-MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms")
SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval")
REDIS_CLIENT = _Redis.from_url(os.getenv("REDIS_URL"), encoding="utf-8")
+MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms")
+MONGO_DATABASE_URL = os.getenv("MONGO_DATABASE_URL")
+
+PSQL_DATABASE_URL = os.getenv("PSQL_DATABASE_URL")
+DATABASE_ECHO = os.getenv("DATABASE_ECHO", "false").lower() == "true"
+_DB_ENGINE = create_async_engine(PSQL_DATABASE_URL, echo=DATABASE_ECHO)
+DB_SESSION_MAKER = async_sessionmaker(_DB_ENGINE)
+
PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false"
PRODUCTION_URL = "https://forms.pythondiscord.com"
diff --git a/backend/middleware.py b/backend/middleware.py
index 0b08859..5b36473 100644
--- a/backend/middleware.py
+++ b/backend/middleware.py
@@ -3,7 +3,7 @@ from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
-from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE
+from backend.constants import DB_SESSION_MAKER, DOCS_PASSWORD, MONGO_DATABASE, MONGO_DATABASE_URL
class DatabaseMiddleware:
@@ -12,12 +12,14 @@ class DatabaseMiddleware:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
client: AsyncIOMotorClient = AsyncIOMotorClient(
- DATABASE_URL,
+ MONGO_DATABASE_URL,
tlsAllowInvalidCertificates=True,
)
db = client[MONGO_DATABASE]
Request(scope).state.db = db
- await self._app(scope, receive, send)
+ async with DB_SESSION_MAKER() as session, session.begin():
+ Request(scope).state.psql_db = session
+ await self._app(scope, receive, send)
class ProtectedDocsMiddleware: