aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Adrian Garcia Badaracco <[email protected]>2022-06-13 07:33:06 -0700
committerGravatar GitHub <[email protected]>2022-06-13 07:33:06 -0700
commitcb3a4c2a8cdb2cf8a594eabb368284465613e1b8 (patch)
tree6b8c890f45ef4e92e6aec204647cd2b82ad31618 /backend
parentMerge pull request #168 from python-discord/dependabot/pip/pyjwt-2.4.0 (diff)
Replace BaseHTTPMiddleware with pure ASGI middleware
Diffstat (limited to 'backend')
-rw-r--r--backend/middleware.py30
1 files changed, 19 insertions, 11 deletions
diff --git a/backend/middleware.py b/backend/middleware.py
index f74091b..a555b25 100644
--- a/backend/middleware.py
+++ b/backend/middleware.py
@@ -5,27 +5,35 @@ from motor.motor_asyncio import AsyncIOMotorClient
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
+from starlette.types import ASGIApp, Scope, Receive, Send
from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE
-class DatabaseMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: t.Callable) -> Response:
+class DatabaseMiddleware:
+
+ def __init__(self, app: ASGIApp) -> None:
+ self._app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
client: AsyncIOMotorClient = AsyncIOMotorClient(
DATABASE_URL,
ssl_cert_reqs=ssl.CERT_NONE
)
db = client[MONGO_DATABASE]
- request.state.db = db
- response = await call_next(request)
- return response
+ scope["state"].db = db
+ await self._app(scope, send, receive)
+
+class ProtectedDocsMiddleware:
-class ProtectedDocsMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next: t.Callable) -> Response:
+ def __init__(self, app: ASGIApp) -> None:
+ self._app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ request = Request(scope)
if DOCS_PASSWORD and request.url.path.startswith("/docs"):
if request.cookies.get("docs_password") != DOCS_PASSWORD:
- return JSONResponse({"status": "unauthorized"}, status_code=403)
-
- resp = await call_next(request)
- return resp
+ await JSONResponse({"status": "unauthorized"}, status_code=403)(scope, receive, send)
+ return
+ await self._app(scope, receive, send)