aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--backend/authentication/backend.py2
-rw-r--r--backend/authentication/user.py4
-rw-r--r--backend/discord.py43
-rw-r--r--backend/routes/discord.py4
4 files changed, 16 insertions, 37 deletions
diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py
index c84ba10..e150580 100644
--- a/backend/authentication/backend.py
+++ b/backend/authentication/backend.py
@@ -68,6 +68,6 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
if await user.fetch_admin_status(request.state.db):
scopes.append("admin")
- scopes.extend(await user.get_user_roles(request.state.db))
+ scopes.extend(await user.get_user_roles())
return authentication.AuthCredentials(scopes), user
diff --git a/backend/authentication/user.py b/backend/authentication/user.py
index ad59103..5e99546 100644
--- a/backend/authentication/user.py
+++ b/backend/authentication/user.py
@@ -44,12 +44,12 @@ class User(BaseUser):
def decoded_token(self) -> dict[str, any]:
return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])
- async def get_user_roles(self, database: Database) -> list[str]:
+ async def get_user_roles(self) -> list[str]:
"""Get a list of the user's discord roles."""
if not self.member:
return []
- server_roles = await discord.get_roles(database)
+ server_roles = await discord.get_roles()
roles = [role.name for role in server_roles if role.id in self.member.roles]
if "admin" in roles:
diff --git a/backend/discord.py b/backend/discord.py
index 192fc60..4a1ecf5 100644
--- a/backend/discord.py
+++ b/backend/discord.py
@@ -1,11 +1,9 @@
"""Various utilities for working with the Discord API."""
-import datetime
import json
import httpx
import starlette.requests
-from pymongo.database import Database
from starlette import exceptions
from backend import constants, models
@@ -66,7 +64,6 @@ async def _get_role_info() -> list[models.DiscordRole]:
async def get_roles(
- database: Database,
*,
force_refresh: bool = False,
) -> list[models.DiscordRole]:
@@ -75,35 +72,17 @@ async def get_roles(
If `force_refresh` is True, the cache is skipped and the roles are updated.
"""
- collection = database.get_collection("roles")
-
- if force_refresh:
- # Drop all values in the collection
- await collection.delete_many({})
-
- # `create_index` creates the index if it does not exist, or passes
- # This handles TTL on role objects
- await collection.create_index(
- "inserted_at",
- expireAfterSeconds=60 * 60 * 24, # 1 day
- name="inserted_at",
- )
-
- roles = [models.DiscordRole(**json.loads(role["data"])) async for role in collection.find()]
-
- if len(roles) == 0:
- # Fetch roles from the API and insert into the database
- roles = await _get_role_info()
- await collection.insert_many(
- {
- "name": role.name,
- "id": role.id,
- "data": role.json(),
- "inserted_at": datetime.datetime.now(tz=datetime.UTC),
- }
- for role in roles
- )
-
+ role_cache_key = "forms-backend:role_cache"
+ if not force_refresh:
+ roles = await constants.REDIS_CLIENT.hgetall(role_cache_key)
+ if roles:
+ return [
+ models.DiscordRole(**json.loads(role_data)) for role_id, role_data in roles.items()
+ ]
+
+ roles = await _get_role_info()
+ await constants.REDIS_CLIENT.hmset(role_cache_key, {role.id: role.json() for role in roles})
+ await constants.REDIS_CLIENT.expire(role_cache_key, 60 * 60 * 24) # 1 day
return roles
diff --git a/backend/routes/discord.py b/backend/routes/discord.py
index 196d902..5cd6b47 100644
--- a/backend/routes/discord.py
+++ b/backend/routes/discord.py
@@ -31,9 +31,9 @@ class RolesRoute(route.Route):
resp=Response(HTTP_200=RolesResponse),
tags=["roles"],
)
- async def patch(self, request: Request) -> JSONResponse:
+ async def patch(self, request: Request) -> JSONResponse: # noqa: ARG002 Request is required by @requires
"""Refresh the roles database."""
- roles = await discord.get_roles(request.state.db, force_refresh=True)
+ roles = await discord.get_roles(force_refresh=True)
return JSONResponse(
{"roles": [role.dict() for role in roles]},