aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--backend/__init__.py7
-rw-r--r--backend/constants.py2
-rw-r--r--backend/middleware.py14
-rw-r--r--backend/models/__init__.py14
-rw-r--r--backend/models/form.py4
-rw-r--r--backend/models/form_response.py4
-rw-r--r--backend/routes/auth/authorize.py28
-rw-r--r--backend/routes/forms/discover.py6
-rw-r--r--backend/routes/forms/form.py7
-rw-r--r--backend/routes/forms/index.py19
-rw-r--r--backend/routes/forms/response.py14
-rw-r--r--backend/routes/forms/responses.py8
-rw-r--r--backend/routes/forms/submit.py25
-rw-r--r--backend/routes/index.py18
-rw-r--r--backend/validation.py30
-rw-r--r--poetry.lock25
-rw-r--r--pyproject.toml1
18 files changed, 208 insertions, 21 deletions
diff --git a/.gitignore b/.gitignore
index ce78b36..f2bc6c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ dmypy.json
# IntelliJ project settings
.idea/
+
+# VSCode Settings
+.vscode/settings.json
diff --git a/backend/__init__.py b/backend/__init__.py
index 770107b..5c6328b 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -5,7 +5,8 @@ from starlette.middleware.cors import CORSMiddleware
from backend.authentication import JWTAuthenticationBackend
from backend.route_manager import create_route_map
-from backend.middleware import DatabaseMiddleware
+from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware
+from backend.validation import api
middleware = [
Middleware(
@@ -19,7 +20,9 @@ middleware = [
allow_methods=["*"]
),
Middleware(DatabaseMiddleware),
- Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend())
+ Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()),
+ Middleware(ProtectedDocsMiddleware)
]
app = Starlette(routes=create_route_map(), middleware=middleware)
+api.register(app)
diff --git a/backend/constants.py b/backend/constants.py
index fdf7092..4218bff 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -16,6 +16,8 @@ OAUTH2_REDIRECT_URI = os.getenv(
"https://forms.pythondiscord.com/callback"
)
+DOCS_PASSWORD = os.getenv("DOCS_PASSWORD")
+
SECRET_KEY = os.getenv("SECRET_KEY", binascii.hexlify(os.urandom(30)).decode())
HCAPTCHA_API_SECRET = os.getenv("HCAPTCHA_API_SECRET")
diff --git a/backend/middleware.py b/backend/middleware.py
index 2267a9a..f74091b 100644
--- a/backend/middleware.py
+++ b/backend/middleware.py
@@ -4,9 +4,9 @@ import ssl
from motor.motor_asyncio import AsyncIOMotorClient
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
-from starlette.responses import Response
+from starlette.responses import JSONResponse, Response
-from backend.constants import DATABASE_URL, MONGO_DATABASE
+from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE
class DatabaseMiddleware(BaseHTTPMiddleware):
@@ -19,3 +19,13 @@ class DatabaseMiddleware(BaseHTTPMiddleware):
request.state.db = db
response = await call_next(request)
return response
+
+
+class ProtectedDocsMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request: Request, call_next: t.Callable) -> Response:
+ if DOCS_PASSWORD and request.url.path.startswith("/docs"):
+ if request.cookies.get("docs_password") != DOCS_PASSWORD:
+ return JSONResponse({"status": "unauthorized"}, status_code=403)
+
+ resp = await call_next(request)
+ return resp
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
index 98fa619..29ccb24 100644
--- a/backend/models/__init__.py
+++ b/backend/models/__init__.py
@@ -1,7 +1,15 @@
from .antispam import AntiSpam
from .discord_user import DiscordUser
-from .form import Form
-from .form_response import FormResponse
+from .form import Form, FormList
+from .form_response import FormResponse, ResponseList
from .question import Question
-__all__ = ["AntiSpam", "DiscordUser", "Form", "FormResponse", "Question"]
+__all__ = [
+ "AntiSpam",
+ "DiscordUser",
+ "Form",
+ "FormResponse",
+ "Question",
+ "FormList",
+ "ResponseList"
+]
diff --git a/backend/models/form.py b/backend/models/form.py
index cb58065..9d8ffaa 100644
--- a/backend/models/form.py
+++ b/backend/models/form.py
@@ -52,3 +52,7 @@ class Form(BaseModel):
returned_data = data
return returned_data
+
+
+class FormList(BaseModel):
+ __root__: t.List[Form]
diff --git a/backend/models/form_response.py b/backend/models/form_response.py
index 0da7b15..933f5e4 100644
--- a/backend/models/form_response.py
+++ b/backend/models/form_response.py
@@ -30,3 +30,7 @@ class FormResponse(BaseModel):
class Config:
allow_population_by_field_name = True
+
+
+class ResponseList(BaseModel):
+ __root__: t.List[FormResponse]
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index 41c0a0b..2509109 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -2,13 +2,26 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""
+import httpx
import jwt
+from pydantic.fields import Field
+from pydantic.main import BaseModel
+from spectree.response import Response
from starlette.requests import Request
from starlette.responses import JSONResponse
from backend.constants import SECRET_KEY
from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.validation import ErrorMessage, api
+
+
+class AuthorizeRequest(BaseModel):
+ token: str = Field(description="The access token received from Discord.")
+
+
+class AuthorizeResponse(BaseModel):
+ token: str = Field(description="A JWT token containing the user information")
class AuthorizeRoute(Route):
@@ -19,11 +32,22 @@ class AuthorizeRoute(Route):
name = "authorize"
path = "/authorize"
+ @api.validate(
+ json=AuthorizeRequest,
+ resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
+ tags=["auth"]
+ )
async def post(self, request: Request) -> JSONResponse:
+ """Generate an authorization token."""
data = await request.json()
- bearer_token = await fetch_bearer_token(data["token"])
- user_details = await fetch_user_details(bearer_token["access_token"])
+ try:
+ bearer_token = await fetch_bearer_token(data["token"])
+ user_details = await fetch_user_details(bearer_token["access_token"])
+ except httpx.HTTPStatusError:
+ return JSONResponse({
+ "error": "auth_failure"
+ }, status_code=400)
user_details["admin"] = await request.state.db.admins.find_one(
{"_id": user_details["id"]}
diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py
index bba6fd4..9400f05 100644
--- a/backend/routes/forms/discover.py
+++ b/backend/routes/forms/discover.py
@@ -1,11 +1,13 @@
"""
Return a list of all publicly discoverable forms to unauthenticated users.
"""
+from spectree.response import Response
from starlette.requests import Request
from starlette.responses import JSONResponse
-from backend.models import Form
+from backend.models import Form, FormList
from backend.route import Route
+from backend.validation import api
class DiscoverableFormsList(Route):
@@ -16,7 +18,9 @@ class DiscoverableFormsList(Route):
name = "discoverable_forms_list"
path = "/discoverable"
+ @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
+ """List all discoverable forms that should be shown on the homepage."""
forms = []
cursor = request.state.db.forms.find({"features": "DISCOVERABLE"})
diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py
index 8fdd8a2..c953135 100644
--- a/backend/routes/forms/form.py
+++ b/backend/routes/forms/form.py
@@ -1,12 +1,14 @@
"""
Returns or deletes a single form given an ID.
"""
+from spectree.response import Response
from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
from backend.route import Route
from backend.models import Form
+from backend.validation import OkayResponse, api, ErrorMessage
class SingleForm(Route):
@@ -19,6 +21,7 @@ class SingleForm(Route):
name = "form"
path = "/{form_id:str}"
+ @api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Returns single form information by ID."""
admin = request.user.payload["admin"] if request.user.is_authenticated else False # noqa
@@ -37,6 +40,10 @@ class SingleForm(Route):
return JSONResponse({"error": "not_found"}, status_code=404)
@requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=OkayResponse, HTTP_404=ErrorMessage),
+ tags=["forms"]
+ )
async def delete(self, request: Request) -> JSONResponse:
"""Deletes form by ID."""
if not await request.state.db.forms.find_one(
diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py
index bb2b299..d1373e4 100644
--- a/backend/routes/forms/index.py
+++ b/backend/routes/forms/index.py
@@ -1,13 +1,14 @@
"""
Return a list of all forms to authenticated users.
"""
-from pydantic import ValidationError
+from spectree.response import Response
from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
from backend.route import Route
-from backend.models import Form
+from backend.models import Form, FormList
+from backend.validation import ErrorMessage, OkayResponse, api
class FormsList(Route):
@@ -19,7 +20,9 @@ class FormsList(Route):
path = "/"
@requires(["authenticated", "admin"])
+ @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
+ """Return a list of all forms to authenticated users."""
forms = []
cursor = request.state.db.forms.find()
@@ -34,12 +37,16 @@ class FormsList(Route):
)
@requires(["authenticated", "admin"])
+ @api.validate(
+ json=Form,
+ resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage),
+ tags=["forms"]
+ )
async def post(self, request: Request) -> JSONResponse:
+ """Create a new form."""
form_data = await request.json()
- try:
- form = Form(**form_data)
- except ValidationError as e:
- return JSONResponse(e.errors())
+
+ form = Form(**form_data)
if await request.state.db.forms.find_one({"_id": form.id}):
return JSONResponse({
diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py
index acaa647..d8d8d17 100644
--- a/backend/routes/forms/response.py
+++ b/backend/routes/forms/response.py
@@ -1,12 +1,14 @@
"""
Returns or deletes form response by ID.
"""
+from spectree import Response as RouteResponse
from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
from backend.models import FormResponse
from backend.route import Route
+from backend.validation import ErrorMessage, OkayResponse, api
class Response(Route):
@@ -16,8 +18,12 @@ class Response(Route):
path = "/{form_id:str}/responses/{response_id:str}"
@requires(["authenticated", "admin"])
+ @api.validate(
+ resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage),
+ tags=["forms", "responses"]
+ )
async def get(self, request: Request) -> JSONResponse:
- """Returns single form response by ID."""
+ """Return a single form response by ID."""
if raw_response := await request.state.db.responses.find_one(
{
"_id": request.path_params["response_id"],
@@ -30,8 +36,12 @@ class Response(Route):
return JSONResponse({"error": "not_found"}, status_code=404)
@requires(["authenticated", "admin"])
+ @api.validate(
+ resp=RouteResponse(HTTP_200=OkayResponse, HTTP_404=ErrorMessage),
+ tags=["forms", "responses"]
+ )
async def delete(self, request: Request) -> JSONResponse:
- """Deletes form response by ID."""
+ """Delete a form response by ID."""
if not await request.state.db.responses.find_one(
{
"_id": request.path_params["response_id"],
diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py
index ee8ab84..54da246 100644
--- a/backend/routes/forms/responses.py
+++ b/backend/routes/forms/responses.py
@@ -1,12 +1,14 @@
"""
Returns all form responses by form ID.
"""
+from spectree import Response
from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse
-from backend.models import FormResponse
+from backend.models import FormResponse, ResponseList
from backend.route import Route
+from backend.validation import api, ErrorMessage
class Responses(Route):
@@ -18,6 +20,10 @@ class Responses(Route):
path = "/{form_id:str}/responses"
@requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=ResponseList, HTTP_404=ErrorMessage),
+ tags=["forms", "responses"]
+ )
async def get(self, request: Request) -> JSONResponse:
"""Returns all form responses by form ID."""
if not await request.state.db.forms.find_one(
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index 48ae4f6..dfa0de6 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -4,11 +4,14 @@ Submit a form.
import binascii
import hashlib
+from typing import Any, Optional
import uuid
import httpx
+from pydantic.main import BaseModel
import pydnsbl
from pydantic import ValidationError
+from spectree import Response
from starlette.requests import Request
from starlette.responses import JSONResponse
@@ -16,6 +19,7 @@ from starlette.responses import JSONResponse
from backend.constants import HCAPTCHA_API_SECRET, FormFeatures
from backend.models import Form, FormResponse
from backend.route import Route
+from backend.validation import AuthorizationHeaders, ErrorMessage, api
HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify"
HCAPTCHA_HEADERS = {
@@ -23,6 +27,16 @@ HCAPTCHA_HEADERS = {
}
+class SubmissionResponse(BaseModel):
+ form: Form
+ response: FormResponse
+
+
+class PartialSubmission(BaseModel):
+ response: dict[str, Any]
+ captcha: Optional[str]
+
+
class SubmitForm(Route):
"""
Submit a form with the provided form ID.
@@ -31,7 +45,18 @@ class SubmitForm(Route):
name = "submit_form"
path = "/submit/{form_id:str}"
+ @api.validate(
+ json=PartialSubmission,
+ resp=Response(
+ HTTP_200=SubmissionResponse,
+ HTTP_404=ErrorMessage,
+ HTTP_400=ErrorMessage
+ ),
+ headers=AuthorizationHeaders,
+ tags=["forms", "responses"]
+ )
async def post(self, request: Request) -> JSONResponse:
+ """Submit a response to the form."""
data = await request.json()
data["timestamp"] = None
diff --git a/backend/routes/index.py b/backend/routes/index.py
index b37f381..dd40d01 100644
--- a/backend/routes/index.py
+++ b/backend/routes/index.py
@@ -1,10 +1,24 @@
"""
Index route for the forms API.
"""
+from pydantic import BaseModel
+from pydantic.fields import Field
+from spectree import Response
from starlette.requests import Request
from starlette.responses import JSONResponse
from backend.route import Route
+from backend.validation import api
+
+
+class IndexResponse(BaseModel):
+ message: str = Field(description="A hello message")
+ client: str = Field(
+ description=(
+ "The connecting client, in production this will"
+ " be an IP of our internal load balancer"
+ )
+ )
class IndexRoute(Route):
@@ -17,7 +31,11 @@ class IndexRoute(Route):
name = "index"
path = "/"
+ @api.validate(resp=Response(HTTP_200=IndexResponse))
def get(self, request: Request) -> JSONResponse:
+ """
+ Return a hello from Python Discord forms!
+ """
response_data = {
"message": "Hello, world!",
"client": request.client.host,
diff --git a/backend/validation.py b/backend/validation.py
new file mode 100644
index 0000000..e696683
--- /dev/null
+++ b/backend/validation.py
@@ -0,0 +1,30 @@
+"""Utilities for providing API payload validation."""
+
+from typing import Optional
+from pydantic.fields import Field
+from pydantic.main import BaseModel
+from spectree import SpecTree
+
+api = SpecTree(
+ "starlette",
+ TITLE="Python Discord Forms",
+ PATH="docs"
+)
+
+
+class ErrorMessage(BaseModel):
+ error: str = Field(description="The details on the error")
+
+
+class OkayResponse(BaseModel):
+ status: str = "ok"
+
+
+class AuthorizationHeaders(BaseModel):
+ authorization: Optional[str] = Field(
+ title="Authorization",
+ description=(
+ "The Authorization JWT token received from the "
+ "authorize route in the format `JWT {token}`"
+ )
+ )
diff --git a/poetry.lock b/poetry.lock
index a32ce7e..cd0d1bc 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -280,7 +280,7 @@ version = "5.3.1"
description = "YAML parser and emitter for Python"
category = "main"
optional = false
-python-versions = "*"
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "rfc3986"
@@ -305,6 +305,23 @@ optional = false
python-versions = ">=3.5"
[[package]]
+name = "spectree"
+version = "0.3.16"
+description = "generate OpenAPI document and validate request&response with Python annotations."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+pydantic = ">=1.2"
+
+[package.extras]
+dev = ["pytest (>=6)", "flake8 (>=3.8)", "black (>=20.8b1)", "isort (>=5.6)", "autoflake (>=1.4)"]
+falcon = ["falcon"]
+flask = ["flask"]
+starlette = ["starlette"]
+
+[[package]]
name = "starlette"
version = "0.14.1"
description = "The little ASGI library that shines."
@@ -364,7 +381,7 @@ python-versions = ">=3.6.1"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "31df3f6fb5c2739f0ac3158fc73d7ec699bf0b4a228b936e35463f0f977d4beb"
+content-hash = "f0529cd6559892497787a807a6fd3ee7c84b60c04cbc2513bf8caca6b7c3b367"
[metadata.files]
aiodns = [
@@ -642,6 +659,10 @@ sniffio = [
{file = "sniffio-1.2.0-py3-none-any.whl", hash = "sha256:471b71698eac1c2112a40ce2752bb2f4a4814c22a54a3eed3676bc0f5ca9f663"},
{file = "sniffio-1.2.0.tar.gz", hash = "sha256:c4666eecec1d3f50960c6bdf61ab7bc350648da6c126e3cf6898d8cd4ddcd3de"},
]
+spectree = [
+ {file = "spectree-0.3.16-py3-none-any.whl", hash = "sha256:e6cb74ce759361103805dcbd05b311eb46bf11e23486d0787e3f93723d6bab31"},
+ {file = "spectree-0.3.16.tar.gz", hash = "sha256:4d94b79ce2c73acaee5e306c71c7408f5a52e43048e1f7e734a0ce1e75b0c8c8"},
+]
starlette = [
{file = "starlette-0.14.1-py3-none-any.whl", hash = "sha256:d2f55fb835378442b812637ed3e3fcef3d3e22d292fcb8400fa48d2473202411"},
{file = "starlette-0.14.1.tar.gz", hash = "sha256:5268ef5d4904ec69582d5fd207b869a5aa0cd59529848ba4cf429b06e3ced99a"},
diff --git a/pyproject.toml b/pyproject.toml
index 774dc4c..b14e876 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ httpx = "^0.16.1"
gunicorn = "^20.0.4"
pydantic = "^1.7.2"
pydnsbl = "^1.1"
+spectree = "^0.3.16"
[tool.poetry.dev-dependencies]
flake8 = "^3.8.4"