aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2020-08-25 13:23:38 +0100
committerGravatar Joe Banks <[email protected]>2020-08-25 13:23:38 +0100
commite62918abb9d8722c4c0388f3c9fdf9f61b82b572 (patch)
tree6635eeaf68f215a565de2248f68a1ec727f725a9
parentAdd Alembic (diff)
Add database
-rw-r--r--metricity/database.py28
-rw-r--r--metricity/models.py75
2 files changed, 103 insertions, 0 deletions
diff --git a/metricity/database.py b/metricity/database.py
new file mode 100644
index 0000000..1534e2c
--- /dev/null
+++ b/metricity/database.py
@@ -0,0 +1,28 @@
+"""Methods for connecting and interacting with the database."""
+import logging
+
+import gino
+
+from metricity.config import DatabaseConfig
+
+log = logging.getLogger(__name__)
+
+db = gino.Gino()
+
+
+def build_db_uri() -> str:
+ """Use information from the config file to build a PostgreSQL URI."""
+ if DatabaseConfig.uri:
+ return DatabaseConfig.uri
+
+ return (
+ f"postgresql://{DatabaseConfig.username}:{DatabaseConfig.password}"
+ f"@{DatabaseConfig.host}:{DatabaseConfig.port}/{DatabaseConfig.database}"
+ )
+
+
+async def connect() -> None:
+ """Initiate a connection to the database."""
+ log.info("Initiating connection to the database")
+ await db.set_bind(build_db_uri())
+ log.info("Database connection established")
diff --git a/metricity/models.py b/metricity/models.py
new file mode 100644
index 0000000..73296ca
--- /dev/null
+++ b/metricity/models.py
@@ -0,0 +1,75 @@
+"""Database models used by Metricity for statistic collection."""
+
+from datetime import datetime
+from typing import Any, Dict, List
+
+from sqlalchemy.dialects.postgresql import insert
+
+from metricity.database import db
+
+
+class Category(db.Model):
+ """Database model representing a Discord category channel."""
+
+ __tablename__ = "categories"
+
+ id = db.Column(db.BigInteger, primary_key=True)
+ name = db.Column(db.String, nullable=False)
+
+
+class Channel(db.Model):
+ """Database model representing a Discord channel."""
+
+ __tablename__ = "channels"
+
+ id = db.Column(db.BigInteger, primary_key=True)
+ name = db.Column(db.String, nullable=False)
+ category_id = db.Column(db.BigInteger, nullable=True)
+ is_staff = db.Column(db.Boolean, nullable=False)
+
+
+class User(db.Model):
+ """Database model representing a Discord user."""
+
+ __tablename__ = "users"
+
+ id = db.Column(db.BigInteger, primary_key=True)
+ name = db.Column(db.String, nullable=False)
+ avatar_hash = db.Column(db.String, nullable=True)
+ joined_at = db.Column(db.DateTime, nullable=False)
+ created_at = db.Column(db.DateTime, nullable=False)
+ is_staff = db.Column(db.Boolean, nullable=False)
+ opt_out = db.Column(db.Boolean, default=False)
+ bot = db.Column(db.Boolean, default=False)
+
+ @classmethod
+ def bulk_upsert(cls: type, users: List[Dict[str, Any]]) -> Any:
+ qs = insert(cls.__table__).values(users)
+
+ update_cols = [
+ "name",
+ "avatar_hash",
+ "joined_at",
+ "is_staff",
+ "opt_out",
+ "bot"
+ ]
+
+ return qs.on_conflict_do_update(
+ index_elements=[cls.id],
+ set_={k: getattr(qs.excluded, k) for k in update_cols}
+ ).returning(cls.__table__).gino.all()
+
+
+class Message(db.Model):
+ """Database model representing a message sent in a Discord server."""
+
+ __tablename__ = "messages"
+
+ id = db.Column(db.BigInteger, primary_key=True)
+ channel_id = db.Column(
+ db.BigInteger,
+ db.ForeignKey("channels.id", ondelete="CASCADE"),
+ )
+ author_id = db.Column(db.BigInteger, db.ForeignKey("users.id", ondelete="CASCADE"))
+ created_at = db.Column(db.DateTime, default=datetime.utcnow)