aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2023-09-04 20:25:00 +0100
committerGravatar Joe Banks <[email protected]>2023-09-04 20:25:00 +0100
commit79b27945e0c2f1fd7adede76d0c50a92127363d6 (patch)
tree7e846ef195fb29c784a096d75155226074845f5c
parentUpdate database & models modules to use SQLAlchemy 2 (diff)
Update listeners and utils to use new SQLAlchemy 2 models
-rw-r--r--metricity/exts/event_listeners/_utils.py15
-rw-r--r--metricity/exts/event_listeners/guild_listeners.py143
-rw-r--r--metricity/exts/event_listeners/member_listeners.py120
-rw-r--r--metricity/exts/event_listeners/message_listeners.py30
4 files changed, 179 insertions, 129 deletions
diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py
index 6b2aacf..69d44ab 100644
--- a/metricity/exts/event_listeners/_utils.py
+++ b/metricity/exts/event_listeners/_utils.py
@@ -1,11 +1,12 @@
import discord
+from sqlalchemy.ext.asyncio import AsyncSession
from metricity import models
-async def insert_thread(thread: discord.Thread) -> None:
- """Insert the given thread to the database."""
- await models.Thread.create(
+def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None:
+ """Insert the given thread to the database session."""
+ sess.add(models.Thread(
id=str(thread.id),
parent_channel_id=str(thread.parent_id),
name=thread.name,
@@ -13,12 +14,12 @@ async def insert_thread(thread: discord.Thread) -> None:
auto_archive_duration=thread.auto_archive_duration,
locked=thread.locked,
type=thread.type.name,
- )
+ ))
-async def sync_message(message: discord.Message, *, from_thread: bool) -> None:
+async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None:
"""Sync the given message with the database."""
- if await models.Message.get(str(message.id)):
+ if await sess.get(models.Message, str(message.id)):
return
args = {
@@ -33,4 +34,4 @@ async def sync_message(message: discord.Message, *, from_thread: bool) -> None:
args["channel_id"] = str(thread.parent_id)
args["thread_id"] = str(thread.id)
- await models.Message.create(**args)
+ sess.add(models.Message(**args))
diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py
index c7c074f..9ad0bda 100644
--- a/metricity/exts/event_listeners/guild_listeners.py
+++ b/metricity/exts/event_listeners/guild_listeners.py
@@ -3,16 +3,17 @@
import discord
from discord.ext import commands
from pydis_core.utils import logging, scheduling
+from sqlalchemy import update
+from sqlalchemy.dialects.postgresql import insert
from metricity import models
from metricity.bot import Bot
from metricity.config import BotConfig
-from metricity.database import db
+from metricity.database import async_session
from metricity.exts.event_listeners import _utils
log = logging.get_logger(__name__)
-
class GuildListeners(commands.Cog):
"""Listen for guild (and guild channel) events and sync them to the database."""
@@ -31,7 +32,9 @@ class GuildListeners(commands.Cog):
await self.sync_thread_archive_state(guild)
log.info("Beginning user synchronisation process")
- await models.User.update.values(in_guild=False).gino.status()
+ async with async_session() as sess:
+ await sess.execute(update(models.User).values(in_guild=False))
+ await sess.commit()
users = [
{
@@ -54,9 +57,29 @@ class GuildListeners(commands.Cog):
user_chunks = discord.utils.as_chunks(users, 500)
- for chunk in user_chunks:
- log.info("Upserting chunk of %d", len(chunk))
- await models.User.bulk_upsert(chunk)
+ async with async_session() as sess:
+ for chunk in user_chunks:
+ log.info("Upserting chunk of %d", len(chunk))
+ qs = insert(models.User).values(chunk)
+
+ update_cols = [
+ "name",
+ "avatar_hash",
+ "guild_avatar_hash",
+ "joined_at",
+ "is_staff",
+ "bot",
+ "in_guild",
+ "public_flags",
+ "pending",
+ ]
+
+ await sess.execute(qs.on_conflict_do_update(
+ index_elements=[models.User.id],
+ set_={k: getattr(qs.excluded, k) for k in update_cols},
+ ))
+
+ await sess.commit()
log.info("User upsert complete")
@@ -66,9 +89,19 @@ class GuildListeners(commands.Cog):
async def sync_thread_archive_state(guild: discord.Guild) -> None:
"""Sync the archive state of all threads in the database with the state in guild."""
active_thread_ids = [str(thread.id) for thread in guild.threads]
- async with db.transaction() as tx:
- async for db_thread in tx.connection.iterate(models.Thread.query):
- await db_thread.update(archived=db_thread.id not in active_thread_ids).apply()
+
+ async with async_session() as sess:
+ await sess.execute(
+ update(models.Thread)
+ .where(models.Thread.id.in_(active_thread_ids))
+ .values(archived=False),
+ )
+ await sess.execute(
+ update(models.Thread)
+ .where(~models.Thread.id.in_(active_thread_ids))
+ .values(archived=True),
+ )
+ await sess.commit()
async def sync_channels(self, guild: discord.Guild) -> None:
"""Sync channels and categories with the database."""
@@ -76,59 +109,61 @@ class GuildListeners(commands.Cog):
log.info("Beginning category synchronisation process")
- for channel in guild.channels:
- if isinstance(channel, discord.CategoryChannel):
- if db_cat := await models.Category.get(str(channel.id)):
- await db_cat.update(name=channel.name).apply()
- else:
- await models.Category.create(id=str(channel.id), name=channel.name)
+ async with async_session() as sess:
+ for channel in guild.channels:
+ if isinstance(channel, discord.CategoryChannel):
+ if existing_cat := await sess.get(models.Category, str(channel.id)):
+ existing_cat.name = channel.name
+ else:
+ sess.add(models.Category(id=str(channel.id), name=channel.name))
+
+ await sess.commit()
log.info("Category synchronisation process complete, synchronising channels")
- for channel in guild.channels:
- if channel.category and channel.category.id in BotConfig.ignore_categories:
- continue
-
- if not isinstance(channel, discord.CategoryChannel):
- category_id = str(channel.category.id) if channel.category else None
- # Cast to bool so is_staff is False if channel.category is None
- is_staff = channel.id in BotConfig.staff_channels or bool(
- channel.category and channel.category.id in BotConfig.staff_categories,
- )
- if db_chan := await models.Channel.get(str(channel.id)):
- await db_chan.update(
- name=channel.name,
- category_id=category_id,
- is_staff=is_staff,
- ).apply()
- else:
- await models.Channel.create(
- id=str(channel.id),
- name=channel.name,
- category_id=category_id,
- is_staff=is_staff,
+ async with async_session() as sess:
+ for channel in guild.channels:
+ if channel.category and channel.category.id in BotConfig.ignore_categories:
+ continue
+
+ if not isinstance(channel, discord.CategoryChannel):
+ category_id = str(channel.category.id) if channel.category else None
+ # Cast to bool so is_staff is False if channel.category is None
+ is_staff = channel.id in BotConfig.staff_channels or bool(
+ channel.category and channel.category.id in BotConfig.staff_categories,
)
+ if db_chan := await sess.get(models.Channel, str(channel.id)):
+ db_chan.name = channel.name
+ else:
+ sess.add(models.Channel(
+ id=str(channel.id),
+ name=channel.name,
+ category_id=category_id,
+ is_staff=is_staff,
+ ))
+
+ await sess.commit()
log.info("Channel synchronisation process complete, synchronising threads")
- for thread in guild.threads:
- if thread.parent and thread.parent.category:
- if thread.parent.category.id in BotConfig.ignore_categories:
+ async with async_session() as sess:
+ for thread in guild.threads:
+ if thread.parent and thread.parent.category:
+ if thread.parent.category.id in BotConfig.ignore_categories:
+ continue
+ else:
+ # This is a forum channel, not currently supported by Discord.py. Ignore it.
continue
- else:
- # This is a forum channel, not currently supported by Discord.py. Ignore it.
- continue
-
- if db_thread := await models.Thread.get(str(thread.id)):
- await db_thread.update(
- name=thread.name,
- archived=thread.archived,
- auto_archive_duration=thread.auto_archive_duration,
- locked=thread.locked,
- type=thread.type.name,
- ).apply()
- else:
- await _utils.insert_thread(thread)
+
+ if db_thread := await sess.get(models.Thread, str(thread.id)):
+ db_thread.name = thread.name
+ db_thread.archived = thread.archived
+ db_thread.auto_archive_duration = thread.auto_archive_duration
+ db_thread.locked = thread.locked
+ db_thread.type = thread.type.name
+ else:
+ _utils.insert_thread(thread, sess)
+ await sess.commit()
log.info("Thread synchronisation process complete, finished synchronising guild.")
self.bot.channel_sync_in_progress.set()
diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py
index ddf5954..dc1e3c1 100644
--- a/metricity/exts/event_listeners/member_listeners.py
+++ b/metricity/exts/event_listeners/member_listeners.py
@@ -5,9 +5,11 @@ import contextlib
import discord
from asyncpg.exceptions import UniqueViolationError
from discord.ext import commands
+from sqlalchemy import update
from metricity.bot import Bot
from metricity.config import BotConfig
+from metricity.database import async_session
from metricity.models import User
@@ -25,10 +27,11 @@ class MemberListeners(commands.Cog):
if member.guild.id != BotConfig.guild_id:
return
- if db_user := await User.get(str(member.id)):
- await db_user.update(
- in_guild=False,
- ).apply()
+ async with async_session() as sess:
+ await sess.execute(
+ update(User).where(User.id == str(member.id)).values(in_guild=False),
+ )
+ await sess.commit()
@commands.Cog.listener()
async def on_member_join(self, member: discord.Member) -> None:
@@ -38,22 +41,9 @@ class MemberListeners(commands.Cog):
if member.guild.id != BotConfig.guild_id:
return
- if db_user := await User.get(str(member.id)):
- await db_user.update(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
- public_flags=dict(member.public_flags),
- pending=member.pending,
- in_guild=True,
- ).apply()
- else:
- with contextlib.suppress(UniqueViolationError):
- await User.create(
+ async with async_session() as sess:
+ if await sess.get(User, str(member.id)):
+ await sess.execute(update(User).where(User.id == str(member.id)).values(
id=str(member.id),
name=member.name,
avatar_hash=getattr(member.avatar, "key", None),
@@ -64,7 +54,23 @@ class MemberListeners(commands.Cog):
public_flags=dict(member.public_flags),
pending=member.pending,
in_guild=True,
- )
+ ))
+ else:
+ with contextlib.suppress(UniqueViolationError):
+ await sess.add(User(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
+ public_flags=dict(member.public_flags),
+ pending=member.pending,
+ in_guild=True,
+ ))
+
+ await sess.commit()
@commands.Cog.listener()
async def on_member_update(self, _before: discord.Member, member: discord.Member) -> None:
@@ -80,41 +86,43 @@ class MemberListeners(commands.Cog):
roles = {role.id for role in member.roles}
- if db_user := await User.get(str(member.id)):
- if (
- db_user.name != member.name or
- db_user.avatar_hash != getattr(member.avatar, "key", None) or
- db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or
- BotConfig.staff_role_id in
- [role.id for role in member.roles] != db_user.is_staff
- or db_user.pending is not member.pending
- ):
- await db_user.update(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in roles,
- public_flags=dict(member.public_flags),
- in_guild=True,
- pending=member.pending,
- ).apply()
- else:
- with contextlib.suppress(UniqueViolationError):
- await User.create(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in roles,
- public_flags=dict(member.public_flags),
- in_guild=True,
- pending=member.pending,
- )
+ async with async_session() as sess:
+ if db_user := await sess.get(User, str(member.id)):
+ if (
+ db_user.name != member.name or
+ db_user.avatar_hash != getattr(member.avatar, "key", None) or
+ db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or
+ (BotConfig.staff_role_id in roles) != db_user.is_staff
+ or db_user.pending is not member.pending
+ ):
+ await sess.execute(update(User).where(User.id == str(member.id)).values(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in roles,
+ public_flags=dict(member.public_flags),
+ in_guild=True,
+ pending=member.pending,
+ ))
+ else:
+ with contextlib.suppress(UniqueViolationError):
+ sess.add(User(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in roles,
+ public_flags=dict(member.public_flags),
+ in_guild=True,
+ pending=member.pending,
+ ))
+
+ await sess.commit()
diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py
index b446e26..28329d0 100644
--- a/metricity/exts/event_listeners/message_listeners.py
+++ b/metricity/exts/event_listeners/message_listeners.py
@@ -2,9 +2,11 @@
import discord
from discord.ext import commands
+from sqlalchemy import update
from metricity.bot import Bot
from metricity.config import BotConfig
+from metricity.database import async_session
from metricity.exts.event_listeners import _utils
from metricity.models import Message, User
@@ -33,28 +35,32 @@ class MessageListeners(commands.Cog):
await self.bot.sync_process_complete.wait()
await self.bot.channel_sync_in_progress.wait()
- if not await User.get(str(message.author.id)):
- return
+ async with async_session() as sess:
+ if not await sess.get(User, str(message.author.id)):
+ return
- cat_id = message.channel.category.id if message.channel.category else None
- if cat_id in BotConfig.ignore_categories:
- return
+ cat_id = message.channel.category.id if message.channel.category else None
+ if cat_id in BotConfig.ignore_categories:
+ return
+
+ from_thread = isinstance(message.channel, discord.Thread)
+ await _utils.sync_message(message, sess, from_thread=from_thread)
- from_thread = isinstance(message.channel, discord.Thread)
- await _utils.sync_message(message, from_thread=from_thread)
+ await sess.commit()
@commands.Cog.listener()
async def on_raw_message_delete(self, message: discord.RawMessageDeleteEvent) -> None:
"""If a message is deleted and we have a record of it set the is_deleted flag."""
- if message := await Message.get(str(message.message_id)):
- await message.update(is_deleted=True).apply()
+ async with async_session() as sess:
+ await sess.execute(update(Message).where(Message.id == str(message.message_id)).values(is_deleted=True))
+ await sess.commit()
@commands.Cog.listener()
async def on_raw_bulk_message_delete(self, messages: discord.RawBulkMessageDeleteEvent) -> None:
"""If messages are deleted in bulk and we have a record of them set the is_deleted flag."""
- for message_id in messages.message_ids:
- if message := await Message.get(str(message_id)):
- await message.update(is_deleted=True).apply()
+ async with async_session() as sess:
+ await sess.execute(update(Message).where(Message.id.in_(messages.message_ids)).values(is_deleted=True))
+ await sess.commit()
async def setup(bot: Bot) -> None: