diff options
author | 2022-07-09 21:50:00 +0100 | |
---|---|---|
committer | 2022-07-09 21:50:00 +0100 | |
commit | 21212e5e2e03414a824d6136112bb9414d3964c1 (patch) | |
tree | 6de6778a9093cb4da6c09fcfde3c8d60e161b203 | |
parent | Migrate metricity to use BotBase from botcore (diff) |
Move all event listener logic to extensions
-rw-r--r-- | metricity/exts/__init__.py | 1 | ||||
-rw-r--r-- | metricity/exts/event_listeners/__init__.py | 1 | ||||
-rw-r--r-- | metricity/exts/event_listeners/_utils.py | 42 | ||||
-rw-r--r-- | metricity/exts/event_listeners/guild_listeners.py | 187 | ||||
-rw-r--r-- | metricity/exts/event_listeners/member_listeners.py | 124 | ||||
-rw-r--r-- | metricity/exts/event_listeners/message_listeners.py | 62 |
6 files changed, 417 insertions, 0 deletions
diff --git a/metricity/exts/__init__.py b/metricity/exts/__init__.py new file mode 100644 index 0000000..a8cce86 --- /dev/null +++ b/metricity/exts/__init__.py @@ -0,0 +1 @@ +"""A module containing all extensions to be loaded into the bot on startup.""" diff --git a/metricity/exts/event_listeners/__init__.py b/metricity/exts/event_listeners/__init__.py new file mode 100644 index 0000000..2830bfb --- /dev/null +++ b/metricity/exts/event_listeners/__init__.py @@ -0,0 +1 @@ +"""A module containing all extensions around listening to events and storing them in the database.""" diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py new file mode 100644 index 0000000..4de01fd --- /dev/null +++ b/metricity/exts/event_listeners/_utils.py @@ -0,0 +1,42 @@ +import discord + +from metricity import models + + +async def insert_thread(thread: discord.Thread) -> None: + """Insert the given thread to the database.""" + await models.Thread.create( + id=str(thread.id), + parent_channel_id=str(thread.parent_id), + name=thread.name, + archived=thread.archived, + auto_archive_duration=thread.auto_archive_duration, + locked=thread.locked, + type=thread.type.name, + ) + + +async def sync_message(message: discord.Message, from_thread: bool) -> None: + """Sync the given message with the database.""" + if await models.Message.get(str(message.id)): + return + + if from_thread and not message.channel.parent: + # This is a forum channel, not currently supported by Discord.py. Ignore it. + return + + args = { + "id": str(message.id), + "channel_id": str(message.channel.id), + "author_id": str(message.author.id), + "created_at": message.created_at + } + + if from_thread: + thread = message.channel + args["channel_id"] = str(thread.parent_id) + args["thread_id"] = str(thread.id) + if not await models.Thread.get(str(thread.id)): + await insert_thread(thread) + + await models.Message.create(**args) diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py new file mode 100644 index 0000000..18eb79a --- /dev/null +++ b/metricity/exts/event_listeners/guild_listeners.py @@ -0,0 +1,187 @@ +"""An ext to listen for guild (and guild channel) events and syncs them to the database.""" + +import discord +from botcore.utils import logging, scheduling +from discord.ext import commands + +from metricity import models +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.database import db +from metricity.exts.event_listeners import _utils + +log = logging.get_logger(__name__) + + +class GuildListeners(commands.Cog): + """Listen for guild (and guild channel) events and sync them to the database.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + scheduling.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Sync all channels and members in the guild.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(self.bot.guild_id) + await self.sync_channels(guild) + + log.info("Beginning thread archive state synchronisation process") + await self.sync_thread_archive_state(guild) + + log.info("Beginning user synchronisation process") + await models.User.update.values(in_guild=False).gino.status() + + users = [] + for user in guild.members: + users.append({ + "id": str(user.id), + "name": user.name, + "avatar_hash": getattr(user.avatar, "key", None), + "guild_avatar_hash": getattr(user.guild_avatar, "key", None), + "joined_at": user.joined_at, + "created_at": user.created_at, + "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], + "bot": user.bot, + "in_guild": True, + "public_flags": dict(user.public_flags), + "pending": user.pending + }) + + log.info(f"Performing bulk upsert of {len(users)} rows") + + user_chunks = discord.utils.as_chunks(users, 500) + + for chunk in user_chunks: + log.info(f"Upserting chunk of {len(chunk)}") + await models.User.bulk_upsert(chunk) + + log.info("User upsert complete") + + self.bot.sync_process_complete.set() + + @staticmethod + async def sync_thread_archive_state(guild: discord.Guild) -> None: + """Sync the archive state of all threads in the database with the state in guild.""" + active_thread_ids = [str(thread.id) for thread in guild.threads] + async with db.transaction() as tx: + async for db_thread in tx.connection.iterate(models.Thread.query): + await db_thread.update(archived=db_thread.id not in active_thread_ids).apply() + + async def sync_channels(self, guild: discord.Guild) -> None: + """Sync channels and categories with the database.""" + self.bot.channel_sync_in_progress.clear() + + log.info("Beginning category synchronisation process") + + for channel in guild.channels: + if isinstance(channel, discord.CategoryChannel): + if db_cat := await models.Category.get(str(channel.id)): + await db_cat.update(name=channel.name).apply() + else: + await models.Category.create(id=str(channel.id), name=channel.name) + + log.info("Category synchronisation process complete, synchronising channels") + + for channel in guild.channels: + if channel.category: + if channel.category.id in BotConfig.ignore_categories: + continue + + if not isinstance(channel, discord.CategoryChannel): + category_id = str(channel.category.id) if channel.category else None + # Cast to bool so is_staff is False if channel.category is None + is_staff = bool( + channel.category + and channel.category.id in BotConfig.staff_categories + ) + if db_chan := await models.Channel.get(str(channel.id)): + await db_chan.update( + name=channel.name, + category_id=category_id, + is_staff=is_staff, + ).apply() + else: + await models.Channel.create( + id=str(channel.id), + name=channel.name, + category_id=category_id, + is_staff=is_staff, + ) + + log.info("Channel synchronisation process complete, synchronising threads") + + for thread in guild.threads: + if thread.parent and thread.parent.category: + if thread.parent.category.id in BotConfig.ignore_categories: + continue + else: + # This is a forum channel, not currently supported by Discord.py. Ignore it. + continue + + if db_thread := await models.Thread.get(str(thread.id)): + await db_thread.update( + name=thread.name, + archived=thread.archived, + auto_archive_duration=thread.auto_archive_duration, + locked=thread.locked, + type=thread.type.name, + ).apply() + else: + await _utils.insert_thread(thread) + + log.info("Thread synchronisation process complete, finished synchronising guild.") + self.bot.channel_sync_in_progress.set() + + @commands.Cog.listener() + async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> None: + """Sync the channels when one is created.""" + if channel.guild.id != BotConfig.guild_id: + return + + await self.sync_channels(channel.guild) + + @commands.Cog.listener() + async def on_guild_channel_update( + self, + _before: discord.abc.GuildChannel, + channel: discord.abc.GuildChannel + ) -> None: + """Sync the channels when one is updated.""" + if channel.guild.id != BotConfig.guild_id: + return + + await self.sync_channels(channel.guild) + + @commands.Cog.listener() + async def on_thread_create(self, thread: discord.Thread) -> None: + """Sync channels when a thread is created.""" + if thread.guild.id != BotConfig.guild_id: + return + + await self.sync_channels(thread.guild) + + @commands.Cog.listener() + async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread) -> None: + """Sync the channels when one is updated.""" + if thread.guild.id != BotConfig.guild_id: + return + + await self.sync_channels(thread.guild) + + @commands.Cog.listener() + async def on_guild_available(self, guild: discord.Guild) -> None: + """Synchronize the user table with the Discord users.""" + log.info(f"Received guild available for {guild.id}") + + if guild.id != BotConfig.guild_id: + log.info("Guild was not the configured guild, discarding event") + return + + await self.sync_guild() + + +async def setup(bot: Bot) -> None: + """Load the GuildListeners cog.""" + await bot.add_cog(GuildListeners(bot)) diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py new file mode 100644 index 0000000..f3074ce --- /dev/null +++ b/metricity/exts/event_listeners/member_listeners.py @@ -0,0 +1,124 @@ +"""An ext to listen for member events and syncs them to the database.""" + +import discord +from asyncpg.exceptions import UniqueViolationError +from discord.ext import commands + +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.models import User + + +class MemberListeners(commands.Cog): + """Listen for member events and sync them to the database.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @commands.Cog.listener() + async def on_member_remove(self, member: discord.Member) -> None: + """On a user leaving the server mark in_guild as False.""" + await self.bot.sync_process_complete.wait() + + if member.guild.id != BotConfig.guild_id: + return + + if db_user := await User.get(str(member.id)): + await db_user.update( + in_guild=False + ).apply() + + @commands.Cog.listener() + async def on_member_join(self, member: discord.Member) -> None: + """On a user joining the server add them to the database.""" + await self.bot.sync_process_complete.wait() + + if member.guild.id != BotConfig.guild_id: + return + + if db_user := await User.get(str(member.id)): + await db_user.update( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], + public_flags=dict(member.public_flags), + pending=member.pending, + in_guild=True + ).apply() + else: + try: + await User.create( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], + public_flags=dict(member.public_flags), + pending=member.pending, + in_guild=True + ) + except UniqueViolationError: + pass + + @commands.Cog.listener() + async def on_member_update(self, before: discord.Member, member: discord.Member) -> None: + """When a member updates their profile, update the DB record.""" + await self.bot.sync_process_complete.wait() + + if member.guild.id != BotConfig.guild_id: + return + + # Joined at will be null if we are not ready to process events yet + if not member.joined_at: + return + + roles = set([role.id for role in member.roles]) + + if db_user := await User.get(str(member.id)): + if ( + db_user.name != member.name or + db_user.avatar_hash != getattr(member.avatar, "key", None) or + db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or + BotConfig.staff_role_id in + [role.id for role in member.roles] != db_user.is_staff + or db_user.pending is not member.pending + ): + await db_user.update( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in roles, + public_flags=dict(member.public_flags), + in_guild=True, + pending=member.pending + ).apply() + else: + try: + await User.create( + id=str(member.id), + name=member.name, + avatar_hash=getattr(member.avatar, "key", None), + guild_avatar_hash=getattr(member.guild_avatar, "key", None), + joined_at=member.joined_at, + created_at=member.created_at, + is_staff=BotConfig.staff_role_id in roles, + public_flags=dict(member.public_flags), + in_guild=True, + pending=member.pending + ) + except UniqueViolationError: + pass + + +async def setup(bot: Bot) -> None: + """Load the MemberListeners cog.""" + await bot.add_cog(MemberListeners(bot)) diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py new file mode 100644 index 0000000..b446e26 --- /dev/null +++ b/metricity/exts/event_listeners/message_listeners.py @@ -0,0 +1,62 @@ +"""An ext to listen for message events and syncs them to the database.""" + +import discord +from discord.ext import commands + +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.exts.event_listeners import _utils +from metricity.models import Message, User + + +class MessageListeners(commands.Cog): + """Listen for message events and sync them to the database.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @commands.Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Add a message to the table when one is sent providing the author has accepted.""" + if not message.guild: + return + + if message.author.bot: + return + + if message.guild.id != BotConfig.guild_id: + return + + if message.type in (discord.MessageType.thread_created, discord.MessageType.auto_moderation_action): + return + + await self.bot.sync_process_complete.wait() + await self.bot.channel_sync_in_progress.wait() + + if not await User.get(str(message.author.id)): + return + + cat_id = message.channel.category.id if message.channel.category else None + if cat_id in BotConfig.ignore_categories: + return + + from_thread = isinstance(message.channel, discord.Thread) + await _utils.sync_message(message, from_thread=from_thread) + + @commands.Cog.listener() + async def on_raw_message_delete(self, message: discord.RawMessageDeleteEvent) -> None: + """If a message is deleted and we have a record of it set the is_deleted flag.""" + if message := await Message.get(str(message.message_id)): + await message.update(is_deleted=True).apply() + + @commands.Cog.listener() + async def on_raw_bulk_message_delete(self, messages: discord.RawBulkMessageDeleteEvent) -> None: + """If messages are deleted in bulk and we have a record of them set the is_deleted flag.""" + for message_id in messages.message_ids: + if message := await Message.get(str(message_id)): + await message.update(is_deleted=True).apply() + + +async def setup(bot: Bot) -> None: + """Load the MessageListeners cog.""" + await bot.add_cog(MessageListeners(bot)) |