aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--metricity/exts/event_listeners/_syncer_utils.py152
-rw-r--r--metricity/exts/event_listeners/_utils.py38
-rw-r--r--metricity/exts/event_listeners/guild_listeners.py212
-rw-r--r--metricity/exts/event_listeners/message_listeners.py4
-rw-r--r--metricity/exts/event_listeners/startup_sync.py115
-rw-r--r--pyproject.toml2
6 files changed, 277 insertions, 246 deletions
diff --git a/metricity/exts/event_listeners/_syncer_utils.py b/metricity/exts/event_listeners/_syncer_utils.py
new file mode 100644
index 0000000..258a165
--- /dev/null
+++ b/metricity/exts/event_listeners/_syncer_utils.py
@@ -0,0 +1,152 @@
+import discord
+from pydis_core.utils import logging
+from sqlalchemy import update
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from metricity import models
+from metricity.bot import Bot
+from metricity.config import BotConfig
+from metricity.database import async_session
+
+log = logging.get_logger(__name__)
+
+
+def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None:
+ """Insert the given thread to the database session."""
+ sess.add(models.Thread(
+ id=str(thread.id),
+ parent_channel_id=str(thread.parent_id),
+ name=thread.name,
+ archived=thread.archived,
+ auto_archive_duration=thread.auto_archive_duration,
+ locked=thread.locked,
+ type=thread.type.name,
+ created_at=thread.created_at,
+ ))
+
+
+async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None:
+ """Sync the given message with the database."""
+ if await sess.get(models.Message, str(message.id)):
+ return
+
+ args = {
+ "id": str(message.id),
+ "channel_id": str(message.channel.id),
+ "author_id": str(message.author.id),
+ "created_at": message.created_at,
+ }
+
+ if from_thread:
+ thread = message.channel
+ args["channel_id"] = str(thread.parent_id)
+ args["thread_id"] = str(thread.id)
+
+ sess.add(models.Message(**args))
+
+
+async def sync_channels(bot: Bot, guild: discord.Guild) -> None:
+ """Sync channels and categories with the database."""
+ bot.channel_sync_in_progress.clear()
+
+ log.info("Beginning category synchronisation process")
+
+ async with async_session() as sess:
+ for channel in guild.channels:
+ if isinstance(channel, discord.CategoryChannel):
+ if existing_cat := await sess.get(models.Category, str(channel.id)):
+ existing_cat.name = channel.name
+ else:
+ sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False))
+
+ await sess.commit()
+
+ log.info("Category synchronisation process complete, synchronising deleted categories")
+
+ async with async_session() as sess:
+ await sess.execute(
+ update(models.Category)
+ .where(~models.Category.id.in_(
+ [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)],
+ ))
+ .values(deleted=True),
+ )
+ await sess.commit()
+
+ log.info("Deleted category synchronisation process complete, synchronising channels")
+
+ async with async_session() as sess:
+ for channel in guild.channels:
+ if channel.category and channel.category.id in BotConfig.ignore_categories:
+ continue
+
+ if not isinstance(channel, discord.CategoryChannel):
+ category_id = str(channel.category.id) if channel.category else None
+ # Cast to bool so is_staff is False if channel.category is None
+ is_staff = channel.id in BotConfig.staff_channels or bool(
+ channel.category and channel.category.id in BotConfig.staff_categories,
+ )
+ if db_chan := await sess.get(models.Channel, str(channel.id)):
+ db_chan.name = channel.name
+ else:
+ sess.add(models.Channel(
+ id=str(channel.id),
+ name=channel.name,
+ category_id=category_id,
+ is_staff=is_staff,
+ deleted=False,
+ ))
+
+ await sess.commit()
+
+ log.info("Channel synchronisation process complete, synchronising deleted channels")
+
+ async with async_session() as sess:
+ await sess.execute(
+ update(models.Channel)
+ .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels]))
+ .values(deleted=True),
+ )
+ await sess.commit()
+
+ log.info("Deleted channel synchronisation process complete, synchronising threads")
+
+ async with async_session() as sess:
+ for thread in guild.threads:
+ if thread.parent and thread.parent.category:
+ if thread.parent.category.id in BotConfig.ignore_categories:
+ continue
+ else:
+ # This is a forum channel, not currently supported by Discord.py. Ignore it.
+ continue
+
+ if db_thread := await sess.get(models.Thread, str(thread.id)):
+ db_thread.name = thread.name
+ db_thread.archived = thread.archived
+ db_thread.auto_archive_duration = thread.auto_archive_duration
+ db_thread.locked = thread.locked
+ db_thread.type = thread.type.name
+ else:
+ insert_thread(thread, sess)
+ await sess.commit()
+
+ log.info("Thread synchronisation process complete, finished synchronising guild.")
+ bot.channel_sync_in_progress.set()
+
+
+async def sync_thread_archive_state(guild: discord.Guild) -> None:
+ """Sync the archive state of all threads in the database with the state in guild."""
+ active_thread_ids = [str(thread.id) for thread in guild.threads]
+
+ async with async_session() as sess:
+ await sess.execute(
+ update(models.Thread)
+ .where(models.Thread.id.in_(active_thread_ids))
+ .values(archived=False),
+ )
+ await sess.execute(
+ update(models.Thread)
+ .where(~models.Thread.id.in_(active_thread_ids))
+ .values(archived=True),
+ )
+ await sess.commit()
diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py
deleted file mode 100644
index 4006ea2..0000000
--- a/metricity/exts/event_listeners/_utils.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import discord
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from metricity import models
-
-
-def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None:
- """Insert the given thread to the database session."""
- sess.add(models.Thread(
- id=str(thread.id),
- parent_channel_id=str(thread.parent_id),
- name=thread.name,
- archived=thread.archived,
- auto_archive_duration=thread.auto_archive_duration,
- locked=thread.locked,
- type=thread.type.name,
- created_at=thread.created_at,
- ))
-
-
-async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None:
- """Sync the given message with the database."""
- if await sess.get(models.Message, str(message.id)):
- return
-
- args = {
- "id": str(message.id),
- "channel_id": str(message.channel.id),
- "author_id": str(message.author.id),
- "created_at": message.created_at,
- }
-
- if from_thread:
- thread = message.channel
- args["channel_id"] = str(thread.parent_id)
- args["thread_id"] = str(thread.id)
-
- sess.add(models.Message(**args))
diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py
index 79cd8f4..db976b2 100644
--- a/metricity/exts/event_listeners/guild_listeners.py
+++ b/metricity/exts/event_listeners/guild_listeners.py
@@ -1,18 +1,12 @@
"""An ext to listen for guild (and guild channel) events and syncs them to the database."""
-import math
-
import discord
from discord.ext import commands
-from pydis_core.utils import logging, scheduling
-from sqlalchemy import column, update
-from sqlalchemy.dialects.postgresql import insert
+from pydis_core.utils import logging
-from metricity import models
from metricity.bot import Bot
from metricity.config import BotConfig
-from metricity.database import async_session
-from metricity.exts.event_listeners import _utils
+from metricity.exts.event_listeners import _syncer_utils
log = logging.get_logger(__name__)
@@ -22,187 +16,6 @@ class GuildListeners(commands.Cog):
def __init__(self, bot: Bot) -> None:
self.bot = bot
- scheduling.create_task(self.sync_guild())
-
- async def sync_guild(self) -> None:
- """Sync all channels and members in the guild."""
- await self.bot.wait_until_guild_available()
-
- guild = self.bot.get_guild(self.bot.guild_id)
- await self.sync_channels(guild)
-
- log.info("Beginning thread archive state synchronisation process")
- await self.sync_thread_archive_state(guild)
-
- log.info("Beginning user synchronisation process")
- async with async_session() as sess:
- await sess.execute(update(models.User).values(in_guild=False))
- await sess.commit()
-
- users = [
- {
- "id": str(user.id),
- "name": user.name,
- "avatar_hash": getattr(user.avatar, "key", None),
- "guild_avatar_hash": getattr(user.guild_avatar, "key", None),
- "joined_at": user.joined_at,
- "created_at": user.created_at,
- "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles],
- "bot": user.bot,
- "in_guild": True,
- "public_flags": dict(user.public_flags),
- "pending": user.pending,
- }
- for user in guild.members
- ]
-
- user_chunks = discord.utils.as_chunks(users, 500)
- created = 0
- updated = 0
- total_users = len(users)
-
- log.info("Performing bulk upsert of %d rows in %d chunks", total_users, math.ceil(total_users / 500))
-
- async with async_session() as sess:
- for chunk in user_chunks:
- qs = insert(models.User).returning(column("xmax")).values(chunk)
-
- update_cols = [
- "name",
- "avatar_hash",
- "guild_avatar_hash",
- "joined_at",
- "is_staff",
- "bot",
- "in_guild",
- "public_flags",
- "pending",
- ]
-
- res = await sess.execute(qs.on_conflict_do_update(
- index_elements=[models.User.id],
- set_={k: getattr(qs.excluded, k) for k in update_cols},
- ))
-
- objs = list(res)
-
- created += [obj[0] == 0 for obj in objs].count(True)
- updated += [obj[0] != 0 for obj in objs].count(True)
-
- log.info("User upsert: inserted %d rows, updated %d rows, done %d rows, %d rows remaining",
- created, updated, created + updated, total_users - (created + updated))
-
- await sess.commit()
-
- log.info("User upsert complete")
-
- self.bot.sync_process_complete.set()
-
- @staticmethod
- async def sync_thread_archive_state(guild: discord.Guild) -> None:
- """Sync the archive state of all threads in the database with the state in guild."""
- active_thread_ids = [str(thread.id) for thread in guild.threads]
-
- async with async_session() as sess:
- await sess.execute(
- update(models.Thread)
- .where(models.Thread.id.in_(active_thread_ids))
- .values(archived=False),
- )
- await sess.execute(
- update(models.Thread)
- .where(~models.Thread.id.in_(active_thread_ids))
- .values(archived=True),
- )
- await sess.commit()
-
- async def sync_channels(self, guild: discord.Guild) -> None:
- """Sync channels and categories with the database."""
- self.bot.channel_sync_in_progress.clear()
-
- log.info("Beginning category synchronisation process")
-
- async with async_session() as sess:
- for channel in guild.channels:
- if isinstance(channel, discord.CategoryChannel):
- if existing_cat := await sess.get(models.Category, str(channel.id)):
- existing_cat.name = channel.name
- else:
- sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False))
-
- await sess.commit()
-
- log.info("Category synchronisation process complete, synchronising deleted categories")
-
- async with async_session() as sess:
- await sess.execute(
- update(models.Category)
- .where(~models.Category.id.in_(
- [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)],
- ))
- .values(deleted=True),
- )
- await sess.commit()
-
- log.info("Deleted category synchronisation process complete, synchronising channels")
-
- async with async_session() as sess:
- for channel in guild.channels:
- if channel.category and channel.category.id in BotConfig.ignore_categories:
- continue
-
- if not isinstance(channel, discord.CategoryChannel):
- category_id = str(channel.category.id) if channel.category else None
- # Cast to bool so is_staff is False if channel.category is None
- is_staff = channel.id in BotConfig.staff_channels or bool(
- channel.category and channel.category.id in BotConfig.staff_categories,
- )
- if db_chan := await sess.get(models.Channel, str(channel.id)):
- db_chan.name = channel.name
- else:
- sess.add(models.Channel(
- id=str(channel.id),
- name=channel.name,
- category_id=category_id,
- is_staff=is_staff,
- deleted=False,
- ))
-
- await sess.commit()
-
- log.info("Channel synchronisation process complete, synchronising deleted channels")
-
- async with async_session() as sess:
- await sess.execute(
- update(models.Channel)
- .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels]))
- .values(deleted=True),
- )
- await sess.commit()
-
- log.info("Deleted channel synchronisation process complete, synchronising threads")
-
- async with async_session() as sess:
- for thread in guild.threads:
- if thread.parent and thread.parent.category:
- if thread.parent.category.id in BotConfig.ignore_categories:
- continue
- else:
- # This is a forum channel, not currently supported by Discord.py. Ignore it.
- continue
-
- if db_thread := await sess.get(models.Thread, str(thread.id)):
- db_thread.name = thread.name
- db_thread.archived = thread.archived
- db_thread.auto_archive_duration = thread.auto_archive_duration
- db_thread.locked = thread.locked
- db_thread.type = thread.type.name
- else:
- _utils.insert_thread(thread, sess)
- await sess.commit()
-
- log.info("Thread synchronisation process complete, finished synchronising guild.")
- self.bot.channel_sync_in_progress.set()
@commands.Cog.listener()
async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> None:
@@ -210,7 +23,7 @@ class GuildListeners(commands.Cog):
if channel.guild.id != BotConfig.guild_id:
return
- await self.sync_channels(channel.guild)
+ await _syncer_utils.sync_channels(self.bot, channel.guild)
@commands.Cog.listener()
async def on_guild_channel_delete(self, channel: discord.abc.GuildChannel) -> None:
@@ -218,7 +31,7 @@ class GuildListeners(commands.Cog):
if channel.guild.id != BotConfig.guild_id:
return
- await self.sync_channels(channel.guild)
+ await _syncer_utils.sync_channels(self.bot, channel.guild)
@commands.Cog.listener()
async def on_guild_channel_update(
@@ -230,7 +43,7 @@ class GuildListeners(commands.Cog):
if channel.guild.id != BotConfig.guild_id:
return
- await self.sync_channels(channel.guild)
+ await _syncer_utils.sync_channels(self.bot, channel.guild)
@commands.Cog.listener()
async def on_thread_create(self, thread: discord.Thread) -> None:
@@ -238,7 +51,7 @@ class GuildListeners(commands.Cog):
if thread.guild.id != BotConfig.guild_id:
return
- await self.sync_channels(thread.guild)
+ await _syncer_utils.sync_channels(self.bot, thread.guild)
@commands.Cog.listener()
async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread) -> None:
@@ -246,18 +59,7 @@ class GuildListeners(commands.Cog):
if thread.guild.id != BotConfig.guild_id:
return
- await self.sync_channels(thread.guild)
-
- @commands.Cog.listener()
- async def on_guild_available(self, guild: discord.Guild) -> None:
- """Synchronize the user table with the Discord users."""
- log.info("Received guild available for %d", guild.id)
-
- if guild.id != BotConfig.guild_id:
- log.info("Guild was not the configured guild, discarding event")
- return
-
- await self.sync_guild()
+ await _syncer_utils.sync_channels(self.bot, thread.guild)
async def setup(bot: Bot) -> None:
diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py
index a71e53f..917b13c 100644
--- a/metricity/exts/event_listeners/message_listeners.py
+++ b/metricity/exts/event_listeners/message_listeners.py
@@ -7,7 +7,7 @@ from sqlalchemy import update
from metricity.bot import Bot
from metricity.config import BotConfig
from metricity.database import async_session
-from metricity.exts.event_listeners import _utils
+from metricity.exts.event_listeners import _syncer_utils
from metricity.models import Message, User
@@ -44,7 +44,7 @@ class MessageListeners(commands.Cog):
return
from_thread = isinstance(message.channel, discord.Thread)
- await _utils.sync_message(message, sess, from_thread=from_thread)
+ await _syncer_utils.sync_message(message, sess, from_thread=from_thread)
await sess.commit()
diff --git a/metricity/exts/event_listeners/startup_sync.py b/metricity/exts/event_listeners/startup_sync.py
new file mode 100644
index 0000000..0f6264f
--- /dev/null
+++ b/metricity/exts/event_listeners/startup_sync.py
@@ -0,0 +1,115 @@
+"""An ext to sync the guild when the bot starts up."""
+
+import math
+
+import discord
+from discord.ext import commands
+from pydis_core.utils import logging, scheduling
+from sqlalchemy import column, update
+from sqlalchemy.dialects.postgresql import insert
+
+from metricity import models
+from metricity.bot import Bot
+from metricity.config import BotConfig
+from metricity.database import async_session
+from metricity.exts.event_listeners import _syncer_utils
+
+log = logging.get_logger(__name__)
+
+
+class StartupSyncer(commands.Cog):
+ """Sync the guild on bot startup."""
+
+ def __init__(self, bot: Bot) -> None:
+ self.bot = bot
+ scheduling.create_task(self.sync_guild())
+
+ async def sync_guild(self) -> None:
+ """Sync all channels and members in the guild."""
+ await self.bot.wait_until_guild_available()
+
+ guild = self.bot.get_guild(self.bot.guild_id)
+ await _syncer_utils.sync_channels(self.bot, guild)
+
+ log.info("Beginning thread archive state synchronisation process")
+ await _syncer_utils.sync_thread_archive_state(guild)
+
+ log.info("Beginning user synchronisation process")
+ async with async_session() as sess:
+ await sess.execute(update(models.User).values(in_guild=False))
+ await sess.commit()
+
+ users = (
+ {
+ "id": str(user.id),
+ "name": user.name,
+ "avatar_hash": getattr(user.avatar, "key", None),
+ "guild_avatar_hash": getattr(user.guild_avatar, "key", None),
+ "joined_at": user.joined_at,
+ "created_at": user.created_at,
+ "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles],
+ "bot": user.bot,
+ "in_guild": True,
+ "public_flags": dict(user.public_flags),
+ "pending": user.pending,
+ }
+ for user in guild.members
+ )
+
+ user_chunks = discord.utils.as_chunks(users, 500)
+ created = 0
+ updated = 0
+ total_users = len(guild.members)
+
+ log.info("Performing bulk upsert of %d rows in %d chunks", total_users, math.ceil(total_users / 500))
+
+ async with async_session() as sess:
+ for chunk in user_chunks:
+ qs = insert(models.User).returning(column("xmax")).values(chunk)
+
+ update_cols = [
+ "name",
+ "avatar_hash",
+ "guild_avatar_hash",
+ "joined_at",
+ "is_staff",
+ "bot",
+ "in_guild",
+ "public_flags",
+ "pending",
+ ]
+
+ res = await sess.execute(qs.on_conflict_do_update(
+ index_elements=[models.User.id],
+ set_={k: getattr(qs.excluded, k) for k in update_cols},
+ ))
+
+ objs = list(res)
+
+ created += [obj[0] == 0 for obj in objs].count(True)
+ updated += [obj[0] != 0 for obj in objs].count(True)
+
+ log.info("User upsert: inserted %d rows, updated %d rows, done %d rows, %d rows remaining",
+ created, updated, created + updated, total_users - (created + updated))
+
+ await sess.commit()
+
+ log.info("User upsert complete")
+
+ self.bot.sync_process_complete.set()
+
+ @commands.Cog.listener()
+ async def on_guild_available(self, guild: discord.Guild) -> None:
+ """Synchronize the user table with the Discord users."""
+ log.info("Received guild available for %d", guild.id)
+
+ if guild.id != BotConfig.guild_id:
+ log.info("Guild was not the configured guild, discarding event")
+ return
+
+ await self.sync_guild()
+
+
+async def setup(bot: Bot) -> None:
+ """Load the GuildListeners cog."""
+ await bot.add_cog(StartupSyncer(bot))
diff --git a/pyproject.toml b/pyproject.toml
index e817933..5c5e61c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "metricity"
-version = "2.5.1"
+version = "2.6.0"
description = "Advanced metric collection for the Python Discord server"
authors = ["Joe Banks <[email protected]>"]
license = "MIT"