aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Hassan Abouelela <[email protected]>2021-03-07 04:17:54 +0300
committerGravatar GitHub <[email protected]>2021-03-07 04:17:54 +0300
commit9b47639ca31a14dfce59b2e8a395cea43fea91d2 (patch)
tree5a4392bdc8279b6e23d36441017562f8bc804847 /backend
parentBump httpx from 0.16.1 to 0.17.0 (diff)
parentCorrects Token Cookie Domain (diff)
Merge branch 'main' into dependabot/pip/httpx-0.17.0
Diffstat (limited to '')
-rw-r--r--backend/__init__.py21
-rw-r--r--backend/authentication/backend.py37
-rw-r--r--backend/authentication/user.py26
-rw-r--r--backend/constants.py8
-rw-r--r--backend/discord.py15
-rw-r--r--backend/routes/auth/authorize.py121
-rw-r--r--backend/routes/forms/form.py2
-rw-r--r--backend/routes/forms/submit.py48
-rw-r--r--backend/validation.py11
9 files changed, 229 insertions, 60 deletions
diff --git a/backend/__init__.py b/backend/__init__.py
index a3704a0..220b457 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -7,10 +7,21 @@ from starlette.middleware.cors import CORSMiddleware
from backend import constants
from backend.authentication import JWTAuthenticationBackend
-from backend.route_manager import create_route_map
from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware
+from backend.route_manager import create_route_map
from backend.validation import api
+ORIGINS = [
+ r"(https://[^.?#]*--pydis-forms\.netlify\.app)", # Netlify Previews
+ r"(https?://[^.?#]*.forms-frontend.pages.dev)", # Cloudflare Previews
+]
+
+if not constants.PRODUCTION:
+ # Allow all hosts on non-production deployments
+ ORIGINS.append(r"(.*)")
+
+ALLOW_ORIGIN_REGEX = "|".join(ORIGINS)
+
sentry_sdk.init(
dsn=constants.FORMS_BACKEND_DSN,
send_default_pii=True,
@@ -20,13 +31,13 @@ sentry_sdk.init(
middleware = [
Middleware(
CORSMiddleware,
- # TODO: Convert this into a RegEx that works for prod, netlify & previews
- allow_origins=["*"],
+ allow_origins=["https://forms.pythondiscord.com"],
+ allow_origin_regex=ALLOW_ORIGIN_REGEX,
allow_headers=[
- "Authorization",
"Content-Type"
],
- allow_methods=["*"]
+ allow_methods=["*"],
+ allow_credentials=True
),
Middleware(DatabaseMiddleware),
Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()),
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index f1d2ece..c7590e9 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -1,6 +1,6 @@
-import jwt
import typing as t
+import jwt
from starlette import authentication
from starlette.requests import Request
@@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
"""Custom Starlette authentication backend for JWT."""
@staticmethod
- def get_token_from_header(header: str) -> str:
- """Parse JWT token from header value."""
+ def get_token_from_cookie(cookie: str) -> str:
+ """Parse JWT token from cookie."""
try:
- prefix, token = header.split()
+ prefix, token = cookie.split()
except ValueError:
raise authentication.AuthenticationError(
- "Unable to split prefix and token from Authorization header."
+ "Unable to split prefix and token from authorization cookie."
)
if prefix.upper() != "JWT":
raise authentication.AuthenticationError(
- f"Invalid Authorization header prefix '{prefix}'."
+ f"Invalid authorization cookie prefix '{prefix}'."
)
return token
@@ -33,11 +33,11 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
self, request: Request
) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
"""Handles JWT authentication process."""
- if "Authorization" not in request.headers:
+ cookie = request.cookies.get("token")
+ if not cookie:
return None
- auth = request.headers["Authorization"]
- token = self.get_token_from_header(auth)
+ token = self.get_token_from_cookie(cookie)
try:
payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"])
@@ -46,7 +46,22 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
scopes = ["authenticated"]
- if payload.get("admin") is True:
+ if not payload.get("token"):
+ raise authentication.AuthenticationError("Token is missing from JWT.")
+ if not payload.get("refresh"):
+ raise authentication.AuthenticationError(
+ "Refresh token is missing from JWT."
+ )
+
+ try:
+ user_details = payload.get("user_details")
+ if not user_details or not user_details.get("id"):
+ raise authentication.AuthenticationError("Improper user details.")
+ except Exception:
+ raise authentication.AuthenticationError("Could not parse user details.")
+
+ user = User(token, user_details)
+ if await user.fetch_admin_status(request):
scopes.append("admin")
- return authentication.AuthCredentials(scopes), User(token, payload)
+ return authentication.AuthCredentials(scopes), user
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index f40c68c..857c2ed 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,6 +1,11 @@
import typing as t
+import jwt
from starlette.authentication import BaseUser
+from starlette.requests import Request
+
+from backend.constants import SECRET_KEY
+from backend.discord import fetch_user_details
class User(BaseUser):
@@ -9,6 +14,7 @@ class User(BaseUser):
def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
self.token = token
self.payload = payload
+ self.admin = False
@property
def is_authenticated(self) -> bool:
@@ -23,3 +29,23 @@ class User(BaseUser):
@property
def discord_mention(self) -> str:
return f"<@{self.payload['id']}>"
+
+ @property
+ def decoded_token(self) -> dict[str, any]:
+ return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])
+
+ async def fetch_admin_status(self, request: Request) -> bool:
+ self.admin = await request.state.db.admins.find_one(
+ {"_id": self.payload["id"]}
+ ) is not None
+
+ return self.admin
+
+ async def refresh_data(self) -> None:
+ """Fetches user data from discord, and updates the instance."""
+ self.payload = await fetch_user_details(self.decoded_token.get("token"))
+
+ updated_info = self.decoded_token
+ updated_info["user_details"] = self.payload
+
+ self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256")
diff --git a/backend/constants.py b/backend/constants.py
index 59b56e0..4bb7fd1 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -1,8 +1,9 @@
-from dotenv import load_dotenv
-import os
import binascii
+import os
from enum import Enum
+from dotenv import load_dotenv
+
load_dotenv()
@@ -11,6 +12,9 @@ DATABASE_URL = os.getenv("DATABASE_URL")
MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms")
SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval")
+PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false"
+PRODUCTION_URL = "https://forms.pythondiscord.com/"
+
OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")
OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET")
OAUTH2_REDIRECT_URI = os.getenv(
diff --git a/backend/discord.py b/backend/discord.py
index d6310b7..8cb602c 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -2,22 +2,27 @@
import httpx
from backend.constants import (
- OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET, OAUTH2_REDIRECT_URI
+ OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET
)
API_BASE_URL = "https://discord.com/api/v8"
-async def fetch_bearer_token(access_code: str) -> dict:
+async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:
async with httpx.AsyncClient() as client:
data = {
"client_id": OAUTH2_CLIENT_ID,
"client_secret": OAUTH2_CLIENT_SECRET,
- "grant_type": "authorization_code",
- "code": access_code,
- "redirect_uri": OAUTH2_REDIRECT_URI
+ "redirect_uri": f"{redirect}/callback"
}
+ if refresh:
+ data["grant_type"] = "refresh_token"
+ data["refresh_token"] = code
+ else:
+ data["grant_type"] = "authorization_code"
+ data["code"] = code
+
r = await client.post(f"{API_BASE_URL}/oauth2/token", headers={
"Content-Type": "application/x-www-form-urlencoded"
}, data=data)
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 975936a..d4587f0 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -2,26 +2,101 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""
+import datetime
+from typing import Union
+
import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
+from starlette import responses
+from starlette.authentication import requires
from starlette.requests import Request
-from starlette.responses import JSONResponse
+from backend import constants
+from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.route import Route
from backend.validation import ErrorMessage, api
+AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400)
+
class AuthorizeRequest(BaseModel):
token: str = Field(description="The access token received from Discord.")
class AuthorizeResponse(BaseModel):
- token: str = Field(description="A JWT token containing the user information")
+ username: str = Field("Discord display name.")
+ expiry: str = Field("ISO formatted timestamp of expiry.")
+
+
+async def process_token(
+ bearer_token: dict,
+ request: Request
+) -> Union[AuthorizeResponse, AUTH_FAILURE]:
+ """Post a bearer token to Discord, and return a JWT and username."""
+ interaction_start = datetime.datetime.now()
+
+ try:
+ user_details = await fetch_user_details(bearer_token["access_token"])
+ except httpx.HTTPStatusError:
+ AUTH_FAILURE.delete_cookie("token")
+ return AUTH_FAILURE
+
+ max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
+ token_expiry = interaction_start + max_age
+
+ data = {
+ "token": bearer_token["access_token"],
+ "refresh": bearer_token["refresh_token"],
+ "user_details": user_details,
+ "expiry": token_expiry.isoformat()
+ }
+
+ token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
+ user = User(token, user_details)
+
+ response = responses.JSONResponse({
+ "username": user.display_name,
+ "expiry": token_expiry.isoformat()
+ })
+
+ await set_response_token(response, request, token, bearer_token["expires_in"])
+ return response
+
+
+async def set_response_token(
+ response: responses.Response,
+ request: Request,
+ new_token: str,
+ expiry: int
+) -> None:
+ """Helper that handles logic for updating a token in a set-cookie response."""
+ origin_url = request.headers.get("origin")
+
+ if origin_url == constants.PRODUCTION_URL:
+ domain = request.url.netloc
+ samesite = "strict"
+
+ elif not constants.PRODUCTION:
+ domain = None
+ samesite = "strict"
+
+ else:
+ domain = request.url.netloc
+ samesite = "None"
+
+ response.set_cookie(
+ "token", f"JWT {new_token}",
+ secure=constants.PRODUCTION,
+ httponly=True,
+ samesite=samesite,
+ domain=domain,
+ max_age=expiry
+ )
class AuthorizeRoute(Route):
@@ -37,22 +112,38 @@ class AuthorizeRoute(Route):
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"]
)
- async def post(self, request: Request) -> JSONResponse:
+ async def post(self, request: Request) -> responses.JSONResponse:
"""Generate an authorization token."""
data = await request.json()
-
try:
- bearer_token = await fetch_bearer_token(data["token"])
- user_details = await fetch_user_details(bearer_token["access_token"])
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)
except httpx.HTTPStatusError:
- return JSONResponse({
- "error": "auth_failure"
- }, status_code=400)
+ return AUTH_FAILURE
- user_details["admin"] = await request.state.db.admins.find_one(
- {"_id": user_details["id"]}
- ) is not None
+ return await process_token(bearer_token, request)
+
+
+class TokenRefreshRoute(Route):
+ """
+ Use the refresh code from a JWT to get a new token and generate a new JWT token.
+ """
- token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")
+ name = "refresh"
+ path = "/refresh"
+
+ @requires(["authenticated"])
+ @api.validate(
+ resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
+ tags=["auth"]
+ )
+ async def post(self, request: Request) -> responses.JSONResponse:
+ """Refresh an authorization token."""
+ try:
+ token = request.user.decoded_token.get("refresh")
+ url = request.headers.get("origin")
+ bearer_token = await fetch_bearer_token(token, url, refresh=True)
+ except httpx.HTTPStatusError:
+ return AUTH_FAILURE
- return JSONResponse({"token": token})
+ return await process_token(bearer_token, request)
diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py
index dd1c83f..1c6e44a 100644
--- a/backend/routes/forms/form.py
+++ b/backend/routes/forms/form.py
@@ -27,7 +27,7 @@ class SingleForm(Route):
@api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Returns single form information by ID."""
- admin = request.user.payload["admin"] if request.user.is_authenticated else False
+ admin = request.user.admin if request.user.is_authenticated else False
filters = {
"_id": request.path_params["form_id"]
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index b3a6afd..2624c98 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -3,6 +3,7 @@ Submit a form.
"""
import binascii
+import datetime
import hashlib
import uuid
from typing import Any, Optional
@@ -15,11 +16,13 @@ from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import JSONResponse
-from backend.constants import FRONTEND_URL, FormFeatures, HCAPTCHA_API_SECRET
+from backend import constants
+from backend.authentication.user import User
from backend.models import Form, FormResponse
from backend.route import Route
+from backend.routes.auth.authorize import set_response_token
from backend.routes.forms.unittesting import execute_unittest
-from backend.validation import AuthorizationHeaders, ErrorMessage, api
+from backend.validation import ErrorMessage, api
HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"
HCAPTCHA_HEADERS = {
@@ -52,13 +55,37 @@ class SubmitForm(Route):
HTTP_404=ErrorMessage,
HTTP_400=ErrorMessage
),
- headers=AuthorizationHeaders,
tags=["forms", "responses"]
)
async def post(self, request: Request) -> JSONResponse:
"""Submit a response to the form."""
- data = await request.json()
+ response = await self.submit(request)
+
+ # Silently try to update user data
+ try:
+ if hasattr(request.user, User.refresh_data.__name__):
+ old = request.user.token
+ await request.user.refresh_data()
+
+ if old != request.user.token:
+ try:
+ expiry = datetime.datetime.fromisoformat(
+ request.user.decoded_token.get("expiry")
+ )
+ except ValueError:
+ expiry = None
+ expiry_seconds = (expiry - datetime.datetime.now()).seconds
+ await set_response_token(response, request, request.user.token, expiry_seconds)
+
+ except httpx.HTTPStatusError:
+ pass
+
+ return response
+
+ async def submit(self, request: Request) -> JSONResponse:
+ """Helper method for handling submission logic."""
+ data = await request.json()
data["timestamp"] = None
if form := await request.state.db.forms.find_one(
@@ -69,7 +96,7 @@ class SubmitForm(Route):
response["id"] = str(uuid.uuid4())
response["form_id"] = form.id
- if FormFeatures.DISABLE_ANTISPAM.value not in form.features:
+ if constants.FormFeatures.DISABLE_ANTISPAM.value not in form.features:
ip_hash_ctx = hashlib.md5()
ip_hash_ctx.update(request.client.host.encode())
ip_hash = binascii.hexlify(ip_hash_ctx.digest())
@@ -79,7 +106,7 @@ class SubmitForm(Route):
async with httpx.AsyncClient() as client:
query_params = {
- "secret": HCAPTCHA_API_SECRET,
+ "secret": constants.HCAPTCHA_API_SECRET,
"response": data.get("captcha")
}
r = await client.post(
@@ -96,12 +123,13 @@ class SubmitForm(Route):
"captcha_pass": captcha_data["success"]
}
- if FormFeatures.REQUIRES_LOGIN.value in form.features:
+ if constants.FormFeatures.REQUIRES_LOGIN.value in form.features:
if request.user.is_authenticated:
response["user"] = request.user.payload
+ response["user"]["admin"] = request.user.admin
if (
- FormFeatures.COLLECT_EMAIL.value in form.features
+ constants.FormFeatures.COLLECT_EMAIL.value in form.features
and "email" not in response["user"]
):
return JSONResponse({
@@ -153,7 +181,7 @@ class SubmitForm(Route):
)
send_webhook = None
- if FormFeatures.WEBHOOK_ENABLED.value in form.features:
+ if constants.FormFeatures.WEBHOOK_ENABLED.value in form.features:
send_webhook = BackgroundTask(
self.send_submission_webhook,
form=form,
@@ -193,7 +221,7 @@ class SubmitForm(Route):
embed = {
"title": "New Form Response",
"description": f"{mention} submitted a response to `{form.name}`.",
- "url": f"{FRONTEND_URL}/path_to_view_form/{response.id}", # TODO: Enter Form View URL
+ "url": f"{constants.FRONTEND_URL}/path_to_view_form/{response.id}", # noqa # TODO: Enter Form View URL
"timestamp": response.timestamp,
"color": 7506394,
}
diff --git a/backend/validation.py b/backend/validation.py
index e696683..8771924 100644
--- a/backend/validation.py
+++ b/backend/validation.py
@@ -1,6 +1,5 @@
"""Utilities for providing API payload validation."""
-from typing import Optional
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree import SpecTree
@@ -18,13 +17,3 @@ class ErrorMessage(BaseModel):
class OkayResponse(BaseModel):
status: str = "ok"
-
-
-class AuthorizationHeaders(BaseModel):
- authorization: Optional[str] = Field(
- title="Authorization",
- description=(
- "The Authorization JWT token received from the "
- "authorize route in the format `JWT {token}`"
- )
- )