aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Gareth Coles <[email protected]>2018-02-14 23:10:31 +0000
committerGravatar Sam Wedgwood <[email protected]>2018-02-14 23:10:31 +0000
commit70f0a9166b15645845370f4db8b1c9d1cfb75e6a (patch)
treea9ff85c2ecdb1db529a935f070b2d54ecbee93ef
parent[API] You need to return the value of `self.error()` (diff)
Database API Improvements #1qcra (#13)
* A large set of changes, including: * A mixin for views that need the DB * Many changes to the database class in order to make things more fluid * Provide the route manager in view setup() methods * Pushing up the progress so far * snekchek * Full (undocumented) database implementation * snekchek * Don't rely on exceptions for table deletion * Add RethinkDB data to gitignore * Documentation for DB class * Make Flake8 ignore P102 What even is that? What does "docstring does contain unindexed parameters" mean? * Document the base_routes module * Cleanup RE latest reviews * snekchek (bah)
-rw-r--r--.gitignore3
-rw-r--r--.snekrc3
-rw-r--r--pysite/base_route.py130
-rw-r--r--pysite/database.py446
-rw-r--r--pysite/route_manager.py2
-rw-r--r--pysite/views/api/bot/tag.py64
6 files changed, 603 insertions, 45 deletions
diff --git a/.gitignore b/.gitignore
index 6669a983..23151389 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,3 +99,6 @@ ENV/
# PyCharm
.idea/
+
+# RethinkDB data
+rethinkdb_data/
diff --git a/.snekrc b/.snekrc
index 3f26eea5..5aabe4e0 100644
--- a/.snekrc
+++ b/.snekrc
@@ -3,4 +3,5 @@ linters = flake8, safety, dodgy
[flake8]
max-line-length=120
-application_import_names=pysite \ No newline at end of file
+application_import_names=pysite
+ignore=P102 \ No newline at end of file
diff --git a/pysite/base_route.py b/pysite/base_route.py
index e0871e49..730b3e10 100644
--- a/pysite/base_route.py
+++ b/pysite/base_route.py
@@ -4,13 +4,22 @@ import random
import string
from functools import wraps
-from flask import Blueprint, jsonify, render_template, request
+from flask import Blueprint, g, jsonify, render_template, request
from flask.views import MethodView
+from rethinkdb.ast import Table
+
from pysite.constants import ErrorCodes
+from pysite.database import RethinkDB
class BaseView(MethodView):
+ """
+ Base view class with functions and attributes that should be common to all view classes.
+
+ This class should be subclassed, and is not intended to be used directly.
+ """
+
name = None # type: str
def render(self, *template_names, **context):
@@ -21,10 +30,39 @@ class BaseView(MethodView):
class RouteView(BaseView):
+ """
+ Standard route-based page view. For a standard page, this is what you want.
+
+ This class is intended to be subclassed - use it as a base class for your own views, and set the class-level
+ attributes as appropriate. For example:
+
+ >>> class MyView(RouteView):
+ ... name = "my_view" # Flask internal name for this route
+ ... path = "/my_view" # Actual URL path to reach this route
+ ...
+ ... def get(self): # Name your function after the relevant HTTP method
+ ... return self.render("index.html")
+
+ For more complicated routing, see http://exploreflask.com/en/latest/views.html#built-in-converters
+ """
+
path = None # type: str
@classmethod
- def setup(cls: "RouteView", blueprint: Blueprint):
+ def setup(cls: "RouteView", manager: "pysite.route_manager.RouteManager", blueprint: Blueprint):
+ """
+ Set up the view by adding it to the blueprint passed in - this will also deal with multiple inheritance by
+ calling `super().setup()` as appropriate.
+
+ This is for a standard route view. Nothing special here.
+
+ :param manager: Instance of the current RouteManager
+ :param blueprint: Current Flask blueprint to register this route to
+ """
+
+ if hasattr(super(), "setup"):
+ super().setup(manager, blueprint)
+
if not cls.path or not cls.name:
raise RuntimeError("Route views must have both `path` and `name` defined")
@@ -32,6 +70,20 @@ class RouteView(BaseView):
class APIView(RouteView):
+ """
+ API route view, with extra methods to help you add routes to the JSON API with ease.
+
+ This class is intended to be subclassed - use it as a base class for your own views, and set the class-level
+ attributes as appropriate. For example:
+
+ >>> class MyView(APIView):
+ ... name = "my_view" # Flask internal name for this route
+ ... path = "/my_view" # Actual URL path to reach this route
+ ...
+ ... def get(self): # Name your function after the relevant HTTP method
+ ... return self.error(ErrorCodes.unknown_route)
+ """
+
def validate_key(self, api_key: str):
""" Placeholder! """
return api_key == os.environ.get("API_KEY")
@@ -81,11 +133,83 @@ class APIView(RouteView):
return response
+class DBViewMixin:
+ """
+ Mixin for views that make use of RethinkDB. It can automatically create a table with the specified primary
+ key using the attributes set at class-level.
+
+ This class is intended to be mixed in alongside one of the other view classes. For example:
+
+ >>> class MyView(APIView, DBViewMixin):
+ ... name = "my_view" # Flask internal name for this route
+ ... path = "/my_view" # Actual URL path to reach this route
+ ... table_name = "my_table" # Name of the table to create
+ ... table_primary_key = "username" # Primary key to set for this table
+
+ You may omit `table_primary_key` and it will be defaulted to RethinkDB's default column - "id".
+ """
+
+ table_name = "" # type: str
+ table_primary_key = "id" # type: str
+
+ @classmethod
+ def setup(cls: "DBViewMixin", manager: "pysite.route_manager.RouteManager", blueprint: Blueprint):
+ """
+ Set up the view by creating the table specified by the class attributes - this will also deal with multiple
+ inheritance by calling `super().setup()` as appropriate.
+
+ :param manager: Instance of the current RouteManager (used to get a handle for the database object)
+ :param blueprint: Current Flask blueprint
+ """
+
+ if hasattr(super(), "setup"):
+ super().setup(manager, blueprint)
+
+ if not cls.table_name:
+ raise RuntimeError("Routes using DBViewMixin must define `table_name`")
+
+ manager.db.create_table(cls.table_name, primary_key=cls.table_primary_key)
+
+ @property
+ def table(self) -> Table:
+ return self.db.query(self.table_name)
+
+ @property
+ def db(self) -> RethinkDB:
+ return g.db
+
+
class ErrorView(BaseView):
+ """
+ Error view, shown for a specific HTTP status code, as defined in the class attributes.
+
+ This class is intended to be subclassed - use it as a base class for your own views, and set the class-level
+ attributes as appropriate. For example:
+
+ >>> class MyView(ErrorView):
+ ... name = "my_view" # Flask internal name for this route
+ ... path = "/my_view" # Actual URL path to reach this route
+ ... error_code = 404
+ ...
+ ... def get(self): # Name your function after the relevant HTTP method
+ ... return "Replace me with a template, 404 not found", 404
+ """
+
error_code = None # type: int
@classmethod
- def setup(cls: "ErrorView", blueprint: Blueprint):
+ def setup(cls: "ErrorView", manager: "pysite.route_manager.RouteManager", blueprint: Blueprint):
+ """
+ Set up the view by registering it as the error handler for the HTTP status code specified in the class
+ attributes - this will also deal with multiple inheritance by calling `super().setup()` as appropriate.
+
+ :param manager: Instance of the current RouteManager
+ :param blueprint: Current Flask blueprint to register the error handler for
+ """
+
+ if hasattr(super(), "setup"):
+ super().setup(manager, blueprint)
+
if not cls.name or not cls.error_code:
raise RuntimeError("Error views must have both `name` and `error_code` defined")
diff --git a/pysite/database.py b/pysite/database.py
index 75f01378..1c9cb838 100644
--- a/pysite/database.py
+++ b/pysite/database.py
@@ -1,10 +1,13 @@
# coding=utf-8
import os
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from flask import abort
import rethinkdb
+from rethinkdb.ast import RqlMethodQuery, Table, UserError
+from rethinkdb.net import DefaultConnection
class RethinkDB:
@@ -20,24 +23,461 @@ class RethinkDB:
with self.get_connection(connect_database=False) as conn:
try:
rethinkdb.db_create(self.database).run(conn)
- print(f"Database created: {self.database}")
+ print(f"Database created: '{self.database}'")
except rethinkdb.RqlRuntimeError:
- print(f"Database found: {self.database}")
+ print(f"Database found: '{self.database}'")
+
+ def get_connection(self, connect_database: bool=True) -> DefaultConnection:
+ """
+ Grab a connection to the RethinkDB server, optionally without selecting a database
+
+ :param connect_database: Whether to immediately connect to the database or not
+ """
- def get_connection(self, connect_database: bool = True):
if connect_database:
return rethinkdb.connect(host=self.host, port=self.port, db=self.database)
else:
return rethinkdb.connect(host=self.host, port=self.port)
def before_request(self):
+ """
+ Flask pre-request callback to set up a connection for the duration of the request
+ """
+
try:
self.conn = self.get_connection()
except rethinkdb.RqlDriverError:
abort(503, "Database connection could not be established.")
def teardown_request(self, _):
+ """
+ Flask post-request callback to close a previously set-up connection
+
+ :param _: Exception object, not used here
+ """
+
try:
self.conn.close()
except AttributeError:
pass
+
+ # region: Convenience wrappers
+
+ def create_table(self, table_name: str, primary_key: str="id", durability: str="hard", shards: int=1,
+ replicas: Union[int, Dict[str, int]]=1, primary_replica_tag: Optional[str]=None) -> bool:
+ """
+ Attempt to create a new table on the current database
+
+ :param table_name: The name of the table to create
+ :param primary_key: The name of the primary key - defaults to "id"
+ :param durability: "hard" (the default) to write the change immediately, "soft" otherwise
+ :param shards: The number of shards to span the table over - defaults to 1
+ :param replicas: See the RethinkDB documentation relating to replicas
+ :param primary_replica_tag: See the RethinkDB documentation relating to replicas
+
+ :return: True if the table was created, False if it already exists
+ """
+
+ with self.get_connection() as conn:
+ all_tables = rethinkdb.db(self.database).table_list().run(conn)
+
+ if table_name in all_tables:
+ print(f"Table found: '{table_name}' ({len(all_tables)} tables in total)")
+ return False
+
+ # Use a kwargs dict because the driver doesn't check the value
+ # of `primary_replica_tag` properly; None is not handled
+ kwargs = {
+ "primary_key": primary_key,
+ "durability": durability,
+ "shards": shards,
+ "replicas": replicas
+ }
+
+ if primary_replica_tag is not None:
+ kwargs["primary_replica_tag"] = primary_replica_tag
+
+ rethinkdb.db(self.database).table_create(table_name, **kwargs).run(conn)
+
+ print(f"Table created: '{table_name}'")
+ return True
+
+ def drop_table(self, table_name: str):
+ """
+ Attempt to drop a table from the database, along with its data
+
+ :param table_name: The name of the table to drop
+ :return: True if the table was dropped, False if the table doesn't exist
+ """
+
+ with self.get_connection() as conn:
+ all_tables = rethinkdb.db(self.database).table_list().run(conn)
+
+ if table_name not in all_tables:
+ return False
+
+ rethinkdb.db(self.database).table_drop(table_name).run(conn)
+ return True
+
+ def query(self, table_name: str) -> Table:
+ """
+ Get a RethinkDB table object that you can run queries against
+
+ >>> db = RethinkDB()
+ >>> query = db.query("my_table")
+ >>> db.run(query.insert({"key": "value"}), coerce=dict)
+ {
+ "deleted": 0,
+ "errors": 0,
+ "inserted": 1,
+ "replaced": 0,
+ "skipped": 0,
+ "unchanged": 0
+ }
+
+ :param table_name: Name of the table to query against
+ :return: The RethinkDB table object for the table
+ """
+
+ return rethinkdb.table(table_name)
+
+ def run(self, query: RqlMethodQuery, *, new_connection: bool=False,
+ connect_database: bool=True, coerce: type=None) -> Union[rethinkdb.Cursor, List, Dict, object]:
+ """
+ Run a query using a table object obtained from a call to `query()`
+
+ >>> db = RethinkDB()
+ >>> query = db.query("my_table")
+ >>> db.run(query.insert({"key": "value"}), coerce=dict)
+ {
+ "deleted": 0,
+ "errors": 0,
+ "inserted": 1,
+ "replaced": 0,
+ "skipped": 0,
+ "unchanged": 0
+ }
+
+ Note that result coercion is very basic, and doesn't really do any magic. If you want to be able to work
+ directly with the result of your query, then don't specify the `coerce` argument - the object that you'd
+ usually get from the RethinkDB API will be returned instead.
+
+ :param query: The full query to run
+ :param new_connection: Whether to create a new connection or use the current request-bound one
+ :param connect_database: If creating a new connection, whether to connect to the database immediately
+ :param coerce: Optionally, an object type to attempt to coerce the result to
+
+ :return: THe result of the operation
+ """
+
+ if not new_connection:
+ result = query.run(self.conn)
+ else:
+ result = query.run(self.get_connection(connect_database))
+
+ if coerce:
+ return coerce(result) if result else coerce()
+ return result
+
+ # endregion
+
+ # region: RethinkDB wrapper functions
+
+ def insert(self, table_name: str, *objects: Dict[str, Any],
+ durability: str="hard",
+ return_changes: Union[bool, str]=False,
+ conflict: Union[ # Any of...
+ str, Callable[ # ...str, or a callable that...
+ [Dict[str, Any], Dict[str, Any]], # ...takes two dicts with string keys and any values...
+ Dict[str, Any] # ...and returns a dict with string keys and any values
+ ]
+ ]="error") -> Union[List, Dict]: # flake8: noqa
+ """
+ Insert an object or a set of objects into a table
+
+ :param table_name: The name of the table to insert into
+ :param objects: The objects to be inserted into the table
+ :param durability: "hard" (the default) to write the change immediately, "soft" otherwise
+ :param return_changes: Whether to return a list of changed values or not - defaults to False
+ :param conflict: What to do in the event of a conflict - "error", "replace" and "update" are included, but
+ you can also provide your own function in order to handle conflicts yourself. If you do this, the function
+ should take two arguments (the old document and the new one), and return a single document to replace both.
+
+ :return: A list of changes if `return_changes` is True; a dict detailing the operations run otherwise
+ """
+
+ query = self.query(table_name).insert(
+ *objects, durability=durability, return_changes=return_changes, conflict=conflict
+ )
+
+ if return_changes:
+ return self.run(query, coerce=list)
+ else:
+ return self.run(query, coerce=dict)
+
+ def get(self, table_name: str, key: str) -> Union[Dict[str, Any], None]:
+ """
+ Get a single document from a table by primary key
+
+ :param table_name: The name of the table to get the document from
+ :param key: The value of the primary key belonging to the document you want
+
+ :return: The document, or None if it wasn't found
+ """
+
+ result = self.run(
+ self.query(table_name).get(key)
+ )
+
+ return dict(result) if result else None
+
+ def get_all(self, table_name: str, *keys: str, index: str="id") -> List[Any]:
+ """
+ Get a list of documents matching a set of keys, on a specific index
+
+ :param table_name: The name of the table to get documents from
+ :param keys: The key values to match against
+ :param index: The name of the key or index to match on
+
+ :return: A list of matching documents; may be empty if no matches were made
+ """
+
+ return self.run(
+ self.query(table_name).get_all(*keys, index=index),
+ coerce=list
+ )
+
+ def wait(self, table_name: str, wait_for: str="all_replicas_ready", timeout: int=0) -> bool:
+ """
+ Wait until an operation has happened on a specific table; will block the current function
+
+ :param table_name: The name of the table to wait against
+ :param wait_for: The operation to wait for; may be "ready_for_outdated_reads",
+ "ready_for_reads", "ready_for_writes" or "all_replicas_ready", which is the default
+ :param timeout: How long to wait before returning; defaults to 0 (forever)
+
+ :return: True; but may return False if the timeout was reached
+ """
+
+ result = self.run(
+ self.query(table_name).wait(wait_for=wait_for, timeout=timeout),
+ coerce=dict
+ )
+
+ return result.get("ready", 0) > 0
+
+ def sync(self, table_name: str) -> bool:
+ """
+ Following a set of edits with durability set to "soft", this must be called to save those edits
+
+ :param table_name: The name of the table to sync
+
+ :return: True if the sync was successful; False otherwise
+ """
+ result = self.run(
+ self.query(table_name).sync(),
+ coerce=dict
+ )
+
+ return result.get("synced", 0) > 0
+
+ def changes(self, table_name: str, squash: Union[bool, int]=False, changefeed_queue_size: int=100_000,
+ include_initial: Optional[bool]=None, include_states: bool=False,
+ include_types: bool=False) -> Iterator[Dict[str, Any]]:
+ """
+ A complicated function allowing you to follow a changefeed for a specific table
+
+ This function will not allow you to specify a set of conditions for your changefeed, so you'll
+ have to write your own query and run it with `run()` if you need that. If not, you'll just get every
+ change for the specified table.
+
+ >>> db = RethinkDB()
+ >>> for document in db.changes("my_table", squash=True):
+ ... print(document.get("new_val", {}))
+
+ Documents take the form of a dict with `old_val` and `new_val` fields by default. These are set to a copy of
+ the document before and after the change being represented was made, respectively. The format of these dicts
+ can change depending on the arguments you pass to the function, however.
+
+ If a changefeed must be aborted (for example, if the table was deleted), a ReqlRuntimeError will be
+ raised.
+
+ Note: This function always creates a new connection. This is to prevent you from losing your changefeed
+ when the connection used for a request context is closed.
+
+ :param table_name: The name of the table to watch for changes on
+
+ :param squash: How to deal with batches of changes to a single document - False (the default) to send changes
+ as they happen, True to squash changes for single objects together and send them as a single change,
+ or an int to specify how many seconds to wait for an object to change before batching it
+
+ :param changefeed_queue_size: The number of changes the server will buffer between client reads before it
+ starts to drop changes and issues errors - defaults to 100,000
+
+ :param include_initial: If True, the changefeed will start with the initial values of all the documents in
+ the table; the results will have `new_val` fields ONLY to start with if this is the case. Note that
+ the old values may be intermixed with new changes if you're still iterating through the old values, but
+ only as long as the old value for that field has already been sent. If the order of a document you've
+ already seen moves it to a part of the group you haven't yet seen, an "unitial" notification is sent, which
+ is simply a dict with an `old_val` field set, and not a `new_val` field set. This option defaults to
+ False.
+
+ :param include_states: Whether to send special state documents to the changefeed as its state changes. This
+ comprises of special documents with only a `state` field, set to a string - the state of the feed. There
+ are currently two states - "initializing" and "ready". This option defaults to False.
+
+ :param include_types: If True, each document generated will include a `type` field which states what type
+ of change the document represents. This may be "add", "remove", "change", "initial", "uninitial" or
+ "state". This option defaults to False.
+
+ :return: A special iterator that will iterate over documents in the changefeed as they're sent. If there is
+ no document waiting, this will block the function until there is.
+ """
+ return self.run(
+ self.query(table_name).changes(
+ squash=squash, changefeed_queue_size=changefeed_queue_size, include_initial=include_initial,
+ include_states=include_states, include_offsets=False, include_types=include_types
+ ),
+ new_connection=True
+ )
+
+ def pluck(self, table_name: str, *selectors: Union[str, Dict[str, Union[List[...], Dict[str, ...]]]]):
+ """
+ Get a list of values for a specific set of keys for every document in the table; this can include
+ nested values
+
+ >>> db = RethinkDB()
+ >>> db.pluck("users", "username", "password") # Select a flat document
+ [
+ {"username": "lemon", "password": "hunter2"}
+ ]
+ >>> db.pluck("users", {"posts": ["title"]}) # Select from nested documents
+ [
+ {
+ "posts": [
+ {"title": "New website!"}
+ ]
+ }
+ ]
+
+ :param table_name: The table to get values from
+ :param selectors: The set of keys to get values for
+ :return: A list containing the requested documents, with only the keys requested
+ """
+
+ return self.run(
+ self.query(table_name).pluck(*selectors)
+ )
+
+ def without(self, table_name: str, *selectors: Union[str, Dict[str, Union[List[...], Dict[str, ...]]]]):
+ """
+ The functional opposite of `pluck()`, returning full documents without the specified selectors
+
+ >>> db = RethinkDB()
+ >>> db.without("users", "posts")
+ [
+ {"username": "lemon", "password": "hunter2}
+ ]
+
+ :param table_name: The table to get values from
+ :param selectors: The set of keys to exclude
+ :return: A list containing the requested documents, without the keys requested
+ """
+
+ return self.run(
+ self.query(table_name).without(*selectors)
+ )
+
+ def between(self, table_name: str, *, lower: Any=rethinkdb.minval, upper: Any=rethinkdb.maxval,
+ index: Optional[str]=None, left_bound: str="closed", right_bound: str ="open") -> List[Dict[str, Any]]:
+ """
+ Get all documents between two keys
+
+ >>> db = RethinkDB()
+ >>> db.between("users", upper=10, index="conquests")
+ [
+ {"username": "gdude", "conquests": 2},
+ {"username": "joseph", "conquests": 5}
+ ]
+ >>> db.between("users", lower=10, index="conquests")
+ [
+ {"username": "lemon", "conquests": 15}
+ ]
+ >>> db.between("users", lower=2, upper=10, index="conquests" left_bound="open")
+ [
+ {"username": "gdude", "conquests": 2},
+ {"username": "joseph", "conquests": 5}
+ ]
+
+ :param table_name: The table to get documents from
+ :param lower: The lower-bounded value, leave blank to ignore
+ :param upper: The upper-bounded value, leave blank to ignore
+ :param index: The key or index to check on each document
+ :param left_bound: "open" to include documents that exactly match the lower bound, "closed" otherwise
+ :param right_bound: "open" to include documents that exactly match the upper bound, "closed" otherwise
+
+ :return: A list of matched documents; may be empty
+ """
+ return self.run(
+ self.query(table_name).between(lower, upper, index=index, left_bound=left_bound, right_bound=right_bound),
+ coerce=list
+ )
+
+ def map(self, table_name: str, func: Callable):
+ """
+ Map a function over every document in a table, with the possibility of modifying it
+
+ r.table('users').map(
+ lambda doc: doc.merge({'user_id': doc['id']}).without('id')).run(conn)
+
+ As an example, you could do the following to rename the "id" field to "user_id" for all documents
+ in the "users" table.
+
+ >>> db = RethinkDB()
+ >>> db.map(
+ ... "users",
+ ... lambda doc: doc.merge({"user_id": doc["id"]}).without("id")
+ ... )
+
+ :param table_name: The name of the table to map the function over
+ :param func: A callable that takes a single argument
+
+ :return: Unknown, needs more testing
+ """
+
+ return self.run(
+ self.query(table_name).map(func),
+ coerce=list
+ )
+
+ def filter(self, table_name: str, predicate: Callable[[Dict[str, Any]], bool],
+ default: Union[bool, UserError]=False) -> List[Dict[str, Any]]:
+ """
+ Return all documents in a table for which `predicate` returns true.
+
+ The `predicate` argument should be a function that takes a single argument - a single document to check - and
+ it should return True or False depending on whether the document should be included.
+
+ >>> def many_conquests(doc):
+ ... '''Return documents with at least 10 conquests'''
+ ... return doc["conquests"] >= 10
+ ...
+ >>> db = RethinkDB()
+ >>> db.filter("users", many_conquests)
+ [
+ {"username": "lemon", "conquests": 15}
+ ]
+
+ :param table_name: The name of the table to get documents for
+ :param predicate: The callable to use to filter the documents
+ :param default: What to do if a document is missing fields; True to include them, `rethink.error()` to raise
+ aa ReqlRuntimeError, or False to skip over the document (the default)
+ :return: A list of documents that match the predicate; may be empty
+ """
+
+ return self.run(
+ self.query(table_name).filter(predicate, default=default),
+ coerce=list
+ )
+
+ # endregion
diff --git a/pysite/route_manager.py b/pysite/route_manager.py
index 6f973767..ddd969d7 100644
--- a/pysite/route_manager.py
+++ b/pysite/route_manager.py
@@ -71,5 +71,5 @@ class RouteManager:
cls is not APIView and
BaseView in cls.__mro__
):
- cls.setup(blueprint)
+ cls.setup(self, blueprint)
print(f">> View loaded: {cls.name: <15} ({module.__name__}.{cls_name})")
diff --git a/pysite/views/api/bot/tag.py b/pysite/views/api/bot/tag.py
index e679b500..84fd8977 100644
--- a/pysite/views/api/bot/tag.py
+++ b/pysite/views/api/bot/tag.py
@@ -1,66 +1,56 @@
# coding=utf-8
-from flask import g, jsonify, request
+from flask import jsonify, request
-import rethinkdb
-
-from pysite.base_route import APIView
+from pysite.base_route import APIView, DBViewMixin
from pysite.constants import ErrorCodes
-class TagView(APIView):
- path = '/tag'
- name = 'tag'
- table = 'tag'
-
- def __init__(self):
- # make sure the table exists
- with g.db.get_connection() as conn:
- try:
- rethinkdb.db(g.db.database).table_create(self.table, {'primary_key': 'tag_name'}).run(conn)
- except rethinkdb.RqlRuntimeError:
- print(f'Table {self.table} exists')
+class TagView(APIView, DBViewMixin):
+ path = "/tag"
+ name = "tag"
+ table_name = "tag"
+ table_primary_key = "tag_name"
def get(self):
"""
- Indata must be provided as params,
+ Data must be provided as params,
API key must be provided as header
"""
- rdb = rethinkdb.table(self.table)
- api_key = request.headers.get('X-API-Key')
- tag_name = request.args.get('tag_name')
+ api_key = request.headers.get("X-API-Key")
+ tag_name = request.args.get("tag_name")
if self.validate_key(api_key):
if tag_name:
- data = rdb.get(tag_name).run(g.db.conn)
- data = dict(data) if data else {}
+ data = self.db.get(self.table_name, tag_name)
else:
- data = rdb.pluck('tag_name').run(g.db.conn)
- data = list(data) if data else []
+ data = self.db.pluck(self.table_name, "tag_name")
else:
return self.error(ErrorCodes.invalid_api_key)
- return jsonify(data)
+ return jsonify(data or {})
def post(self):
- """ Indata must be provided as JSON. """
- rdb = rethinkdb.table(self.table)
+ """ Data must be provided as JSON. """
indata = request.get_json()
- tag_name = indata.get('tag_name')
- tag_content = indata.get('tag_content')
- tag_category = indata.get('tag_category')
- api_key = indata.get('api_key')
+ tag_name = indata.get("tag_name")
+ tag_content = indata.get("tag_content")
+ tag_category = indata.get("tag_category")
+ api_key = request.headers.get("X-API-Key")
if self.validate_key(api_key):
if tag_name and tag_content:
- rdb.insert({
- 'tag_name': tag_name,
- 'tag_content': tag_content,
- 'tag_category': tag_category
- }).run(g.db.conn)
+ self.db.insert(
+ self.table_name,
+ {
+ "tag_name": tag_name,
+ "tag_content": tag_content,
+ "tag_category": tag_category
+ }
+ )
else:
return self.error(ErrorCodes.missing_parameters)
else:
return self.error(ErrorCodes.invalid_api_key)
- return jsonify({'success': True})
+ return jsonify({"success": True})