diff options
author | 2020-08-25 13:23:38 +0100 | |
---|---|---|
committer | 2020-08-25 13:23:38 +0100 | |
commit | e62918abb9d8722c4c0388f3c9fdf9f61b82b572 (patch) | |
tree | 6635eeaf68f215a565de2248f68a1ec727f725a9 | |
parent | Add Alembic (diff) |
Add database
-rw-r--r-- | metricity/database.py | 28 | ||||
-rw-r--r-- | metricity/models.py | 75 |
2 files changed, 103 insertions, 0 deletions
diff --git a/metricity/database.py b/metricity/database.py new file mode 100644 index 0000000..1534e2c --- /dev/null +++ b/metricity/database.py @@ -0,0 +1,28 @@ +"""Methods for connecting and interacting with the database.""" +import logging + +import gino + +from metricity.config import DatabaseConfig + +log = logging.getLogger(__name__) + +db = gino.Gino() + + +def build_db_uri() -> str: + """Use information from the config file to build a PostgreSQL URI.""" + if DatabaseConfig.uri: + return DatabaseConfig.uri + + return ( + f"postgresql://{DatabaseConfig.username}:{DatabaseConfig.password}" + f"@{DatabaseConfig.host}:{DatabaseConfig.port}/{DatabaseConfig.database}" + ) + + +async def connect() -> None: + """Initiate a connection to the database.""" + log.info("Initiating connection to the database") + await db.set_bind(build_db_uri()) + log.info("Database connection established") diff --git a/metricity/models.py b/metricity/models.py new file mode 100644 index 0000000..73296ca --- /dev/null +++ b/metricity/models.py @@ -0,0 +1,75 @@ +"""Database models used by Metricity for statistic collection.""" + +from datetime import datetime +from typing import Any, Dict, List + +from sqlalchemy.dialects.postgresql import insert + +from metricity.database import db + + +class Category(db.Model): + """Database model representing a Discord category channel.""" + + __tablename__ = "categories" + + id = db.Column(db.BigInteger, primary_key=True) + name = db.Column(db.String, nullable=False) + + +class Channel(db.Model): + """Database model representing a Discord channel.""" + + __tablename__ = "channels" + + id = db.Column(db.BigInteger, primary_key=True) + name = db.Column(db.String, nullable=False) + category_id = db.Column(db.BigInteger, nullable=True) + is_staff = db.Column(db.Boolean, nullable=False) + + +class User(db.Model): + """Database model representing a Discord user.""" + + __tablename__ = "users" + + id = db.Column(db.BigInteger, primary_key=True) + name = db.Column(db.String, nullable=False) + avatar_hash = db.Column(db.String, nullable=True) + joined_at = db.Column(db.DateTime, nullable=False) + created_at = db.Column(db.DateTime, nullable=False) + is_staff = db.Column(db.Boolean, nullable=False) + opt_out = db.Column(db.Boolean, default=False) + bot = db.Column(db.Boolean, default=False) + + @classmethod + def bulk_upsert(cls: type, users: List[Dict[str, Any]]) -> Any: + qs = insert(cls.__table__).values(users) + + update_cols = [ + "name", + "avatar_hash", + "joined_at", + "is_staff", + "opt_out", + "bot" + ] + + return qs.on_conflict_do_update( + index_elements=[cls.id], + set_={k: getattr(qs.excluded, k) for k in update_cols} + ).returning(cls.__table__).gino.all() + + +class Message(db.Model): + """Database model representing a message sent in a Discord server.""" + + __tablename__ = "messages" + + id = db.Column(db.BigInteger, primary_key=True) + channel_id = db.Column( + db.BigInteger, + db.ForeignKey("channels.id", ondelete="CASCADE"), + ) + author_id = db.Column(db.BigInteger, db.ForeignKey("users.id", ondelete="CASCADE")) + created_at = db.Column(db.DateTime, default=datetime.utcnow) |