diff options
| -rw-r--r-- | metricity/exts/event_listeners/_syncer_utils.py | 152 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/_utils.py | 38 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/guild_listeners.py | 212 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/message_listeners.py | 4 | ||||
| -rw-r--r-- | metricity/exts/event_listeners/startup_sync.py | 115 | ||||
| -rw-r--r-- | pyproject.toml | 2 | 
6 files changed, 277 insertions, 246 deletions
| diff --git a/metricity/exts/event_listeners/_syncer_utils.py b/metricity/exts/event_listeners/_syncer_utils.py new file mode 100644 index 0000000..258a165 --- /dev/null +++ b/metricity/exts/event_listeners/_syncer_utils.py @@ -0,0 +1,152 @@ +import discord +from pydis_core.utils import logging +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession + +from metricity import models +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.database import async_session + +log = logging.get_logger(__name__) + + +def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None: +    """Insert the given thread to the database session.""" +    sess.add(models.Thread( +        id=str(thread.id), +        parent_channel_id=str(thread.parent_id), +        name=thread.name, +        archived=thread.archived, +        auto_archive_duration=thread.auto_archive_duration, +        locked=thread.locked, +        type=thread.type.name, +        created_at=thread.created_at, +    )) + + +async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None: +    """Sync the given message with the database.""" +    if await sess.get(models.Message, str(message.id)): +        return + +    args = { +        "id": str(message.id), +        "channel_id": str(message.channel.id), +        "author_id": str(message.author.id), +        "created_at": message.created_at, +    } + +    if from_thread: +        thread = message.channel +        args["channel_id"] = str(thread.parent_id) +        args["thread_id"] = str(thread.id) + +    sess.add(models.Message(**args)) + + +async def sync_channels(bot: Bot, guild: discord.Guild) -> None: +    """Sync channels and categories with the database.""" +    bot.channel_sync_in_progress.clear() + +    log.info("Beginning category synchronisation process") + +    async with async_session() as sess: +        for channel in guild.channels: +            if isinstance(channel, discord.CategoryChannel): +                if existing_cat := await sess.get(models.Category, str(channel.id)): +                    existing_cat.name = channel.name +                else: +                    sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False)) + +        await sess.commit() + +    log.info("Category synchronisation process complete, synchronising deleted categories") + +    async with async_session() as sess: +        await sess.execute( +            update(models.Category) +            .where(~models.Category.id.in_( +                [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)], +            )) +            .values(deleted=True), +        ) +        await sess.commit() + +    log.info("Deleted category synchronisation process complete, synchronising channels") + +    async with async_session() as sess: +        for channel in guild.channels: +            if channel.category and channel.category.id in BotConfig.ignore_categories: +                continue + +            if not isinstance(channel, discord.CategoryChannel): +                category_id = str(channel.category.id) if channel.category else None +                # Cast to bool so is_staff is False if channel.category is None +                is_staff = channel.id in BotConfig.staff_channels or bool( +                    channel.category and channel.category.id in BotConfig.staff_categories, +                ) +                if db_chan := await sess.get(models.Channel, str(channel.id)): +                    db_chan.name = channel.name +                else: +                    sess.add(models.Channel( +                        id=str(channel.id), +                        name=channel.name, +                        category_id=category_id, +                        is_staff=is_staff, +                        deleted=False, +                    )) + +        await sess.commit() + +    log.info("Channel synchronisation process complete, synchronising deleted channels") + +    async with async_session() as sess: +        await sess.execute( +            update(models.Channel) +            .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels])) +            .values(deleted=True), +        ) +        await sess.commit() + +    log.info("Deleted channel synchronisation process complete, synchronising threads") + +    async with async_session() as sess: +        for thread in guild.threads: +            if thread.parent and thread.parent.category: +                if thread.parent.category.id in BotConfig.ignore_categories: +                    continue +            else: +                # This is a forum channel, not currently supported by Discord.py. Ignore it. +                continue + +            if db_thread := await sess.get(models.Thread, str(thread.id)): +                db_thread.name = thread.name +                db_thread.archived = thread.archived +                db_thread.auto_archive_duration = thread.auto_archive_duration +                db_thread.locked = thread.locked +                db_thread.type = thread.type.name +            else: +                insert_thread(thread, sess) +        await sess.commit() + +    log.info("Thread synchronisation process complete, finished synchronising guild.") +    bot.channel_sync_in_progress.set() + + +async def sync_thread_archive_state(guild: discord.Guild) -> None: +    """Sync the archive state of all threads in the database with the state in guild.""" +    active_thread_ids = [str(thread.id) for thread in guild.threads] + +    async with async_session() as sess: +        await sess.execute( +            update(models.Thread) +            .where(models.Thread.id.in_(active_thread_ids)) +            .values(archived=False), +        ) +        await sess.execute( +            update(models.Thread) +            .where(~models.Thread.id.in_(active_thread_ids)) +            .values(archived=True), +        ) +        await sess.commit() diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py deleted file mode 100644 index 4006ea2..0000000 --- a/metricity/exts/event_listeners/_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import discord -from sqlalchemy.ext.asyncio import AsyncSession - -from metricity import models - - -def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None: -    """Insert the given thread to the database session.""" -    sess.add(models.Thread( -        id=str(thread.id), -        parent_channel_id=str(thread.parent_id), -        name=thread.name, -        archived=thread.archived, -        auto_archive_duration=thread.auto_archive_duration, -        locked=thread.locked, -        type=thread.type.name, -        created_at=thread.created_at, -    )) - - -async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None: -    """Sync the given message with the database.""" -    if await sess.get(models.Message, str(message.id)): -        return - -    args = { -        "id": str(message.id), -        "channel_id": str(message.channel.id), -        "author_id": str(message.author.id), -        "created_at": message.created_at, -    } - -    if from_thread: -        thread = message.channel -        args["channel_id"] = str(thread.parent_id) -        args["thread_id"] = str(thread.id) - -    sess.add(models.Message(**args)) diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py index 79cd8f4..db976b2 100644 --- a/metricity/exts/event_listeners/guild_listeners.py +++ b/metricity/exts/event_listeners/guild_listeners.py @@ -1,18 +1,12 @@  """An ext to listen for guild (and guild channel) events and syncs them to the database.""" -import math -  import discord  from discord.ext import commands -from pydis_core.utils import logging, scheduling -from sqlalchemy import column, update -from sqlalchemy.dialects.postgresql import insert +from pydis_core.utils import logging -from metricity import models  from metricity.bot import Bot  from metricity.config import BotConfig -from metricity.database import async_session -from metricity.exts.event_listeners import _utils +from metricity.exts.event_listeners import _syncer_utils  log = logging.get_logger(__name__) @@ -22,187 +16,6 @@ class GuildListeners(commands.Cog):      def __init__(self, bot: Bot) -> None:          self.bot = bot -        scheduling.create_task(self.sync_guild()) - -    async def sync_guild(self) -> None: -        """Sync all channels and members in the guild.""" -        await self.bot.wait_until_guild_available() - -        guild = self.bot.get_guild(self.bot.guild_id) -        await self.sync_channels(guild) - -        log.info("Beginning thread archive state synchronisation process") -        await self.sync_thread_archive_state(guild) - -        log.info("Beginning user synchronisation process") -        async with async_session() as sess: -            await sess.execute(update(models.User).values(in_guild=False)) -            await sess.commit() - -        users = [ -            { -                "id": str(user.id), -                "name": user.name, -                "avatar_hash": getattr(user.avatar, "key", None), -                "guild_avatar_hash": getattr(user.guild_avatar, "key", None), -                "joined_at": user.joined_at, -                "created_at": user.created_at, -                "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], -                "bot": user.bot, -                "in_guild": True, -                "public_flags": dict(user.public_flags), -                "pending": user.pending, -            } -            for user in guild.members -        ] - -        user_chunks = discord.utils.as_chunks(users, 500) -        created = 0 -        updated = 0 -        total_users = len(users) - -        log.info("Performing bulk upsert of %d rows in %d chunks", total_users, math.ceil(total_users / 500)) - -        async with async_session() as sess: -            for chunk in user_chunks: -                qs = insert(models.User).returning(column("xmax")).values(chunk) - -                update_cols = [ -                    "name", -                    "avatar_hash", -                    "guild_avatar_hash", -                    "joined_at", -                    "is_staff", -                    "bot", -                    "in_guild", -                    "public_flags", -                    "pending", -                ] - -                res = await sess.execute(qs.on_conflict_do_update( -                    index_elements=[models.User.id], -                    set_={k: getattr(qs.excluded, k) for k in update_cols}, -                )) - -                objs = list(res) - -                created += [obj[0] == 0 for obj in objs].count(True) -                updated += [obj[0] != 0 for obj in objs].count(True) - -                log.info("User upsert: inserted %d rows, updated %d rows, done %d rows, %d rows remaining", -                         created, updated, created + updated, total_users - (created + updated)) - -            await sess.commit() - -        log.info("User upsert complete") - -        self.bot.sync_process_complete.set() - -    @staticmethod -    async def sync_thread_archive_state(guild: discord.Guild) -> None: -        """Sync the archive state of all threads in the database with the state in guild.""" -        active_thread_ids = [str(thread.id) for thread in guild.threads] - -        async with async_session() as sess: -            await sess.execute( -                update(models.Thread) -                .where(models.Thread.id.in_(active_thread_ids)) -                .values(archived=False), -            ) -            await sess.execute( -                update(models.Thread) -                .where(~models.Thread.id.in_(active_thread_ids)) -                .values(archived=True), -            ) -            await sess.commit() - -    async def sync_channels(self, guild: discord.Guild) -> None: -        """Sync channels and categories with the database.""" -        self.bot.channel_sync_in_progress.clear() - -        log.info("Beginning category synchronisation process") - -        async with async_session() as sess: -            for channel in guild.channels: -                if isinstance(channel, discord.CategoryChannel): -                    if existing_cat := await sess.get(models.Category, str(channel.id)): -                        existing_cat.name = channel.name -                    else: -                        sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False)) - -            await sess.commit() - -        log.info("Category synchronisation process complete, synchronising deleted categories") - -        async with async_session() as sess: -            await sess.execute( -                update(models.Category) -                .where(~models.Category.id.in_( -                    [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)], -                )) -                .values(deleted=True), -            ) -            await sess.commit() - -        log.info("Deleted category synchronisation process complete, synchronising channels") - -        async with async_session() as sess: -            for channel in guild.channels: -                if channel.category and channel.category.id in BotConfig.ignore_categories: -                    continue - -                if not isinstance(channel, discord.CategoryChannel): -                    category_id = str(channel.category.id) if channel.category else None -                    # Cast to bool so is_staff is False if channel.category is None -                    is_staff = channel.id in BotConfig.staff_channels or bool( -                        channel.category and channel.category.id in BotConfig.staff_categories, -                    ) -                    if db_chan := await sess.get(models.Channel, str(channel.id)): -                        db_chan.name = channel.name -                    else: -                        sess.add(models.Channel( -                            id=str(channel.id), -                            name=channel.name, -                            category_id=category_id, -                            is_staff=is_staff, -                            deleted=False, -                        )) - -            await sess.commit() - -        log.info("Channel synchronisation process complete, synchronising deleted channels") - -        async with async_session() as sess: -            await sess.execute( -                update(models.Channel) -                .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels])) -                .values(deleted=True), -            ) -            await sess.commit() - -        log.info("Deleted channel synchronisation process complete, synchronising threads") - -        async with async_session() as sess: -            for thread in guild.threads: -                if thread.parent and thread.parent.category: -                    if thread.parent.category.id in BotConfig.ignore_categories: -                        continue -                else: -                    # This is a forum channel, not currently supported by Discord.py. Ignore it. -                    continue - -                if db_thread := await sess.get(models.Thread, str(thread.id)): -                    db_thread.name = thread.name -                    db_thread.archived = thread.archived -                    db_thread.auto_archive_duration = thread.auto_archive_duration -                    db_thread.locked = thread.locked -                    db_thread.type = thread.type.name -                else: -                    _utils.insert_thread(thread, sess) -            await sess.commit() - -        log.info("Thread synchronisation process complete, finished synchronising guild.") -        self.bot.channel_sync_in_progress.set()      @commands.Cog.listener()      async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> None: @@ -210,7 +23,7 @@ class GuildListeners(commands.Cog):          if channel.guild.id != BotConfig.guild_id:              return -        await self.sync_channels(channel.guild) +        await _syncer_utils.sync_channels(self.bot, channel.guild)      @commands.Cog.listener()      async def on_guild_channel_delete(self, channel: discord.abc.GuildChannel) -> None: @@ -218,7 +31,7 @@ class GuildListeners(commands.Cog):          if channel.guild.id != BotConfig.guild_id:              return -        await self.sync_channels(channel.guild) +        await _syncer_utils.sync_channels(self.bot, channel.guild)      @commands.Cog.listener()      async def on_guild_channel_update( @@ -230,7 +43,7 @@ class GuildListeners(commands.Cog):          if channel.guild.id != BotConfig.guild_id:              return -        await self.sync_channels(channel.guild) +        await _syncer_utils.sync_channels(self.bot, channel.guild)      @commands.Cog.listener()      async def on_thread_create(self, thread: discord.Thread) -> None: @@ -238,7 +51,7 @@ class GuildListeners(commands.Cog):          if thread.guild.id != BotConfig.guild_id:              return -        await self.sync_channels(thread.guild) +        await _syncer_utils.sync_channels(self.bot, thread.guild)      @commands.Cog.listener()      async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread) -> None: @@ -246,18 +59,7 @@ class GuildListeners(commands.Cog):          if thread.guild.id != BotConfig.guild_id:              return -        await self.sync_channels(thread.guild) - -    @commands.Cog.listener() -    async def on_guild_available(self, guild: discord.Guild) -> None: -        """Synchronize the user table with the Discord users.""" -        log.info("Received guild available for %d", guild.id) - -        if guild.id != BotConfig.guild_id: -            log.info("Guild was not the configured guild, discarding event") -            return - -        await self.sync_guild() +        await _syncer_utils.sync_channels(self.bot, thread.guild)  async def setup(bot: Bot) -> None: diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py index a71e53f..917b13c 100644 --- a/metricity/exts/event_listeners/message_listeners.py +++ b/metricity/exts/event_listeners/message_listeners.py @@ -7,7 +7,7 @@ from sqlalchemy import update  from metricity.bot import Bot  from metricity.config import BotConfig  from metricity.database import async_session -from metricity.exts.event_listeners import _utils +from metricity.exts.event_listeners import _syncer_utils  from metricity.models import Message, User @@ -44,7 +44,7 @@ class MessageListeners(commands.Cog):                  return              from_thread = isinstance(message.channel, discord.Thread) -            await _utils.sync_message(message, sess, from_thread=from_thread) +            await _syncer_utils.sync_message(message, sess, from_thread=from_thread)              await sess.commit() diff --git a/metricity/exts/event_listeners/startup_sync.py b/metricity/exts/event_listeners/startup_sync.py new file mode 100644 index 0000000..0f6264f --- /dev/null +++ b/metricity/exts/event_listeners/startup_sync.py @@ -0,0 +1,115 @@ +"""An ext to sync the guild when the bot starts up.""" + +import math + +import discord +from discord.ext import commands +from pydis_core.utils import logging, scheduling +from sqlalchemy import column, update +from sqlalchemy.dialects.postgresql import insert + +from metricity import models +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.database import async_session +from metricity.exts.event_listeners import _syncer_utils + +log = logging.get_logger(__name__) + + +class StartupSyncer(commands.Cog): +    """Sync the guild on bot startup.""" + +    def __init__(self, bot: Bot) -> None: +        self.bot = bot +        scheduling.create_task(self.sync_guild()) + +    async def sync_guild(self) -> None: +        """Sync all channels and members in the guild.""" +        await self.bot.wait_until_guild_available() + +        guild = self.bot.get_guild(self.bot.guild_id) +        await _syncer_utils.sync_channels(self.bot, guild) + +        log.info("Beginning thread archive state synchronisation process") +        await _syncer_utils.sync_thread_archive_state(guild) + +        log.info("Beginning user synchronisation process") +        async with async_session() as sess: +            await sess.execute(update(models.User).values(in_guild=False)) +            await sess.commit() + +        users = ( +            { +                "id": str(user.id), +                "name": user.name, +                "avatar_hash": getattr(user.avatar, "key", None), +                "guild_avatar_hash": getattr(user.guild_avatar, "key", None), +                "joined_at": user.joined_at, +                "created_at": user.created_at, +                "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], +                "bot": user.bot, +                "in_guild": True, +                "public_flags": dict(user.public_flags), +                "pending": user.pending, +            } +            for user in guild.members +        ) + +        user_chunks = discord.utils.as_chunks(users, 500) +        created = 0 +        updated = 0 +        total_users = len(guild.members) + +        log.info("Performing bulk upsert of %d rows in %d chunks", total_users, math.ceil(total_users / 500)) + +        async with async_session() as sess: +            for chunk in user_chunks: +                qs = insert(models.User).returning(column("xmax")).values(chunk) + +                update_cols = [ +                    "name", +                    "avatar_hash", +                    "guild_avatar_hash", +                    "joined_at", +                    "is_staff", +                    "bot", +                    "in_guild", +                    "public_flags", +                    "pending", +                ] + +                res = await sess.execute(qs.on_conflict_do_update( +                    index_elements=[models.User.id], +                    set_={k: getattr(qs.excluded, k) for k in update_cols}, +                )) + +                objs = list(res) + +                created += [obj[0] == 0 for obj in objs].count(True) +                updated += [obj[0] != 0 for obj in objs].count(True) + +                log.info("User upsert: inserted %d rows, updated %d rows, done %d rows, %d rows remaining", +                         created, updated, created + updated, total_users - (created + updated)) + +            await sess.commit() + +        log.info("User upsert complete") + +        self.bot.sync_process_complete.set() + +    @commands.Cog.listener() +    async def on_guild_available(self, guild: discord.Guild) -> None: +        """Synchronize the user table with the Discord users.""" +        log.info("Received guild available for %d", guild.id) + +        if guild.id != BotConfig.guild_id: +            log.info("Guild was not the configured guild, discarding event") +            return + +        await self.sync_guild() + + +async def setup(bot: Bot) -> None: +    """Load the GuildListeners cog.""" +    await bot.add_cog(StartupSyncer(bot)) diff --git a/pyproject.toml b/pyproject.toml index e817933..5c5e61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@  [tool.poetry]  name = "metricity" -version = "2.5.1" +version = "2.6.0"  description = "Advanced metric collection for the Python Discord server"  authors = ["Joe Banks <[email protected]>"]  license = "MIT" | 
