aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorGravatar Hassan Abouelela <[email protected]>2022-02-05 16:50:11 +0400
committerGravatar Hassan Abouelela <[email protected]>2022-02-05 17:20:40 +0400
commit134b2f70e4cf947744f1b061766bb37fe616ad65 (patch)
treeef6b95bc5a78528d91ae969f3cfd00bc8e5be8ed /backend
parentAdd Helper Functions For Managing Roles (diff)
Overhaul Scope System
Adds discord role support to the pre-existing scopes system to power more complex access permissions. Signed-off-by: Hassan Abouelela <[email protected]>
Diffstat (limited to 'backend')
-rw-r--r--backend/authentication/backend.py9
-rw-r--r--backend/authentication/user.py45
-rw-r--r--backend/discord.py105
-rw-r--r--backend/models/__init__.py5
-rw-r--r--backend/models/discord_user.py9
-rw-r--r--backend/routes/auth/authorize.py12
-rw-r--r--backend/routes/discord.py83
-rw-r--r--backend/routes/forms/submit.py2
-rw-r--r--backend/routes/roles.py36
9 files changed, 246 insertions, 60 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index c7590e9..54385e2 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -5,6 +5,7 @@ from starlette import authentication
from starlette.requests import Request
from backend import constants
+from backend import discord
# We must import user such way here to avoid circular imports
from .user import User
@@ -60,8 +61,12 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
except Exception:
raise authentication.AuthenticationError("Could not parse user details.")
- user = User(token, user_details)
- if await user.fetch_admin_status(request):
+ user = User(
+ token, user_details, await discord.get_member(request.state.db, user_details["id"])
+ )
+ if await user.fetch_admin_status(request.state.db):
scopes.append("admin")
+ scopes.extend(await user.get_user_roles(request.state.db))
+
return authentication.AuthCredentials(scopes), user
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index 857c2ed..0ec0188 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -1,20 +1,27 @@
+import typing
import typing as t
import jwt
+from pymongo.database import Database
from starlette.authentication import BaseUser
-from starlette.requests import Request
+from backend import discord, models
from backend.constants import SECRET_KEY
-from backend.discord import fetch_user_details
class User(BaseUser):
"""Starlette BaseUser implementation for JWT authentication."""
- def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
+ def __init__(
+ self,
+ token: str,
+ payload: dict[str, t.Any],
+ member: typing.Optional[models.DiscordMember],
+ ) -> None:
self.token = token
self.payload = payload
self.admin = False
+ self.member = member
@property
def is_authenticated(self) -> bool:
@@ -34,16 +41,40 @@ class User(BaseUser):
def decoded_token(self) -> dict[str, any]:
return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])
- async def fetch_admin_status(self, request: Request) -> bool:
- self.admin = await request.state.db.admins.find_one(
+ async def get_user_roles(self, database: Database) -> list[str]:
+ """Get a list of the user's discord roles."""
+ if not self.member:
+ return []
+
+ server_roles = await discord.get_roles(database)
+ roles = []
+
+ for role in server_roles:
+ if role.id in self.member.roles:
+ roles.append(role.name)
+
+ if "admin" in roles:
+ # Protect against collision with the forms admin role
+ roles.remove("admin")
+ roles.append("discord admin")
+
+ return roles
+
+ async def fetch_admin_status(self, database: Database) -> bool:
+ self.admin = await database.admins.find_one(
{"_id": self.payload["id"]}
) is not None
return self.admin
- async def refresh_data(self) -> None:
+ async def refresh_data(self, database: Database) -> None:
"""Fetches user data from discord, and updates the instance."""
- self.payload = await fetch_user_details(self.decoded_token.get("token"))
+ self.member = await discord.get_member(database, self.payload["id"])
+
+ if self.member:
+ self.payload = self.member.user.dict()
+ else:
+ self.payload = await discord.fetch_user_details(self.decoded_token.get("token"))
updated_info = self.decoded_token
updated_info["user_details"] = self.payload
diff --git a/backend/discord.py b/backend/discord.py
index cf80cf3..51de26a 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -1,8 +1,13 @@
"""Various utilities for working with the Discord API."""
+
+import datetime
+import json
+import typing
+
import httpx
+from pymongo.database import Database
-from backend import constants
-from backend.models import discord_role, discord_user
+from backend import constants, models
async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:
@@ -40,7 +45,7 @@ async def fetch_user_details(bearer_token: str) -> dict:
return r.json()
-async def get_role_info() -> list[discord_role.DiscordRole]:
+async def _get_role_info() -> list[models.DiscordRole]:
"""Get information about the roles in the configured guild."""
async with httpx.AsyncClient() as client:
r = await client.get(
@@ -49,11 +54,50 @@ async def get_role_info() -> list[discord_role.DiscordRole]:
)
r.raise_for_status()
- return [discord_role.DiscordRole(**role) for role in r.json()]
-
-
-async def get_member(member_id: str) -> discord_user.DiscordMember:
- """Get a member by ID from the configured guild."""
+ return [models.DiscordRole(**role) for role in r.json()]
+
+
+async def get_roles(
+ database: Database, *, force_refresh: bool = False
+) -> list[models.DiscordRole]:
+ """
+ Get a list of all roles from the cache, or discord API if not available.
+
+ If `force_refresh` is True, the cache is skipped and the roles are updated.
+ """
+ collection = database.get_collection("roles")
+
+ if force_refresh:
+ # Drop all values in the collection
+ await collection.delete_many({})
+
+ # `create_index` creates the index if it does not exist, or passes
+ # This handles TTL on role objects
+ await collection.create_index(
+ "inserted_at",
+ expireAfterSeconds=60 * 60 * 24, # 1 day
+ name="inserted_at",
+ )
+
+ roles = []
+ async for role in collection.find():
+ roles.append(models.DiscordRole(**json.loads(role["data"])))
+
+ if len(roles) == 0:
+ # Fetch roles from the API and insert into the database
+ roles = await _get_role_info()
+ await collection.insert_many({
+ "name": role.name,
+ "id": role.id,
+ "data": role.json(),
+ "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
+ } for role in roles)
+
+ return roles
+
+
+async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]:
+ """Get a member by ID from the configured guild using the discord API."""
async with httpx.AsyncClient() as client:
r = await client.get(
f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}"
@@ -61,5 +105,48 @@ async def get_member(member_id: str) -> discord_user.DiscordMember:
headers={"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}
)
+ if r.status_code == 404:
+ return None
+
r.raise_for_status()
- return discord_user.DiscordMember(**r.json())
+ return models.DiscordMember(**r.json())
+
+
+async def get_member(
+ database: Database, user_id: str, *, force_refresh: bool = False
+) -> typing.Optional[models.DiscordMember]:
+ """
+ Get a member from the cache, or from the discord API.
+
+ If `force_refresh` is True, the cache is skipped and the entry is updated.
+ None may be returned if the member object does not exist.
+ """
+ collection = database.get_collection("discord_members")
+
+ if force_refresh:
+ await collection.delete_one({"user": user_id})
+
+ # `create_index` creates the index if it does not exist, or passes
+ # This handles TTL on member objects
+ await collection.create_index(
+ "inserted_at",
+ expireAfterSeconds=60 * 60, # 1 hour
+ name="inserted_at",
+ )
+
+ result = await collection.find_one({"user": user_id})
+
+ if result is not None:
+ return models.DiscordMember(**json.loads(result["data"]))
+
+ member = await _fetch_member_api(user_id)
+
+ if not member:
+ return None
+
+ await collection.insert_one({
+ "user": user_id,
+ "data": member.json(),
+ "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
+ })
+ return member
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
index 8ad7f7f..a9f76e0 100644
--- a/backend/models/__init__.py
+++ b/backend/models/__init__.py
@@ -1,12 +1,15 @@
from .antispam import AntiSpam
-from .discord_user import DiscordUser
+from .discord_role import DiscordRole
+from .discord_user import DiscordMember, DiscordUser
from .form import Form, FormList
from .form_response import FormResponse, ResponseList
from .question import CodeQuestion, Question
__all__ = [
"AntiSpam",
+ "DiscordRole",
"DiscordUser",
+ "DiscordMember",
"Form",
"FormResponse",
"CodeQuestion",
diff --git a/backend/models/discord_user.py b/backend/models/discord_user.py
index 3f4209d..0eca15b 100644
--- a/backend/models/discord_user.py
+++ b/backend/models/discord_user.py
@@ -43,3 +43,12 @@ class DiscordMember(BaseModel):
pending: t.Optional[bool]
permissions: t.Optional[str]
communication_disabled_until: t.Optional[datetime.datetime]
+
+ def dict(self, *args, **kwargs) -> dict[str, t.Any]:
+ """Convert the model to a python dict, and encode timestamps in a serializable format."""
+ data = super().dict(*args, **kwargs)
+ for field, value in data.items():
+ if isinstance(value, datetime.datetime):
+ data[field] = value.isoformat()
+
+ return data
diff --git a/backend/routes/auth/authorize.py b/backend/routes/auth/authorize.py
index d4587f0..42fb3ec 100644
--- a/backend/routes/auth/authorize.py
+++ b/backend/routes/auth/authorize.py
@@ -17,7 +17,7 @@ from starlette.requests import Request
from backend import constants
from backend.authentication.user import User
from backend.constants import SECRET_KEY
-from backend.discord import fetch_bearer_token, fetch_user_details
+from backend.discord import fetch_bearer_token, fetch_user_details, get_member
from backend.route import Route
from backend.validation import ErrorMessage, api
@@ -34,8 +34,8 @@ class AuthorizeResponse(BaseModel):
async def process_token(
- bearer_token: dict,
- request: Request
+ bearer_token: dict,
+ request: Request
) -> Union[AuthorizeResponse, AUTH_FAILURE]:
"""Post a bearer token to Discord, and return a JWT and username."""
interaction_start = datetime.datetime.now()
@@ -46,6 +46,9 @@ async def process_token(
AUTH_FAILURE.delete_cookie("token")
return AUTH_FAILURE
+ user_id = user_details["id"]
+ member = await get_member(request.state.db, user_id, force_refresh=True)
+
max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
token_expiry = interaction_start + max_age
@@ -53,11 +56,12 @@ async def process_token(
"token": bearer_token["access_token"],
"refresh": bearer_token["refresh_token"],
"user_details": user_details,
+ "in_guild": bool(member),
"expiry": token_expiry.isoformat()
}
token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
- user = User(token, user_details)
+ user = User(token, user_details, member)
response = responses.JSONResponse({
"username": user.display_name,
diff --git a/backend/routes/discord.py b/backend/routes/discord.py
new file mode 100644
index 0000000..a980d94
--- /dev/null
+++ b/backend/routes/discord.py
@@ -0,0 +1,83 @@
+"""Routes which directly interact with discord related data."""
+
+import pydantic
+from spectree import Response
+from starlette.authentication import requires
+from starlette.responses import JSONResponse
+from starlette.routing import Request
+
+from backend import discord, models, route
+from backend.validation import ErrorMessage, OkayResponse, api
+
+NOT_FOUND_EXCEPTION = JSONResponse(
+ {"error": "Could not find the requested resource in the guild or cache."}, status_code=404
+)
+
+
+class RolesRoute(route.Route):
+ """Refreshes the roles database."""
+
+ name = "roles"
+ path = "/roles"
+
+ class RolesResponse(pydantic.BaseModel):
+ """A list of all roles on the configured server."""
+
+ roles: list[models.DiscordRole]
+
+ @requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=OkayResponse),
+ tags=["roles"]
+ )
+ async def patch(self, request: Request) -> JSONResponse:
+ """Refresh the roles database."""
+ roles = await discord.get_roles(request.state.db, force_refresh=True)
+
+ return JSONResponse(
+ {"status": "ok"},
+ )
+
+
+class MemberRoute(route.Route):
+ """Retrieve information about a server member."""
+
+ name = "member"
+ path = "/member"
+
+ class MemberRequest(pydantic.BaseModel):
+ """An ID of the member to update."""
+
+ user_id: str
+
+ @requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),
+ json=MemberRequest,
+ tags=["auth"]
+ )
+ async def delete(self, request: Request):
+ """Force a resync of the cache for the given user."""
+ body = await request.json()
+ member = await discord.get_member(request.state.db, body["user_id"], force_refresh=True)
+
+ if member:
+ return JSONResponse(member.dict())
+ else:
+ return NOT_FOUND_EXCEPTION
+
+ @requires(["authenticated", "admin"])
+ @api.validate(
+ resp=Response(HTTP_200=models.DiscordMember, HTTP_400=ErrorMessage),
+ json=MemberRequest,
+ tags=["auth"]
+ )
+ async def get(self, request: Request):
+ """Get a user's roles on the configured server."""
+ body = await request.json()
+ member = await discord.get_member(request.state.db, body["user_id"])
+
+ if member:
+ return JSONResponse(member.dict())
+ else:
+ return NOT_FOUND_EXCEPTION
diff --git a/backend/routes/forms/submit.py b/backend/routes/forms/submit.py
index 95e30b0..baf403d 100644
--- a/backend/routes/forms/submit.py
+++ b/backend/routes/forms/submit.py
@@ -83,7 +83,7 @@ class SubmitForm(Route):
try:
if hasattr(request.user, User.refresh_data.__name__):
old = request.user.token
- await request.user.refresh_data()
+ await request.user.refresh_data(request.state.db)
if old != request.user.token:
try:
diff --git a/backend/routes/roles.py b/backend/routes/roles.py
deleted file mode 100644
index b18a04b..0000000
--- a/backend/routes/roles.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import starlette.background
-from pymongo.database import Database
-from spectree import Response
-from starlette.authentication import requires
-from starlette.responses import JSONResponse
-from starlette.routing import Request
-
-from backend import discord, route
-from backend.validation import OkayResponse, api
-
-
-async def refresh_roles(database: Database) -> None:
- """Connect to the discord API and refresh the roles database."""
- roles = await discord.get_role_info()
- roles_collection = database.get_collection("roles")
- roles_collection.drop()
- roles_collection.insert_many([role.dict() for role in roles])
-
-
-class RolesRoute(route.Route):
- """Refreshes the roles database."""
-
- name = "roles"
- path = "/roles"
-
- @requires(["authenticated", "admin"])
- @api.validate(
- resp=Response(HTTP_200=OkayResponse),
- tags=["roles"]
- )
- async def patch(self, request: Request) -> JSONResponse:
- """Refresh the roles database."""
- return JSONResponse(
- {"status": "ok"},
- background=starlette.background.BackgroundTask(refresh_roles, request.state.db)
- )