aboutsummaryrefslogtreecommitdiffstats
path: root/pysite/oauth.py
blob: 86e7cdde7fa46cdd9c314dc092be3d501d67d0ee (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import logging
from uuid import uuid4, uuid5

from flask import session
from flask_dance.consumer.backend import BaseBackend
from flask_dance.contrib.discord import discord

from pysite.constants import DISCORD_API_ENDPOINT, OAUTH_DATABASE


class OAuthBackend(BaseBackend):
    """
    Storage backend for the Discord OAuth dance.

    Persists the credentials of users that have completed an OAuth
    dance. The backend interface consists of get, set and delete;
    only set does any work here, get and delete are no-ops.

    Inherits:
        flask_dance.consumer.backend.BaseBackend

    Properties:
        db: The database handle taken from the manager
        key: The app's secret, we use it to make session IDs
    """

    def __init__(self, manager):
        """
        Bind the backend to the manager's database and app secret.

        Also asks the database to create the OAuth table, keyed on "id".
        """
        super().__init__()
        self.db = manager.db
        self.key = manager.app.secret_key
        self.db.create_table(OAUTH_DATABASE, primary_key="id")

    def get(self, *args, **kwargs):  # Not used
        pass

    def set(self, blueprint, token):
        """
        Store a freshly issued token together with the user it belongs to.

        Looks up the current user, generates a new session ID and hands
        everything to add_user. If the user lookup did not yield a user
        object, nothing is stored and the user stays logged out.
        """
        user = self.get_user()
        if "id" not in user:
            # get_user already logged the failure details. Without a user ID
            # there is nothing to key the records on, so bail out before the
            # session or the database is touched.
            logging.warning("Not creating an OAuth session: user lookup returned no ID")
            return

        # uuid5 over a random (uuid4) namespace, with the app secret as name.
        sess_id = str(uuid5(uuid4(), self.key))
        self.add_user(token, user, sess_id)

    def delete(self, blueprint):  # Not used
        pass

    def add_user(self, token_data: dict, user_data: dict, session_id: str):
        """
        Record the session in the Flask session and persist both records.

        Args:
            token_data: Token payload; must contain "access_token",
                "refresh_token" and "expires_at".
            user_data: User payload; must contain "id", "username"
                and "discriminator".
            session_id: ID under which the credentials are stored.

        Raises:
            KeyError: If a required field is missing. Both records are
                built before any side effect, so neither the session nor
                the database is modified in that case.
        """
        oauth_record = {
            "id": session_id,
            "access_token": token_data["access_token"],
            "refresh_token": token_data["refresh_token"],
            "expires_at": token_data["expires_at"],
            "snowflake": user_data["id"]
        }
        user_record = {
            "user_id": user_data["id"],
            "username": user_data["username"],
            "discriminator": user_data["discriminator"]
        }

        session["session_id"] = session_id
        self.db.insert(OAUTH_DATABASE, oauth_record, conflict="replace")
        self.db.insert("users", user_record, conflict="update")

    def get_user(self) -> dict:
        """
        Fetch the current user's object from the Discord API.

        Returns:
            The decoded JSON body of the response. On a non-200 status the
            body is logged as a warning and still returned, so callers must
            check for the fields they need. An empty dict is returned when
            the body is not valid JSON.
        """
        resp = discord.get(DISCORD_API_ENDPOINT + "/users/@me")  # 'discord' is a request.Session with oauth information
        try:
            payload = resp.json()
        except ValueError:
            # Error responses are not guaranteed to carry a JSON body.
            logging.warning(
                "Unable to get user information: HTTP %s with a non-JSON body",
                resp.status_code
            )
            return {}

        if resp.status_code != 200:
            logging.warning("Unable to get user information: %s", payload)
        return payload

    def user_data(self):
        """
        Return the stored user record for the current session.

        Returns None when the session has no session ID or when no
        credentials are stored under that ID.
        """
        user_id = session.get("session_id")
        if user_id:  # If the user is logged in, get user info.
            creds = self.db.get(OAUTH_DATABASE, user_id)
            if creds:
                return self.db.get("users", creds["snowflake"])

    def logout(self):
        """
        Delete the stored credentials of the current session and clear it.

        Does nothing when the session has no session ID or when no
        credentials are stored under that ID.
        """
        sess_id = session.get("session_id")
        if sess_id and self.db.get(OAUTH_DATABASE, sess_id):  # If user exists in db,
            self.db.delete(OAUTH_DATABASE, sess_id)  # remove them (at least, their session)
            session.clear()