diff options
author | 2018-03-29 09:56:24 +0100 | |
---|---|---|
committer | 2018-03-29 09:56:24 +0100 | |
commit | 5fcd1647e5f8f55240492b6df3b3ff15bab86bf7 (patch) | |
tree | 63c348feb999ca358716ebd9f08b0f5259d68ad1 | |
parent | Make flake8 happier (diff) | |
parent | Oauth (#45) (diff) |
Merge remote-tracking branch 'origin/master'
-rw-r--r-- | app_test.py | 125 | ||||
-rw-r--r-- | pysite/base_route.py | 4 | ||||
-rw-r--r-- | pysite/constants.py | 21 | ||||
-rw-r--r-- | pysite/database.py | 8 | ||||
-rw-r--r-- | pysite/mixins.py | 48 | ||||
-rw-r--r-- | pysite/oauth.py | 77 | ||||
-rw-r--r-- | pysite/route_manager.py | 20 | ||||
-rw-r--r-- | pysite/views/main/index.py | 3 | ||||
-rw-r--r-- | pysite/views/main/logout.py | 15 | ||||
-rw-r--r-- | pysite/views/tests/index.py | 1 | ||||
-rw-r--r-- | requirements.txt | 1 | ||||
-rw-r--r-- | templates/main/navigation.html | 6 |
12 files changed, 305 insertions, 24 deletions
diff --git a/app_test.py b/app_test.py index 2176fe08..31acf12d 100644 --- a/app_test.py +++ b/app_test.py @@ -5,6 +5,7 @@ from flask import Blueprint from flask_testing import TestCase from app import manager +from pysite.constants import DISCORD_OAUTH_REDIRECT, DISCORD_OAUTH_AUTHORIZED manager.app.tests_blueprint = Blueprint("tests", __name__) manager.load_views(manager.app.tests_blueprint, "pysite/views/tests") @@ -33,7 +34,7 @@ class RootEndpoint(SiteTest): """ Test cases for the root endpoint and error handling """ def test_index(self): - """ Check the root path reponds with 200 OK """ + """ Check the root path responds with 200 OK """ response = self.client.get('/', 'http://pytest.local') self.assertEqual(response.status_code, 200) @@ -79,23 +80,38 @@ class RootEndpoint(SiteTest): self.assertEqual(response.status_code, 302) def test_ws_test(self): - """ check ws_test responds """ + """ Check ws_test responds """ response = self.client.get('/ws_test') self.assertEqual(response.status_code, 200) + def test_oauth_redirects(self): + """ Check oauth redirects """ + response = self.client.get(DISCORD_OAUTH_REDIRECT) + self.assertEqual(response.status_code, 302) + + def test_oauth_logout(self): + """ Check oauth redirects """ + response = self.client.get('/auth/logout') + self.assertEqual(response.status_code, 302) + + def test_oauth_authorized(self): + """ Check oauth authorization """ + response = self.client.get(DISCORD_OAUTH_AUTHORIZED) + self.assertEqual(response.status_code, 302) + def test_datadog_redirect(self): """ Check datadog path redirects """ response = self.client.get('/datadog') self.assertEqual(response.status_code, 302) def test_500_easter_egg(self): - """Check the status of the /500 page""" + """ Check the status of the /500 page""" response = self.client.get("/500") self.assertEqual(response.status_code, 500) class ApiEndpoints(SiteTest): - """ test cases for the api subdomain """ + """ Test cases for the api subdomain """ def test_api_unknown_route(self): """ Check api unknown route """ response = self.client.get('/', app.config['API_SUBDOMAIN']) @@ -240,7 +256,7 @@ class Utilities(SiteTest): ev = pysite.base_route.ErrorView() try: - ev.setup('sdf', 'sdfsdf') + ev.setup(manager, 'sdfsdf') except RuntimeError: return True raise Exception('Expected runtime error on setup() when giving wrongful arguments') @@ -259,9 +275,9 @@ class Utilities(SiteTest): return True - class MixinTests(SiteTest): """ Test cases for mixins """ + def test_dbmixin_runtime_error(self): """ Check that wrong values for error view setup raises runtime error """ from pysite.mixins import DBMixin @@ -280,7 +296,7 @@ class MixinTests(SiteTest): try: dbm = DBMixin() dbm.table_name = 'Table' - self.assertEquals(dbm.table, 'Table') + self.assertEqual(dbm.table, 'Table') except AttributeError: pass @@ -299,7 +315,7 @@ class MixinTests(SiteTest): rv = RouteView() try: - rv.setup('sdf', 'sdfsdf') + rv.setup(manager, 'sdfsdf') except RuntimeError: return True raise Exception('Expected runtime error on setup() when giving wrongful arguments') @@ -307,10 +323,54 @@ class MixinTests(SiteTest): def test_route_manager(self): """ Check route manager """ from pysite.route_manager import RouteManager + os.environ['WEBPAGE_SECRET_KEY'] = 'super_secret' rm = RouteManager() self.assertEqual(rm.app.secret_key, 'super_secret') + def test_oauth_property(self): + """ Make sure the oauth property works""" + from flask import Blueprint + + from pysite.route_manager import RouteView + from pysite.oauth import OauthBackend + + class TestRoute(RouteView): + name = "test" + path = "/test" + + tr = TestRoute() + tr.setup(manager, Blueprint("test", "test_name")) + self.assertIsInstance(tr.oauth, OauthBackend) + + def test_user_data_property(self): + """ Make sure the user_data property works""" + from flask import Blueprint + + from pysite.route_manager import RouteView + + class TestRoute(RouteView): + name = "test" + path = "/test" + + tr = TestRoute() + tr.setup(manager, Blueprint("test", "test_name")) + self.assertIs(tr.user_data, None) + + def test_logged_in_property(self): + """ Make sure the user_data property works""" + from flask import Blueprint + + from pysite.route_manager import RouteView + + class TestRoute(RouteView): + name = "test" + path = "/test" + + tr = TestRoute() + tr.setup(manager, Blueprint("test", "test_name")) + self.assertIs(tr.logged_in, False) + class DecoratorTests(SiteTest): def test_decorator_api_json(self): @@ -352,22 +412,22 @@ class DatabaseTests(SiteTest): rdb = RethinkDB() # Create table name and expect it to work result = rdb.create_table(generated_table_name) - self.assertEquals(result, True) + self.assertEqual(result, True) # Create the same table name and expect it to already exist result = rdb.create_table(generated_table_name) - self.assertEquals(result, False) + self.assertEqual(result, False) # Drop table and expect it to work result = rdb.drop_table(generated_table_name) - self.assertEquals(result, True) + self.assertEqual(result, True) # Drop the same table and expect it to already be gone result = rdb.drop_table(generated_table_name) - self.assertEquals(result, False) + self.assertEqual(result, False) # This is to get some more code coverage - self.assertEquals(rdb.teardown_request('_'), None) + self.assertEqual(rdb.teardown_request('_'), None) class TestWebsocketEcho(SiteTest): @@ -380,3 +440,42 @@ class TestWebsocketEcho(SiteTest): ew.on_open() ew.on_message('message') ew.on_close() + + +class TestOauthBackend(SiteTest): + """ Test cases for the oauth.py file """ + + def test_get(self): + """ Make sure the get function returns nothing """ + self.assertIs(manager.oauth_backend.get(), None) + + def test_delete(self): + """ Make sure the delete function returns nothing """ + self.assertIs(manager.oauth_backend.delete(None), None) + + def test_logout(self): + """ Make sure at least apart of logout is working :/ """ + self.assertIs(manager.oauth_backend.logout(), None) + + def test_add_user(self): + """ Make sure function adds values to database and session """ + from flask import session + + from pysite.constants import OAUTH_DATABASE + + sess_id = "hey bro wazup" + fake_token = {"access_token": "access_token", "id": sess_id, "refresh_token": "refresh_token", "expires_at": 5} + fake_user = {"id": "1235678987654321", "username": "Zwacky", "discriminator": "#6660", "email": "[email protected]"} + manager.db.conn = manager.db.get_connection() + manager.oauth_backend.add_user(fake_token, fake_user, sess_id) + + self.assertEqual(sess_id, session["session_id"]) + fake_token["snowflake"] = fake_user["id"] + fake_user["user_id"] = fake_user["id"] + del fake_user["id"] + self.assertEqual(fake_token, manager.db.get(OAUTH_DATABASE, sess_id)) + self.assertEqual(fake_user, manager.db.get("users", fake_user["user_id"])) + + manager.db.delete(OAUTH_DATABASE, sess_id) + manager.db.delete("users", fake_user["user_id"]) + manager.db.teardown_request(None) diff --git a/pysite/base_route.py b/pysite/base_route.py index 4e1a63a7..71a4c894 100644 --- a/pysite/base_route.py +++ b/pysite/base_route.py @@ -6,9 +6,10 @@ from flask import Blueprint, Response, jsonify, render_template from flask.views import MethodView from pysite.constants import ErrorCodes +from pysite.mixins import OauthMixin -class BaseView(MethodView): +class BaseView(MethodView, OauthMixin): """ Base view class with functions and attributes that should be common to all view classes. @@ -27,6 +28,7 @@ class BaseView(MethodView): """ context["current_page"] = self.name context["view"] = self + context["logged_in"] = self.logged_in return render_template(template_names, **context) diff --git a/pysite/constants.py b/pysite/constants.py index c84ca245..7df4674e 100644 --- a/pysite/constants.py +++ b/pysite/constants.py @@ -1,7 +1,7 @@ # coding=utf-8 from enum import Enum, IntEnum -import os +from os import environ class ErrorCodes(IntEnum): @@ -22,6 +22,17 @@ ADMIN_ROLE = 267628507062992896 MODERATOR_ROLE = 267629731250176001 HELPER_ROLE = 267630620367257601 +SERVER_ID = 267624335836053506 + +DISCORD_API_ENDPOINT = "https://discordapp.com/api" + +DISCORD_OAUTH_REDIRECT = "/auth/discord" +DISCORD_OAUTH_AUTHORIZED = "/auth/discord/authorized" +DISCORD_OAUTH_ID = environ.get('DISCORD_OAUTH_ID', '') +DISCORD_OAUTH_SECRET = environ.get('DISCORD_OAUTH_SECRET', '') +DISCORD_OAUTH_SCOPE = 'identify email guilds.join' +OAUTH_DATABASE = "oauth_data" + ERROR_DESCRIPTIONS = { # 5XX 500: "The server encountered an unexpected error ._.", @@ -46,9 +57,9 @@ ERROR_DESCRIPTIONS = { } # PaperTrail logging -PAPERTRAIL_ADDRESS = os.environ.get("PAPERTRAIL_ADDRESS") or None -PAPERTRAIL_PORT = int(os.environ.get("PAPERTRAIL_PORT") or 0) +PAPERTRAIL_ADDRESS = environ.get("PAPERTRAIL_ADDRESS") or None +PAPERTRAIL_PORT = int(environ.get("PAPERTRAIL_PORT") or 0) # DataDog logging -DATADOG_ADDRESS = os.environ.get("DATADOG_ADDRESS") or None -DATADOG_PORT = int(os.environ.get("DATADOG_PORT") or 0) +DATADOG_ADDRESS = environ.get("DATADOG_ADDRESS") or None +DATADOG_PORT = int(environ.get("DATADOG_PORT") or 0) diff --git a/pysite/database.py b/pysite/database.py index add76923..4c2153fe 100644 --- a/pysite/database.py +++ b/pysite/database.py @@ -103,9 +103,11 @@ class RethinkDB: self.log.debug(f"Table created: '{table_name}'") return True - def delete(self, table_name: str, primary_key: Optional[str] = None, - durability: str = "hard", return_changes: Union[bool, str] = False - ) -> Union[Dict[str, Any], None]: + def delete(self, + table_name: str, + primary_key: Union[str, None] = None, + durability: str="hard", + return_changes: Union[bool, str] = False) -> dict: """ Delete one or all documents from a table. This can only delete either the contents of an entire table, or a single document. diff --git a/pysite/mixins.py b/pysite/mixins.py index 059f871d..5b1a780f 100644 --- a/pysite/mixins.py +++ b/pysite/mixins.py @@ -6,7 +6,7 @@ from _weakref import ref from pysite.database import RethinkDB -class DBMixin(): +class DBMixin: """ Mixin for classes that make use of RethinkDB. It can automatically create a table with the specified primary key using the attributes set at class-level. @@ -59,3 +59,49 @@ class DBMixin(): @property def db(self) -> RethinkDB: return self._db() + + +class OauthMixin: + """ + Mixin for the classes that need access to a logged in user's information. This class should be used + to grant route's access to user information, such as name, email, id, ect. + + There will almost never be a need for someone to inherit this, as BaseView does that for you. + + This class will add 3 properties to your route: + + * logged_in (bool): True if user is registered with the site, False else wise. + + * user_data (dict): A dict that looks like this: + + { + "user_id": Their discord ID, + "username": Their discord username (without discriminator), + "discriminator": Their discord discriminator, + "email": Their email, in which is connected to discord + } + + user_data returns None, if the user isn't logged in. + + * oauth (OauthBackend): The instance of pysite.oauth.OauthBackend, connected to the RouteManager. + """ + + @classmethod + def setup(cls: "OauthMixin", manager: "pysite.route_manager.RouteManager", blueprint: Blueprint): + + if hasattr(super(), "setup"): + super().setup(manager, blueprint) # pragma: no cover + + cls._oauth = ref(manager.oauth_backend) + + @property + def logged_in(self) -> bool: + return self.user_data is not None + + @property + def user_data(self) -> dict: + return self.oauth.user_data() + + @property + def oauth(self): + return self._oauth() diff --git a/pysite/oauth.py b/pysite/oauth.py new file mode 100644 index 00000000..8370b713 --- /dev/null +++ b/pysite/oauth.py @@ -0,0 +1,77 @@ +import logging +from uuid import uuid4, uuid5 + +from flask import session +from flask_dance.consumer.backend import BaseBackend +from flask_dance.contrib.discord import discord + + +from pysite.constants import DISCORD_API_ENDPOINT, OAUTH_DATABASE + + +class OauthBackend(BaseBackend): + """ + This is the backend for the oauth + + This is used to manage users that have completed + an oauth dance. It contains 3 functions, get, set, + and delete, however we only use set. + + Inherits: + flake_dance.consumer.backend.BaseBackend + pysite.mixins.DBmixin + + Properties: + key: The app's secret, we use it too make session IDs + """ + + def __init__(self, manager): + super().__init__() + self.db = manager.db + self.key = manager.app.secret_key + self.db.create_table(OAUTH_DATABASE, primary_key="id") + + def get(self, *args, **kwargs): # Not used + pass + + def set(self, blueprint, token): + + user = self.get_user() + self.join_discord(token["access_token"], user["id"]) + sess_id = str(uuid5(uuid4(), self.key)) + self.add_user(token, user, sess_id) + + def delete(self, blueprint): # Not used + pass + + def add_user(self, token_data: dict, user_data: dict, session_id: str): + session["session_id"] = session_id + + self.db.insert(OAUTH_DATABASE, {"id": session_id, + "access_token": token_data["access_token"], + "refresh_token": token_data["refresh_token"], + "expires_at": token_data["expires_at"], + "snowflake": user_data["id"]}) + + self.db.insert("users", {"user_id": user_data["id"], + "username": user_data["username"], + "discriminator": user_data["discriminator"], + "email": user_data["email"]}) + + def get_user(self) -> dict: + resp = discord.get(DISCORD_API_ENDPOINT + "/users/@me") # 'discord' is a request.Session with oauth information + if resp.status_code != 200: + logging.warning("Unable to get user information: " + str(resp.json())) + return resp.json() + + def user_data(self): + user_id = session.get("session_id") + if user_id: # If the user is logged in, get user info. + creds = self.db.get(OAUTH_DATABASE, user_id) + if creds: + return self.db.get("users", creds["snowflake"]) + + def logout(self): + sess_id = session.get("session_id") + if sess_id and self.db.get(OAUTH_DATABASE, sess_id): # If user exists in db, + self.db.delete(OAUTH_DATABASE, sess_id) # remove them (at least, their session) diff --git a/pysite/route_manager.py b/pysite/route_manager.py index 72517a3c..9ecd3ced 100644 --- a/pysite/route_manager.py +++ b/pysite/route_manager.py @@ -5,10 +5,15 @@ import logging import os from flask import Blueprint, Flask +from flask_dance.contrib.discord import make_discord_blueprint from flask_sockets import Sockets from pysite.base_route import APIView, BaseView, ErrorView, RouteView +from pysite.constants import ( + DISCORD_OAUTH_ID, DISCORD_OAUTH_SCOPE, DISCORD_OAUTH_SECRET, DISCORD_OAUTH_REDIRECT, DISCORD_OAUTH_AUTHORIZED +) from pysite.database import RethinkDB +from pysite.oauth import OauthBackend from pysite.websockets import WS TEMPLATES_PATH = "../templates" @@ -31,6 +36,21 @@ class RouteManager: self.app.before_request(self.db.before_request) self.app.teardown_request(self.db.teardown_request) + # Load the oauth blueprint + self.oauth_backend = OauthBackend(self) + self.oauth_blueprint = make_discord_blueprint( + DISCORD_OAUTH_ID, + DISCORD_OAUTH_SECRET, + DISCORD_OAUTH_SCOPE, + '/', + login_url=DISCORD_OAUTH_REDIRECT, + authorized_url=DISCORD_OAUTH_AUTHORIZED, + backend=self.oauth_backend + ) + self.log.debug(f"Loading Blueprint: {self.oauth_blueprint.name}") + self.app.register_blueprint(self.oauth_blueprint) + self.log.debug("") + # Load the main blueprint self.main_blueprint = Blueprint("main", __name__) self.log.debug(f"Loading Blueprint: {self.main_blueprint.name}") diff --git a/pysite/views/main/index.py b/pysite/views/main/index.py index 210eb057..8d0cb349 100644 --- a/pysite/views/main/index.py +++ b/pysite/views/main/index.py @@ -1,5 +1,6 @@ # coding=utf-8 from pysite.base_route import RouteView +from pysite.constants import DISCORD_OAUTH_REDIRECT class IndexView(RouteView): @@ -7,4 +8,4 @@ class IndexView(RouteView): name = "index" def get(self): - return self.render("main/index.html") + return self.render("main/index.html", login_url=DISCORD_OAUTH_REDIRECT) diff --git a/pysite/views/main/logout.py b/pysite/views/main/logout.py new file mode 100644 index 00000000..fce30972 --- /dev/null +++ b/pysite/views/main/logout.py @@ -0,0 +1,15 @@ +from flask import redirect, session + +from pysite.base_route import RouteView + + +class LogoutView(RouteView): + name = "logout" + path = "/auth/logout" + + def get(self): + if self.logged_in: + # remove user's session + del session["session_id"] + self.oauth.logout() + return redirect("/") diff --git a/pysite/views/tests/index.py b/pysite/views/tests/index.py index 3071bf0e..2a55a112 100644 --- a/pysite/views/tests/index.py +++ b/pysite/views/tests/index.py @@ -7,6 +7,7 @@ from pysite.base_route import RouteView from pysite.constants import ValidationTypes from pysite.decorators import api_params + SCHEMA = Schema([{"test": str}]) REQUIRED_KEYS = ["test"] diff --git a/requirements.txt b/requirements.txt index 138d02f0..2717d1e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ wsaccel ujson schema flask_sockets +Flask-Dance logmatic-python diff --git a/templates/main/navigation.html b/templates/main/navigation.html index cabd0d87..aca3037a 100644 --- a/templates/main/navigation.html +++ b/templates/main/navigation.html @@ -30,6 +30,12 @@ <li class="uk-nav-item uk-hidden@m"><a href="/invite"><i class="uk-icon fab fa-discord"></i> Discord</a></li> <li class="uk-nav-divider uk-hidden@m"></li> + {% if logged_in %} + <li class="uk-active"><a href="/auth/logout">Logout</a></li> + {% else %} + <li class="uk-active"><a href={{ login_url }}>Connect to Discord</a></li> + {% endif %} + {% if current_page.startswith("info") %} <li class="uk-nav-header uk-active"><a href="/info">Information</a></li> {% else %} |