aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/__init__.py5
-rw-r--r--backend/authentication/__init__.py4
-rw-r--r--backend/authentication/backend.py53
-rw-r--r--backend/authentication/user.py22
-rw-r--r--backend/routes/index.py18
5 files changed, 98 insertions, 4 deletions
diff --git a/backend/__init__.py b/backend/__init__.py
index 6215961..7b4cb4d 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -2,8 +2,10 @@ import os
from starlette.applications import Starlette
from starlette.middleware import Middleware
+from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
+from backend.authentication import JWTAuthenticationBackend
from backend.route_manager import create_route_map
from backend.middleware import DatabaseMiddleware
@@ -19,7 +21,8 @@ middleware = [
],
allow_methods=["*"]
),
- Middleware(DatabaseMiddleware)
+ Middleware(DatabaseMiddleware),
+ Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend())
]
app = Starlette(routes=create_route_map(), middleware=middleware)
diff --git a/backend/authentication/__init__.py b/backend/authentication/__init__.py
new file mode 100644
index 0000000..43601a7
--- /dev/null
+++ b/backend/authentication/__init__.py
@@ -0,0 +1,4 @@
+from .backend import JWTAuthenticationBackend
+from .user import User
+
+__all__ = ["JWTAuthenticationBackend", "User"]
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
new file mode 100644
index 0000000..38668eb
--- /dev/null
+++ b/backend/authentication/backend.py
@@ -0,0 +1,53 @@
+import jwt
+import typing as t
+from abc import ABC
+
+from starlette import authentication
+from starlette.requests import Request
+
+from backend import constants
+# We must import user such way here to avoid circular imports
+from .user import User
+
+
+class JWTAuthenticationBackend(authentication.AuthenticationBackend, ABC):
+ """Custom Starlette authentication backend for JWT."""
+
+ @staticmethod
+ def get_token_from_header(header: str) -> t.Optional[str]:
+ """Parse JWT token from header value."""
+ try:
+ prefix, token = header.split()
+ except ValueError:
+ raise authentication.AuthenticationError(
+ "Unable to split prefix and token from Authorization header."
+ )
+
+ if prefix.upper() != "JWT":
+ raise authentication.AuthenticationError(
+ f"Invalid Authorization header prefix '{prefix}'."
+ )
+
+ return token
+
+ async def authenticate(
+ self, request: Request
+ ) -> t.Optional[t.Tuple[authentication.AuthCredentials, authentication.BaseUser]]:
+ """Handles JWT authentication process."""
+ if "Authorization" not in request.headers:
+ return
+
+ auth = request.headers["Authorization"]
+ token = self.get_token_from_header(auth)
+
+ try:
+ payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"])
+ except jwt.InvalidTokenError as e:
+ raise authentication.AuthenticationError(str(e))
+
+ scopes = ["authenticated"]
+
+ if payload.get("admin", False) is True:
+ scopes.append("admin")
+
+ return authentication.AuthCredentials(scopes), User(token, payload)
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
new file mode 100644
index 0000000..afa243f
--- /dev/null
+++ b/backend/authentication/user.py
@@ -0,0 +1,22 @@
+import typing as t
+from abc import ABC
+
+from starlette.authentication import BaseUser
+
+
+class User(BaseUser, ABC):
+ """Starlette BaseUser implementation for JWT authentication."""
+
+ def __init__(self, token: str, payload: t.Dict) -> None:
+ self.token = token
+ self.payload = payload
+
+ @property
+ def is_authenticated(self) -> bool:
+ """Returns True because user is always authenticated at this stage."""
+ return True
+
+ @property
+ def display_name(self) -> str:
+ """Return username and discriminator as display name."""
+ return f"{self.payload['username']}#{self.payload['discriminator']}"
diff --git a/backend/routes/index.py b/backend/routes/index.py
index 8144723..b37f381 100644
--- a/backend/routes/index.py
+++ b/backend/routes/index.py
@@ -18,7 +18,19 @@ class IndexRoute(Route):
path = "/"
def get(self, request: Request) -> JSONResponse:
- return JSONResponse({
+ response_data = {
"message": "Hello, world!",
- "client": request.client.host
- })
+ "client": request.client.host,
+ "user": {
+ "authenticated": False
+ }
+ }
+
+ if request.user.is_authenticated:
+ response_data["user"] = {
+ "authenticated": True,
+ "user": request.user.payload,
+ "scopes": request.auth.scopes
+ }
+
+ return JSONResponse(response_data)