diff options
Diffstat (limited to 'backend/discord.py')
-rw-r--r-- | backend/discord.py | 171 |
1 files changed, 164 insertions, 7 deletions
diff --git a/backend/discord.py b/backend/discord.py index e5c7f8f..be12109 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -1,16 +1,22 @@ """Various utilities for working with the Discord API.""" + +import datetime +import json +import typing + import httpx +import starlette.requests +from pymongo.database import Database +from starlette import exceptions -from backend.constants import ( - DISCORD_API_BASE_URL, OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET -) +from backend import constants, models async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict: async with httpx.AsyncClient() as client: data = { - "client_id": OAUTH2_CLIENT_ID, - "client_secret": OAUTH2_CLIENT_SECRET, + "client_id": constants.OAUTH2_CLIENT_ID, + "client_secret": constants.OAUTH2_CLIENT_SECRET, "redirect_uri": f"{redirect}/callback" } @@ -21,7 +27,7 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict data["grant_type"] = "authorization_code" data["code"] = code - r = await client.post(f"{DISCORD_API_BASE_URL}/oauth2/token", headers={ + r = await client.post(f"{constants.DISCORD_API_BASE_URL}/oauth2/token", headers={ "Content-Type": "application/x-www-form-urlencoded" }, data=data) @@ -32,10 +38,161 @@ async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict async def fetch_user_details(bearer_token: str) -> dict: async with httpx.AsyncClient() as client: - r = await client.get(f"{DISCORD_API_BASE_URL}/users/@me", headers={ + r = await client.get(f"{constants.DISCORD_API_BASE_URL}/users/@me", headers={ "Authorization": f"Bearer {bearer_token}" }) r.raise_for_status() return r.json() + + +async def _get_role_info() -> list[models.DiscordRole]: + """Get information about the roles in the configured guild.""" + async with httpx.AsyncClient() as client: + r = await client.get( + f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}/roles", + headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} + ) + + r.raise_for_status() + return [models.DiscordRole(**role) for role in r.json()] + + +async def get_roles( + database: Database, *, force_refresh: bool = False +) -> list[models.DiscordRole]: + """ + Get a list of all roles from the cache, or discord API if not available. + + If `force_refresh` is True, the cache is skipped and the roles are updated. + """ + collection = database.get_collection("roles") + + if force_refresh: + # Drop all values in the collection + await collection.delete_many({}) + + # `create_index` creates the index if it does not exist, or passes + # This handles TTL on role objects + await collection.create_index( + "inserted_at", + expireAfterSeconds=60 * 60 * 24, # 1 day + name="inserted_at", + ) + + roles = [models.DiscordRole(**json.loads(role["data"])) async for role in collection.find()] + + if len(roles) == 0: + # Fetch roles from the API and insert into the database + roles = await _get_role_info() + await collection.insert_many({ + "name": role.name, + "id": role.id, + "data": role.json(), + "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), + } for role in roles) + + return roles + + +async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]: + """Get a member by ID from the configured guild using the discord API.""" + async with httpx.AsyncClient() as client: + r = await client.get( + f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}" + f"/members/{member_id}", + headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"} + ) + + if r.status_code == 404: + return None + + r.raise_for_status() + return models.DiscordMember(**r.json()) + + +async def get_member( + database: Database, user_id: str, *, force_refresh: bool = False +) -> typing.Optional[models.DiscordMember]: + """ + Get a member from the cache, or from the discord API. + + If `force_refresh` is True, the cache is skipped and the entry is updated. + None may be returned if the member object does not exist. + """ + collection = database.get_collection("discord_members") + + if force_refresh: + await collection.delete_one({"user": user_id}) + + # `create_index` creates the index if it does not exist, or passes + # This handles TTL on member objects + await collection.create_index( + "inserted_at", + expireAfterSeconds=60 * 60, # 1 hour + name="inserted_at", + ) + + result = await collection.find_one({"user": user_id}) + + if result is not None: + return models.DiscordMember(**json.loads(result["data"])) + + member = await _fetch_member_api(user_id) + + if not member: + return None + + await collection.insert_one({ + "user": user_id, + "data": member.json(), + "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc), + }) + return member + + +class FormNotFoundError(exceptions.HTTPException): + """The requested form was not found.""" + + +class UnauthorizedError(exceptions.HTTPException): + """You are not authorized to use this resource.""" + + +async def _verify_access_helper( + form_id: str, request: starlette.requests.Request, attribute: str +) -> None: + """A low level helper to validate access to a form resource based on the user's scopes.""" + form = await request.state.db.forms.find_one({"_id": form_id}) + + if not form: + raise FormNotFoundError(status_code=404) + + # Short circuit all resources for forms admins + if "admin" in request.auth.scopes: + return + + form = models.Form(**form) + + for role_id in getattr(form, attribute, []): + role = await request.state.db.roles.find_one({"id": role_id}) + if not role: + continue + + role = models.DiscordRole(**json.loads(role["data"])) + + if role.name in request.auth.scopes: + return + + raise UnauthorizedError(status_code=401) + + +async def verify_response_access(form_id: str, request: starlette.requests.Request) -> None: + """Ensure the user can access responses on the requested resource.""" + await _verify_access_helper(form_id, request, "response_readers") + + +async def verify_edit_access(form_id: str, request: starlette.requests.Request) -> None: + """Ensure the user can view and modify the requested resource.""" + await _verify_access_helper(form_id, request, "editors") |