aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--metricity/exts/__init__.py1
-rw-r--r--metricity/exts/event_listeners/__init__.py1
-rw-r--r--metricity/exts/event_listeners/_utils.py42
-rw-r--r--metricity/exts/event_listeners/guild_listeners.py187
-rw-r--r--metricity/exts/event_listeners/member_listeners.py124
-rw-r--r--metricity/exts/event_listeners/message_listeners.py62
6 files changed, 417 insertions, 0 deletions
diff --git a/metricity/exts/__init__.py b/metricity/exts/__init__.py
new file mode 100644
index 0000000..a8cce86
--- /dev/null
+++ b/metricity/exts/__init__.py
@@ -0,0 +1 @@
+"""A module containing all extensions to be loaded into the bot on startup."""
diff --git a/metricity/exts/event_listeners/__init__.py b/metricity/exts/event_listeners/__init__.py
new file mode 100644
index 0000000..2830bfb
--- /dev/null
+++ b/metricity/exts/event_listeners/__init__.py
@@ -0,0 +1 @@
+"""A module containing all extensions around listening to events and storing them in the database."""
diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py
new file mode 100644
index 0000000..4de01fd
--- /dev/null
+++ b/metricity/exts/event_listeners/_utils.py
@@ -0,0 +1,42 @@
+import discord
+
+from metricity import models
+
+
+async def insert_thread(thread: discord.Thread) -> None:
+ """Insert the given thread to the database."""
+ await models.Thread.create(
+ id=str(thread.id),
+ parent_channel_id=str(thread.parent_id),
+ name=thread.name,
+ archived=thread.archived,
+ auto_archive_duration=thread.auto_archive_duration,
+ locked=thread.locked,
+ type=thread.type.name,
+ )
+
+
+async def sync_message(message: discord.Message, from_thread: bool) -> None:
+ """Sync the given message with the database."""
+ if await models.Message.get(str(message.id)):
+ return
+
+ if from_thread and not message.channel.parent:
+ # This is a forum channel, not currently supported by Discord.py. Ignore it.
+ return
+
+ args = {
+ "id": str(message.id),
+ "channel_id": str(message.channel.id),
+ "author_id": str(message.author.id),
+ "created_at": message.created_at
+ }
+
+ if from_thread:
+ thread = message.channel
+ args["channel_id"] = str(thread.parent_id)
+ args["thread_id"] = str(thread.id)
+ if not await models.Thread.get(str(thread.id)):
+ await insert_thread(thread)
+
+ await models.Message.create(**args)
diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py
new file mode 100644
index 0000000..18eb79a
--- /dev/null
+++ b/metricity/exts/event_listeners/guild_listeners.py
@@ -0,0 +1,187 @@
+"""An ext to listen for guild (and guild channel) events and syncs them to the database."""
+
+import discord
+from botcore.utils import logging, scheduling
+from discord.ext import commands
+
+from metricity import models
+from metricity.bot import Bot
+from metricity.config import BotConfig
+from metricity.database import db
+from metricity.exts.event_listeners import _utils
+
+log = logging.get_logger(__name__)
+
+
+class GuildListeners(commands.Cog):
+ """Listen for guild (and guild channel) events and sync them to the database."""
+
+ def __init__(self, bot: Bot) -> None:
+ self.bot = bot
+ scheduling.create_task(self.sync_guild())
+
+ async def sync_guild(self) -> None:
+ """Sync all channels and members in the guild."""
+ await self.bot.wait_until_guild_available()
+
+ guild = self.bot.get_guild(self.bot.guild_id)
+ await self.sync_channels(guild)
+
+ log.info("Beginning thread archive state synchronisation process")
+ await self.sync_thread_archive_state(guild)
+
+ log.info("Beginning user synchronisation process")
+ await models.User.update.values(in_guild=False).gino.status()
+
+ users = []
+ for user in guild.members:
+ users.append({
+ "id": str(user.id),
+ "name": user.name,
+ "avatar_hash": getattr(user.avatar, "key", None),
+ "guild_avatar_hash": getattr(user.guild_avatar, "key", None),
+ "joined_at": user.joined_at,
+ "created_at": user.created_at,
+ "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles],
+ "bot": user.bot,
+ "in_guild": True,
+ "public_flags": dict(user.public_flags),
+ "pending": user.pending
+ })
+
+ log.info(f"Performing bulk upsert of {len(users)} rows")
+
+ user_chunks = discord.utils.as_chunks(users, 500)
+
+ for chunk in user_chunks:
+ log.info(f"Upserting chunk of {len(chunk)}")
+ await models.User.bulk_upsert(chunk)
+
+ log.info("User upsert complete")
+
+ self.bot.sync_process_complete.set()
+
+ @staticmethod
+ async def sync_thread_archive_state(guild: discord.Guild) -> None:
+ """Sync the archive state of all threads in the database with the state in guild."""
+ active_thread_ids = [str(thread.id) for thread in guild.threads]
+ async with db.transaction() as tx:
+ async for db_thread in tx.connection.iterate(models.Thread.query):
+ await db_thread.update(archived=db_thread.id not in active_thread_ids).apply()
+
+ async def sync_channels(self, guild: discord.Guild) -> None:
+ """Sync channels and categories with the database."""
+ self.bot.channel_sync_in_progress.clear()
+
+ log.info("Beginning category synchronisation process")
+
+ for channel in guild.channels:
+ if isinstance(channel, discord.CategoryChannel):
+ if db_cat := await models.Category.get(str(channel.id)):
+ await db_cat.update(name=channel.name).apply()
+ else:
+ await models.Category.create(id=str(channel.id), name=channel.name)
+
+ log.info("Category synchronisation process complete, synchronising channels")
+
+ for channel in guild.channels:
+ if channel.category:
+ if channel.category.id in BotConfig.ignore_categories:
+ continue
+
+ if not isinstance(channel, discord.CategoryChannel):
+ category_id = str(channel.category.id) if channel.category else None
+ # Cast to bool so is_staff is False if channel.category is None
+ is_staff = bool(
+ channel.category
+ and channel.category.id in BotConfig.staff_categories
+ )
+ if db_chan := await models.Channel.get(str(channel.id)):
+ await db_chan.update(
+ name=channel.name,
+ category_id=category_id,
+ is_staff=is_staff,
+ ).apply()
+ else:
+ await models.Channel.create(
+ id=str(channel.id),
+ name=channel.name,
+ category_id=category_id,
+ is_staff=is_staff,
+ )
+
+ log.info("Channel synchronisation process complete, synchronising threads")
+
+ for thread in guild.threads:
+ if thread.parent and thread.parent.category:
+ if thread.parent.category.id in BotConfig.ignore_categories:
+ continue
+ else:
+ # This is a forum channel, not currently supported by Discord.py. Ignore it.
+ continue
+
+ if db_thread := await models.Thread.get(str(thread.id)):
+ await db_thread.update(
+ name=thread.name,
+ archived=thread.archived,
+ auto_archive_duration=thread.auto_archive_duration,
+ locked=thread.locked,
+ type=thread.type.name,
+ ).apply()
+ else:
+ await _utils.insert_thread(thread)
+
+ log.info("Thread synchronisation process complete, finished synchronising guild.")
+ self.bot.channel_sync_in_progress.set()
+
+ @commands.Cog.listener()
+ async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> None:
+ """Sync the channels when one is created."""
+ if channel.guild.id != BotConfig.guild_id:
+ return
+
+ await self.sync_channels(channel.guild)
+
+ @commands.Cog.listener()
+ async def on_guild_channel_update(
+ self,
+ _before: discord.abc.GuildChannel,
+ channel: discord.abc.GuildChannel
+ ) -> None:
+ """Sync the channels when one is updated."""
+ if channel.guild.id != BotConfig.guild_id:
+ return
+
+ await self.sync_channels(channel.guild)
+
+ @commands.Cog.listener()
+ async def on_thread_create(self, thread: discord.Thread) -> None:
+ """Sync channels when a thread is created."""
+ if thread.guild.id != BotConfig.guild_id:
+ return
+
+ await self.sync_channels(thread.guild)
+
+ @commands.Cog.listener()
+ async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread) -> None:
+ """Sync the channels when one is updated."""
+ if thread.guild.id != BotConfig.guild_id:
+ return
+
+ await self.sync_channels(thread.guild)
+
+ @commands.Cog.listener()
+ async def on_guild_available(self, guild: discord.Guild) -> None:
+ """Synchronize the user table with the Discord users."""
+ log.info(f"Received guild available for {guild.id}")
+
+ if guild.id != BotConfig.guild_id:
+ log.info("Guild was not the configured guild, discarding event")
+ return
+
+ await self.sync_guild()
+
+
+async def setup(bot: Bot) -> None:
+ """Load the GuildListeners cog."""
+ await bot.add_cog(GuildListeners(bot))
diff --git a/metricity/exts/event_listeners/member_listeners.py b/metricity/exts/event_listeners/member_listeners.py
new file mode 100644
index 0000000..f3074ce
--- /dev/null
+++ b/metricity/exts/event_listeners/member_listeners.py
@@ -0,0 +1,124 @@
+"""An ext to listen for member events and syncs them to the database."""
+
+import discord
+from asyncpg.exceptions import UniqueViolationError
+from discord.ext import commands
+
+from metricity.bot import Bot
+from metricity.config import BotConfig
+from metricity.models import User
+
+
+class MemberListeners(commands.Cog):
+ """Listen for member events and sync them to the database."""
+
+ def __init__(self, bot: Bot) -> None:
+ self.bot = bot
+
+ @commands.Cog.listener()
+ async def on_member_remove(self, member: discord.Member) -> None:
+ """On a user leaving the server mark in_guild as False."""
+ await self.bot.sync_process_complete.wait()
+
+ if member.guild.id != BotConfig.guild_id:
+ return
+
+ if db_user := await User.get(str(member.id)):
+ await db_user.update(
+ in_guild=False
+ ).apply()
+
+ @commands.Cog.listener()
+ async def on_member_join(self, member: discord.Member) -> None:
+ """On a user joining the server add them to the database."""
+ await self.bot.sync_process_complete.wait()
+
+ if member.guild.id != BotConfig.guild_id:
+ return
+
+ if db_user := await User.get(str(member.id)):
+ await db_user.update(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
+ public_flags=dict(member.public_flags),
+ pending=member.pending,
+ in_guild=True
+ ).apply()
+ else:
+ try:
+ await User.create(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
+ public_flags=dict(member.public_flags),
+ pending=member.pending,
+ in_guild=True
+ )
+ except UniqueViolationError:
+ pass
+
+ @commands.Cog.listener()
+ async def on_member_update(self, before: discord.Member, member: discord.Member) -> None:
+ """When a member updates their profile, update the DB record."""
+ await self.bot.sync_process_complete.wait()
+
+ if member.guild.id != BotConfig.guild_id:
+ return
+
+ # Joined at will be null if we are not ready to process events yet
+ if not member.joined_at:
+ return
+
+ roles = set([role.id for role in member.roles])
+
+ if db_user := await User.get(str(member.id)):
+ if (
+ db_user.name != member.name or
+ db_user.avatar_hash != getattr(member.avatar, "key", None) or
+ db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or
+ BotConfig.staff_role_id in
+ [role.id for role in member.roles] != db_user.is_staff
+ or db_user.pending is not member.pending
+ ):
+ await db_user.update(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in roles,
+ public_flags=dict(member.public_flags),
+ in_guild=True,
+ pending=member.pending
+ ).apply()
+ else:
+ try:
+ await User.create(
+ id=str(member.id),
+ name=member.name,
+ avatar_hash=getattr(member.avatar, "key", None),
+ guild_avatar_hash=getattr(member.guild_avatar, "key", None),
+ joined_at=member.joined_at,
+ created_at=member.created_at,
+ is_staff=BotConfig.staff_role_id in roles,
+ public_flags=dict(member.public_flags),
+ in_guild=True,
+ pending=member.pending
+ )
+ except UniqueViolationError:
+ pass
+
+
+async def setup(bot: Bot) -> None:
+ """Load the MemberListeners cog."""
+ await bot.add_cog(MemberListeners(bot))
diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py
new file mode 100644
index 0000000..b446e26
--- /dev/null
+++ b/metricity/exts/event_listeners/message_listeners.py
@@ -0,0 +1,62 @@
+"""An ext to listen for message events and syncs them to the database."""
+
+import discord
+from discord.ext import commands
+
+from metricity.bot import Bot
+from metricity.config import BotConfig
+from metricity.exts.event_listeners import _utils
+from metricity.models import Message, User
+
+
+class MessageListeners(commands.Cog):
+ """Listen for message events and sync them to the database."""
+
+ def __init__(self, bot: Bot) -> None:
+ self.bot = bot
+
+ @commands.Cog.listener()
+ async def on_message(self, message: discord.Message) -> None:
+ """Add a message to the table when one is sent providing the author has accepted."""
+ if not message.guild:
+ return
+
+ if message.author.bot:
+ return
+
+ if message.guild.id != BotConfig.guild_id:
+ return
+
+ if message.type in (discord.MessageType.thread_created, discord.MessageType.auto_moderation_action):
+ return
+
+ await self.bot.sync_process_complete.wait()
+ await self.bot.channel_sync_in_progress.wait()
+
+ if not await User.get(str(message.author.id)):
+ return
+
+ cat_id = message.channel.category.id if message.channel.category else None
+ if cat_id in BotConfig.ignore_categories:
+ return
+
+ from_thread = isinstance(message.channel, discord.Thread)
+ await _utils.sync_message(message, from_thread=from_thread)
+
+ @commands.Cog.listener()
+ async def on_raw_message_delete(self, message: discord.RawMessageDeleteEvent) -> None:
+ """If a message is deleted and we have a record of it set the is_deleted flag."""
+ if message := await Message.get(str(message.message_id)):
+ await message.update(is_deleted=True).apply()
+
+ @commands.Cog.listener()
+ async def on_raw_bulk_message_delete(self, messages: discord.RawBulkMessageDeleteEvent) -> None:
+ """If messages are deleted in bulk and we have a record of them set the is_deleted flag."""
+ for message_id in messages.message_ids:
+ if message := await Message.get(str(message_id)):
+ await message.update(is_deleted=True).apply()
+
+
+async def setup(bot: Bot) -> None:
+ """Load the MessageListeners cog."""
+ await bot.add_cog(MessageListeners(bot))