diff options
author | 2023-09-04 20:25:00 +0100 | |
---|---|---|
committer | 2023-09-04 20:25:00 +0100 | |
commit | 79b27945e0c2f1fd7adede76d0c50a92127363d6 (patch) | |
tree | 7e846ef195fb29c784a096d75155226074845f5c | |
parent | Update database & models modules to use SQLAlchemy 2 (diff) |
Update listeners and utils to use new SQLAlchemy 2 models
-rw-r--r-- | metricity/exts/event_listeners/_utils.py | 15 | ||||
-rw-r--r-- | metricity/exts/event_listeners/guild_listeners.py | 143 | ||||
-rw-r--r-- | metricity/exts/event_listeners/member_listeners.py | 120 | ||||
-rw-r--r-- | metricity/exts/event_listeners/message_listeners.py | 30 |
4 files changed, 179 insertions, 129 deletions
diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py index 6b2aacf..69d44ab 100644 --- a/metricity/exts/event_listeners/_utils.py +++ b/metricity/exts/event_listeners/_utils.py @@ -1,11 +1,12 @@ import discord +from sqlalchemy.ext.asyncio import AsyncSession from metricity import models -async def insert_thread(thread: discord.Thread) -> None: - """Insert the given thread to the database.""" - await models.Thread.create( +def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None: + """Insert the given thread to the database session.""" + sess.add(models.Thread( id=str(thread.id), parent_channel_id=str(thread.parent_id), name=thread.name, @@ -13,12 +14,12 @@ async def insert_thread(thread: discord.Thread) -> None: auto_archive_duration=thread.auto_archive_duration, locked=thread.locked, type=thread.type.name, - ) + )) -async def sync_message(message: discord.Message, *, from_thread: bool) -> None: +async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None: """Sync the given message with the database.""" - if await models.Message.get(str(message.id)): + if await sess.get(models.Message, str(message.id)): return args = { @@ -33,4 +34,4 @@ async def sync_message(message: discord.Message, *, from_thread: bool) -> None: args["channel_id"] = str(thread.parent_id) args["thread_id"] = str(thread.id) - await models.Message.create(**args) + sess.add(models.Message(**args)) diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py index c7c074f..9ad0bda 100644 --- a/metricity/exts/event_listeners/guild_listeners.py +++ b/metricity/exts/event_listeners/guild_listeners.py @@ -3,16 +3,17 @@ import discord from discord.ext import commands from pydis_core.utils import logging, scheduling +from sqlalchemy import update +from sqlalchemy.dialects.postgresql import insert from metricity import models from metricity.bot import Bot from metricity.config import BotConfig -from metricity.database import db +from metricity.database import async_session from metricity.exts.event_listeners import _utils log = logging.get_logger(__name__) - class GuildListeners(commands.Cog): """Listen for guild (and guild channel) events and sync them to the database.""" @@ -31,7 +32,9 @@ class GuildListeners(commands.Cog): await self.sync_thread_archive_state(guild) log.info("Beginning user synchronisation process") - await models.User.update.values(in_guild=False).gino.status() + async with async_session() as sess: + await sess.execute(update(models.User).values(in_guild=False)) + await sess.commit() users = [ { @@ -54,9 +57,29 @@ class GuildListeners(commands.Cog): user_chunks = discord.utils.as_chunks(users, 500) - for chunk in user_chunks: - log.info("Upserting chunk of %d", len(chunk)) - await models.User.bulk_upsert(chunk) + async with async_session() as sess: + for chunk in user_chunks: + log.info("Upserting chunk of %d", len(chunk)) + qs = insert(models.User).values(chunk) + + update_cols = [ + "name", + "avatar_hash", + "guild_avatar_hash", + "joined_at", + "is_staff", + "bot", + "in_guild", + "public_flags", + "pending", + ] + + await sess.execute(qs.on_conflict_do_update( + index_elements=[models.User.id], + set_={k: getattr(qs.excluded, k) for k in update_cols}, + )) + + await sess.commit() log.info("User upsert complete") @@ -66,9 +89,19 @@ class GuildListeners(commands.Cog): async def sync_thread_archive_state(guild: discord.Guild) -> None: """Sync the archive state of all threads in the database with the state in guild.""" active_thread_ids = [str(thread.id) for thread in guild.threads] - async with db.transaction() as tx: - async for db_thread in tx.connection.iterate(models.Thread.query): - await db_thread.update(archived=db_thread.id not in active_thread_ids).apply() + + async with async_session() as sess: + await sess.execute( + update(models.Thread) + .where(models.Thread.id.in_(active_thread_ids)) + .values(archived=False), + ) + await sess.execute( + update(models.Thread) + .where(~models.Thread.id.in_(active_thread_ids)) + .values(archived=True), + ) + await sess.commit() async def sync_channels(self, guild: discord.Guild) -> None: """Sync channels and categories with the database.""" @@ -76,59 +109,61 @@ class GuildListeners(commands.Cog): log.info("Beginning category synchronisation process") - for channel in guild.channels: - if isinstance(channel, discord.CategoryChannel): - if db_cat := await models.Category.get(str(channel.id)): - await db_cat.update(name=channel.name).apply() - else: - await models.Category.create(id=str(channel.id), name=channel.name) + async with async_session() as sess: + for channel in guild.channels: + if isinstance(channel, discord.CategoryChannel): + if existing_cat := await sess.get(models.Category, str(channel.id)): + existing_cat.name = channel.name + else: + sess.add(models.Category(id=str(channel.id), name=channel.name)) + + await sess.commit() log.info("Category synchronisation process complete, synchronising channels") - for channel in guild.channels: - if channel.category and channel.category.id in BotConfig.ignore_categories: - continue - - if not isinstance(channel, discord.CategoryChannel): - category_id = str(channel.category.id) if channel.category else None - # Cast to bool so is_staff is False if channel.category is None - is_staff = channel.id in BotConfig.staff_channels or bool( - channel.category and channel.category.id in BotConfig.staff_categories, - ) - if db_chan := await models.Channel.get(str(channel.id)): - await db_chan.update( - name=channel.name, - category_id=category_id, - is_staff=is_staff, - ).apply() - else: - await models.Channel.create( - id=str(channel.id), - name=channel.name, - category_id=category_id, - is_staff=is_staff, + async with async_session() as sess: + for channel in guild.channels: + if channel.category and channel.category.id in BotConfig.ignore_categories: + continue + + if not isinstance(channel, discord.CategoryChannel): + category_id = str(channel.category.id) if channel.category else None + # Cast to bool so is_staff is False if channel.category is None + is_staff = channel.id in BotConfig.staff_channels or bool( + channel.category and channel.category.id in BotConfig.staff_categories, ) + if db_chan := await sess.get(models.Channel, str(channel.id)): + db_chan.name = channel.name + else: + sess.add(models.Channel( + id=str(channel.id), + name=channel.name, + category_id=category_id, + is_staff=is_staff, + )) + + await sess.commit() log.info("Channel synchronisation process complete, synchronising threads") - for thread in guild.threads: - if thread.parent and thread.parent.category: - if thread.parent.category.id in BotConfig.ignore_categories: + async with async_session() as sess: + for thread in guild.threads: + if thread.parent and thread.parent.category: + if thread.parent.category.id in BotConfig.ignore_categories: + continue + else: + # This is a forum channel, not currently supported by Discord.py. Ignore it. continue - else: - # This is a forum channel, not currently supported by Discord.py. Ignore it. - continue - - if db_thread := await models.Thread.get(str(thread.id)): - await db_thread.update( - name=thread.name, - archived=thread.archived, - auto_archive_duration=thread.auto_archive_duration, - locked=thread.locked, - type=thread.type.name, - ).apply() - else: - await _utils.insert_thread(thread) + + if db_thread := await sess.get(models.Thread, str(thread.id)): + db_thread.name = thread.name + db_thread.archived = thread.archived + db_thread.auto_archive_duration = thread.auto_archive_duration + db_thread.locked = thread.locked + db_thread.type = thread.type.name + else: + _utils.insert_thread(thread, sess) + await sess.commit() log.info("Thread synchronisation process complete, finished synchronising guild.") self.bot.channel_sync_in_progress.set() diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py index ddf5954..dc1e3c1 100644 --- a/metricity/exts/event_listeners/member_listeners.py +++ b/metricity/exts/event_listeners/member_listeners.py @@ -5,9 +5,11 @@ import contextlib import discord from asyncpg.exceptions import UniqueViolationError from discord.ext import commands +from sqlalchemy import update from metricity.bot import Bot from metricity.config import BotConfig +from metricity.database import async_session from metricity.models import User @@ -25,10 +27,11 @@ class MemberListeners(commands.Cog): if member.guild.id != BotConfig.guild_id: return - if db_user := await User.get(str(member.id)): - await db_user.update( - in_guild=False, - ).apply() + async with async_session() as sess: + await sess.execute( + update(User).where(User.id == str(member.id)).values(in_guild=False), + ) + await sess.commit() @commands.Cog.listener() async def on_member_join(self, member: discord.Member) -> None: @@ -38,22 +41,9 @@ class MemberListeners(commands.Cog): if member.guild.id != BotConfig.guild_id: return - if db_user := await User.get(str(member.id)): - await db_user.update( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], - public_flags=dict(member.public_flags), - pending=member.pending, - in_guild=True, - ).apply() - else: - with contextlib.suppress(UniqueViolationError): - await User.create( + async with async_session() as sess: + if await sess.get(User, str(member.id)): + await sess.execute(update(User).where(User.id == str(member.id)).values( id=str(member.id), name=member.name, avatar_hash=getattr(member.avatar, "key", None), @@ -64,7 +54,23 @@ class MemberListeners(commands.Cog): public_flags=dict(member.public_flags), pending=member.pending, in_guild=True, - ) + )) + else: + with contextlib.suppress(UniqueViolationError): + await sess.add(User( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], + public_flags=dict(member.public_flags), + pending=member.pending, + in_guild=True, + )) + + await sess.commit() @commands.Cog.listener() async def on_member_update(self, _before: discord.Member, member: discord.Member) -> None: @@ -80,41 +86,43 @@ class MemberListeners(commands.Cog): roles = {role.id for role in member.roles} - if db_user := await User.get(str(member.id)): - if ( - db_user.name != member.name or - db_user.avatar_hash != getattr(member.avatar, "key", None) or - db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or - BotConfig.staff_role_id in - [role.id for role in member.roles] != db_user.is_staff - or db_user.pending is not member.pending - ): - await db_user.update( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in roles, - public_flags=dict(member.public_flags), - in_guild=True, - pending=member.pending, - ).apply() - else: - with contextlib.suppress(UniqueViolationError): - await User.create( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in roles, - public_flags=dict(member.public_flags), - in_guild=True, - pending=member.pending, - ) + async with async_session() as sess: + if db_user := await sess.get(User, str(member.id)): + if ( + db_user.name != member.name or + db_user.avatar_hash != getattr(member.avatar, "key", None) or + db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or + (BotConfig.staff_role_id in roles) != db_user.is_staff + or db_user.pending is not member.pending + ): + await sess.execute(update(User).where(User.id == str(member.id)).values( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in roles, + public_flags=dict(member.public_flags), + in_guild=True, + pending=member.pending, + )) + else: + with contextlib.suppress(UniqueViolationError): + sess.add(User( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in roles, + public_flags=dict(member.public_flags), + in_guild=True, + pending=member.pending, + )) + + await sess.commit() diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py index b446e26..28329d0 100644 --- a/metricity/exts/event_listeners/message_listeners.py +++ b/metricity/exts/event_listeners/message_listeners.py @@ -2,9 +2,11 @@ import discord from discord.ext import commands +from sqlalchemy import update from metricity.bot import Bot from metricity.config import BotConfig +from metricity.database import async_session from metricity.exts.event_listeners import _utils from metricity.models import Message, User @@ -33,28 +35,32 @@ class MessageListeners(commands.Cog): await self.bot.sync_process_complete.wait() await self.bot.channel_sync_in_progress.wait() - if not await User.get(str(message.author.id)): - return + async with async_session() as sess: + if not await sess.get(User, str(message.author.id)): + return - cat_id = message.channel.category.id if message.channel.category else None - if cat_id in BotConfig.ignore_categories: - return + cat_id = message.channel.category.id if message.channel.category else None + if cat_id in BotConfig.ignore_categories: + return + + from_thread = isinstance(message.channel, discord.Thread) + await _utils.sync_message(message, sess, from_thread=from_thread) - from_thread = isinstance(message.channel, discord.Thread) - await _utils.sync_message(message, from_thread=from_thread) + await sess.commit() @commands.Cog.listener() async def on_raw_message_delete(self, message: discord.RawMessageDeleteEvent) -> None: """If a message is deleted and we have a record of it set the is_deleted flag.""" - if message := await Message.get(str(message.message_id)): - await message.update(is_deleted=True).apply() + async with async_session() as sess: + await sess.execute(update(Message).where(Message.id == str(message.message_id)).values(is_deleted=True)) + await sess.commit() @commands.Cog.listener() async def on_raw_bulk_message_delete(self, messages: discord.RawBulkMessageDeleteEvent) -> None: """If messages are deleted in bulk and we have a record of them set the is_deleted flag.""" - for message_id in messages.message_ids: - if message := await Message.get(str(message_id)): - await message.update(is_deleted=True).apply() + async with async_session() as sess: + await sess.execute(update(Message).where(Message.id.in_(messages.message_ids)).values(is_deleted=True)) + await sess.commit() async def setup(bot: Bot) -> None: |