aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--README.md1
-rw-r--r--backend/__init__.py20
-rw-r--r--backend/authentication/backend.py37
-rw-r--r--backend/authentication/user.py26
-rw-r--r--backend/constants.py2
-rw-r--r--backend/discord.py15
-rw-r--r--backend/routes/auth/authorize.py88
-rw-r--r--backend/routes/forms/form.py2
-rw-r--r--backend/routes/forms/submit.py50
-rw-r--r--backend/validation.py11
10 files changed, 196 insertions, 56 deletions
diff --git a/README.md b/README.md
index be0c8b9..59bdf17 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ Create a `.env` file in the root with the following values inside it (each varia
- `OAUTH2_CLIENT_ID`: Client ID of Discord OAuth2 Application (see prerequisites).
- `OAUTH2_CLIENT_SECRET`: Client Secret of Discord OAuth2 Application (see prerequisites).
- `ALLOWED_URL`: Allowed origin for CORS middleware.
+- `PRODUCTION`: Set to False if running on localhost. Defaults to true.
#### Running
To start using the application, simply run `docker-compose up` in the repository root. You'll be able to access the application by visiting http://localhost:8000/
diff --git a/backend/__init__.py b/backend/__init__.py
index a3704a0..d56edfb 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -7,10 +7,20 @@ from starlette.middleware.cors import CORSMiddleware
from backend import constants
from backend.authentication import JWTAuthenticationBackend
-from backend.route_manager import create_route_map
from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware
+from backend.route_manager import create_route_map
from backend.validation import api
+ORIGINS = [
+ r"(https://[^.?#]*--pydis-forms\.netlify\.app)", # Netlify Previews
+ r"(https?://[^.?#]*.forms-frontend.pages.dev)", # Cloudflare Previews
+]
+if not constants.PRODUCTION:
+ # Add localhost to allowed origins on non-production deployments
+ ORIGINS.append(r"(https?://localhost:\d{0,4})")
+
+ALLOW_ORIGIN_REGEX = "|".join(ORIGINS)
+
sentry_sdk.init(
dsn=constants.FORMS_BACKEND_DSN,
send_default_pii=True,
@@ -20,13 +30,13 @@ sentry_sdk.init(
middleware = [
Middleware(
CORSMiddleware,
- # TODO: Convert this into a RegEx that works for prod, netlify & previews
- allow_origins=["*"],
+ allow_origins=["https://forms.pythondiscord.com"],
+ allow_origin_regex=ALLOW_ORIGIN_REGEX,
allow_headers=[
- "Authorization",
"Content-Type"
],
- allow_methods=["*"]
+ allow_methods=["*"],
+ allow_credentials=True
),
Middleware(DatabaseMiddleware),
Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()),
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index f1d2ece..bdff796 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -1,6 +1,6 @@
-import jwt
import typing as t
+import jwt
from starlette import authentication
from starlette.requests import Request
@@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
"""Custom Starlette authentication backend for JWT."""
@staticmethod
- def get_token_from_header(header: str) -> str:
- """Parse JWT token from header value."""
+ def get_token_from_cookie(cookie: str) -> str:
+ """Parse JWT token from cookie."""
try:
- prefix, token = header.split()
+ prefix, token = cookie.split()
except ValueError:
raise authentication.AuthenticationError(
- "Unable to split prefix and token from Authorization header."
+ "Unable to split prefix and token from authorization cookie."
)
if prefix.upper() != "JWT":
raise authentication.AuthenticationError(
- f"Invalid Authorization header prefix '{prefix}'."
+ f"Invalid authorization cookie prefix '{prefix}'."
)
return token
@@ -33,11 +33,11 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
self, request: Request
) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
"""Handles JWT authentication process."""
- if "Authorization" not in request.headers:
+ cookie = request.cookies.get("BackendToken")
+ if not cookie:
return None
- auth = request.headers["Authorization"]
- token = self.get_token_from_header(auth)
+ token = self.get_token_from_cookie(cookie)
try:
payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"])
@@ -46,7 +46,22 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
scopes = ["authenticated"]
- if payload.get("admin") is True:
+ if not payload.get("token"):
+ raise authentication.AuthenticationError("Token is missing from JWT.")
+ if not payload.get("refresh"):
+ raise authentication.AuthenticationError(
+ "Refresh token is missing from JWT."
+ )
+
+ try:
+ user_details = payload.get("user_details")
+ if not user_details or not user_details.get("id"):
+ raise authentication.AuthenticationError("Improper user details.")
+ except Exception:
+ raise authentication.AuthenticationError("Could not parse user details.")
+
+ user = User(token, user_details)
+ if user.fetch_admin_status(request):
scopes.append("admin")
- return authentication.AuthCredentials(scopes), User(token, payload)
+ return authentication.AuthCredentials(scopes), user
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index f40c68c..52baa61 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,6 +1,11 @@
import typing as t
+import jwt
from starlette.authentication import BaseUser
+from starlette.requests import Request
+
+from backend.constants import SECRET_KEY
+from backend.discord import fetch_user_details
class User(BaseUser):
@@ -9,6 +14,7 @@ class User(BaseUser):
def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
self.token = token
self.payload = payload
+ self.admin = False
@property
def is_authenticated(self) -> bool:
@@ -23,3 +29,23 @@ class User(BaseUser):
@property
def discord_mention(self) -> str:
return f"<@{self.payload['id']}>"
+
+ @property
+ def decoded_token(self) -> dict[str, any]:
+ return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])
+
+ def fetch_admin_status(self, request: Request) -> bool:
+ self.admin = request.state.db.admins.find_one(
+ {"_id": self.payload["id"]}
+ ) is not None
+
+ return self.admin
+
+ async def refresh_data(self) -> None:
+ """Fetches user data from discord, and updates the instance."""
+ self.payload = await fetch_user_details(self.decoded_token.get("token"))
+
+ updated_info = self.decoded_token
+ updated_info["user_details"] = self.payload
+
+ self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256")
diff --git a/backend/constants.py b/backend/constants.py
index 59b56e0..e1f4a5b 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -11,6 +11,8 @@ DATABASE_URL = os.getenv("DATABASE_URL")
MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms")
SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval")
+PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false"
+
OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")
OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET")
OAUTH2_REDIRECT_URI = os.getenv(
diff --git a/backend/discord.py b/backend/discord.py
index d6310b7..8cb602c 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -2,22 +2,27 @@
import httpx
from backend.constants import (
- OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET, OAUTH2_REDIRECT_URI
+ OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET
)
API_BASE_URL = "https://discord.com/api/v8"
-async def fetch_bearer_token(access_code: str) -> dict:
+async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:
async with httpx.AsyncClient() as client:
data = {
"client_id": OAUTH2_CLIENT_ID,
"client_secret": OAUTH2_CLIENT_SECRET,
- "grant_type": "authorization_code",
- "code": access_code,
- "redirect_uri": OAUTH2_REDIRECT_URI
+ "redirect_uri": f"{redirect}/callback"
}
+ if refresh:
+ data["grant_type"] = "refresh_token"
+ data["refresh_token"] = code
+ else:
+ data["grant_type"] = "authorization_code"
+ data["code"] = code
+
r = await client.post(f"{API_BASE_URL}/oauth2/token", headers={
"Content-Type": "application/x-www-form-urlencoded"
}, data=data)
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 975936a..65709ab 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -2,17 +2,23 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""
+import datetime
+from typing import Union
+
import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
+from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
+from backend import constants
+from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.route import Route
from backend.validation import ErrorMessage, api
@@ -21,7 +27,47 @@ class AuthorizeRequest(BaseModel):
class AuthorizeResponse(BaseModel):
- token: str = Field(description="A JWT token containing the user information")
+ username: str = Field("Discord display name.")
+ expiry: str = Field("ISO formatted timestamp of expiry.")
+
+
+AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400)
+
+
+async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+ """Post a bearer token to Discord, and return a JWT and username."""
+ interaction_start = datetime.datetime.now()
+
+ try:
+ user_details = await fetch_user_details(bearer_token["access_token"])
+ except httpx.HTTPStatusError:
+ AUTH_FAILURE.delete_cookie("BackendToken")
+ return AUTH_FAILURE
+
+ max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
+ token_expiry = interaction_start + max_age
+
+ data = {
+ "token": bearer_token["access_token"],
+ "refresh": bearer_token["refresh_token"],
+ "user_details": user_details,
+ "expiry": token_expiry.isoformat()
+ }
+
+ token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
+ user = User(token, user_details)
+
+ response = JSONResponse({
+ "username": user.display_name,
+ "expiry": token_expiry.isoformat()
+ })
+
+ response.set_cookie(
+ "BackendToken", f"JWT {token}",
+ secure=constants.PRODUCTION, httponly=True, samesite="strict",
+ max_age=bearer_token["expires_in"]
+ )
+ return response
class AuthorizeRoute(Route):
@@ -40,19 +86,35 @@ class AuthorizeRoute(Route):
async def post(self, request: Request) -> JSONResponse:
"""Generate an authorization token."""
data = await request.json()
-
try:
- bearer_token = await fetch_bearer_token(data["token"])
- user_details = await fetch_user_details(bearer_token["access_token"])
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)
except httpx.HTTPStatusError:
- return JSONResponse({
- "error": "auth_failure"
- }, status_code=400)
+ return AUTH_FAILURE
+
+ return await process_token(bearer_token)
- user_details["admin"] = await request.state.db.admins.find_one(
- {"_id": user_details["id"]}
- ) is not None
- token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")
+class TokenRefreshRoute(Route):
+ """
+ Use the refresh code from a JWT to get a new token and generate a new JWT token.
+ """
+
+ name = "refresh"
+ path = "/refresh"
+
+ @requires(["authenticated"])
+ @api.validate(
+ resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
+ tags=["auth"]
+ )
+ async def post(self, request: Request) -> JSONResponse:
+ """Refresh an authorization token."""
+ try:
+ token = request.user.decoded_token.get("refresh")
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(token, url, refresh=True)
+ except httpx.HTTPStatusError:
+ return AUTH_FAILURE
- return JSONResponse({"token": token})
+ return await process_token(bearer_token)
diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py
index dd1c83f..1c6e44a 100644
--- a/backend/routes/forms/form.py
+++ b/backend/routes/forms/form.py
@@ -27,7 +27,7 @@ class SingleForm(Route):
@api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Returns single form information by ID."""
- admin = request.user.payload["admin"] if request.user.is_authenticated else False
+ admin = request.user.admin if request.user.is_authenticated else False
filters = {
"_id": request.path_params["form_id"]
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index b3a6afd..4224586 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -3,6 +3,7 @@ Submit a form.
"""
import binascii
+import datetime
import hashlib
import uuid
from typing import Any, Optional
@@ -15,11 +16,12 @@ from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import JSONResponse
-from backend.constants import FRONTEND_URL, FormFeatures, HCAPTCHA_API_SECRET
+from backend import constants
+from backend.authentication.user import User
from backend.models import Form, FormResponse
from backend.route import Route
from backend.routes.forms.unittesting import execute_unittest
-from backend.validation import AuthorizationHeaders, ErrorMessage, api
+from backend.validation import ErrorMessage, api
HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"
HCAPTCHA_HEADERS = {
@@ -52,13 +54,40 @@ class SubmitForm(Route):
HTTP_404=ErrorMessage,
HTTP_400=ErrorMessage
),
- headers=AuthorizationHeaders,
tags=["forms", "responses"]
)
async def post(self, request: Request) -> JSONResponse:
"""Submit a response to the form."""
- data = await request.json()
+ response = await self.submit(request)
+
+ # Silently try to update user data
+ try:
+ if hasattr(request.user, User.refresh_data.__name__):
+ old = request.user.token
+ await request.user.refresh_data()
+
+ if old != request.user.token:
+ try:
+ expiry = datetime.datetime.fromisoformat(
+ request.user.decoded_token.get("expiry")
+ )
+ except ValueError:
+ expiry = None
+
+ response.set_cookie(
+ "BackendToken", f"JWT {request.user.token}",
+ secure=constants.PRODUCTION, httponly=True, samesite="strict",
+ max_age=(expiry - datetime.datetime.now()).seconds
+ )
+ except httpx.HTTPStatusError:
+ pass
+
+ return response
+
+ async def submit(self, request: Request) -> JSONResponse:
+ """Helper method for handling submission logic."""
+ data = await request.json()
data["timestamp"] = None
if form := await request.state.db.forms.find_one(
@@ -69,7 +98,7 @@ class SubmitForm(Route):
response["id"] = str(uuid.uuid4())
response["form_id"] = form.id
- if FormFeatures.DISABLE_ANTISPAM.value not in form.features:
+ if constants.FormFeatures.DISABLE_ANTISPAM.value not in form.features:
ip_hash_ctx = hashlib.md5()
ip_hash_ctx.update(request.client.host.encode())
ip_hash = binascii.hexlify(ip_hash_ctx.digest())
@@ -79,7 +108,7 @@ class SubmitForm(Route):
async with httpx.AsyncClient() as client:
query_params = {
- "secret": HCAPTCHA_API_SECRET,
+ "secret": constants.HCAPTCHA_API_SECRET,
"response": data.get("captcha")
}
r = await client.post(
@@ -96,12 +125,13 @@ class SubmitForm(Route):
"captcha_pass": captcha_data["success"]
}
- if FormFeatures.REQUIRES_LOGIN.value in form.features:
+ if constants.FormFeatures.REQUIRES_LOGIN.value in form.features:
if request.user.is_authenticated:
response["user"] = request.user.payload
+ response["user"]["admin"] = request.user.admin
if (
- FormFeatures.COLLECT_EMAIL.value in form.features
+ constants.FormFeatures.COLLECT_EMAIL.value in form.features
and "email" not in response["user"]
):
return JSONResponse({
@@ -153,7 +183,7 @@ class SubmitForm(Route):
)
send_webhook = None
- if FormFeatures.WEBHOOK_ENABLED.value in form.features:
+ if constants.FormFeatures.WEBHOOK_ENABLED.value in form.features:
send_webhook = BackgroundTask(
self.send_submission_webhook,
form=form,
@@ -193,7 +223,7 @@ class SubmitForm(Route):
embed = {
"title": "New Form Response",
"description": f"{mention} submitted a response to `{form.name}`.",
- "url": f"{FRONTEND_URL}/path_to_view_form/{response.id}", # TODO: Enter Form View URL
+ "url": f"{constants.FRONTEND_URL}/path_to_view_form/{response.id}", # noqa # TODO: Enter Form View URL
"timestamp": response.timestamp,
"color": 7506394,
}
diff --git a/backend/validation.py b/backend/validation.py
index e696683..8771924 100644
--- a/backend/validation.py
+++ b/backend/validation.py
@@ -1,6 +1,5 @@
"""Utilities for providing API payload validation."""
-from typing import Optional
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree import SpecTree
@@ -18,13 +17,3 @@ class ErrorMessage(BaseModel):
class OkayResponse(BaseModel):
status: str = "ok"
-
-
-class AuthorizationHeaders(BaseModel):
- authorization: Optional[str] = Field(
- title="Authorization",
- description=(
- "The Authorization JWT token received from the "
- "authorize route in the format `JWT {token}`"
- )
- )