diff options
| author | 2022-07-09 21:50:00 +0100 | |
|---|---|---|
| committer | 2022-07-09 21:50:00 +0100 | |
| commit | 21212e5e2e03414a824d6136112bb9414d3964c1 (patch) | |
| tree | 6de6778a9093cb4da6c09fcfde3c8d60e161b203 | |
| parent | Migrate metricity to use BotBase from botcore (diff) | |
Move all event listener logic to extensions
| -rw-r--r-- | metricity/exts/__init__.py | 1 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/__init__.py | 1 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/_utils.py | 42 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/guild_listeners.py | 187 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/member_listeners.py | 124 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/message_listeners.py | 62 | 
6 files changed, 417 insertions, 0 deletions
| diff --git a/metricity/exts/__init__.py b/metricity/exts/__init__.py new file mode 100644 index 0000000..a8cce86 --- /dev/null +++ b/metricity/exts/__init__.py @@ -0,0 +1 @@ +"""A module containing all extensions to be loaded into the bot on startup.""" diff --git a/metricity/exts/event_listeners/__init__.py b/metricity/exts/event_listeners/__init__.py new file mode 100644 index 0000000..2830bfb --- /dev/null +++ b/metricity/exts/event_listeners/__init__.py @@ -0,0 +1 @@ +"""A module containing all extensions around listening to events and storing them in the database.""" diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py new file mode 100644 index 0000000..4de01fd --- /dev/null +++ b/metricity/exts/event_listeners/_utils.py @@ -0,0 +1,42 @@ +import discord + +from metricity import models + + +async def insert_thread(thread: discord.Thread) -> None: +    """Insert the given thread to the database.""" +    await models.Thread.create( +        id=str(thread.id), +        parent_channel_id=str(thread.parent_id), +        name=thread.name, +        archived=thread.archived, +        auto_archive_duration=thread.auto_archive_duration, +        locked=thread.locked, +        type=thread.type.name, +    ) + + +async def sync_message(message: discord.Message, from_thread: bool) -> None: +    """Sync the given message with the database.""" +    if await models.Message.get(str(message.id)): +        return + +    if from_thread and not message.channel.parent: +        # This is a forum channel, not currently supported by Discord.py. Ignore it. +        return + +    args = { +        "id": str(message.id), +        "channel_id": str(message.channel.id), +        "author_id": str(message.author.id), +        "created_at": message.created_at +    } + +    if from_thread: +        thread = message.channel +        args["channel_id"] = str(thread.parent_id) +        args["thread_id"] = str(thread.id) +        if not await models.Thread.get(str(thread.id)): +            await insert_thread(thread) + +    await models.Message.create(**args) diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py new file mode 100644 index 0000000..18eb79a --- /dev/null +++ b/metricity/exts/event_listeners/guild_listeners.py @@ -0,0 +1,187 @@ +"""An ext to listen for guild (and guild channel) events and syncs them to the database.""" + +import discord +from botcore.utils import logging, scheduling +from discord.ext import commands + +from metricity import models +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.database import db +from metricity.exts.event_listeners import _utils + +log = logging.get_logger(__name__) + + +class GuildListeners(commands.Cog): +    """Listen for guild (and guild channel) events and sync them to the database.""" + +    def __init__(self, bot: Bot) -> None: +        self.bot = bot +        scheduling.create_task(self.sync_guild()) + +    async def sync_guild(self) -> None: +        """Sync all channels and members in the guild.""" +        await self.bot.wait_until_guild_available() + +        guild = self.bot.get_guild(self.bot.guild_id) +        await self.sync_channels(guild) + +        log.info("Beginning thread archive state synchronisation process") +        await self.sync_thread_archive_state(guild) + +        log.info("Beginning user synchronisation process") +        await models.User.update.values(in_guild=False).gino.status() + +        users = [] +        for user in guild.members: +            users.append({ +                "id": str(user.id), +                "name": user.name, +                "avatar_hash": getattr(user.avatar, "key", None), +                "guild_avatar_hash": getattr(user.guild_avatar, "key", None), +                "joined_at": user.joined_at, +                "created_at": user.created_at, +                "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], +                "bot": user.bot, +                "in_guild": True, +                "public_flags": dict(user.public_flags), +                "pending": user.pending +            }) + +        log.info(f"Performing bulk upsert of {len(users)} rows") + +        user_chunks = discord.utils.as_chunks(users, 500) + +        for chunk in user_chunks: +            log.info(f"Upserting chunk of {len(chunk)}") +            await models.User.bulk_upsert(chunk) + +        log.info("User upsert complete") + +        self.bot.sync_process_complete.set() + +    @staticmethod +    async def sync_thread_archive_state(guild: discord.Guild) -> None: +        """Sync the archive state of all threads in the database with the state in guild.""" +        active_thread_ids = [str(thread.id) for thread in guild.threads] +        async with db.transaction() as tx: +            async for db_thread in tx.connection.iterate(models.Thread.query): +                await db_thread.update(archived=db_thread.id not in active_thread_ids).apply() + +    async def sync_channels(self, guild: discord.Guild) -> None: +        """Sync channels and categories with the database.""" +        self.bot.channel_sync_in_progress.clear() + +        log.info("Beginning category synchronisation process") + +        for channel in guild.channels: +            if isinstance(channel, discord.CategoryChannel): +                if db_cat := await models.Category.get(str(channel.id)): +                    await db_cat.update(name=channel.name).apply() +                else: +                    await models.Category.create(id=str(channel.id), name=channel.name) + +        log.info("Category synchronisation process complete, synchronising channels") + +        for channel in guild.channels: +            if channel.category: +                if channel.category.id in BotConfig.ignore_categories: +                    continue + +            if not isinstance(channel, discord.CategoryChannel): +                category_id = str(channel.category.id) if channel.category else None +                # Cast to bool so is_staff is False if channel.category is None +                is_staff = bool( +                    channel.category +                    and channel.category.id in BotConfig.staff_categories +                ) +                if db_chan := await models.Channel.get(str(channel.id)): +                    await db_chan.update( +                        name=channel.name, +                        category_id=category_id, +                        is_staff=is_staff, +                    ).apply() +                else: +                    await models.Channel.create( +                        id=str(channel.id), +                        name=channel.name, +                        category_id=category_id, +                        is_staff=is_staff, +                    ) + +        log.info("Channel synchronisation process complete, synchronising threads") + +        for thread in guild.threads: +            if thread.parent and thread.parent.category: +                if thread.parent.category.id in BotConfig.ignore_categories: +                    continue +            else: +                # This is a forum channel, not currently supported by Discord.py. Ignore it. +                continue + +            if db_thread := await models.Thread.get(str(thread.id)): +                await db_thread.update( +                    name=thread.name, +                    archived=thread.archived, +                    auto_archive_duration=thread.auto_archive_duration, +                    locked=thread.locked, +                    type=thread.type.name, +                ).apply() +            else: +                await _utils.insert_thread(thread) + +        log.info("Thread synchronisation process complete, finished synchronising guild.") +        self.bot.channel_sync_in_progress.set() + +    @commands.Cog.listener() +    async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> None: +        """Sync the channels when one is created.""" +        if channel.guild.id != BotConfig.guild_id: +            return + +        await self.sync_channels(channel.guild) + +    @commands.Cog.listener() +    async def on_guild_channel_update( +        self, +        _before: discord.abc.GuildChannel, +        channel: discord.abc.GuildChannel +    ) -> None: +        """Sync the channels when one is updated.""" +        if channel.guild.id != BotConfig.guild_id: +            return + +        await self.sync_channels(channel.guild) + +    @commands.Cog.listener() +    async def on_thread_create(self, thread: discord.Thread) -> None: +        """Sync channels when a thread is created.""" +        if thread.guild.id != BotConfig.guild_id: +            return + +        await self.sync_channels(thread.guild) + +    @commands.Cog.listener() +    async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread) -> None: +        """Sync the channels when one is updated.""" +        if thread.guild.id != BotConfig.guild_id: +            return + +        await self.sync_channels(thread.guild) + +    @commands.Cog.listener() +    async def on_guild_available(self, guild: discord.Guild) -> None: +        """Synchronize the user table with the Discord users.""" +        log.info(f"Received guild available for {guild.id}") + +        if guild.id != BotConfig.guild_id: +            log.info("Guild was not the configured guild, discarding event") +            return + +        await self.sync_guild() + + +async def setup(bot: Bot) -> None: +    """Load the GuildListeners cog.""" +    await bot.add_cog(GuildListeners(bot)) diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py new file mode 100644 index 0000000..f3074ce --- /dev/null +++ b/metricity/exts/event_listeners/member_listeners.py @@ -0,0 +1,124 @@ +"""An ext to listen for member events and syncs them to the database.""" + +import discord +from asyncpg.exceptions import UniqueViolationError +from discord.ext import commands + +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.models import User + + +class MemberListeners(commands.Cog): +    """Listen for member events and sync them to the database.""" + +    def __init__(self, bot: Bot) -> None: +        self.bot = bot + +    @commands.Cog.listener() +    async def on_member_remove(self, member: discord.Member) -> None: +        """On a user leaving the server mark in_guild as False.""" +        await self.bot.sync_process_complete.wait() + +        if member.guild.id != BotConfig.guild_id: +            return + +        if db_user := await User.get(str(member.id)): +            await db_user.update( +                in_guild=False +            ).apply() + +    @commands.Cog.listener() +    async def on_member_join(self, member: discord.Member) -> None: +        """On a user joining the server add them to the database.""" +        await self.bot.sync_process_complete.wait() + +        if member.guild.id != BotConfig.guild_id: +            return + +        if db_user := await User.get(str(member.id)): +            await db_user.update( +                id=str(member.id), +                name=member.name, +                avatar_hash=getattr(member.avatar, "key", None), +                guild_avatar_hash=getattr(member.guild_avatar, "key", None), +                joined_at=member.joined_at, +                created_at=member.created_at, +                is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], +                public_flags=dict(member.public_flags), +                pending=member.pending, +                in_guild=True +            ).apply() +        else: +            try: +                await User.create( +                    id=str(member.id), +                    name=member.name, +                    avatar_hash=getattr(member.avatar, "key", None), +                    guild_avatar_hash=getattr(member.guild_avatar, "key", None), +                    joined_at=member.joined_at, +                    created_at=member.created_at, +                    is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], +                    public_flags=dict(member.public_flags), +                    pending=member.pending, +                    in_guild=True +                ) +            except UniqueViolationError: +                pass + +    @commands.Cog.listener() +    async def on_member_update(self, before: discord.Member, member: discord.Member) -> None: +        """When a member updates their profile, update the DB record.""" +        await self.bot.sync_process_complete.wait() + +        if member.guild.id != BotConfig.guild_id: +            return + +        # Joined at will be null if we are not ready to process events yet +        if not member.joined_at: +            return + +        roles = set([role.id for role in member.roles]) + +        if db_user := await User.get(str(member.id)): +            if ( +                db_user.name != member.name or +                db_user.avatar_hash != getattr(member.avatar, "key", None) or +                db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or +                BotConfig.staff_role_id in +                [role.id for role in member.roles] != db_user.is_staff +                or db_user.pending is not member.pending +            ): +                await db_user.update( +                    id=str(member.id), +                    name=member.name, +                    avatar_hash=getattr(member.avatar, "key", None), +                    guild_avatar_hash=getattr(member.guild_avatar, "key", None), +                    joined_at=member.joined_at, +                    created_at=member.created_at, +                    is_staff=BotConfig.staff_role_id in roles, +                    public_flags=dict(member.public_flags), +                    in_guild=True, +                    pending=member.pending +                ).apply() +        else: +            try: +                await User.create( +                    id=str(member.id), +                    name=member.name, +                    avatar_hash=getattr(member.avatar, "key", None), +                    guild_avatar_hash=getattr(member.guild_avatar, "key", None), +                    joined_at=member.joined_at, +                    created_at=member.created_at, +                    is_staff=BotConfig.staff_role_id in roles, +                    public_flags=dict(member.public_flags), +                    in_guild=True, +                    pending=member.pending +                ) +            except UniqueViolationError: +                pass + + +async def setup(bot: Bot) -> None: +    """Load the MemberListeners cog.""" +    await bot.add_cog(MemberListeners(bot)) diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py new file mode 100644 index 0000000..b446e26 --- /dev/null +++ b/metricity/exts/event_listeners/message_listeners.py @@ -0,0 +1,62 @@ +"""An ext to listen for message events and syncs them to the database.""" + +import discord +from discord.ext import commands + +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.exts.event_listeners import _utils +from metricity.models import Message, User + + +class MessageListeners(commands.Cog): +    """Listen for message events and sync them to the database.""" + +    def __init__(self, bot: Bot) -> None: +        self.bot = bot + +    @commands.Cog.listener() +    async def on_message(self, message: discord.Message) -> None: +        """Add a message to the table when one is sent providing the author has accepted.""" +        if not message.guild: +            return + +        if message.author.bot: +            return + +        if message.guild.id != BotConfig.guild_id: +            return + +        if message.type in (discord.MessageType.thread_created, discord.MessageType.auto_moderation_action): +            return + +        await self.bot.sync_process_complete.wait() +        await self.bot.channel_sync_in_progress.wait() + +        if not await User.get(str(message.author.id)): +            return + +        cat_id = message.channel.category.id if message.channel.category else None +        if cat_id in BotConfig.ignore_categories: +            return + +        from_thread = isinstance(message.channel, discord.Thread) +        await _utils.sync_message(message, from_thread=from_thread) + +    @commands.Cog.listener() +    async def on_raw_message_delete(self, message: discord.RawMessageDeleteEvent) -> None: +        """If a message is deleted and we have a record of it set the is_deleted flag.""" +        if message := await Message.get(str(message.message_id)): +            await message.update(is_deleted=True).apply() + +    @commands.Cog.listener() +    async def on_raw_bulk_message_delete(self, messages: discord.RawBulkMessageDeleteEvent) -> None: +        """If messages are deleted in bulk and we have a record of them set the is_deleted flag.""" +        for message_id in messages.message_ids: +            if message := await Message.get(str(message_id)): +                await message.update(is_deleted=True).apply() + + +async def setup(bot: Bot) -> None: +    """Load the MessageListeners cog.""" +    await bot.add_cog(MessageListeners(bot)) | 
