diff options
Diffstat (limited to 'pysite/oauth.py')
-rw-r--r-- | pysite/oauth.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/pysite/oauth.py b/pysite/oauth.py new file mode 100644 index 00000000..8370b713 --- /dev/null +++ b/pysite/oauth.py @@ -0,0 +1,77 @@ +import logging +from uuid import uuid4, uuid5 + +from flask import session +from flask_dance.consumer.backend import BaseBackend +from flask_dance.contrib.discord import discord + + +from pysite.constants import DISCORD_API_ENDPOINT, OAUTH_DATABASE + + +class OauthBackend(BaseBackend): + """ + This is the backend for the oauth + + This is used to manage users that have completed + an oauth dance. It contains 3 functions, get, set, + and delete, however we only use set. + + Inherits: + flake_dance.consumer.backend.BaseBackend + pysite.mixins.DBmixin + + Properties: + key: The app's secret, we use it too make session IDs + """ + + def __init__(self, manager): + super().__init__() + self.db = manager.db + self.key = manager.app.secret_key + self.db.create_table(OAUTH_DATABASE, primary_key="id") + + def get(self, *args, **kwargs): # Not used + pass + + def set(self, blueprint, token): + + user = self.get_user() + self.join_discord(token["access_token"], user["id"]) + sess_id = str(uuid5(uuid4(), self.key)) + self.add_user(token, user, sess_id) + + def delete(self, blueprint): # Not used + pass + + def add_user(self, token_data: dict, user_data: dict, session_id: str): + session["session_id"] = session_id + + self.db.insert(OAUTH_DATABASE, {"id": session_id, + "access_token": token_data["access_token"], + "refresh_token": token_data["refresh_token"], + "expires_at": token_data["expires_at"], + "snowflake": user_data["id"]}) + + self.db.insert("users", {"user_id": user_data["id"], + "username": user_data["username"], + "discriminator": user_data["discriminator"], + "email": user_data["email"]}) + + def get_user(self) -> dict: + resp = discord.get(DISCORD_API_ENDPOINT + "/users/@me") # 'discord' is a request.Session with oauth information + if resp.status_code != 200: + logging.warning("Unable to get user information: " + str(resp.json())) + return resp.json() + + def user_data(self): + user_id = session.get("session_id") + if user_id: # If the user is logged in, get user info. + creds = self.db.get(OAUTH_DATABASE, user_id) + if creds: + return self.db.get("users", creds["snowflake"]) + + def logout(self): + sess_id = session.get("session_id") + if sess_id and self.db.get(OAUTH_DATABASE, sess_id): # If user exists in db, + self.db.delete(OAUTH_DATABASE, sess_id) # remove them (at least, their session) |