diff options
-rw-r--r-- | pysite/base_route.py | 4 | ||||
-rw-r--r-- | pysite/database.py | 3 | ||||
-rw-r--r-- | pysite/decorators.py | 4 | ||||
-rw-r--r-- | pysite/route_manager.py | 6 | ||||
-rw-r--r-- | pysite/views/api/bot/tag.py | 6 |
5 files changed, 12 insertions, 11 deletions
diff --git a/pysite/base_route.py b/pysite/base_route.py index e1b9c6b2..fe8632de 100644 --- a/pysite/base_route.py +++ b/pysite/base_route.py @@ -2,6 +2,7 @@ import os import random import string +from _weakref import ref from flask import Blueprint, g, jsonify, render_template from flask.views import MethodView @@ -153,6 +154,7 @@ class DBViewMixin: if not cls.table_name: raise RuntimeError("Routes using DBViewMixin must define `table_name`") + cls._db = ref(manager.db) manager.db.create_table(cls.table_name, primary_key=cls.table_primary_key) @property @@ -161,7 +163,7 @@ class DBViewMixin: @property def db(self) -> RethinkDB: - return g.db + return self._db() class ErrorView(BaseView): diff --git a/pysite/database.py b/pysite/database.py index 5a319ebc..3df4bbed 100644 --- a/pysite/database.py +++ b/pysite/database.py @@ -366,7 +366,8 @@ class RethinkDB: """ return self.run( - self.query(table_name).pluck(*selectors) + self.query(table_name).pluck(*selectors), + coerce=list ) def without(self, table_name: str, *selectors: Union[str, Dict[str, Union[List, Dict]]]): diff --git a/pysite/decorators.py b/pysite/decorators.py index 6951e875..8d0cf7f4 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -1,5 +1,6 @@ # coding=utf-8 import os +from functools import wraps from flask import request @@ -13,9 +14,10 @@ def valid_api_key(f): Should only be applied to functions on APIView routes. """ + @wraps(f) def has_valid_api_key(self, *args, **kwargs): if not request.headers.get("X-API-Key") == os.environ.get("API_KEY"): return self.error(ErrorCodes.invalid_api_key) - return f(*args, **kwargs) + return f(self, *args, **kwargs) return has_valid_api_key diff --git a/pysite/route_manager.py b/pysite/route_manager.py index ddd969d7..aeaec7c9 100644 --- a/pysite/route_manager.py +++ b/pysite/route_manager.py @@ -21,14 +21,10 @@ class RouteManager: ) self.db = RethinkDB() self.app.secret_key = os.environ.get("WEBPAGE_SECRET_KEY", "super_secret") - self.app.config["SERVER_NAME"] = os.environ.get("SERVER_NAME", "pythondiscord.com:8080") + self.app.config["SERVER_NAME"] = os.environ.get("", "pythondiscord.com:8080") self.app.before_request(self.db.before_request) self.app.teardown_request(self.db.teardown_request) - # Store the database in the Flask global context - with self.app.app_context(): - g.db = self.db # type: RethinkDB - # Load the main blueprint self.main_blueprint = Blueprint("main", __name__) print(f"Loading Blueprint: {self.main_blueprint.name}") diff --git a/pysite/views/api/bot/tag.py b/pysite/views/api/bot/tag.py index ef17e8fa..b2ce145b 100644 --- a/pysite/views/api/bot/tag.py +++ b/pysite/views/api/bot/tag.py @@ -23,11 +23,11 @@ class TagView(APIView, DBViewMixin): tag_name = request.args.get("tag_name") if tag_name: - data = self.db.get(self.table_name, tag_name) + data = self.db.get(self.table_name, tag_name) or {} else: - data = self.db.pluck(self.table_name, "tag_name") + data = self.db.pluck(self.table_name, "tag_name") or [] - return jsonify(data or {}) + return jsonify(data) @valid_api_key def post(self): |