diff options
Diffstat (limited to '')
| -rw-r--r-- | backend/authentication/backend.py | 11 | ||||
| -rw-r--r-- | backend/authentication/user.py | 5 | ||||
| -rw-r--r-- | backend/models/form.py | 12 | ||||
| -rw-r--r-- | backend/models/form_response.py | 2 | ||||
| -rw-r--r-- | backend/models/question.py | 8 | ||||
| -rw-r--r-- | backend/route.py | 10 | ||||
| -rw-r--r-- | backend/route_manager.py | 60 | 
7 files changed, 56 insertions, 52 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index 38668eb..f1d2ece 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -1,6 +1,5 @@  import jwt  import typing as t -from abc import ABC  from starlette import authentication  from starlette.requests import Request @@ -10,11 +9,11 @@ from backend import constants  from .user import User -class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC): +class JWTAuthenticationBackend(authentication.AuthenticationBackend):      """Custom Starlette authentication backend for JWT."""      @staticmethod -    def get_token_from_header(header: str) -> t.Optional[str]: +    def get_token_from_header(header: str) -> str:          """Parse JWT token from header value."""          try:              prefix, token = header.split() @@ -32,10 +31,10 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):      async def authenticate(          self, request: Request -    ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]: +    ) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:          """Handles JWT authentication process."""          if "Authorization" not in request.headers: -            return +            return None          auth = request.headers["Authorization"]          token = self.get_token_from_header(auth) @@ -47,7 +46,7 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):          scopes = ["authenticated"] -        if payload.get("admin", False) is True: +        if payload.get("admin") is True:              scopes.append("admin")          return authentication.AuthCredentials(scopes), User(token, payload) diff --git a/backend/authentication/user.py b/backend/authentication/user.py index afa243f..722c348 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,13 +1,12 @@  import typing as t -from abc import ABC  from starlette.authentication import BaseUser -class User(BaseUser, ABC): +class User(BaseUser):      """Starlette BaseUser implementation for JWT authentication.""" -    def __init__(self, token: str, payload: t.Dict) -> None: +    def __init__(self, token: str, payload: dict[str, t.Any]) -> None:          self.token = token          self.payload = payload diff --git a/backend/models/form.py b/backend/models/form.py index 2cf8486..cb58065 100644 --- a/backend/models/form.py +++ b/backend/models/form.py @@ -12,8 +12,8 @@ class Form(BaseModel):      """Schema model for form."""      id: str = Field(alias="_id") -    features: t.List[str] -    questions: t.List[Question] +    features: list[str] +    questions: list[Question]      name: str      description: str @@ -21,12 +21,12 @@ class Form(BaseModel):          allow_population_by_field_name = True      @validator("features") -    def validate_features(cls, value: t.List[str]) -> t.Optional[t.List[str]]: +    def validate_features(cls, value: list[str]) -> t.Optional[list[str]]:          """Validates is all features in allowed list."""          # Uppercase everything to avoid mixed case in DB          value = [v.upper() for v in value] -        allowed_values = list(v.value for v in FormFeatures.__members__.values()) -        if not all(v in allowed_values for v in value): +        allowed_values = [v.value for v in FormFeatures.__members__.values()] +        if any(v not in allowed_values for v in value):              raise ValueError("Form features list contains one or more invalid values.")          if FormFeatures.COLLECT_EMAIL in value and FormFeatures.REQUIRES_LOGIN not in value:  # noqa @@ -34,7 +34,7 @@ class Form(BaseModel):          return value -    def dict(self, admin: bool = True, **kwargs: t.Dict) -> t.Dict[str, t.Any]: +    def dict(self, admin: bool = True, **kwargs: t.Any) -> dict[str, t.Any]:          """Wrapper for original function to exclude private data for public access."""          data = super().dict(**kwargs) diff --git a/backend/models/form_response.py b/backend/models/form_response.py index bea070f..f3296cd 100644 --- a/backend/models/form_response.py +++ b/backend/models/form_response.py @@ -12,7 +12,7 @@ class FormResponse(BaseModel):      id: str = Field(alias="_id")      user: t.Optional[DiscordUser]      antispam: t.Optional[AntiSpam] -    response: t.Dict[str, t.Any] +    response: dict[str, t.Any]      form_id: str      class Config: diff --git a/backend/models/question.py b/backend/models/question.py index 1a012ff..3b98024 100644 --- a/backend/models/question.py +++ b/backend/models/question.py @@ -11,7 +11,7 @@ class Question(BaseModel):      id: str = Field(alias="_id")      name: str      type: str -    data: t.Dict[str, t.Any] +    data: dict[str, t.Any]      class Config:          allow_population_by_field_name = True @@ -31,14 +31,14 @@ class Question(BaseModel):      @root_validator      def validate_question_data(              cls, -            value: t.Dict[str, t.Any] -    ) -> t.Optional[t.Dict[str, t.Any]]: +            value: dict[str, t.Any] +    ) -> t.Optional[dict[str, t.Any]]:          """Check does required data exists for question type and remove other data."""          # When question type don't need data, don't add anything to keep DB clean.          if value.get("type") not in REQUIRED_QUESTION_TYPE_DATA:              return value -        for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value.get("type")].items(): +        for key, data_type in REQUIRED_QUESTION_TYPE_DATA[value["type"]].items():              if key not in value.get("data", {}):                  raise ValueError(f"Required question data key '{key}' not provided.") diff --git a/backend/route.py b/backend/route.py index eb69ebc..d778bf0 100644 --- a/backend/route.py +++ b/backend/route.py @@ -5,13 +5,13 @@ from starlette.endpoints import HTTPEndpoint  class Route(HTTPEndpoint): -    name: str = None -    path: str = None +    name: str +    path: str      @classmethod -    def check_parameters(cls) -> "Route": -        if cls.name is None: +    def check_parameters(cls) -> None: +        if not hasattr(cls, "name"):              raise ValueError(f"Route {cls.__name__} has not defined a name") -        if cls.path is None: +        if not hasattr(cls, "path"):              raise ValueError(f"Route {cls.__name__} has not defined a path") diff --git a/backend/route_manager.py b/backend/route_manager.py index 25529eb..031c9b3 100644 --- a/backend/route_manager.py +++ b/backend/route_manager.py @@ -4,15 +4,17 @@ Module to dynamically generate a Starlette routing map based on a directory tree  import importlib  import inspect +import typing as t +  from pathlib import Path -from starlette.routing import Route as StarletteRoute, Mount +from starlette.routing import Route as StarletteRoute, BaseRoute, Mount  from nested_dict import nested_dict  from backend.route import Route -def construct_route_map_from_dict(route_dict: dict) -> list: +def construct_route_map_from_dict(route_dict: dict) -> list[BaseRoute]:      route_map = []      for mount, item in route_dict.items():          if inspect.isclass(item): @@ -26,35 +28,39 @@ def construct_route_map_from_dict(route_dict: dict) -> list:      return route_map -def create_route_map() -> list: -    routes_directory = Path("backend") / "routes" - -    route_dict = nested_dict() - -    for file in routes_directory.rglob("*.py"): -        import_name = f"{str(file.parent).replace('/', '.')}.{file.stem}" +def is_route_class(member: t.Any) -> bool: +    return inspect.isclass(member) and issubclass(member, Route) and member != Route -        route = importlib.import_module(import_name) -        for _member_name, member in inspect.getmembers(route): -            if inspect.isclass(member): -                if issubclass(member, Route) and member != Route: -                    member.check_parameters() +def route_classes() -> t.Iterator[tuple[Path, type[Route]]]: +    routes_directory = Path("backend") / "routes" -                    levels = str(file.parent).split("/")[2:] +    for module_path in routes_directory.rglob("*.py"): +        import_name = f"{'.'.join(module_path.parent.parts)}.{module_path.stem}" +        route_module = importlib.import_module(import_name) +        for _member_name, member in inspect.getmembers(route_module): +            if is_route_class(member): +                member.check_parameters() +                yield (module_path, member) -                    current_level = None -                    for level in levels: -                        if current_level is None: -                            current_level = route_dict[f"/{level}"] -                        else: -                            current_level = current_level[f"/{level}"] -                    if current_level is not None: -                        current_level[member.path] = member -                    else: -                        route_dict[member.path] = member +def create_route_map() -> list[BaseRoute]: +    route_dict = nested_dict() -    route_map = construct_route_map_from_dict(route_dict.to_dict()) +    for module_path, member in route_classes(): +        #    module_path == Path("backend/routes/foo/bar/baz/bin.py") +        # => levels == ["foo", "bar", "baz"] +        levels = module_path.parent.parts[2:] +        current_level = None +        for level in levels: +            if current_level is None: +                current_level = route_dict[f"/{level}"] +            else: +                current_level = current_level[f"/{level}"] + +        if current_level is not None: +            current_level[member.path] = member +        else: +            route_dict[member.path] = member -    return route_map +    return construct_route_map_from_dict(route_dict.to_dict())  |