diff options
author | 2022-02-05 16:50:11 +0400 | |
---|---|---|
committer | 2022-02-05 17:20:40 +0400 | |
commit | 134b2f70e4cf947744f1b061766bb37fe616ad65 (patch) | |
tree | ef6b95bc5a78528d91ae969f3cfd00bc8e5be8ed /backend | |
parent | Add Helper Functions For Managing Roles (diff) |
Overhaul Scope System
Adds discord role support to the pre-existing scopes system to power
more complex access permissions.
Signed-off-by: Hassan Abouelela <[email protected]>
Diffstat (limited to 'backend')
-rw-r--r-- | backend/authentication/backend.py | 9 | ||||
-rw-r--r-- | backend/authentication/user.py | 45 | ||||
-rw-r--r-- | backend/discord.py | 105 | ||||
-rw-r--r-- | backend/models/__init__.py | 5 | ||||
-rw-r--r-- | backend/models/discord_user.py | 9 | ||||
-rw-r--r-- | backend/routes/auth/authorize.py | 12 | ||||
-rw-r--r-- | backend/routes/discord.py | 83 | ||||
-rw-r--r-- | backend/routes/forms/submit.py | 2 | ||||
-rw-r--r-- | backend/routes/roles.py | 36 |
9 files changed, 246 insertions, 60 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index c7590e9..54385e2 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -5,6 +5,7 @@ from starlette import authentication from starlette.requests import Request from backend import constants +from backend import discord # We must import user such way here to avoid circular imports from .user import User @@ -60,8 +61,12 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend): except Exception: raise authentication.AuthenticationError("Could not parse user details.") - user = User(token, user_details) - if await user.fetch_admin_status(request): + user = User( + token, user_details, await discord.get_member(request.state.db, user_details["id"]) + ) + if await user.fetch_admin_status(request.state.db): scopes.append("admin") + scopes.extend(await user.get_user_roles(request.state.db)) + return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index 857c2ed..0ec0188 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -1,20 +1,27 @@ +import typing import typing as t import jwt +from pymongo.database import Database from starlette.authentication import BaseUser -from starlette.requests import Request +from backend import discord, models from backend.constants import SECRET_KEY -from backend.discord import fetch_user_details class User(BaseUser): """Starlette BaseUser implementation for JWT authentication.""" - def __init__(self, token: str, payload: dict[str, t.Any]) -> None: + def __init__( + self, + token: str, + payload: dict[str, t.Any], + member: typing.Optional[models.DiscordMember], + ) -> None: self.token = token self.payload = payload self.admin = False + self.member = member @property def is_authenticated(self) -> bool: @@ -34,16 +41,40 @@ class User(BaseUser): def decoded_token(self) -> dict[str, any]: return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) - async def fetch_admin_status(self, request: Request) -> bool: - self.admin = await request.state.db.admins.find_one( + async def get_user_roles(self, database: Database) -> list[str]: + """Get a list of the user's discord roles.""" + if not self.member: + return [] + + server_roles = await discord.get_roles(database) + roles = [] + + for role in server_roles: + if role.id in self.member.roles: + roles.append(role.name) + + if "admin" in roles: + # Protect against collision with the forms admin role + roles.remove("admin") + roles.append("discord admin") + + return roles + + async def fetch_admin_status(self, database: Database) -> bool: + self.admin = await database.admins.find_one( {"_id": self.payload["id"]} ) is not None return self.admin - async def refresh_data(self) -> None: + async def refresh_data(self, database: Database) -> None: """Fetches user data from discord, and updates the instance.""" - self.payload = await fetch_user_details(self.decoded_token.get("token")) + self.member = await discord.get_member(database, self.payload["id"]) + + if self.member: + self.payload = self.member.user.dict() + else: + self.payload = await discord.fetch_user_details(self.decoded_token.get("token")) updated_info = self.decoded_token updated_info["user_details"] = self.payload diff --git a/backend/discord.py b/backend/discord.py index cf80cf3..51de26a 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -1,8 +1,13 @@ """Various utilities for working with the Discord API.""" + +import datetime +import json +import typing + import httpx +from pymongo.database import Database -from backend import constants -from backend.models import discord_role, discord_user +from backend import constants, models async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict: @@ -40,7 +45,7 @@ async def fetch_user_details(bearer_token: str) -> dict: return r.json() -async def get_role_info() -> list[discord_role.DiscordRole]: +async def _get_role_info() -> list[models.DiscordRole]: """Get information about the roles in the configured guild.""" async with httpx.AsyncClient() as client: r = await client.get( @@ -49,11 +54,50 @@ async def get_role_info() -> list[discord_role.DiscordRole]: ) r.raise_for_status() - return [discord_role.DiscordRole(**role) for role in r.json()] - - -async def get_member(member_id: str) -> discord_user.DiscordMember: - """Get a member by ID from the configured guild.""" + return [models.DiscordRole(**role) for role in r.json()] + + +async def get_roles( + database: Database, *, force_refresh: bool = False +) -> list[models.DiscordRole]: + """ + Get a list of all roles from the cache, or discord API if not available. + + If `force_refresh` is True, the cache is skipped and the roles are updated. + """ + collection = database.get_collection("roles") + + if force_refresh: + # Drop all values in the collection + await collection.delete_many({}) + + # `create_index` creates the index if it does not exist, or passes + # This handles TTL on role objects + await collection.create_index( + "inserted_at", + expireAfterSeconds=60 * 60 * 24, # 1 day + name="inserted_at", + ) + + roles = [] + async for role in collection.find(): + roles.append(models.DiscordRole(**json.loads(role["data"]))) + + if len(roles) == 0: + # Fetch roles from the API and insert into the database + roles = await _get_role_info() + await collection.insert_many({ + "name": role.name, + "id": role.id, + "data": role.json(), + "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), + } for role in roles) + + return roles + + +async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]: + """Get a member by ID from the configured guild using the discord API.""" async with httpx.AsyncClient() as client: r = await client.get( f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" @@ -61,5 +105,48 @@ async def get_member(member_id: str) -> discord_user.DiscordMember: headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} ) + if r.status_code == 404: + return None + r.raise_for_status() - return discord_user.DiscordMember(**r.json()) + return models.DiscordMember(**r.json()) + + +async def get_member( + database: Database, user_id: str, *, force_refresh: bool = False +) -> typing.Optional[models.DiscordMember]: + """ + Get a member from the cache, or from the discord API. + + If `force_refresh` is True, the cache is skipped and the entry is updated. + None may be returned if the member object does not exist. + """ + collection = database.get_collection("discord_members") + + if force_refresh: + await collection.delete_one({"user": user_id}) + + # `create_index` creates the index if it does not exist, or passes + # This handles TTL on member objects + await collection.create_index( + "inserted_at", + expireAfterSeconds=60 * 60, # 1 hour + name="inserted_at", + ) + + result = await collection.find_one({"user": user_id}) + + if result is not None: + return models.DiscordMember(**json.loads(result["data"])) + + member = await _fetch_member_api(user_id) + + if not member: + return None + + await collection.insert_one({ + "user": user_id, + "data": member.json(), + "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), + }) + return member diff --git a/backend/models/__init__.py b/backend/models/__init__.py index 8ad7f7f..a9f76e0 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -1,12 +1,15 @@ from .antispam import AntiSpam -from .discord_user import DiscordUser +from .discord_role import DiscordRole +from .discord_user import DiscordMember, DiscordUser from .form import Form, FormList from .form_response import FormResponse, ResponseList from .question import CodeQuestion, Question __all__ = [ "AntiSpam", + "DiscordRole", "DiscordUser", + "DiscordMember", "Form", "FormResponse", "CodeQuestion", diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py index 3f4209d..0eca15b 100644 --- a/backend/models/discord_user.py +++ b/backend/models/discord_user.py @@ -43,3 +43,12 @@ class DiscordMember(BaseModel): pending: t.Optional[bool] permissions: t.Optional[str] communication_disabled_until: t.Optional[datetime.datetime] + + def dict(self, *args, **kwargs) -> dict[str, t.Any]: + """Convert the model to a python dict, and encode timestamps in a serializable format.""" + data = super().dict(*args, **kwargs) + for field, value in data.items(): + if isinstance(value, datetime.datetime): + data[field] = value.isoformat() + + return data diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py index d4587f0..42fb3ec 100644 --- a/backend/routes/auth/authorize.py +++ b/backend/routes/auth/authorize.py @@ -17,7 +17,7 @@ from starlette.requests import Request from backend import constants from backend.authentication.user import User from backend.constants import SECRET_KEY -from backend.discord import fetch_bearer_token, fetch_user_details +from backend.discord import fetch_bearer_token, fetch_user_details, get_member from backend.route import Route from backend.validation import ErrorMessage, api @@ -34,8 +34,8 @@ class AuthorizeResponse(BaseModel): async def process_token( - bearer_token: dict, - request: Request + bearer_token: dict, + request: Request ) -> Union[AuthorizeResponse, AUTH_FAILURE]: """Post a bearer token to Discord, and return a JWT and username.""" interaction_start = datetime.datetime.now() @@ -46,6 +46,9 @@ async def process_token( AUTH_FAILURE.delete_cookie("token") return AUTH_FAILURE + user_id = user_details["id"] + member = await get_member(request.state.db, user_id, force_refresh=True) + max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"])) token_expiry = interaction_start + max_age @@ -53,11 +56,12 @@ async def process_token( "token": bearer_token["access_token"], "refresh": bearer_token["refresh_token"], "user_details": user_details, + "in_guild": bool(member), "expiry": token_expiry.isoformat() } token = jwt.encode(data, SECRET_KEY, algorithm="HS256") - user = User(token, user_details) + user = User(token, user_details, member) response = responses.JSONResponse({ "username": user.display_name, diff --git a/backend/routes/discord.py b/backend/routes/discord.py new file mode 100644 index 0000000..a980d94 --- /dev/null +++ b/backend/routes/discord.py @@ -0,0 +1,83 @@ +"""Routes which directly interact with discord related data.""" + +import pydantic +from spectree import Response +from starlette.authentication import requires +from starlette.responses import JSONResponse +from starlette.routing import Request + +from backend import discord, models, route +from backend.validation import ErrorMessage, OkayResponse, api + +NOT_FOUND_EXCEPTION = JSONResponse( + {"error": "Could not find the requested resource in the guild or cache."}, status_code=404 +) + + +class RolesRoute(route.Route): + """Refreshes the roles database.""" + + name = "roles" + path = "/roles" + + class RolesResponse(pydantic.BaseModel): + """A list of all roles on the configured server.""" + + roles: list[models.DiscordRole] + + @requires(["authenticated", "admin"]) + @api.validate( + resp=Response(HTTP_200=OkayResponse), + tags=["roles"] + ) + async def patch(self, request: Request) -> JSONResponse: + """Refresh the roles database.""" + roles = await discord.get_roles(request.state.db, force_refresh=True) + + return JSONResponse( + {"status": "ok"}, + ) + + +class MemberRoute(route.Route): + """Retrieve information about a server member.""" + + name = "member" + path = "/member" + + class MemberRequest(pydantic.BaseModel): + """An ID of the member to update.""" + + user_id: str + + @requires(["authenticated", "admin"]) + @api.validate( + resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), + json=MemberRequest, + tags=["auth"] + ) + async def delete(self, request: Request): + """Force a resync of the cache for the given user.""" + body = await request.json() + member = await discord.get_member(request.state.db, body["user_id"], force_refresh=True) + + if member: + return JSONResponse(member.dict()) + else: + return NOT_FOUND_EXCEPTION + + @requires(["authenticated", "admin"]) + @api.validate( + resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage), + json=MemberRequest, + tags=["auth"] + ) + async def get(self, request: Request): + """Get a user's roles on the configured server.""" + body = await request.json() + member = await discord.get_member(request.state.db, body["user_id"]) + + if member: + return JSONResponse(member.dict()) + else: + return NOT_FOUND_EXCEPTION diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py index 95e30b0..baf403d 100644 --- a/backend/routes/forms/submit.py +++ b/backend/routes/forms/submit.py @@ -83,7 +83,7 @@ class SubmitForm(Route): try: if hasattr(request.user, User.refresh_data.__name__): old = request.user.token - await request.user.refresh_data() + await request.user.refresh_data(request.state.db) if old != request.user.token: try: diff --git a/backend/routes/roles.py b/backend/routes/roles.py deleted file mode 100644 index b18a04b..0000000 --- a/backend/routes/roles.py +++ /dev/null @@ -1,36 +0,0 @@ -import starlette.background -from pymongo.database import Database -from spectree import Response -from starlette.authentication import requires -from starlette.responses import JSONResponse -from starlette.routing import Request - -from backend import discord, route -from backend.validation import OkayResponse, api - - -async def refresh_roles(database: Database) -> None: - """Connect to the discord API and refresh the roles database.""" - roles = await discord.get_role_info() - roles_collection = database.get_collection("roles") - roles_collection.drop() - roles_collection.insert_many([role.dict() for role in roles]) - - -class RolesRoute(route.Route): - """Refreshes the roles database.""" - - name = "roles" - path = "/roles" - - @requires(["authenticated", "admin"]) - @api.validate( - resp=Response(HTTP_200=OkayResponse), - tags=["roles"] - ) - async def patch(self, request: Request) -> JSONResponse: - """Refresh the roles database.""" - return JSONResponse( - {"status": "ok"}, - background=starlette.background.BackgroundTask(refresh_roles, request.state.db) - ) |