diff options
| author | 2023-09-04 20:24:41 +0100 | |
|---|---|---|
| committer | 2023-09-04 20:24:41 +0100 | |
| commit | 002dff82e9990125b266744034ee374a2989bcfe (patch) | |
| tree | fe7d28aaf710815e75f16d87ac41922090fd5253 | |
| parent | Add SQLAlchemy 2 & asyncpg and remove GINO & psycopg2 (diff) | |
Update database & models modules to use SQLAlchemy 2
| -rw-r--r-- | metricity/database.py | 18 | ||||
| -rw-r--r-- | metricity/models.py | 131 | 
2 files changed, 52 insertions, 97 deletions
diff --git a/metricity/database.py b/metricity/database.py index 19ec1ff..4fdf465 100644 --- a/metricity/database.py +++ b/metricity/database.py @@ -3,34 +3,26 @@  import logging  from datetime import UTC, datetime -import gino  from sqlalchemy.engine import Dialect +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine  from sqlalchemy.types import DateTime, TypeDecorator  from metricity.config import DatabaseConfig  log = logging.getLogger(__name__) -db = gino.Gino() - -  def build_db_uri() -> str: -    """Use information from the config file to build a PostgreSQL URI.""" +    """Build the database uri from the config."""      if DatabaseConfig.uri:          return DatabaseConfig.uri      return ( -        f"postgresql://{DatabaseConfig.username}:{DatabaseConfig.password}" +        f"postgresql+asyncpg://{DatabaseConfig.username}:{DatabaseConfig.password}"          f"@{DatabaseConfig.host}:{DatabaseConfig.port}/{DatabaseConfig.database}"      ) - -async def connect() -> None: -    """Initiate a connection to the database.""" -    log.info("Initiating connection to the database") -    await db.set_bind(build_db_uri()) -    log.info("Database connection established") - +engine: AsyncEngine = create_async_engine(build_db_uri(), echo=DatabaseConfig.log_queries) +async_session = async_sessionmaker(engine, expire_on_commit=False)  class TZDateTime(TypeDecorator):      """ diff --git a/metricity/models.py b/metricity/models.py index c2ae0c7..c1ce2e0 100644 --- a/metricity/models.py +++ b/metricity/models.py @@ -1,116 +1,79 @@  """Database models used by Metricity for statistic collection."""  from datetime import UTC, datetime -from typing import Any +from typing import Any, Optional +from sqlalchemy import JSON, ForeignKey  from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from metricity.database import TZDateTime, db +from metricity.database import TZDateTime -class Category(db.Model): +class Base(DeclarativeBase): +    pass + +class Category(Base):      """Database model representing a Discord category channel."""      __tablename__ = "categories" -    id = db.Column(db.String, primary_key=True) -    name = db.Column(db.String, nullable=False) +    id: Mapped[str] = mapped_column(primary_key=True) +    name: Mapped[str] -class Channel(db.Model): +class Channel(Base):      """Database model representing a Discord channel."""      __tablename__ = "channels" -    id = db.Column(db.String, primary_key=True) -    name = db.Column(db.String, nullable=False) -    category_id = db.Column( -        db.String, -        db.ForeignKey("categories.id", ondelete="CASCADE"), -        nullable=True, -    ) -    is_staff = db.Column(db.Boolean, nullable=False) +    id: Mapped[str] = mapped_column(primary_key=True) +    name: Mapped[str] +    category_id: Mapped[str | None] = mapped_column(ForeignKey("categories.id", ondelete="CASCADE")) +    is_staff: Mapped[bool] -class Thread(db.Model): +class Thread(Base):      """Database model representing a Thread channel."""      __tablename__ = "threads" -    id = db.Column(db.String, primary_key=True) -    parent_channel_id = db.Column( -        db.String, -        db.ForeignKey("channels.id", ondelete="CASCADE"), -        nullable=False, -    ) -    created_at = db.Column(TZDateTime(), default=datetime.now(UTC)) -    name = db.Column(db.String, nullable=False) -    archived = db.Column(db.Boolean, default=False, nullable=False) -    auto_archive_duration = db.Column(db.Integer, nullable=False) -    locked = db.Column(db.Boolean, default=False, nullable=False) -    type = db.Column(db.String, nullable=False, index=True) - - -class User(db.Model): +    id: Mapped[str] = mapped_column(primary_key=True) +    parent_channel_id: Mapped[str] = mapped_column(ForeignKey("channels.id", ondelete="CASCADE")) +    created_at = mapped_column(TZDateTime(), default=datetime.now(UTC)) +    name: Mapped[str] +    archived: Mapped[bool] +    auto_archive_duration: Mapped[int] +    locked: Mapped[bool] +    type: Mapped[str] = mapped_column(index=True) + + +class User(Base):      """Database model representing a Discord user."""      __tablename__ = "users" -    id = db.Column(db.String, primary_key=True) -    name = db.Column(db.String, nullable=False) -    avatar_hash = db.Column(db.String, nullable=True) -    guild_avatar_hash = db.Column(db.String, nullable=True) -    joined_at = db.Column(TZDateTime(), nullable=False) -    created_at = db.Column(TZDateTime(), nullable=False) -    is_staff = db.Column(db.Boolean, nullable=False) -    bot = db.Column(db.Boolean, default=False) -    in_guild = db.Column(db.Boolean, default=True) -    public_flags = db.Column(db.JSON, default={}) -    pending = db.Column(db.Boolean, default=False) - -    @classmethod -    def bulk_upsert(cls: type, users: list[dict[str, Any]]) -> Any:  # noqa: ANN401 -        """Perform a bulk insert/update of the database to sync the user table.""" -        qs = insert(cls.__table__).values(users) - -        update_cols = [ -            "name", -            "avatar_hash", -            "guild_avatar_hash", -            "joined_at", -            "is_staff", -            "bot", -            "in_guild", -            "public_flags", -            "pending", -        ] - -        return qs.on_conflict_do_update( -            index_elements=[cls.id], -            set_={k: getattr(qs.excluded, k) for k in update_cols}, -        ).returning(cls.__table__).gino.all() - - -class Message(db.Model): +    id: Mapped[str] = mapped_column(primary_key=True) +    name: Mapped[str] = mapped_column(nullable=False) +    avatar_hash: Mapped[str] = mapped_column(nullable=True) +    guild_avatar_hash: Mapped[str] = mapped_column(nullable=True) +    joined_at = mapped_column(TZDateTime(), nullable=False) +    created_at = mapped_column(TZDateTime(), nullable=False) +    is_staff: Mapped[bool] = mapped_column(nullable=False) +    bot: Mapped[bool] = mapped_column(default=False) +    in_guild: Mapped[bool] = mapped_column(default=False) +    public_flags = mapped_column(JSON, default={}) +    pending: Mapped[bool] = mapped_column(default=False) + + +class Message(Base):      """Database model representing a message sent in a Discord server."""      __tablename__ = "messages" -    id = db.Column(db.String, primary_key=True) -    channel_id = db.Column( -        db.String, -        db.ForeignKey("channels.id", ondelete="CASCADE"), -        index=True, -    ) -    thread_id = db.Column( -        db.String, -        db.ForeignKey("threads.id", ondelete="CASCADE"), -        index=True, -    ) -    author_id = db.Column( -        db.String, -        db.ForeignKey("users.id", ondelete="CASCADE"), -        index=True, -    ) -    created_at = db.Column(TZDateTime(), default=datetime.now(UTC)) -    is_deleted = db.Column(db.Boolean, default=False) +    id: Mapped[str] = mapped_column(primary_key=True) +    channel_id: Mapped[str] = mapped_column(ForeignKey("channels.id", ondelete="CASCADE"), index=True) +    thread_id: Mapped[str | None] = mapped_column(ForeignKey("threads.id", ondelete="CASCADE"), index=True) +    author_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) +    created_at = mapped_column(TZDateTime()) +    is_deleted: Mapped[bool] = mapped_column(default=False)  |