aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2020-12-16 10:28:52 +0000
committerGravatar GitHub <[email protected]>2020-12-16 10:28:52 +0000
commit2c5610d1dbf5d6e8fb112d9dc4d93330a87c6708 (patch)
tree78b7f2269ef04da0f9c3f871ac45a00cf14ca3f6 /backend
parentMerge pull request #33 from python-discord/renovate/uvicorn-0.x (diff)
parentMerge branch 'main' into ks123/routes-parsing (diff)
Merge pull request #28 from python-discord/ks123/routes-parsing
Diffstat (limited to 'backend')
-rw-r--r--backend/authentication/backend.py11
-rw-r--r--backend/authentication/user.py5
-rw-r--r--backend/models/form.py12
-rw-r--r--backend/models/form_response.py2
-rw-r--r--backend/models/question.py8
-rw-r--r--backend/route.py10
-rw-r--r--backend/route_manager.py60
7 files changed, 56 insertions, 52 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index 38668eb..f1d2ece 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -1,6 +1,5 @@
import jwt
import typing as t
-from abc import ABC
from starlette import authentication
from starlette.requests import Request
@@ -10,11 +9,11 @@ from backend import constants
from .user import User
-class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
+class JWTAuthenticationBackend(authentication.AuthenticationBackend):
"""Custom Starlette authentication backend for JWT."""
@staticmethod
- def get_token_from_header(header: str) -> t.Optional[str]:
+ def get_token_from_header(header: str) -> str:
"""Parse JWT token from header value."""
try:
prefix, token = header.split()
@@ -32,10 +31,10 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
async def authenticate(
self, request: Request
- ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]:
+ ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
"""Handles JWT authentication process."""
if "Authorization" not in request.headers:
- return
+ return None
auth = request.headers["Authorization"]
token = self.get_token_from_header(auth)
@@ -47,7 +46,7 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
scopes = ["authenticated"]
- if payload.get("admin", False) is True:
+ if payload.get("admin") is True:
scopes.append("admin")
return authentication.AuthCredentials(scopes), User(token, payload)
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index afa243f..722c348 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,13 +1,12 @@
import typing as t
-from abc import ABC
from starlette.authentication import BaseUser
-class User(BaseUser, ABC):
+class User(BaseUser):
"""Starlette BaseUser implementation for JWT authentication."""
- def __init__(self, token: str, payload: t.Dict) -> None:
+ def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
self.token = token
self.payload = payload
diff --git a/backend/models/form.py b/backend/models/form.py
index 2cf8486..cb58065 100644
--- a/backend/models/form.py
+++ b/backend/models/form.py
@@ -12,8 +12,8 @@ class Form(BaseModel):
"""Schema model for form."""
id: str = Field(alias="_id")
- features: t.List[str]
- questions: t.List[Question]
+ features: list[str]
+ questions: list[Question]
name: str
description: str
@@ -21,12 +21,12 @@ class Form(BaseModel):
allow_population_by_field_name = True
@validator("features")
- def validate_features(cls, value: t.List[str]) -> t.Optional[t.List[str]]:
+ def validate_features(cls, value: list[str]) -> t.Optional[list[str]]:
"""Validates is all features in allowed list."""
# Uppercase everything to avoid mixed case in DB
value = [v.upper() for v in value]
- allowed_values = list(v.value for v in FormFeatures.__members__.values())
- if not all(v in allowed_values for v in value):
+ allowed_values = [v.value for v in FormFeatures.__members__.values()]
+ if any(v not in allowed_values for v in value):
raise ValueError("Form features list contains one or more invalid values.")
if FormFeatures.COLLECT_EMAIL in value and FormFeatures.REQUIRES_LOGIN not in value: # noqa
@@ -34,7 +34,7 @@ class Form(BaseModel):
return value
- def dict(self, admin: bool = True, **kwargs: t.Dict) -> t.Dict[str, t.Any]:
+ def dict(self, admin: bool = True, **kwargs: t.Any) -> dict[str, t.Any]:
"""Wrapper for original function to exclude private data for public access."""
data = super().dict(**kwargs)
diff --git a/backend/models/form_response.py b/backend/models/form_response.py
index bea070f..f3296cd 100644
--- a/backend/models/form_response.py
+++ b/backend/models/form_response.py
@@ -12,7 +12,7 @@ class FormResponse(BaseModel):
id: str = Field(alias="_id")
user: t.Optional[DiscordUser]
antispam: t.Optional[AntiSpam]
- response: t.Dict[str, t.Any]
+ response: dict[str, t.Any]
form_id: str
class Config:
diff --git a/backend/models/question.py b/backend/models/question.py
index 1a012ff..3b98024 100644
--- a/backend/models/question.py
+++ b/backend/models/question.py
@@ -11,7 +11,7 @@ class Question(BaseModel):
id: str = Field(alias="_id")
name: str
type: str
- data: t.Dict[str, t.Any]
+ data: dict[str, t.Any]
class Config:
allow_population_by_field_name = True
@@ -31,14 +31,14 @@ class Question(BaseModel):
@root_validator
def validate_question_data(
cls,
- value: t.Dict[str, t.Any]
- ) -> t.Optional[t.Dict[str, t.Any]]:
+ value: dict[str, t.Any]
+ ) -> t.Optional[dict[str, t.Any]]:
"""Check does required data exists for question type and remove other data."""
# When question type don't need data, don't add anything to keep DB clean.
if value.get("type") not in REQUIRED_QUESTION_TYPE_DATA:
return value
- for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value.get("type")].items():
+ for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value["type"]].items():
if key not in value.get("data", {}):
raise ValueError(f"Required question data key '{key}' not provided.")
diff --git a/backend/route.py b/backend/route.py
index eb69ebc..d778bf0 100644
--- a/backend/route.py
+++ b/backend/route.py
@@ -5,13 +5,13 @@ from starlette.endpoints import HTTPEndpoint
class Route(HTTPEndpoint):
- name: str = None
- path: str = None
+ name: str
+ path: str
@classmethod
- def check_parameters(cls) -> "Route":
- if cls.name is None:
+ def check_parameters(cls) -> None:
+ if not hasattr(cls, "name"):
raise ValueError(f"Route {cls.__name__} has not defined a name")
- if cls.path is None:
+ if not hasattr(cls, "path"):
raise ValueError(f"Route {cls.__name__} has not defined a path")
diff --git a/backend/route_manager.py b/backend/route_manager.py
index 25529eb..031c9b3 100644
--- a/backend/route_manager.py
+++ b/backend/route_manager.py
@@ -4,15 +4,17 @@ Module to dynamically generate a Starlette routing map based on a directory tree
import importlib
import inspect
+import typing as t
+
from pathlib import Path
-from starlette.routing import Route as StarletteRoute, Mount
+from starlette.routing import Route as StarletteRoute, BaseRoute, Mount
from nested_dict import nested_dict
from backend.route import Route
-def construct_route_map_from_dict(route_dict: dict) -> list:
+def construct_route_map_from_dict(route_dict: dict) -> list[BaseRoute]:
route_map = []
for mount, item in route_dict.items():
if inspect.isclass(item):
@@ -26,35 +28,39 @@ def construct_route_map_from_dict(route_dict: dict) -> list:
return route_map
-def create_route_map() -> list:
- routes_directory = Path("backend") / "routes"
-
- route_dict = nested_dict()
-
- for file in routes_directory.rglob("*.py"):
- import_name = f"{str(file.parent).replace('/', '.')}.{file.stem}"
+def is_route_class(member: t.Any) -> bool:
+ return inspect.isclass(member) and issubclass(member, Route) and member != Route
- route = importlib.import_module(import_name)
- for _member_name, member in inspect.getmembers(route):
- if inspect.isclass(member):
- if issubclass(member, Route) and member != Route:
- member.check_parameters()
+def route_classes() -> t.Iterator[tuple[Path, type[Route]]]:
+ routes_directory = Path("backend") / "routes"
- levels = str(file.parent).split("/")[2:]
+ for module_path in routes_directory.rglob("*.py"):
+ import_name = f"{'.'.join(module_path.parent.parts)}.{module_path.stem}"
+ route_module = importlib.import_module(import_name)
+ for _member_name, member in inspect.getmembers(route_module):
+ if is_route_class(member):
+ member.check_parameters()
+ yield (module_path, member)
- current_level = None
- for level in levels:
- if current_level is None:
- current_level = route_dict[f"/{level}"]
- else:
- current_level = current_level[f"/{level}"]
- if current_level is not None:
- current_level[member.path] = member
- else:
- route_dict[member.path] = member
+def create_route_map() -> list[BaseRoute]:
+ route_dict = nested_dict()
- route_map = construct_route_map_from_dict(route_dict.to_dict())
+ for module_path, member in route_classes():
+ # module_path == Path("backend/routes/foo/bar/baz/bin.py")
+ # => levels == ["foo", "bar", "baz"]
+ levels = module_path.parent.parts[2:]
+ current_level = None
+ for level in levels:
+ if current_level is None:
+ current_level = route_dict[f"/{level}"]
+ else:
+ current_level = current_level[f"/{level}"]
+
+ if current_level is not None:
+ current_level[member.path] = member
+ else:
+ route_dict[member.path] = member
- return route_map
+ return construct_route_map_from_dict(route_dict.to_dict())