diff options
author | 2020-12-16 12:46:02 +0200 | |
---|---|---|
committer | 2020-12-16 12:46:02 +0200 | |
commit | e68e05960ee0b01a34154d811ecc295981c8fdbc (patch) | |
tree | bd270b94154c57af7666e57eddf1d138308bacfc | |
parent | Return some JSON from delete endpoint (diff) | |
parent | Merge pull request #28 from python-discord/ks123/routes-parsing (diff) |
Merge branch 'main' into ks123/form-delete
-rw-r--r-- | backend/authentication/backend.py | 11 | ||||
-rw-r--r-- | backend/authentication/user.py | 5 | ||||
-rw-r--r-- | backend/models/form.py | 12 | ||||
-rw-r--r-- | backend/models/form_response.py | 2 | ||||
-rw-r--r-- | backend/models/question.py | 8 | ||||
-rw-r--r-- | backend/route.py | 10 | ||||
-rw-r--r-- | backend/route_manager.py | 60 | ||||
-rw-r--r-- | poetry.lock | 8 | ||||
-rw-r--r-- | pyproject.toml | 2 |
9 files changed, 61 insertions, 57 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index 38668eb..f1d2ece 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,6 +1,5 @@ import jwt import typing as t -from abc import ABC from starlette import authentication from starlette.requests import Request @@ -10,11 +9,11 @@ from backend import constants from .user import User -class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC): +class JWTAuthenticationBackend(authentication.AuthenticationBackend): """Custom Starlette authentication backend for JWT.""" @staticmethod - def get_token_from_header(header: str) -> t.Optional[str]: + def get_token_from_header(header: str) -> str: """Parse JWT token from header value.""" try: prefix, token = header.split() @@ -32,10 +31,10 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC): async def authenticate( self, request: Request - ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]: + ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]: """Handles JWT authentication process.""" if "Authorization" not in request.headers: - return + return None auth = request.headers["Authorization"] token = self.get_token_from_header(auth) @@ -47,7 +46,7 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC): scopes = ["authenticated"] - if payload.get("admin", False) is True: + if payload.get("admin") is True: scopes.append("admin") return authentication.AuthCredentials(scopes), User(token, payload) diff --git a/backend/authentication/user.py b/backend/authentication/user.py index afa243f..722c348 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,13 +1,12 @@ import typing as t -from abc import ABC from starlette.authentication import BaseUser -class User(BaseUser, ABC): +class User(BaseUser): """Starlette BaseUser implementation for JWT authentication.""" - def __init__(self, token: str, payload: t.Dict) -> None: + def __init__(self, token: str, payload: dict[str, t.Any]) -> None: self.token = token self.payload = payload diff --git a/backend/models/form.py b/backend/models/form.py index 2cf8486..cb58065 100644 --- a/backend/models/form.py +++ b/backend/models/form.py @@ -12,8 +12,8 @@ class Form(BaseModel): """Schema model for form.""" id: str = Field(alias="_id") - features: t.List[str] - questions: t.List[Question] + features: list[str] + questions: list[Question] name: str description: str @@ -21,12 +21,12 @@ class Form(BaseModel): allow_population_by_field_name = True @validator("features") - def validate_features(cls, value: t.List[str]) -> t.Optional[t.List[str]]: + def validate_features(cls, value: list[str]) -> t.Optional[list[str]]: """Validates is all features in allowed list.""" # Uppercase everything to avoid mixed case in DB value = [v.upper() for v in value] - allowed_values = list(v.value for v in FormFeatures.__members__.values()) - if not all(v in allowed_values for v in value): + allowed_values = [v.value for v in FormFeatures.__members__.values()] + if any(v not in allowed_values for v in value): raise ValueError("Form features list contains one or more invalid values.") if FormFeatures.COLLECT_EMAIL in value and FormFeatures.REQUIRES_LOGIN not in value: # noqa @@ -34,7 +34,7 @@ class Form(BaseModel): return value - def dict(self, admin: bool = True, **kwargs: t.Dict) -> t.Dict[str, t.Any]: + def dict(self, admin: bool = True, **kwargs: t.Any) -> dict[str, t.Any]: """Wrapper for original function to exclude private data for public access.""" data = super().dict(**kwargs) diff --git a/backend/models/form_response.py b/backend/models/form_response.py index bea070f..f3296cd 100644 --- a/backend/models/form_response.py +++ b/backend/models/form_response.py @@ -12,7 +12,7 @@ class FormResponse(BaseModel): id: str = Field(alias="_id") user: t.Optional[DiscordUser] antispam: t.Optional[AntiSpam] - response: t.Dict[str, t.Any] + response: dict[str, t.Any] form_id: str class Config: diff --git a/backend/models/question.py b/backend/models/question.py index 1a012ff..3b98024 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -11,7 +11,7 @@ class Question(BaseModel): id: str = Field(alias="_id") name: str type: str - data: t.Dict[str, t.Any] + data: dict[str, t.Any] class Config: allow_population_by_field_name = True @@ -31,14 +31,14 @@ class Question(BaseModel): @root_validator def validate_question_data( cls, - value: t.Dict[str, t.Any] - ) -> t.Optional[t.Dict[str, t.Any]]: + value: dict[str, t.Any] + ) -> t.Optional[dict[str, t.Any]]: """Check does required data exists for question type and remove other data.""" # When question type don't need data, don't add anything to keep DB clean. if value.get("type") not in REQUIRED_QUESTION_TYPE_DATA: return value - for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value.get("type")].items(): + for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value["type"]].items(): if key not in value.get("data", {}): raise ValueError(f"Required question data key '{key}' not provided.") diff --git a/backend/route.py b/backend/route.py index eb69ebc..d778bf0 100644 --- a/backend/route.py +++ b/backend/route.py @@ -5,13 +5,13 @@ from starlette.endpoints import HTTPEndpoint class Route(HTTPEndpoint): - name: str = None - path: str = None + name: str + path: str @classmethod - def check_parameters(cls) -> "Route": - if cls.name is None: + def check_parameters(cls) -> None: + if not hasattr(cls, "name"): raise ValueError(f"Route {cls.__name__} has not defined a name") - if cls.path is None: + if not hasattr(cls, "path"): raise ValueError(f"Route {cls.__name__} has not defined a path") diff --git a/backend/route_manager.py b/backend/route_manager.py index 25529eb..031c9b3 100644 --- a/backend/route_manager.py +++ b/backend/route_manager.py @@ -4,15 +4,17 @@ Module to dynamically generate a Starlette routing map based on a directory tree import importlib import inspect +import typing as t + from pathlib import Path -from starlette.routing import Route as StarletteRoute, Mount +from starlette.routing import Route as StarletteRoute, BaseRoute, Mount from nested_dict import nested_dict from backend.route import Route -def construct_route_map_from_dict(route_dict: dict) -> list: +def construct_route_map_from_dict(route_dict: dict) -> list[BaseRoute]: route_map = [] for mount, item in route_dict.items(): if inspect.isclass(item): @@ -26,35 +28,39 @@ def construct_route_map_from_dict(route_dict: dict) -> list: return route_map -def create_route_map() -> list: - routes_directory = Path("backend") / "routes" - - route_dict = nested_dict() - - for file in routes_directory.rglob("*.py"): - import_name = f"{str(file.parent).replace('/', '.')}.{file.stem}" +def is_route_class(member: t.Any) -> bool: + return inspect.isclass(member) and issubclass(member, Route) and member != Route - route = importlib.import_module(import_name) - for _member_name, member in inspect.getmembers(route): - if inspect.isclass(member): - if issubclass(member, Route) and member != Route: - member.check_parameters() +def route_classes() -> t.Iterator[tuple[Path, type[Route]]]: + routes_directory = Path("backend") / "routes" - levels = str(file.parent).split("/")[2:] + for module_path in routes_directory.rglob("*.py"): + import_name = f"{'.'.join(module_path.parent.parts)}.{module_path.stem}" + route_module = importlib.import_module(import_name) + for _member_name, member in inspect.getmembers(route_module): + if is_route_class(member): + member.check_parameters() + yield (module_path, member) - current_level = None - for level in levels: - if current_level is None: - current_level = route_dict[f"/{level}"] - else: - current_level = current_level[f"/{level}"] - if current_level is not None: - current_level[member.path] = member - else: - route_dict[member.path] = member +def create_route_map() -> list[BaseRoute]: + route_dict = nested_dict() - route_map = construct_route_map_from_dict(route_dict.to_dict()) + for module_path, member in route_classes(): + # module_path == Path("backend/routes/foo/bar/baz/bin.py") + # => levels == ["foo", "bar", "baz"] + levels = module_path.parent.parts[2:] + current_level = None + for level in levels: + if current_level is None: + current_level = route_dict[f"/{level}"] + else: + current_level = current_level[f"/{level}"] + + if current_level is not None: + current_level[member.path] = member + else: + route_dict[member.path] = member - return route_map + return construct_route_map_from_dict(route_dict.to_dict()) diff --git a/poetry.lock b/poetry.lock index 9620a71..a32ce7e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -317,7 +317,7 @@ full = ["aiofiles", "graphene", "itsdangerous", "jinja2", "python-multipart", "p [[package]] name = "uvicorn" -version = "0.12.3" +version = "0.13.1" description = "The lightning-fast ASGI server." category = "main" optional = false @@ -364,7 +364,7 @@ python-versions = ">=3.6.1" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "aaf11a0a091a509880f780239e67d0c12c2213d1570e07c78acc60b20814bbe7" +content-hash = "31df3f6fb5c2739f0ac3158fc73d7ec699bf0b4a228b936e35463f0f977d4beb" [metadata.files] aiodns = [ @@ -647,8 +647,8 @@ starlette = [ {file = "starlette-0.14.1.tar.gz", hash = "sha256:5268ef5d4904ec69582d5fd207b869a5aa0cd59529848ba4cf429b06e3ced99a"}, ] uvicorn = [ - {file = "uvicorn-0.12.3-py3-none-any.whl", hash = "sha256:562ef6aaa8fa723ab6b82cf9e67a774088179d0ec57cb17e447b15d58b603bcf"}, - {file = "uvicorn-0.12.3.tar.gz", hash = "sha256:5836edaf4d278fe67ba0298c0537bdb6398cf359eb644f79e6500ca1aad232b3"}, + {file = "uvicorn-0.13.1-py3-none-any.whl", hash = "sha256:6fcce74c00b77d4f4b3ed7ba1b2a370d27133bfdb46f835b7a76dfe0a8c110ae"}, + {file = "uvicorn-0.13.1.tar.gz", hash = "sha256:2a7b17f4d9848d6557ccc2274a5f7c97f1daf037d130a0c6918f67cd9bc8cdf5"}, ] uvloop = [ {file = "uvloop-0.14.0-cp35-cp35m-macosx_10_11_x86_64.whl", hash = "sha256:08b109f0213af392150e2fe6f81d33261bb5ce968a288eb698aad4f46eb711bd"}, diff --git a/pyproject.toml b/pyproject.toml index 4b8e993..774dc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" python = "^3.9" starlette = "^0.14.0" nested_dict = "^1.61" -uvicorn = {extras = ["standard"], version = "^0.12.2"} +uvicorn = {extras = ["standard"], version = "^0.13.0"} motor = "^2.3.0" python-dotenv = "^0.15.0" pyjwt = "^1.7.1" |