diff options
author | 2023-09-04 20:24:41 +0100 | |
---|---|---|
committer | 2023-09-04 20:24:41 +0100 | |
commit | 002dff82e9990125b266744034ee374a2989bcfe (patch) | |
tree | fe7d28aaf710815e75f16d87ac41922090fd5253 | |
parent | Add SQLAlchemy 2 & asyncpg and remove GINO & psycopg2 (diff) |
Update database & models modules to use SQLAlchemy 2
-rw-r--r-- | metricity/database.py | 18 | ||||
-rw-r--r-- | metricity/models.py | 131 |
2 files changed, 52 insertions, 97 deletions
diff --git a/metricity/database.py b/metricity/database.py index 19ec1ff..4fdf465 100644 --- a/metricity/database.py +++ b/metricity/database.py @@ -3,34 +3,26 @@ import logging from datetime import UTC, datetime -import gino from sqlalchemy.engine import Dialect +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.types import DateTime, TypeDecorator from metricity.config import DatabaseConfig log = logging.getLogger(__name__) -db = gino.Gino() - - def build_db_uri() -> str: - """Use information from the config file to build a PostgreSQL URI.""" + """Build the database uri from the config.""" if DatabaseConfig.uri: return DatabaseConfig.uri return ( - f"postgresql://{DatabaseConfig.username}:{DatabaseConfig.password}" + f"postgresql+asyncpg://{DatabaseConfig.username}:{DatabaseConfig.password}" f"@{DatabaseConfig.host}:{DatabaseConfig.port}/{DatabaseConfig.database}" ) - -async def connect() -> None: - """Initiate a connection to the database.""" - log.info("Initiating connection to the database") - await db.set_bind(build_db_uri()) - log.info("Database connection established") - +engine: AsyncEngine = create_async_engine(build_db_uri(), echo=DatabaseConfig.log_queries) +async_session = async_sessionmaker(engine, expire_on_commit=False) class TZDateTime(TypeDecorator): """ diff --git a/metricity/models.py b/metricity/models.py index c2ae0c7..c1ce2e0 100644 --- a/metricity/models.py +++ b/metricity/models.py @@ -1,116 +1,79 @@ """Database models used by Metricity for statistic collection.""" from datetime import UTC, datetime -from typing import Any +from typing import Any, Optional +from sqlalchemy import JSON, ForeignKey from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from metricity.database import TZDateTime, db +from metricity.database import TZDateTime -class Category(db.Model): +class Base(DeclarativeBase): + pass + +class Category(Base): """Database model representing a Discord category channel.""" __tablename__ = "categories" - id = db.Column(db.String, primary_key=True) - name = db.Column(db.String, nullable=False) + id: Mapped[str] = mapped_column(primary_key=True) + name: Mapped[str] -class Channel(db.Model): +class Channel(Base): """Database model representing a Discord channel.""" __tablename__ = "channels" - id = db.Column(db.String, primary_key=True) - name = db.Column(db.String, nullable=False) - category_id = db.Column( - db.String, - db.ForeignKey("categories.id", ondelete="CASCADE"), - nullable=True, - ) - is_staff = db.Column(db.Boolean, nullable=False) + id: Mapped[str] = mapped_column(primary_key=True) + name: Mapped[str] + category_id: Mapped[str | None] = mapped_column(ForeignKey("categories.id", ondelete="CASCADE")) + is_staff: Mapped[bool] -class Thread(db.Model): +class Thread(Base): """Database model representing a Thread channel.""" __tablename__ = "threads" - id = db.Column(db.String, primary_key=True) - parent_channel_id = db.Column( - db.String, - db.ForeignKey("channels.id", ondelete="CASCADE"), - nullable=False, - ) - created_at = db.Column(TZDateTime(), default=datetime.now(UTC)) - name = db.Column(db.String, nullable=False) - archived = db.Column(db.Boolean, default=False, nullable=False) - auto_archive_duration = db.Column(db.Integer, nullable=False) - locked = db.Column(db.Boolean, default=False, nullable=False) - type = db.Column(db.String, nullable=False, index=True) - - -class User(db.Model): + id: Mapped[str] = mapped_column(primary_key=True) + parent_channel_id: Mapped[str] = mapped_column(ForeignKey("channels.id", ondelete="CASCADE")) + created_at = mapped_column(TZDateTime(), default=datetime.now(UTC)) + name: Mapped[str] + archived: Mapped[bool] + auto_archive_duration: Mapped[int] + locked: Mapped[bool] + type: Mapped[str] = mapped_column(index=True) + + +class User(Base): """Database model representing a Discord user.""" __tablename__ = "users" - id = db.Column(db.String, primary_key=True) - name = db.Column(db.String, nullable=False) - avatar_hash = db.Column(db.String, nullable=True) - guild_avatar_hash = db.Column(db.String, nullable=True) - joined_at = db.Column(TZDateTime(), nullable=False) - created_at = db.Column(TZDateTime(), nullable=False) - is_staff = db.Column(db.Boolean, nullable=False) - bot = db.Column(db.Boolean, default=False) - in_guild = db.Column(db.Boolean, default=True) - public_flags = db.Column(db.JSON, default={}) - pending = db.Column(db.Boolean, default=False) - - @classmethod - def bulk_upsert(cls: type, users: list[dict[str, Any]]) -> Any: # noqa: ANN401 - """Perform a bulk insert/update of the database to sync the user table.""" - qs = insert(cls.__table__).values(users) - - update_cols = [ - "name", - "avatar_hash", - "guild_avatar_hash", - "joined_at", - "is_staff", - "bot", - "in_guild", - "public_flags", - "pending", - ] - - return qs.on_conflict_do_update( - index_elements=[cls.id], - set_={k: getattr(qs.excluded, k) for k in update_cols}, - ).returning(cls.__table__).gino.all() - - -class Message(db.Model): + id: Mapped[str] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False) + avatar_hash: Mapped[str] = mapped_column(nullable=True) + guild_avatar_hash: Mapped[str] = mapped_column(nullable=True) + joined_at = mapped_column(TZDateTime(), nullable=False) + created_at = mapped_column(TZDateTime(), nullable=False) + is_staff: Mapped[bool] = mapped_column(nullable=False) + bot: Mapped[bool] = mapped_column(default=False) + in_guild: Mapped[bool] = mapped_column(default=False) + public_flags = mapped_column(JSON, default={}) + pending: Mapped[bool] = mapped_column(default=False) + + +class Message(Base): """Database model representing a message sent in a Discord server.""" __tablename__ = "messages" - id = db.Column(db.String, primary_key=True) - channel_id = db.Column( - db.String, - db.ForeignKey("channels.id", ondelete="CASCADE"), - index=True, - ) - thread_id = db.Column( - db.String, - db.ForeignKey("threads.id", ondelete="CASCADE"), - index=True, - ) - author_id = db.Column( - db.String, - db.ForeignKey("users.id", ondelete="CASCADE"), - index=True, - ) - created_at = db.Column(TZDateTime(), default=datetime.now(UTC)) - is_deleted = db.Column(db.Boolean, default=False) + id: Mapped[str] = mapped_column(primary_key=True) + channel_id: Mapped[str] = mapped_column(ForeignKey("channels.id", ondelete="CASCADE"), index=True) + thread_id: Mapped[str | None] = mapped_column(ForeignKey("threads.id", ondelete="CASCADE"), index=True) + author_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) + created_at = mapped_column(TZDateTime()) + is_deleted: Mapped[bool] = mapped_column(default=False) |