aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ks129 <[email protected]>2020-12-16 12:46:02 +0200
committerGravatar GitHub <[email protected]>2020-12-16 12:46:02 +0200
commite68e05960ee0b01a34154d811ecc295981c8fdbc (patch)
treebd270b94154c57af7666e57eddf1d138308bacfc
parentReturn some JSON from delete endpoint (diff)
parentMerge pull request #28 from python-discord/ks123/routes-parsing (diff)
Merge branch 'main' into ks123/form-delete
-rw-r--r--backend/authentication/backend.py11
-rw-r--r--backend/authentication/user.py5
-rw-r--r--backend/models/form.py12
-rw-r--r--backend/models/form_response.py2
-rw-r--r--backend/models/question.py8
-rw-r--r--backend/route.py10
-rw-r--r--backend/route_manager.py60
-rw-r--r--poetry.lock8
-rw-r--r--pyproject.toml2
9 files changed, 61 insertions, 57 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index 38668eb..f1d2ece 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -1,6 +1,5 @@
import jwt
import typing as t
-from abc import ABC
from starlette import authentication
from starlette.requests import Request
@@ -10,11 +9,11 @@ from backend import constants
from .user import User
-class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
+class JWTAuthenticationBackend(authentication.AuthenticationBackend):
"""Custom Starlette authentication backend for JWT."""
@staticmethod
- def get_token_from_header(header: str) -> t.Optional[str]:
+ def get_token_from_header(header: str) -> str:
"""Parse JWT token from header value."""
try:
prefix, token = header.split()
@@ -32,10 +31,10 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
async def authenticate(
self, request: Request
- ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]:
+ ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
"""Handles JWT authentication process."""
if "Authorization" not in request.headers:
- return
+ return None
auth = request.headers["Authorization"]
token = self.get_token_from_header(auth)
@@ -47,7 +46,7 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
scopes = ["authenticated"]
- if payload.get("admin", False) is True:
+ if payload.get("admin") is True:
scopes.append("admin")
return authentication.AuthCredentials(scopes), User(token, payload)
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index afa243f..722c348 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,13 +1,12 @@
import typing as t
-from abc import ABC
from starlette.authentication import BaseUser
-class User(BaseUser, ABC):
+class User(BaseUser):
"""Starlette BaseUser implementation for JWT authentication."""
- def __init__(self, token: str, payload: t.Dict) -> None:
+ def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
self.token = token
self.payload = payload
diff --git a/backend/models/form.py b/backend/models/form.py
index 2cf8486..cb58065 100644
--- a/backend/models/form.py
+++ b/backend/models/form.py
@@ -12,8 +12,8 @@ class Form(BaseModel):
"""Schema model for form."""
id: str = Field(alias="_id")
- features: t.List[str]
- questions: t.List[Question]
+ features: list[str]
+ questions: list[Question]
name: str
description: str
@@ -21,12 +21,12 @@ class Form(BaseModel):
allow_population_by_field_name = True
@validator("features")
- def validate_features(cls, value: t.List[str]) -> t.Optional[t.List[str]]:
+ def validate_features(cls, value: list[str]) -> t.Optional[list[str]]:
"""Validates is all features in allowed list."""
# Uppercase everything to avoid mixed case in DB
value = [v.upper() for v in value]
- allowed_values = list(v.value for v in FormFeatures.__members__.values())
- if not all(v in allowed_values for v in value):
+ allowed_values = [v.value for v in FormFeatures.__members__.values()]
+ if any(v not in allowed_values for v in value):
raise ValueError("Form features list contains one or more invalid values.")
if FormFeatures.COLLECT_EMAIL in value and FormFeatures.REQUIRES_LOGIN not in value: # noqa
@@ -34,7 +34,7 @@ class Form(BaseModel):
return value
- def dict(self, admin: bool = True, **kwargs: t.Dict) -> t.Dict[str, t.Any]:
+ def dict(self, admin: bool = True, **kwargs: t.Any) -> dict[str, t.Any]:
"""Wrapper for original function to exclude private data for public access."""
data = super().dict(**kwargs)
diff --git a/backend/models/form_response.py b/backend/models/form_response.py
index bea070f..f3296cd 100644
--- a/backend/models/form_response.py
+++ b/backend/models/form_response.py
@@ -12,7 +12,7 @@ class FormResponse(BaseModel):
id: str = Field(alias="_id")
user: t.Optional[DiscordUser]
antispam: t.Optional[AntiSpam]
- response: t.Dict[str, t.Any]
+ response: dict[str, t.Any]
form_id: str
class Config:
diff --git a/backend/models/question.py b/backend/models/question.py
index 1a012ff..3b98024 100644
--- a/backend/models/question.py
+++ b/backend/models/question.py
@@ -11,7 +11,7 @@ class Question(BaseModel):
id: str = Field(alias="_id")
name: str
type: str
- data: t.Dict[str, t.Any]
+ data: dict[str, t.Any]
class Config:
allow_population_by_field_name = True
@@ -31,14 +31,14 @@ class Question(BaseModel):
@root_validator
def validate_question_data(
cls,
- value: t.Dict[str, t.Any]
- ) -> t.Optional[t.Dict[str, t.Any]]:
+ value: dict[str, t.Any]
+ ) -> t.Optional[dict[str, t.Any]]:
"""Check does required data exists for question type and remove other data."""
# When question type don't need data, don't add anything to keep DB clean.
if value.get("type") not in REQUIRED_QUESTION_TYPE_DATA:
return value
- for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value.get("type")].items():
+ for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value["type"]].items():
if key not in value.get("data", {}):
raise ValueError(f"Required question data key '{key}' not provided.")
diff --git a/backend/route.py b/backend/route.py
index eb69ebc..d778bf0 100644
--- a/backend/route.py
+++ b/backend/route.py
@@ -5,13 +5,13 @@ from starlette.endpoints import HTTPEndpoint
class Route(HTTPEndpoint):
- name: str = None
- path: str = None
+ name: str
+ path: str
@classmethod
- def check_parameters(cls) -> "Route":
- if cls.name is None:
+ def check_parameters(cls) -> None:
+ if not hasattr(cls, "name"):
raise ValueError(f"Route {cls.__name__} has not defined a name")
- if cls.path is None:
+ if not hasattr(cls, "path"):
raise ValueError(f"Route {cls.__name__} has not defined a path")
diff --git a/backend/route_manager.py b/backend/route_manager.py
index 25529eb..031c9b3 100644
--- a/backend/route_manager.py
+++ b/backend/route_manager.py
@@ -4,15 +4,17 @@ Module to dynamically generate a Starlette routing map based on a directory tree
import importlib
import inspect
+import typing as t
+
from pathlib import Path
-from starlette.routing import Route as StarletteRoute, Mount
+from starlette.routing import Route as StarletteRoute, BaseRoute, Mount
from nested_dict import nested_dict
from backend.route import Route
-def construct_route_map_from_dict(route_dict: dict) -> list:
+def construct_route_map_from_dict(route_dict: dict) -> list[BaseRoute]:
route_map = []
for mount, item in route_dict.items():
if inspect.isclass(item):
@@ -26,35 +28,39 @@ def construct_route_map_from_dict(route_dict: dict) -> list:
return route_map
-def create_route_map() -> list:
- routes_directory = Path("backend") / "routes"
-
- route_dict = nested_dict()
-
- for file in routes_directory.rglob("*.py"):
- import_name = f"{str(file.parent).replace('/', '.')}.{file.stem}"
+def is_route_class(member: t.Any) -> bool:
+ return inspect.isclass(member) and issubclass(member, Route) and member != Route
- route = importlib.import_module(import_name)
- for _member_name, member in inspect.getmembers(route):
- if inspect.isclass(member):
- if issubclass(member, Route) and member != Route:
- member.check_parameters()
+def route_classes() -> t.Iterator[tuple[Path, type[Route]]]:
+ routes_directory = Path("backend") / "routes"
- levels = str(file.parent).split("/")[2:]
+ for module_path in routes_directory.rglob("*.py"):
+ import_name = f"{'.'.join(module_path.parent.parts)}.{module_path.stem}"
+ route_module = importlib.import_module(import_name)
+ for _member_name, member in inspect.getmembers(route_module):
+ if is_route_class(member):
+ member.check_parameters()
+ yield (module_path, member)
- current_level = None
- for level in levels:
- if current_level is None:
- current_level = route_dict[f"/{level}"]
- else:
- current_level = current_level[f"/{level}"]
- if current_level is not None:
- current_level[member.path] = member
- else:
- route_dict[member.path] = member
+def create_route_map() -> list[BaseRoute]:
+ route_dict = nested_dict()
- route_map = construct_route_map_from_dict(route_dict.to_dict())
+ for module_path, member in route_classes():
+ # module_path == Path("backend/routes/foo/bar/baz/bin.py")
+ # => levels == ["foo", "bar", "baz"]
+ levels = module_path.parent.parts[2:]
+ current_level = None
+ for level in levels:
+ if current_level is None:
+ current_level = route_dict[f"/{level}"]
+ else:
+ current_level = current_level[f"/{level}"]
+
+ if current_level is not None:
+ current_level[member.path] = member
+ else:
+ route_dict[member.path] = member
- return route_map
+ return construct_route_map_from_dict(route_dict.to_dict())
diff --git a/poetry.lock b/poetry.lock
index 9620a71..a32ce7e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -317,7 +317,7 @@ full = ["aiofiles", "graphene", "itsdangerous", "jinja2", "python-multipart", "p
[[package]]
name = "uvicorn"
-version = "0.12.3"
+version = "0.13.1"
description = "The lightning-fast ASGI server."
category = "main"
optional = false
@@ -364,7 +364,7 @@ python-versions = ">=3.6.1"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "aaf11a0a091a509880f780239e67d0c12c2213d1570e07c78acc60b20814bbe7"
+content-hash = "31df3f6fb5c2739f0ac3158fc73d7ec699bf0b4a228b936e35463f0f977d4beb"
[metadata.files]
aiodns = [
@@ -647,8 +647,8 @@ starlette = [
{file = "starlette-0.14.1.tar.gz", hash = "sha256:5268ef5d4904ec69582d5fd207b869a5aa0cd59529848ba4cf429b06e3ced99a"},
]
uvicorn = [
- {file = "uvicorn-0.12.3-py3-none-any.whl", hash = "sha256:562ef6aaa8fa723ab6b82cf9e67a774088179d0ec57cb17e447b15d58b603bcf"},
- {file = "uvicorn-0.12.3.tar.gz", hash = "sha256:5836edaf4d278fe67ba0298c0537bdb6398cf359eb644f79e6500ca1aad232b3"},
+ {file = "uvicorn-0.13.1-py3-none-any.whl", hash = "sha256:6fcce74c00b77d4f4b3ed7ba1b2a370d27133bfdb46f835b7a76dfe0a8c110ae"},
+ {file = "uvicorn-0.13.1.tar.gz", hash = "sha256:2a7b17f4d9848d6557ccc2274a5f7c97f1daf037d130a0c6918f67cd9bc8cdf5"},
]
uvloop = [
{file = "uvloop-0.14.0-cp35-cp35m-macosx_10_11_x86_64.whl", hash = "sha256:08b109f0213af392150e2fe6f81d33261bb5ce968a288eb698aad4f46eb711bd"},
diff --git a/pyproject.toml b/pyproject.toml
index 4b8e993..774dc4c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ license = "MIT"
python = "^3.9"
starlette = "^0.14.0"
nested_dict = "^1.61"
-uvicorn = {extras = ["standard"], version = "^0.12.2"}
+uvicorn = {extras = ["standard"], version = "^0.13.0"}
motor = "^2.3.0"
python-dotenv = "^0.15.0"
pyjwt = "^1.7.1"