diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/authentication/backend.py | 9 | ||||
| -rw-r--r-- | backend/authentication/user.py | 45 | ||||
| -rw-r--r-- | backend/discord.py | 105 | ||||
| -rw-r--r-- | backend/models/__init__.py | 5 | ||||
| -rw-r--r-- | backend/models/discord_user.py | 9 | ||||
| -rw-r--r-- | backend/routes/auth/authorize.py | 12 | ||||
| -rw-r--r-- | backend/routes/discord.py | 83 | ||||
| -rw-r--r-- | backend/routes/forms/submit.py | 2 | ||||
| -rw-r--r-- | backend/routes/roles.py | 36 | 
9 files changed, 246 insertions, 60 deletions
| diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index c7590e9..54385e2 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -5,6 +5,7 @@ from starlette import authentication  from starlette.requests import Request  from backend import constants +from backend import discord  # We must import user such way here to avoid circular imports  from .user import User @@ -60,8 +61,12 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):          except Exception:              raise authentication.AuthenticationError("Could not parse user details.") -        user = User(token, user_details) -        if await user.fetch_admin_status(request): +        user = User( +            token, user_details, await discord.get_member(request.state.db, user_details["id"]) +        ) +        if await user.fetch_admin_status(request.state.db):              scopes.append("admin") +        scopes.extend(await user.get_user_roles(request.state.db)) +          return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index 857c2ed..0ec0188 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,20 +1,27 @@ +import typing  import typing as t  import jwt +from pymongo.database import Database  from starlette.authentication import BaseUser -from starlette.requests import Request +from backend import discord, models  from backend.constants import SECRET_KEY -from backend.discord import fetch_user_details  class User(BaseUser):      """Starlette BaseUser implementation for JWT authentication.""" -    def __init__(self, token: str, payload: dict[str, t.Any]) -> None: +    def __init__( +        self, +        token: str, +        payload: dict[str, t.Any], +        member: typing.Optional[models.DiscordMember], +    ) -> None:          self.token = token          self.payload = payload          self.admin = False +        self.member = member      @property      def is_authenticated(self) -> bool: @@ -34,16 +41,40 @@ class User(BaseUser):      def decoded_token(self) -> dict[str, any]:          return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) -    async def fetch_admin_status(self, request: Request) -> bool: -        self.admin = await request.state.db.admins.find_one( +    async def get_user_roles(self, database: Database) -> list[str]: +        """Get a list of the user's discord roles.""" +        if not self.member: +            return [] + +        server_roles = await discord.get_roles(database) +        roles = [] + +        for role in server_roles: +            if role.id in self.member.roles: +                roles.append(role.name) + +        if "admin" in roles: +            # Protect against collision with the forms admin role +            roles.remove("admin") +            roles.append("discord admin") + +        return roles + +    async def fetch_admin_status(self, database: Database) -> bool: +        self.admin = await database.admins.find_one(              {"_id": self.payload["id"]}          ) is not None          return self.admin -    async def refresh_data(self) -> None: +    async def refresh_data(self, database: Database) -> None:          """Fetches user data from discord, and updates the instance.""" -        self.payload = await fetch_user_details(self.decoded_token.get("token")) +        self.member = await discord.get_member(database, self.payload["id"]) + +        if self.member: +            self.payload = self.member.user.dict() +        else: +            self.payload = await discord.fetch_user_details(self.decoded_token.get("token"))          updated_info = self.decoded_token          updated_info["user_details"] = self.payload diff --git a/backend/discord.py b/backend/discord.py index cf80cf3..51de26a 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -1,8 +1,13 @@  """Various utilities for working with the Discord API.""" + +import datetime +import json +import typing +  import httpx +from pymongo.database import Database -from backend import constants -from backend.models import discord_role, discord_user +from backend import constants, models  async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict: @@ -40,7 +45,7 @@ async def fetch_user_details(bearer_token: str) -> dict:          return r.json() -async def get_role_info() -> list[discord_role.DiscordRole]: +async def _get_role_info() -> list[models.DiscordRole]:      """Get information about the roles in the configured guild."""      async with httpx.AsyncClient() as client:          r = await client.get( @@ -49,11 +54,50 @@ async def get_role_info() -> list[discord_role.DiscordRole]:          )          r.raise_for_status() -        return [discord_role.DiscordRole(**role) for role in r.json()] - - -async def get_member(member_id: str) -> discord_user.DiscordMember: -    """Get a member by ID from the configured guild.""" +        return [models.DiscordRole(**role) for role in r.json()] + + +async def get_roles( +    database: Database, *, force_refresh: bool = False +) -> list[models.DiscordRole]: +    """ +    Get a list of all roles from the cache, or discord API if not available. + +    If `force_refresh` is True, the cache is skipped and the roles are updated. +    """ +    collection = database.get_collection("roles") + +    if force_refresh: +        # Drop all values in the collection +        await collection.delete_many({}) + +    # `create_index` creates the index if it does not exist, or passes +    # This handles TTL on role objects +    await collection.create_index( +        "inserted_at", +        expireAfterSeconds=60 * 60 * 24,  # 1 day +        name="inserted_at", +    ) + +    roles = [] +    async for role in collection.find(): +        roles.append(models.DiscordRole(**json.loads(role["data"]))) + +    if len(roles) == 0: +        # Fetch roles from the API and insert into the database +        roles = await _get_role_info() +        await collection.insert_many({ +            "name": role.name, +            "id": role.id, +            "data": role.json(), +            "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), +        } for role in roles) + +    return roles + + +async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]: +    """Get a member by ID from the configured guild using the discord API."""      async with httpx.AsyncClient() as client:          r = await client.get(              f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" @@ -61,5 +105,48 @@ async def get_member(member_id: str) -> discord_user.DiscordMember:              headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}          ) +        if r.status_code == 404: +            return None +          r.raise_for_status() -        return discord_user.DiscordMember(**r.json()) +        return models.DiscordMember(**r.json()) + + +async def get_member( +    database: Database, user_id: str, *, force_refresh: bool = False +) -> typing.Optional[models.DiscordMember]: +    """ +    Get a member from the cache, or from the discord API. + +    If `force_refresh` is True, the cache is skipped and the entry is updated. +    None may be returned if the member object does not exist. +    """ +    collection = database.get_collection("discord_members") + +    if force_refresh: +        await collection.delete_one({"user": user_id}) + +    # `create_index` creates the index if it does not exist, or passes +    # This handles TTL on member objects +    await collection.create_index( +        "inserted_at", +        expireAfterSeconds=60 * 60,  # 1 hour +        name="inserted_at", +    ) + +    result = await collection.find_one({"user": user_id}) + +    if result is not None: +        return models.DiscordMember(**json.loads(result["data"])) + +    member = await _fetch_member_api(user_id) + +    if not member: +        return None + +    await collection.insert_one({ +        "user": user_id, +        "data": member.json(), +        "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), +    }) +    return member diff --git a/backend/models/__init__.py b/backend/models/__init__.py index 8ad7f7f..a9f76e0 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -1,12 +1,15 @@  from .antispam import AntiSpam -from .discord_user import DiscordUser +from .discord_role import DiscordRole +from .discord_user import DiscordMember, DiscordUser  from .form import Form, FormList  from .form_response import FormResponse, ResponseList  from .question import CodeQuestion, Question  __all__ = [      "AntiSpam", +    "DiscordRole",      "DiscordUser", +    "DiscordMember",      "Form",      "FormResponse",      "CodeQuestion", diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py index 3f4209d..0eca15b 100644 --- a/backend/models/discord_user.py +++ b/backend/models/discord_user.py @@ -43,3 +43,12 @@ class DiscordMember(BaseModel):      pending: t.Optional[bool]      permissions: t.Optional[str]      communication_disabled_until: t.Optional[datetime.datetime] + +    def dict(self, *args, **kwargs) -> dict[str, t.Any]: +        """Convert the model to a python dict, and encode timestamps in a serializable format.""" +        data = super().dict(*args, **kwargs) +        for field, value in data.items(): +            if isinstance(value, datetime.datetime): +                data[field] = value.isoformat() + +        return data diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index d4587f0..42fb3ec 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -17,7 +17,7 @@ from starlette.requests import Request  from backend import constants  from backend.authentication.user import User  from backend.constants import SECRET_KEY -from backend.discord import fetch_bearer_token, fetch_user_details +from backend.discord import fetch_bearer_token, fetch_user_details, get_member  from backend.route import Route  from backend.validation import ErrorMessage, api @@ -34,8 +34,8 @@ class AuthorizeResponse(BaseModel):  async def process_token( -        bearer_token: dict, -        request: Request +    bearer_token: dict, +    request: Request  ) -> Union[AuthorizeResponse, AUTH_FAILURE]:      """Post a bearer token to Discord, and return a JWT and username."""      interaction_start = datetime.datetime.now() @@ -46,6 +46,9 @@ async def process_token(          AUTH_FAILURE.delete_cookie("token")          return AUTH_FAILURE +    user_id = user_details["id"] +    member = await get_member(request.state.db, user_id, force_refresh=True) +      max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))      token_expiry = interaction_start + max_age @@ -53,11 +56,12 @@ async def process_token(          "token": bearer_token["access_token"],          "refresh": bearer_token["refresh_token"],          "user_details": user_details, +        "in_guild": bool(member),          "expiry": token_expiry.isoformat()      }      token = jwt.encode(data, SECRET_KEY, algorithm="HS256") -    user = User(token, user_details) +    user = User(token, user_details, member)      response = responses.JSONResponse({          "username": user.display_name, diff --git a/backend/routes/discord.py b/backend/routes/discord.py new file mode 100644 index 0000000..a980d94 --- /dev/null +++ b/backend/routes/discord.py @@ -0,0 +1,83 @@ +"""Routes which directly interact with discord related data.""" + +import pydantic +from spectree import Response +from starlette.authentication import requires +from starlette.responses import JSONResponse +from starlette.routing import Request + +from backend import discord, models, route +from backend.validation import ErrorMessage, OkayResponse, api + +NOT_FOUND_EXCEPTION = JSONResponse( +    {"error": "Could not find the requested resource in the guild or cache."}, status_code=404 +) + + +class RolesRoute(route.Route): +    """Refreshes the roles database.""" + +    name = "roles" +    path = "/roles" + +    class RolesResponse(pydantic.BaseModel): +        """A list of all roles on the configured server.""" + +        roles: list[models.DiscordRole] + +    @requires(["authenticated", "admin"]) +    @api.validate( +        resp=Response(HTTP_200=OkayResponse), +        tags=["roles"] +    ) +    async def patch(self, request: Request) -> JSONResponse: +        """Refresh the roles database.""" +        roles = await discord.get_roles(request.state.db, force_refresh=True) + +        return JSONResponse( +            {"status": "ok"}, +        ) + + +class MemberRoute(route.Route): +    """Retrieve information about a server member.""" + +    name = "member" +    path = "/member" + +    class MemberRequest(pydantic.BaseModel): +        """An ID of the member to update.""" + +        user_id: str + +    @requires(["authenticated", "admin"]) +    @api.validate( +        resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), +        json=MemberRequest, +        tags=["auth"] +    ) +    async def delete(self, request: Request): +        """Force a resync of the cache for the given user.""" +        body = await request.json() +        member = await discord.get_member(request.state.db, body["user_id"], force_refresh=True) + +        if member: +            return JSONResponse(member.dict()) +        else: +            return NOT_FOUND_EXCEPTION + +    @requires(["authenticated", "admin"]) +    @api.validate( +        resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), +        json=MemberRequest, +        tags=["auth"] +    ) +    async def get(self, request: Request): +        """Get a user's roles on the configured server.""" +        body = await request.json() +        member = await discord.get_member(request.state.db, body["user_id"]) + +        if member: +            return JSONResponse(member.dict()) +        else: +            return NOT_FOUND_EXCEPTION diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 95e30b0..baf403d 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -83,7 +83,7 @@ class SubmitForm(Route):          try:              if hasattr(request.user, User.refresh_data.__name__):                  old = request.user.token -                await request.user.refresh_data() +                await request.user.refresh_data(request.state.db)                  if old != request.user.token:                      try: diff --git a/backend/routes/roles.py b/backend/routes/roles.py deleted file mode 100644 index b18a04b..0000000 --- a/backend/routes/roles.py +++ /dev/null @@ -1,36 +0,0 @@ -import starlette.background -from pymongo.database import Database -from spectree import Response -from starlette.authentication import requires -from starlette.responses import JSONResponse -from starlette.routing import Request - -from backend import discord, route -from backend.validation import OkayResponse, api - - -async def refresh_roles(database: Database) -> None: -    """Connect to the discord API and refresh the roles database.""" -    roles = await discord.get_role_info() -    roles_collection = database.get_collection("roles") -    roles_collection.drop() -    roles_collection.insert_many([role.dict() for role in roles]) - - -class RolesRoute(route.Route): -    """Refreshes the roles database.""" - -    name = "roles" -    path = "/roles" - -    @requires(["authenticated", "admin"]) -    @api.validate( -        resp=Response(HTTP_200=OkayResponse), -        tags=["roles"] -    ) -    async def patch(self, request: Request) -> JSONResponse: -        """Refresh the roles database.""" -        return JSONResponse( -            {"status": "ok"}, -            background=starlette.background.BackgroundTask(refresh_roles, request.state.db) -        ) | 
