aboutsummaryrefslogtreecommitdiffstats
path: root/pysite/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'pysite/database.py')
-rw-r--r--pysite/database.py63
1 files changed, 43 insertions, 20 deletions
diff --git a/pysite/database.py b/pysite/database.py
index 4c2153fe..82e1e84e 100644
--- a/pysite/database.py
+++ b/pysite/database.py
@@ -8,17 +8,27 @@ from flask import abort
from rethinkdb.ast import RqlMethodQuery, Table, UserError
from rethinkdb.net import DefaultConnection
+ALL_TABLES = {
+ # table: primary_key
+
+ "oauth_data": "id",
+ "tags": "tag_name",
+ "users": "user_id",
+ "wiki": "slug",
+}
+
class RethinkDB:
- def __init__(self, loop_type: str = "gevent"):
+ def __init__(self, loop_type: Optional[str] = "gevent"):
self.host = os.environ.get("RETHINKDB_HOST", "127.0.0.1")
self.port = os.environ.get("RETHINKDB_PORT", "28016")
self.database = os.environ.get("RETHINKDB_DATABASE", "pythondiscord")
- self.log = logging.getLogger()
+ self.log = logging.getLogger(__name__)
self.conn = None
- rethinkdb.set_loop_type(loop_type)
+ if loop_type:
+ rethinkdb.set_loop_type(loop_type)
with self.get_connection(connect_database=False) as conn:
try:
@@ -27,7 +37,16 @@ class RethinkDB:
except rethinkdb.RqlRuntimeError:
self.log.debug(f"Database found: '{self.database}'")
- def get_connection(self, connect_database: bool=True) -> DefaultConnection:
+ def create_tables(self) -> int:
+ created = 0
+
+ for table, primary_key in ALL_TABLES.items():
+ if self.create_table(table, primary_key):
+ created += 1
+
+ return created
+
+ def get_connection(self, connect_database: bool = True) -> DefaultConnection:
"""
Grab a connection to the RethinkDB server, optionally without selecting a database
@@ -63,8 +82,8 @@ class RethinkDB:
# region: Convenience wrappers
- def create_table(self, table_name: str, primary_key: str="id", durability: str="hard", shards: int=1,
- replicas: Union[int, Dict[str, int]]=1, primary_replica_tag: Optional[str]=None) -> bool:
+ def create_table(self, table_name: str, primary_key: str = "id", durability: str = "hard", shards: int = 1,
+ replicas: Union[int, Dict[str, int]] = 1, primary_replica_tag: Optional[str] = None) -> bool:
"""
Attempt to create a new table on the current database
@@ -106,7 +125,7 @@ class RethinkDB:
def delete(self,
table_name: str,
primary_key: Union[str, None] = None,
- durability: str="hard",
+ durability: str = "hard",
return_changes: Union[bool, str] = False) -> dict:
"""
Delete one or all documents from a table. This can only delete
@@ -170,10 +189,13 @@ class RethinkDB:
:return: The RethinkDB table object for the table
"""
+ if table_name not in ALL_TABLES:
+ self.log.warning(f"Table not declared in database.py: {table_name}")
+
return rethinkdb.table(table_name)
- def run(self, query: RqlMethodQuery, *, new_connection: bool=False,
- connect_database: bool=True, coerce: type=None) -> Union[rethinkdb.Cursor, List, Dict, object]:
+ def run(self, query: RqlMethodQuery, *, new_connection: bool = False,
+ connect_database: bool = True, coerce: type = None) -> Union[rethinkdb.Cursor, List, Dict, object]:
"""
Run a query using a table object obtained from a call to `query()`
@@ -215,14 +237,14 @@ class RethinkDB:
# region: RethinkDB wrapper functions
def insert(self, table_name: str, *objects: Dict[str, Any],
- durability: str="hard",
- return_changes: Union[bool, str]=False,
+ durability: str = "hard",
+ return_changes: Union[bool, str] = False,
conflict: Union[ # Any of...
str, Callable[ # ...str, or a callable that...
[Dict[str, Any], Dict[str, Any]], # ...takes two dicts with string keys and any values...
Dict[str, Any] # ...and returns a dict with string keys and any values
]
- ]="error") -> Union[List, Dict]: # flake8: noqa
+ ] = "error") -> Union[List, Dict]: # flake8: noqa
"""
Insert an object or a set of objects into a table
@@ -262,7 +284,7 @@ class RethinkDB:
return dict(result) if result else None # pragma: no cover
- def get_all(self, table_name: str, *keys: str, index: str="id") -> List[Any]:
+ def get_all(self, table_name: str, *keys: str, index: str = "id") -> List[Any]:
"""
Get a list of documents matching a set of keys, on a specific index
@@ -278,7 +300,7 @@ class RethinkDB:
coerce=list
)
- def wait(self, table_name: str, wait_for: str="all_replicas_ready", timeout: int=0) -> bool:
+ def wait(self, table_name: str, wait_for: str = "all_replicas_ready", timeout: int = 0) -> bool:
"""
Wait until an operation has happened on a specific table; will block the current function
@@ -312,9 +334,9 @@ class RethinkDB:
return result.get("synced", 0) > 0 # pragma: no cover
- def changes(self, table_name: str, squash: Union[bool, int]=False, changefeed_queue_size: int=100_000,
- include_initial: Optional[bool]=None, include_states: bool=False,
- include_types: bool=False) -> Iterator[Dict[str, Any]]:
+ def changes(self, table_name: str, squash: Union[bool, int] = False, changefeed_queue_size: int = 100_000,
+ include_initial: Optional[bool] = None, include_states: bool = False,
+ include_types: bool = False) -> Iterator[Dict[str, Any]]:
"""
A complicated function allowing you to follow a changefeed for a specific table
@@ -420,8 +442,9 @@ class RethinkDB:
self.query(table_name).without(*selectors)
)
- def between(self, table_name: str, *, lower: Any=rethinkdb.minval, upper: Any=rethinkdb.maxval,
- index: Optional[str]=None, left_bound: str="closed", right_bound: str ="open") -> List[Dict[str, Any]]:
+ def between(self, table_name: str, *, lower: Any = rethinkdb.minval, upper: Any = rethinkdb.maxval,
+ index: Optional[str] = None, left_bound: str = "closed", right_bound: str = "open") -> List[
+ Dict[str, Any]]:
"""
Get all documents between two keys
@@ -483,7 +506,7 @@ class RethinkDB:
)
def filter(self, table_name: str, predicate: Callable[[Dict[str, Any]], bool],
- default: Union[bool, UserError]=False) -> List[Dict[str, Any]]:
+ default: Union[bool, UserError] = False) -> List[Dict[str, Any]]:
"""
Return all documents in a table for which `predicate` returns true.