aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2023-09-04 20:24:41 +0100
committerGravatar Joe Banks <[email protected]>2023-09-04 20:24:41 +0100
commit002dff82e9990125b266744034ee374a2989bcfe (patch)
treefe7d28aaf710815e75f16d87ac41922090fd5253
parentAdd SQLAlchemy 2 & asyncpg and remove GINO & psycopg2 (diff)
Update database & models modules to use SQLAlchemy 2
-rw-r--r--metricity/database.py18
-rw-r--r--metricity/models.py131
2 files changed, 52 insertions, 97 deletions
diff --git a/metricity/database.py b/metricity/database.py
index 19ec1ff..4fdf465 100644
--- a/metricity/database.py
+++ b/metricity/database.py
@@ -3,34 +3,26 @@
import logging
from datetime import UTC, datetime
-import gino
from sqlalchemy.engine import Dialect
+from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.types import DateTime, TypeDecorator
from metricity.config import DatabaseConfig
log = logging.getLogger(__name__)
-db = gino.Gino()
-
-
def build_db_uri() -> str:
- """Use information from the config file to build a PostgreSQL URI."""
+ """Build the database uri from the config."""
if DatabaseConfig.uri:
return DatabaseConfig.uri
return (
- f"postgresql://{DatabaseConfig.username}:{DatabaseConfig.password}"
+ f"postgresql+asyncpg://{DatabaseConfig.username}:{DatabaseConfig.password}"
f"@{DatabaseConfig.host}:{DatabaseConfig.port}/{DatabaseConfig.database}"
)
-
-async def connect() -> None:
- """Initiate a connection to the database."""
- log.info("Initiating connection to the database")
- await db.set_bind(build_db_uri())
- log.info("Database connection established")
-
+engine: AsyncEngine = create_async_engine(build_db_uri(), echo=DatabaseConfig.log_queries)
+async_session = async_sessionmaker(engine, expire_on_commit=False)
class TZDateTime(TypeDecorator):
"""
diff --git a/metricity/models.py b/metricity/models.py
index c2ae0c7..c1ce2e0 100644
--- a/metricity/models.py
+++ b/metricity/models.py
@@ -1,116 +1,79 @@
"""Database models used by Metricity for statistic collection."""
from datetime import UTC, datetime
-from typing import Any
+from typing import Any, Optional
+from sqlalchemy import JSON, ForeignKey
from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
-from metricity.database import TZDateTime, db
+from metricity.database import TZDateTime
-class Category(db.Model):
+class Base(DeclarativeBase):
+ pass
+
+class Category(Base):
"""Database model representing a Discord category channel."""
__tablename__ = "categories"
- id = db.Column(db.String, primary_key=True)
- name = db.Column(db.String, nullable=False)
+ id: Mapped[str] = mapped_column(primary_key=True)
+ name: Mapped[str]
-class Channel(db.Model):
+class Channel(Base):
"""Database model representing a Discord channel."""
__tablename__ = "channels"
- id = db.Column(db.String, primary_key=True)
- name = db.Column(db.String, nullable=False)
- category_id = db.Column(
- db.String,
- db.ForeignKey("categories.id", ondelete="CASCADE"),
- nullable=True,
- )
- is_staff = db.Column(db.Boolean, nullable=False)
+ id: Mapped[str] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ category_id: Mapped[str | None] = mapped_column(ForeignKey("categories.id", ondelete="CASCADE"))
+ is_staff: Mapped[bool]
-class Thread(db.Model):
+class Thread(Base):
"""Database model representing a Thread channel."""
__tablename__ = "threads"
- id = db.Column(db.String, primary_key=True)
- parent_channel_id = db.Column(
- db.String,
- db.ForeignKey("channels.id", ondelete="CASCADE"),
- nullable=False,
- )
- created_at = db.Column(TZDateTime(), default=datetime.now(UTC))
- name = db.Column(db.String, nullable=False)
- archived = db.Column(db.Boolean, default=False, nullable=False)
- auto_archive_duration = db.Column(db.Integer, nullable=False)
- locked = db.Column(db.Boolean, default=False, nullable=False)
- type = db.Column(db.String, nullable=False, index=True)
-
-
-class User(db.Model):
+ id: Mapped[str] = mapped_column(primary_key=True)
+ parent_channel_id: Mapped[str] = mapped_column(ForeignKey("channels.id", ondelete="CASCADE"))
+ created_at = mapped_column(TZDateTime(), default=datetime.now(UTC))
+ name: Mapped[str]
+ archived: Mapped[bool]
+ auto_archive_duration: Mapped[int]
+ locked: Mapped[bool]
+ type: Mapped[str] = mapped_column(index=True)
+
+
+class User(Base):
"""Database model representing a Discord user."""
__tablename__ = "users"
- id = db.Column(db.String, primary_key=True)
- name = db.Column(db.String, nullable=False)
- avatar_hash = db.Column(db.String, nullable=True)
- guild_avatar_hash = db.Column(db.String, nullable=True)
- joined_at = db.Column(TZDateTime(), nullable=False)
- created_at = db.Column(TZDateTime(), nullable=False)
- is_staff = db.Column(db.Boolean, nullable=False)
- bot = db.Column(db.Boolean, default=False)
- in_guild = db.Column(db.Boolean, default=True)
- public_flags = db.Column(db.JSON, default={})
- pending = db.Column(db.Boolean, default=False)
-
- @classmethod
- def bulk_upsert(cls: type, users: list[dict[str, Any]]) -> Any: # noqa: ANN401
- """Perform a bulk insert/update of the database to sync the user table."""
- qs = insert(cls.__table__).values(users)
-
- update_cols = [
- "name",
- "avatar_hash",
- "guild_avatar_hash",
- "joined_at",
- "is_staff",
- "bot",
- "in_guild",
- "public_flags",
- "pending",
- ]
-
- return qs.on_conflict_do_update(
- index_elements=[cls.id],
- set_={k: getattr(qs.excluded, k) for k in update_cols},
- ).returning(cls.__table__).gino.all()
-
-
-class Message(db.Model):
+ id: Mapped[str] = mapped_column(primary_key=True)
+ name: Mapped[str] = mapped_column(nullable=False)
+ avatar_hash: Mapped[str] = mapped_column(nullable=True)
+ guild_avatar_hash: Mapped[str] = mapped_column(nullable=True)
+ joined_at = mapped_column(TZDateTime(), nullable=False)
+ created_at = mapped_column(TZDateTime(), nullable=False)
+ is_staff: Mapped[bool] = mapped_column(nullable=False)
+ bot: Mapped[bool] = mapped_column(default=False)
+ in_guild: Mapped[bool] = mapped_column(default=False)
+ public_flags = mapped_column(JSON, default={})
+ pending: Mapped[bool] = mapped_column(default=False)
+
+
+class Message(Base):
"""Database model representing a message sent in a Discord server."""
__tablename__ = "messages"
- id = db.Column(db.String, primary_key=True)
- channel_id = db.Column(
- db.String,
- db.ForeignKey("channels.id", ondelete="CASCADE"),
- index=True,
- )
- thread_id = db.Column(
- db.String,
- db.ForeignKey("threads.id", ondelete="CASCADE"),
- index=True,
- )
- author_id = db.Column(
- db.String,
- db.ForeignKey("users.id", ondelete="CASCADE"),
- index=True,
- )
- created_at = db.Column(TZDateTime(), default=datetime.now(UTC))
- is_deleted = db.Column(db.Boolean, default=False)
+ id: Mapped[str] = mapped_column(primary_key=True)
+ channel_id: Mapped[str] = mapped_column(ForeignKey("channels.id", ondelete="CASCADE"), index=True)
+ thread_id: Mapped[str | None] = mapped_column(ForeignKey("threads.id", ondelete="CASCADE"), index=True)
+ author_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
+ created_at = mapped_column(TZDateTime())
+ is_deleted: Mapped[bool] = mapped_column(default=False)