diff options
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | backend/__init__.py | 7 | ||||
-rw-r--r-- | backend/constants.py | 2 | ||||
-rw-r--r-- | backend/middleware.py | 14 | ||||
-rw-r--r-- | backend/models/__init__.py | 14 | ||||
-rw-r--r-- | backend/models/form.py | 4 | ||||
-rw-r--r-- | backend/models/form_response.py | 4 | ||||
-rw-r--r-- | backend/routes/auth/authorize.py | 28 | ||||
-rw-r--r-- | backend/routes/forms/discover.py | 6 | ||||
-rw-r--r-- | backend/routes/forms/form.py | 7 | ||||
-rw-r--r-- | backend/routes/forms/index.py | 19 | ||||
-rw-r--r-- | backend/routes/forms/response.py | 14 | ||||
-rw-r--r-- | backend/routes/forms/responses.py | 8 | ||||
-rw-r--r-- | backend/routes/forms/submit.py | 25 | ||||
-rw-r--r-- | backend/routes/index.py | 18 | ||||
-rw-r--r-- | backend/validation.py | 30 | ||||
-rw-r--r-- | poetry.lock | 25 | ||||
-rw-r--r-- | pyproject.toml | 1 |
18 files changed, 208 insertions, 21 deletions
@@ -130,3 +130,6 @@ dmypy.json # IntelliJ project settings .idea/ + +# VSCode Settings +.vscode/settings.json diff --git a/backend/__init__.py b/backend/__init__.py index 770107b..5c6328b 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -5,7 +5,8 @@ from starlette.middleware.cors import CORSMiddleware from backend.authentication import JWTAuthenticationBackend from backend.route_manager import create_route_map -from backend.middleware import DatabaseMiddleware +from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware +from backend.validation import api middleware = [ Middleware( @@ -19,7 +20,9 @@ middleware = [ allow_methods=["*"] ), Middleware(DatabaseMiddleware), - Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()) + Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()), + Middleware(ProtectedDocsMiddleware) ] app = Starlette(routes=create_route_map(), middleware=middleware) +api.register(app) diff --git a/backend/constants.py b/backend/constants.py index fdf7092..4218bff 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -16,6 +16,8 @@ OAUTH2_REDIRECT_URI = os.getenv( "https://forms.pythondiscord.com/callback" ) +DOCS_PASSWORD = os.getenv("DOCS_PASSWORD") + SECRET_KEY = os.getenv("SECRET_KEY", binascii.hexlify(os.urandom(30)).decode()) HCAPTCHA_API_SECRET = os.getenv("HCAPTCHA_API_SECRET") diff --git a/backend/middleware.py b/backend/middleware.py index 2267a9a..f74091b 100644 --- a/backend/middleware.py +++ b/backend/middleware.py @@ -4,9 +4,9 @@ import ssl from motor.motor_asyncio import AsyncIOMotorClient from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import JSONResponse, Response -from backend.constants import DATABASE_URL, MONGO_DATABASE +from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE class DatabaseMiddleware(BaseHTTPMiddleware): @@ -19,3 +19,13 @@ class DatabaseMiddleware(BaseHTTPMiddleware): request.state.db = db response = await call_next(request) return response + + +class ProtectedDocsMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: t.Callable) -> Response: + if DOCS_PASSWORD and request.url.path.startswith("/docs"): + if request.cookies.get("docs_password") != DOCS_PASSWORD: + return JSONResponse({"status": "unauthorized"}, status_code=403) + + resp = await call_next(request) + return resp diff --git a/backend/models/__init__.py b/backend/models/__init__.py index 98fa619..29ccb24 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -1,7 +1,15 @@ from .antispam import AntiSpam from .discord_user import DiscordUser -from .form import Form -from .form_response import FormResponse +from .form import Form, FormList +from .form_response import FormResponse, ResponseList from .question import Question -__all__ = ["AntiSpam", "DiscordUser", "Form", "FormResponse", "Question"] +__all__ = [ + "AntiSpam", + "DiscordUser", + "Form", + "FormResponse", + "Question", + "FormList", + "ResponseList" +] diff --git a/backend/models/form.py b/backend/models/form.py index cb58065..9d8ffaa 100644 --- a/backend/models/form.py +++ b/backend/models/form.py @@ -52,3 +52,7 @@ class Form(BaseModel): returned_data = data return returned_data + + +class FormList(BaseModel): + __root__: t.List[Form] diff --git a/backend/models/form_response.py b/backend/models/form_response.py index 0da7b15..933f5e4 100644 --- a/backend/models/form_response.py +++ b/backend/models/form_response.py @@ -30,3 +30,7 @@ class FormResponse(BaseModel): class Config: allow_population_by_field_name = True + + +class ResponseList(BaseModel): + __root__: t.List[FormResponse] diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 41c0a0b..2509109 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -2,13 +2,26 @@ Use a token received from the Discord OAuth2 system to fetch user information. """ +import httpx import jwt +from pydantic.fields import Field +from pydantic.main import BaseModel +from spectree.response import Response from starlette.requests import Request from starlette.responses import JSONResponse from backend.constants import SECRET_KEY from backend.route import Route from backend.discord import fetch_bearer_token, fetch_user_details +from backend.validation import ErrorMessage, api + + +class AuthorizeRequest(BaseModel): + token: str = Field(description="The access token received from Discord.") + + +class AuthorizeResponse(BaseModel): + token: str = Field(description="A JWT token containing the user information") class AuthorizeRoute(Route): @@ -19,11 +32,22 @@ class AuthorizeRoute(Route): name = "authorize" path = "/authorize" + @api.validate( + json=AuthorizeRequest, + resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage), + tags=["auth"] + ) async def post(self, request: Request) -> JSONResponse: + """Generate an authorization token.""" data = await request.json() - bearer_token = await fetch_bearer_token(data["token"]) - user_details = await fetch_user_details(bearer_token["access_token"]) + try: + bearer_token = await fetch_bearer_token(data["token"]) + user_details = await fetch_user_details(bearer_token["access_token"]) + except httpx.HTTPStatusError: + return JSONResponse({ + "error": "auth_failure" + }, status_code=400) user_details["admin"] = await request.state.db.admins.find_one( {"_id": user_details["id"]} diff --git a/backend/routes/forms/discover.py b/backend/routes/forms/discover.py index bba6fd4..9400f05 100644 --- a/backend/routes/forms/discover.py +++ b/backend/routes/forms/discover.py @@ -1,11 +1,13 @@ """ Return a list of all publicly discoverable forms to unauthenticated users. """ +from spectree.response import Response from starlette.requests import Request from starlette.responses import JSONResponse -from backend.models import Form +from backend.models import Form, FormList from backend.route import Route +from backend.validation import api class DiscoverableFormsList(Route): @@ -16,7 +18,9 @@ class DiscoverableFormsList(Route): name = "discoverable_forms_list" path = "/discoverable" + @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: + """List all discoverable forms that should be shown on the homepage.""" forms = [] cursor = request.state.db.forms.find({"features": "DISCOVERABLE"}) diff --git a/backend/routes/forms/form.py b/backend/routes/forms/form.py index 8fdd8a2..c953135 100644 --- a/backend/routes/forms/form.py +++ b/backend/routes/forms/form.py @@ -1,12 +1,14 @@ """ Returns or deletes a single form given an ID. """ +from spectree.response import Response from starlette.authentication import requires from starlette.requests import Request from starlette.responses import JSONResponse from backend.route import Route from backend.models import Form +from backend.validation import OkayResponse, api, ErrorMessage class SingleForm(Route): @@ -19,6 +21,7 @@ class SingleForm(Route): name = "form" path = "/{form_id:str}" + @api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: """Returns single form information by ID.""" admin = request.user.payload["admin"] if request.user.is_authenticated else False # noqa @@ -37,6 +40,10 @@ class SingleForm(Route): return JSONResponse({"error": "not_found"}, status_code=404) @requires(["authenticated", "admin"]) + @api.validate( + resp=Response(HTTP_200=OkayResponse, HTTP_404=ErrorMessage), + tags=["forms"] + ) async def delete(self, request: Request) -> JSONResponse: """Deletes form by ID.""" if not await request.state.db.forms.find_one( diff --git a/backend/routes/forms/index.py b/backend/routes/forms/index.py index bb2b299..d1373e4 100644 --- a/backend/routes/forms/index.py +++ b/backend/routes/forms/index.py @@ -1,13 +1,14 @@ """ Return a list of all forms to authenticated users. """ -from pydantic import ValidationError +from spectree.response import Response from starlette.authentication import requires from starlette.requests import Request from starlette.responses import JSONResponse from backend.route import Route -from backend.models import Form +from backend.models import Form, FormList +from backend.validation import ErrorMessage, OkayResponse, api class FormsList(Route): @@ -19,7 +20,9 @@ class FormsList(Route): path = "/" @requires(["authenticated", "admin"]) + @api.validate(resp=Response(HTTP_200=FormList), tags=["forms"]) async def get(self, request: Request) -> JSONResponse: + """Return a list of all forms to authenticated users.""" forms = [] cursor = request.state.db.forms.find() @@ -34,12 +37,16 @@ class FormsList(Route): ) @requires(["authenticated", "admin"]) + @api.validate( + json=Form, + resp=Response(HTTP_200=OkayResponse, HTTP_400=ErrorMessage), + tags=["forms"] + ) async def post(self, request: Request) -> JSONResponse: + """Create a new form.""" form_data = await request.json() - try: - form = Form(**form_data) - except ValidationError as e: - return JSONResponse(e.errors()) + + form = Form(**form_data) if await request.state.db.forms.find_one({"_id": form.id}): return JSONResponse({ diff --git a/backend/routes/forms/response.py b/backend/routes/forms/response.py index acaa647..d8d8d17 100644 --- a/backend/routes/forms/response.py +++ b/backend/routes/forms/response.py @@ -1,12 +1,14 @@ """ Returns or deletes form response by ID. """ +from spectree import Response as RouteResponse from starlette.authentication import requires from starlette.requests import Request from starlette.responses import JSONResponse from backend.models import FormResponse from backend.route import Route +from backend.validation import ErrorMessage, OkayResponse, api class Response(Route): @@ -16,8 +18,12 @@ class Response(Route): path = "/{form_id:str}/responses/{response_id:str}" @requires(["authenticated", "admin"]) + @api.validate( + resp=RouteResponse(HTTP_200=FormResponse, HTTP_404=ErrorMessage), + tags=["forms", "responses"] + ) async def get(self, request: Request) -> JSONResponse: - """Returns single form response by ID.""" + """Return a single form response by ID.""" if raw_response := await request.state.db.responses.find_one( { "_id": request.path_params["response_id"], @@ -30,8 +36,12 @@ class Response(Route): return JSONResponse({"error": "not_found"}, status_code=404) @requires(["authenticated", "admin"]) + @api.validate( + resp=RouteResponse(HTTP_200=OkayResponse, HTTP_404=ErrorMessage), + tags=["forms", "responses"] + ) async def delete(self, request: Request) -> JSONResponse: - """Deletes form response by ID.""" + """Delete a form response by ID.""" if not await request.state.db.responses.find_one( { "_id": request.path_params["response_id"], diff --git a/backend/routes/forms/responses.py b/backend/routes/forms/responses.py index ee8ab84..54da246 100644 --- a/backend/routes/forms/responses.py +++ b/backend/routes/forms/responses.py @@ -1,12 +1,14 @@ """ Returns all form responses by form ID. """ +from spectree import Response from starlette.authentication import requires from starlette.requests import Request from starlette.responses import JSONResponse -from backend.models import FormResponse +from backend.models import FormResponse, ResponseList from backend.route import Route +from backend.validation import api, ErrorMessage class Responses(Route): @@ -18,6 +20,10 @@ class Responses(Route): path = "/{form_id:str}/responses" @requires(["authenticated", "admin"]) + @api.validate( + resp=Response(HTTP_200=ResponseList, HTTP_404=ErrorMessage), + tags=["forms", "responses"] + ) async def get(self, request: Request) -> JSONResponse: """Returns all form responses by form ID.""" if not await request.state.db.forms.find_one( diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 48ae4f6..dfa0de6 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -4,11 +4,14 @@ Submit a form. import binascii import hashlib +from typing import Any, Optional import uuid import httpx +from pydantic.main import BaseModel import pydnsbl from pydantic import ValidationError +from spectree import Response from starlette.requests import Request from starlette.responses import JSONResponse @@ -16,6 +19,7 @@ from starlette.responses import JSONResponse from backend.constants import HCAPTCHA_API_SECRET, FormFeatures from backend.models import Form, FormResponse from backend.route import Route +from backend.validation import AuthorizationHeaders, ErrorMessage, api HCAPTCHA_VERIFY_URL = "https://hcaptcha.com/siteverify" HCAPTCHA_HEADERS = { @@ -23,6 +27,16 @@ HCAPTCHA_HEADERS = { } +class SubmissionResponse(BaseModel): + form: Form + response: FormResponse + + +class PartialSubmission(BaseModel): + response: dict[str, Any] + captcha: Optional[str] + + class SubmitForm(Route): """ Submit a form with the provided form ID. @@ -31,7 +45,18 @@ class SubmitForm(Route): name = "submit_form" path = "/submit/{form_id:str}" + @api.validate( + json=PartialSubmission, + resp=Response( + HTTP_200=SubmissionResponse, + HTTP_404=ErrorMessage, + HTTP_400=ErrorMessage + ), + headers=AuthorizationHeaders, + tags=["forms", "responses"] + ) async def post(self, request: Request) -> JSONResponse: + """Submit a response to the form.""" data = await request.json() data["timestamp"] = None diff --git a/backend/routes/index.py b/backend/routes/index.py index b37f381..dd40d01 100644 --- a/backend/routes/index.py +++ b/backend/routes/index.py @@ -1,10 +1,24 @@ """ Index route for the forms API. """ +from pydantic import BaseModel +from pydantic.fields import Field +from spectree import Response from starlette.requests import Request from starlette.responses import JSONResponse from backend.route import Route +from backend.validation import api + + +class IndexResponse(BaseModel): + message: str = Field(description="A hello message") + client: str = Field( + description=( + "The connecting client, in production this will" + " be an IP of our internal load balancer" + ) + ) class IndexRoute(Route): @@ -17,7 +31,11 @@ class IndexRoute(Route): name = "index" path = "/" + @api.validate(resp=Response(HTTP_200=IndexResponse)) def get(self, request: Request) -> JSONResponse: + """ + Return a hello from Python Discord forms! + """ response_data = { "message": "Hello, world!", "client": request.client.host, diff --git a/backend/validation.py b/backend/validation.py new file mode 100644 index 0000000..e696683 --- /dev/null +++ b/backend/validation.py @@ -0,0 +1,30 @@ +"""Utilities for providing API payload validation.""" + +from typing import Optional +from pydantic.fields import Field +from pydantic.main import BaseModel +from spectree import SpecTree + +api = SpecTree( + "starlette", + TITLE="Python Discord Forms", + PATH="docs" +) + + +class ErrorMessage(BaseModel): + error: str = Field(description="The details on the error") + + +class OkayResponse(BaseModel): + status: str = "ok" + + +class AuthorizationHeaders(BaseModel): + authorization: Optional[str] = Field( + title="Authorization", + description=( + "The Authorization JWT token received from the " + "authorize route in the format `JWT {token}`" + ) + ) diff --git a/poetry.lock b/poetry.lock index a32ce7e..cd0d1bc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -280,7 +280,7 @@ version = "5.3.1" description = "YAML parser and emitter for Python" category = "main" optional = false -python-versions = "*" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "rfc3986" @@ -305,6 +305,23 @@ optional = false python-versions = ">=3.5" [[package]] +name = "spectree" +version = "0.3.16" +description = "generate OpenAPI document and validate request&response with Python annotations." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +pydantic = ">=1.2" + +[package.extras] +dev = ["pytest (>=6)", "flake8 (>=3.8)", "black (>=20.8b1)", "isort (>=5.6)", "autoflake (>=1.4)"] +falcon = ["falcon"] +flask = ["flask"] +starlette = ["starlette"] + +[[package]] name = "starlette" version = "0.14.1" description = "The little ASGI library that shines." @@ -364,7 +381,7 @@ python-versions = ">=3.6.1" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "31df3f6fb5c2739f0ac3158fc73d7ec699bf0b4a228b936e35463f0f977d4beb" +content-hash = "f0529cd6559892497787a807a6fd3ee7c84b60c04cbc2513bf8caca6b7c3b367" [metadata.files] aiodns = [ @@ -642,6 +659,10 @@ sniffio = [ {file = "sniffio-1.2.0-py3-none-any.whl", hash = "sha256:471b71698eac1c2112a40ce2752bb2f4a4814c22a54a3eed3676bc0f5ca9f663"}, {file = "sniffio-1.2.0.tar.gz", hash = "sha256:c4666eecec1d3f50960c6bdf61ab7bc350648da6c126e3cf6898d8cd4ddcd3de"}, ] +spectree = [ + {file = "spectree-0.3.16-py3-none-any.whl", hash = "sha256:e6cb74ce759361103805dcbd05b311eb46bf11e23486d0787e3f93723d6bab31"}, + {file = "spectree-0.3.16.tar.gz", hash = "sha256:4d94b79ce2c73acaee5e306c71c7408f5a52e43048e1f7e734a0ce1e75b0c8c8"}, +] starlette = [ {file = "starlette-0.14.1-py3-none-any.whl", hash = "sha256:d2f55fb835378442b812637ed3e3fcef3d3e22d292fcb8400fa48d2473202411"}, {file = "starlette-0.14.1.tar.gz", hash = "sha256:5268ef5d4904ec69582d5fd207b869a5aa0cd59529848ba4cf429b06e3ced99a"}, diff --git a/pyproject.toml b/pyproject.toml index 774dc4c..b14e876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ httpx = "^0.16.1" gunicorn = "^20.0.4" pydantic = "^1.7.2" pydnsbl = "^1.1" +spectree = "^0.3.16" [tool.poetry.dev-dependencies] flake8 = "^3.8.4" |