diff options
| author | 2021-10-31 15:49:37 +0000 | |
|---|---|---|
| committer | 2021-12-21 19:02:14 +0000 | |
| commit | 597e723db04b50abb1b2fa542eb8bb15df9e686e (patch) | |
| tree | 6dd582ba251fec5c6d111c0a21714dbc998c2b2a | |
| parent | Add discord.py thread models & migration (diff) | |
Upsert threads when syncing channels
| -rw-r--r-- | metricity/bot.py | 55 | 
1 files changed, 50 insertions, 5 deletions
diff --git a/metricity/bot.py b/metricity/bot.py index 3d4f3cd..84f163d 100644 --- a/metricity/bot.py +++ b/metricity/bot.py @@ -6,17 +6,25 @@ from typing import Any, Generator, List  from asyncpg.exceptions import UniqueViolationError  from discord import ( -    CategoryChannel, Game, Guild, Intents, -    Member, Message as DiscordMessage, RawBulkMessageDeleteEvent, RawMessageDeleteEvent, -    VoiceChannel +    CategoryChannel, +    Game, +    Guild, +    Intents, +    Member, +    Message as DiscordMessage, +    MessageType, +    RawBulkMessageDeleteEvent, +    RawMessageDeleteEvent, +    Thread as ThreadChannel, +    VoiceChannel,  )  from discord.abc import Messageable  from discord.ext.commands import Bot  from metricity import __version__  from metricity.config import BotConfig -from metricity.database import connect -from metricity.models import Category, Channel, Message, User +from metricity.database import connect, db +from metricity.models import Category, Channel, Message, Thread, User  log = logging.getLogger(__name__) @@ -49,6 +57,19 @@ channel_sync_in_progress = asyncio.Event()  db_ready = asyncio.Event() +async def insert_thread(thread: ThreadChannel) -> None: +    """Insert the given thread to the database.""" +    await Thread.create( +        id=str(thread.id), +        parent_channel_id=str(thread.parent_id), +        name=thread.name, +        archived=thread.archived, +        auto_archive_duration=thread.auto_archive_duration, +        locked=thread.locked, +        type=thread.type.name, +    ) + +  async def sync_channels(guild: Guild) -> None:      """Sync channels and categories with the database."""      channel_sync_in_progress.clear() @@ -93,6 +114,30 @@ async def sync_channels(guild: Guild) -> None:                      ),                  ) +    log.info("Channel synchronisation process complete, synchronising threads") + +    active_thread_ids = [] +    for thread in guild.threads: +        active_thread_ids.append(thread.id) +        if thread.parent and thread.parent.category: +            if thread.parent.category.id in BotConfig.ignore_categories: +                continue + +        if db_thread := await Thread.get(str(thread.id)): +            await db_thread.update( +                name=thread.name, +                archived=thread.archived, +                auto_archive_duration=thread.auto_archive_duration, +                locked=thread.locked, +                type=thread.type.name, +            ).apply() +        else: +            await insert_thread(thread) + +    async with db.transaction(): +        async for db_thread in Thread.query.gino.iterate(): +            await db_thread.update(archived=db_thread.id in active_thread_ids).apply() +      channel_sync_in_progress.set()  |