aboutsummaryrefslogtreecommitdiffstats
path: root/backend/discord.py
blob: 5a734db62886578d1970462775dfe66a7c95ea94 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""Various utilities for working with the Discord API."""

import datetime
import json
import typing

import httpx
import starlette.requests
from pymongo.database import Database
from starlette import exceptions

from backend import constants, models


async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:
    """
    Exchange an OAuth2 code (or refresh token) for a token payload.

    `code` is sent as the `refresh_token` when `refresh` is True, otherwise
    as the authorization `code`. The redirect URI sent is `redirect` with
    `/callback` appended. The decoded JSON body of the token endpoint's
    response is returned; a non-success status raises via `raise_for_status`.
    """
    if refresh:
        grant_fields = {"grant_type": "refresh_token", "refresh_token": code}
    else:
        grant_fields = {"grant_type": "authorization_code", "code": code}

    # Credentials first, grant-specific fields last (same field order as before)
    payload = {
        "client_id": constants.OAUTH2_CLIENT_ID,
        "client_secret": constants.OAUTH2_CLIENT_SECRET,
        "redirect_uri": f"{redirect}/callback",
        **grant_fields,
    }
    token_url = f"{constants.DISCORD_API_BASE_URL}/oauth2/token"
    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}

    async with httpx.AsyncClient() as client:
        response = await client.post(token_url, headers=form_headers, data=payload)
        response.raise_for_status()
        return response.json()


async def fetch_user_details(bearer_token: str) -> dict:
    """
    Fetch the `/users/@me` object using the given bearer token.

    Returns the decoded JSON body; a non-success status raises via
    `raise_for_status`.
    """
    endpoint = f"{constants.DISCORD_API_BASE_URL}/users/@me"
    auth_headers = {"Authorization": f"Bearer {bearer_token}"}

    async with httpx.AsyncClient() as client:
        response = await client.get(endpoint, headers=auth_headers)
        response.raise_for_status()
        return response.json()


async def _get_role_info() -> list[models.DiscordRole]:
    """
    Get information about the roles in the configured guild.

    The request is authenticated with the bot token, and every entry of the
    JSON response is unpacked into a `models.DiscordRole`.
    """
    roles_url = f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}/roles"
    bot_headers = {"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}

    async with httpx.AsyncClient() as client:
        response = await client.get(roles_url, headers=bot_headers)
        response.raise_for_status()
        raw_roles = response.json()

    return [models.DiscordRole(**raw_role) for raw_role in raw_roles]


async def get_roles(
    database: Database, *, force_refresh: bool = False
) -> list[models.DiscordRole]:
    """
    Get a list of all roles from the cache, or discord API if not available.

    If `force_refresh` is True, the cache is skipped and the roles are updated.
    Cached entries live in the `roles` collection, which carries a TTL index
    on `inserted_at` of one day.
    """
    collection = database.get_collection("roles")

    if force_refresh:
        # Drop all values in the collection
        await collection.delete_many({})

    # `create_index` creates the index if it does not exist, or passes
    # This handles TTL on role objects
    await collection.create_index(
        "inserted_at",
        expireAfterSeconds=60 * 60 * 24,  # 1 day
        name="inserted_at",
    )

    roles = [models.DiscordRole(**json.loads(role["data"])) async for role in collection.find()]
    if roles:
        return roles

    # Cache miss: fetch roles from the API and insert into the database
    roles = await _get_role_info()

    # A single timestamp gives the whole batch the same TTL baseline
    inserted_at = datetime.datetime.now(tz=datetime.timezone.utc)
    documents = [
        {
            "name": role.name,
            "id": role.id,
            "data": role.json(),
            "inserted_at": inserted_at,
        }
        for role in roles
    ]

    # `insert_many` raises when given no documents, so skip the write if the API returned nothing
    if documents:
        await collection.insert_many(documents)

    return roles


async def _fetch_member_api(member_id: str) -> typing.Optional[models.DiscordMember]:
    """
    Get a member by ID from the configured guild using the discord API.

    Returns None when the API answers with a 404; any other non-success
    status raises via `raise_for_status`.
    """
    member_url = (
        f"{constants.DISCORD_API_BASE_URL}/guilds/{constants.DISCORD_GUILD}"
        f"/members/{member_id}"
    )
    bot_headers = {"Authorization": f"Bot {constants.DISCORD_BOT_TOKEN}"}

    async with httpx.AsyncClient() as client:
        response = await client.get(member_url, headers=bot_headers)

        # A 404 is an expected outcome here rather than an error
        if response.status_code != 404:
            response.raise_for_status()
            return models.DiscordMember(**response.json())

        return None


async def get_member(
    database: Database, user_id: str, *, force_refresh: bool = False
) -> typing.Optional[models.DiscordMember]:
    """
    Get a member from the cache, or from the discord API.

    With `force_refresh` set, the cached entry for `user_id` is deleted first,
    so the member is fetched again and re-cached. Returns None when the API
    lookup yields no member.
    """
    members = database.get_collection("discord_members")
    by_user = {"user": user_id}

    if force_refresh:
        await members.delete_one(by_user)

    # `create_index` creates the index if it does not exist, or passes
    # This handles TTL on member objects
    one_hour = 60 * 60
    await members.create_index(
        "inserted_at",
        expireAfterSeconds=one_hour,
        name="inserted_at",
    )

    cached = await members.find_one(by_user)
    if cached is not None:
        # Cache hit: the member is stored as a JSON string under "data"
        return models.DiscordMember(**json.loads(cached["data"]))

    fetched = await _fetch_member_api(user_id)
    if not fetched:
        return None

    cache_entry = {
        "user": user_id,
        "data": fetched.json(),
        "inserted_at": datetime.datetime.now(tz=datetime.timezone.utc),
    }
    await members.insert_one(cache_entry)
    return fetched


class FormNotFoundError(exceptions.HTTPException):
    """
    The requested form was not found.

    Raised by `_verify_access_helper` with a 404 status code when no form
    document matches the requested ID.
    """


class UnauthorizedError(exceptions.HTTPException):
    """
    You are not authorized to use this resource.

    Raised by `_verify_access_helper` with a 401 status code when none of the
    roles listed on the form match the request's auth scopes.
    """


async def _verify_access_helper(
    form_id: str, request: starlette.requests.Request, attribute: str
) -> None:
    """
    A low level helper to validate access to a form resource based on the user's scopes.

    `attribute` names the field on the form that holds the role IDs granting
    access. Access is granted (the function returns) when the request has the
    `admin` scope, or when the name of any listed role found in the `roles`
    collection is among the request's scopes.

    Raises `FormNotFoundError` (404) if the form does not exist, and
    `UnauthorizedError` (401) if no listed role matches.
    """
    form = await request.state.db.forms.find_one({"id": form_id})

    if not form:
        raise FormNotFoundError(status_code=404)

    # Short circuit all resources for admins
    if "admin" in request.auth.scopes:
        return

    form = models.Form(**form)

    # The attribute may be missing or explicitly unset (None); both mean no role grants access.
    # NOTE(review): models.Form is not visible here - presumably these fields are optional.
    allowed_role_ids = getattr(form, attribute, None) or []

    for role_id in allowed_role_ids:
        role = await request.state.db.roles.find_one({"id": role_id})
        if not role:
            # Role IDs without an entry in the roles collection cannot be matched
            continue

        role = models.DiscordRole(**json.loads(role["data"]))

        if role.name in request.auth.scopes:
            return

    raise UnauthorizedError(status_code=401)


async def verify_response_access(form_id: str, request: starlette.requests.Request) -> None:
    """
    Ensure the user can access responses on the requested resource.

    Delegates to `_verify_access_helper`, checking the form's `response_readers` attribute.
    """
    await _verify_access_helper(
        form_id=form_id,
        request=request,
        attribute="response_readers",
    )


async def verify_edit_access(form_id: str, request: starlette.requests.Request) -> None:
    """
    Ensure the user can view and modify the requested resource.

    Delegates to `_verify_access_helper`, checking the form's `editors` attribute.
    """
    await _verify_access_helper(
        form_id=form_id,
        request=request,
        attribute="editors",
    )