diff options
| -rw-r--r-- | backend/constants.py | 6 | ||||
| -rw-r--r-- | backend/routes/auth/authorize.py | 49 | ||||
| -rw-r--r-- | backend/routes/forms/submit.py | 9 | 
3 files changed, 45 insertions, 19 deletions
| diff --git a/backend/constants.py b/backend/constants.py index e1f4a5b..4bb7fd1 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -1,8 +1,9 @@ -from dotenv import load_dotenv -import os  import binascii +import os  from enum import Enum +from dotenv import load_dotenv +  load_dotenv() @@ -12,6 +13,7 @@ MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms")  SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval")  PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false" +PRODUCTION_URL = "https://forms.pythondiscord.com/"  OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")  OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET") diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index 26d8622..1e773d6 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -10,9 +10,9 @@ import jwt  from pydantic.fields import Field  from pydantic.main import BaseModel  from spectree.response import Response +from starlette import responses  from starlette.authentication import requires  from starlette.requests import Request -from starlette.responses import JSONResponse  from backend import constants  from backend.authentication.user import User @@ -21,7 +21,7 @@ from backend.discord import fetch_bearer_token, fetch_user_details  from backend.route import Route  from backend.validation import ErrorMessage, api -AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400) +AUTH_FAILURE = responses.JSONResponse({"error": "auth_failure"}, status_code=400)  class AuthorizeRequest(BaseModel): @@ -33,7 +33,7 @@ class AuthorizeResponse(BaseModel):      expiry: str = Field("ISO formatted timestamp of expiry.") -async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]: +async def process_token(bearer_token: dict, origin: str) -> Union[AuthorizeResponse, AUTH_FAILURE]:      """Post a bearer token to Discord, and return a JWT and username."""      interaction_start = datetime.datetime.now() @@ -56,17 +56,42 @@ async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAI      token = jwt.encode(data, SECRET_KEY, algorithm="HS256")      user = User(token, user_details) -    response = JSONResponse({ +    response = responses.JSONResponse({          "username": user.display_name,          "expiry": token_expiry.isoformat()      }) +    await set_response_token(response, origin, token, bearer_token["expires_in"]) +    return response + + +async def set_response_token( +    response: responses.Response, +    origin_url: str, +    new_token: str, +    expiry: int +) -> None: +    """Helper that handles logic for updating a token in a set-cookie response.""" +    if origin_url == constants.PRODUCTION_URL: +        domain = constants.PRODUCTION_URL +        samesite = "strict" + +    elif not constants.PRODUCTION: +        domain = None +        samesite = "strict" + +    else: +        domain = origin_url +        samesite = "None" +      response.set_cookie( -        "token", f"JWT {token}", -        secure=constants.PRODUCTION, httponly=True, samesite="strict", -        max_age=bearer_token["expires_in"] +        "token", f"JWT {new_token}", +        secure=constants.PRODUCTION, +        httponly=True, +        samesite=samesite, +        domain=domain, +        max_age=expiry      ) -    return response  class AuthorizeRoute(Route): @@ -82,7 +107,7 @@ class AuthorizeRoute(Route):          resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),          tags=["auth"]      ) -    async def post(self, request: Request) -> JSONResponse: +    async def post(self, request: Request) -> responses.JSONResponse:          """Generate an authorization token."""          data = await request.json()          try: @@ -91,7 +116,7 @@ class AuthorizeRoute(Route):          except httpx.HTTPStatusError:              return AUTH_FAILURE -        return await process_token(bearer_token) +        return await process_token(bearer_token, url)  class TokenRefreshRoute(Route): @@ -107,7 +132,7 @@ class TokenRefreshRoute(Route):          resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),          tags=["auth"]      ) -    async def post(self, request: Request) -> JSONResponse: +    async def post(self, request: Request) -> responses.JSONResponse:          """Refresh an authorization token."""          try:              token = request.user.decoded_token.get("refresh") @@ -116,4 +141,4 @@ class TokenRefreshRoute(Route):          except httpx.HTTPStatusError:              return AUTH_FAILURE -        return await process_token(bearer_token) +        return await process_token(bearer_token, url) diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 8680b2d..975307b 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -20,6 +20,7 @@ from backend import constants  from backend.authentication.user import User  from backend.models import Form, FormResponse  from backend.route import Route +from backend.routes.auth.authorize import set_response_token  from backend.routes.forms.unittesting import execute_unittest  from backend.validation import ErrorMessage, api @@ -74,11 +75,9 @@ class SubmitForm(Route):                      except ValueError:                          expiry = None -                    response.set_cookie( -                        "token", f"JWT {request.user.token}", -                        secure=constants.PRODUCTION, httponly=True, samesite="strict", -                        max_age=(expiry - datetime.datetime.now()).seconds -                    ) +                    origin = request.headers.get("origin") +                    expiry_seconds = (expiry - datetime.datetime.now()).seconds +                    await set_response_token(response, origin, request.user.token, expiry_seconds)          except httpx.HTTPStatusError:              pass | 
