diff options
| -rw-r--r-- | metricity/__init__.py | 18 | ||||
| -rw-r--r-- | metricity/__main__.py | 44 | ||||
| -rw-r--r-- | metricity/bot.py | 438 | ||||
| -rw-r--r-- | tox.ini | 8 | 
4 files changed, 85 insertions, 423 deletions
| diff --git a/metricity/__init__.py b/metricity/__init__.py index 5fecffc..9216f05 100644 --- a/metricity/__init__.py +++ b/metricity/__init__.py @@ -1,12 +1,20 @@  """Metric collection for the Python Discord server.""" +import asyncio  import logging +import os +from typing import TYPE_CHECKING +  import coloredlogs +from botcore.utils import apply_monkey_patches  from metricity.config import PythonConfig -__version__ = "1.3.0" +if TYPE_CHECKING: +    from metricity.bot import Bot + +__version__ = "1.4.0"  # Set root log level  logging.basicConfig(level=PythonConfig.log_level) @@ -18,3 +26,11 @@ logging.getLogger("discord.client").setLevel(PythonConfig.discord_log_level)  # Gino has an obnoxiously loud log for all queries executed, not great when inserting  # tens of thousands of users, so we can disable that (it's just a SQLAlchemy logger)  logging.getLogger("gino.engine._SAEngine").setLevel(logging.WARNING) + +# On Windows, the selector event loop is required for aiodns. +if os.name == "nt": +    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +apply_monkey_patches() + +instance: "Bot" = None  # Global Bot instance. diff --git a/metricity/__main__.py b/metricity/__main__.py index bc711b3..71fe8a7 100644 --- a/metricity/__main__.py +++ b/metricity/__main__.py @@ -1,9 +1,49 @@  """Entry point for the Metricity application.""" -from metricity.bot import bot +import asyncio + +import aiohttp +import discord +from discord.ext import commands + +import metricity +from metricity.bot import Bot  from metricity.config import BotConfig +async def main() -> None: +    """Entry async method for starting the bot.""" +    intents = discord.Intents( +        guilds=True, +        members=True, +        bans=False, +        emojis=False, +        integrations=False, +        webhooks=False, +        invites=False, +        voice_states=False, +        presences=False, +        messages=True, +        reactions=False, +        typing=False +    ) + +    async with aiohttp.ClientSession() as session: +        metricity.instance = Bot( +            guild_id=BotConfig.guild_id, +            http_session=session, +            command_prefix=commands.when_mentioned, +            activity=discord.Game(f"Metricity {metricity.__version__}"), +            intents=intents, +            max_messages=None, +            allowed_mentions=None, +            allowed_roles=None, +            help_command=None, +        ) +        async with metricity.instance as _bot: +            await _bot.start(BotConfig.token) + +  def start() -> None:      """Start the Metricity application.""" -    bot.run(BotConfig.token) +    asyncio.run(main()) diff --git a/metricity/bot.py b/metricity/bot.py index 73125bf..17e6a6d 100644 --- a/metricity/bot.py +++ b/metricity/bot.py @@ -1,430 +1,32 @@  """Creating and configuring a Discord client for Metricity."""  import asyncio -import logging -from typing import Any, Generator, List -from asyncpg.exceptions import UniqueViolationError -from discord import ( -    CategoryChannel, -    Game, -    Guild, -    Intents, -    Member, -    Message as DiscordMessage, -    MessageType, -    RawBulkMessageDeleteEvent, -    RawMessageDeleteEvent, -    Thread as ThreadChannel, -    VoiceChannel, -) -from discord.abc import Messageable -from discord.ext.commands import Bot +from botcore import BotBase +from botcore.utils import logging, scheduling -from metricity import __version__ -from metricity.config import BotConfig -from metricity.database import connect, db -from metricity.models import Category, Channel, Message, Thread, User +from metricity import exts +from metricity.database import connect -log = logging.getLogger(__name__) +log = logging.get_logger(__name__) -intents = Intents( -    guilds=True, -    members=True, -    bans=False, -    emojis=False, -    integrations=False, -    webhooks=False, -    invites=False, -    voice_states=False, -    presences=False, -    messages=True, -    reactions=False, -    typing=False -) +class Bot(BotBase): +    """A subclass of `botcore.BotBase` that implements bot-specific functions.""" -bot = Bot( -    command_prefix="", -    help_command=None, -    intents=intents, -    max_messages=None, -    activity=Game(f"Metricity {__version__}") -) +    def __init__(self, *args, **kwargs) -> None: +        super().__init__(*args, **kwargs) -sync_process_complete = asyncio.Event() -channel_sync_in_progress = asyncio.Event() -db_ready = asyncio.Event() +        self.sync_process_complete = asyncio.Event() +        self.channel_sync_in_progress = asyncio.Event() +    async def setup_hook(self) -> None: +        """Connect to db and load cogs.""" +        await super().setup_hook() +        log.info(f"Metricity is online, logged in as {self.user}") +        await connect() +        scheduling.create_task(self.load_extensions(exts)) -async def insert_thread(thread: ThreadChannel) -> None: -    """Insert the given thread to the database.""" -    await Thread.create( -        id=str(thread.id), -        parent_channel_id=str(thread.parent_id), -        name=thread.name, -        archived=thread.archived, -        auto_archive_duration=thread.auto_archive_duration, -        locked=thread.locked, -        type=thread.type.name, -    ) - - -async def sync_channels(guild: Guild) -> None: -    """Sync channels and categories with the database.""" -    channel_sync_in_progress.clear() - -    log.info("Beginning category synchronisation process") - -    for channel in guild.channels: -        if isinstance(channel, CategoryChannel): -            if db_cat := await Category.get(str(channel.id)): -                await db_cat.update(name=channel.name).apply() -            else: -                await Category.create(id=str(channel.id), name=channel.name) - -    log.info("Category synchronisation process complete, synchronising channels") - -    for channel in guild.channels: -        if channel.category: -            if channel.category.id in BotConfig.ignore_categories: -                continue - -        if ( -            not isinstance(channel, CategoryChannel) and -            not isinstance(channel, VoiceChannel) -        ): -            category_id = str(channel.category.id) if channel.category else None -            # Cast to bool so is_staff is False if channel.category is None -            is_staff = bool( -                channel.category -                and channel.category.id in BotConfig.staff_categories -            ) -            if db_chan := await Channel.get(str(channel.id)): -                await db_chan.update( -                    name=channel.name, -                    category_id=category_id, -                    is_staff=is_staff, -                ).apply() -            else: -                await Channel.create( -                    id=str(channel.id), -                    name=channel.name, -                    category_id=category_id, -                    is_staff=is_staff, -                ) - -    log.info("Channel synchronisation process complete, synchronising threads") - -    for thread in guild.threads: -        if thread.parent and thread.parent.category: -            if thread.parent.category.id in BotConfig.ignore_categories: -                continue -        else: -            # This is a forum channel, not currently supported by Discord.py. Ignore it. -            continue - -        if db_thread := await Thread.get(str(thread.id)): -            await db_thread.update( -                name=thread.name, -                archived=thread.archived, -                auto_archive_duration=thread.auto_archive_duration, -                locked=thread.locked, -                type=thread.type.name, -            ).apply() -        else: -            await insert_thread(thread) -    channel_sync_in_progress.set() - - -async def sync_thread_archive_state(guild: Guild) -> None: -    """Sync the archive state of all threads in the database with the state in guild.""" -    active_thread_ids = [str(thread.id) for thread in guild.threads] -    async with db.transaction() as tx: -        async for db_thread in tx.connection.iterate(Thread.query): -            await db_thread.update(archived=db_thread.id not in active_thread_ids).apply() - - -def gen_chunks( -    chunk_src: List[Any], -    chunk_size: int -) -> Generator[List[Any], None, List[Any]]: -    """Yield successive n-sized chunks from lst.""" -    for i in range(0, len(chunk_src), chunk_size): -        yield chunk_src[i:i + chunk_size] - - -async def on_ready() -> None: -    """Initiate tasks when the bot comes online.""" -    log.info(f"Metricity is online, logged in as {bot.user}") -    await connect() -    db_ready.set() - - -async def on_guild_channel_create(channel: Messageable) -> None: -    """Sync the channels when one is created.""" -    await db_ready.wait() - -    if channel.guild.id != BotConfig.guild_id: -        return - -    await sync_channels(channel.guild) - - -async def on_guild_channel_update(_before: Messageable, channel: Messageable) -> None: -    """Sync the channels when one is updated.""" -    await db_ready.wait() - -    if channel.guild.id != BotConfig.guild_id: -        return - -    await sync_channels(channel.guild) - - -async def on_thread_join(thread: ThreadChannel) -> None: -    """ -    Sync channels when thread join is triggered. - -    Unlike what the name suggested, this is also triggered when: -       - A thread is created. -       - An un-cached thread is un-archived. -    """ -    await db_ready.wait() - -    if thread.guild.id != BotConfig.guild_id: -        return - -    await sync_channels(thread.guild) - - -async def on_thread_update(_before: Messageable, thread: Messageable) -> None: -    """Sync the channels when one is updated.""" -    await db_ready.wait() - -    if thread.guild.id != BotConfig.guild_id: -        return - -    await sync_channels(thread.guild) - - -async def on_guild_available(guild: Guild) -> None: -    """Synchronize the user table with the Discord users.""" -    await db_ready.wait() - -    log.info(f"Received guild available for {guild.id}") - -    if guild.id != BotConfig.guild_id: -        return log.info("Guild was not the configured guild, discarding event") - -    await sync_channels(guild) - -    log.info("Beginning thread archive state synchronisation process") -    await sync_thread_archive_state(guild) - -    log.info("Beginning user synchronisation process") - -    await User.update.values(in_guild=False).gino.status() - -    users = [] - -    for user in guild.members: -        users.append({ -            "id": str(user.id), -            "name": user.name, -            "avatar_hash": getattr(user.avatar, "key", None), -            "guild_avatar_hash": getattr(user.guild_avatar, "key", None), -            "joined_at": user.joined_at, -            "created_at": user.created_at, -            "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], -            "bot": user.bot, -            "in_guild": True, -            "public_flags": dict(user.public_flags), -            "pending": user.pending -        }) - -    log.info(f"Performing bulk upsert of {len(users)} rows") - -    user_chunks = gen_chunks(users, 500) - -    for chunk in user_chunks: -        log.info(f"Upserting chunk of {len(chunk)}") -        await User.bulk_upsert(chunk) - -    log.info("User upsert complete") - -    sync_process_complete.set() - - -async def on_member_join(member: Member) -> None: -    """On a user joining the server add them to the database.""" -    await db_ready.wait() -    await sync_process_complete.wait() - -    if member.guild.id != BotConfig.guild_id: -        return - -    if db_user := await User.get(str(member.id)): -        await db_user.update( -            id=str(member.id), -            name=member.name, -            avatar_hash=getattr(member.avatar, "key", None), -            guild_avatar_hash=getattr(member.guild_avatar, "key", None), -            joined_at=member.joined_at, -            created_at=member.created_at, -            is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], -            public_flags=dict(member.public_flags), -            pending=member.pending, -            in_guild=True -        ).apply() -    else: -        try: -            await User.create( -                id=str(member.id), -                name=member.name, -                avatar_hash=getattr(member.avatar, "key", None), -                guild_avatar_hash=getattr(member.guild_avatar, "key", None), -                joined_at=member.joined_at, -                created_at=member.created_at, -                is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], -                public_flags=dict(member.public_flags), -                pending=member.pending, -                in_guild=True -            ) -        except UniqueViolationError: -            pass - - -async def on_member_remove(member: Member) -> None: -    """On a user leaving the server mark in_guild as False.""" -    await db_ready.wait() -    await sync_process_complete.wait() - -    if member.guild.id != BotConfig.guild_id: -        return - -    if db_user := await User.get(str(member.id)): -        await db_user.update( -            in_guild=False -        ).apply() - - -async def on_member_update(before: Member, member: Member) -> None: -    """When a member updates their profile, update the DB record.""" -    await sync_process_complete.wait() - -    if member.guild.id != BotConfig.guild_id: -        return - -    # Joined at will be null if we are not ready to process events yet -    if not member.joined_at: -        return - -    roles = set([role.id for role in member.roles]) - -    if db_user := await User.get(str(member.id)): -        if ( -            db_user.name != member.name or -            db_user.avatar_hash != getattr(member.avatar, "key", None) or -            db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or -            BotConfig.staff_role_id in -            [role.id for role in member.roles] != db_user.is_staff -            or db_user.pending is not member.pending -        ): -            await db_user.update( -                id=str(member.id), -                name=member.name, -                avatar_hash=getattr(member.avatar, "key", None), -                guild_avatar_hash=getattr(member.guild_avatar, "key", None), -                joined_at=member.joined_at, -                created_at=member.created_at, -                is_staff=BotConfig.staff_role_id in roles, -                public_flags=dict(member.public_flags), -                in_guild=True, -                pending=member.pending -            ).apply() -    else: -        try: -            await User.create( -                id=str(member.id), -                name=member.name, -                avatar_hash=getattr(member.avatar, "key", None), -                guild_avatar_hash=getattr(member.guild_avatar, "key", None), -                joined_at=member.joined_at, -                created_at=member.created_at, -                is_staff=BotConfig.staff_role_id in roles, -                public_flags=dict(member.public_flags), -                in_guild=True, -                pending=member.pending -            ) -        except UniqueViolationError: -            pass - - -async def on_message(message: DiscordMessage) -> None: -    """Add a message to the table when one is sent providing the author has accepted.""" -    await db_ready.wait() - -    if not message.guild: -        return - -    if message.author.bot: -        return - -    if message.guild.id != BotConfig.guild_id: -        return - -    if message.type == MessageType.thread_created: -        return - -    await sync_process_complete.wait() -    await channel_sync_in_progress.wait() - -    if not await User.get(str(message.author.id)): -        return - -    cat_id = message.channel.category.id if message.channel.category else None - -    if cat_id in BotConfig.ignore_categories: -        return - -    args = { -        "id": str(message.id), -        "channel_id": str(message.channel.id), -        "author_id": str(message.author.id), -        "created_at": message.created_at -    } - -    if isinstance(message.channel, ThreadChannel): -        if not message.channel.parent: -            # This is a forum channel, not currently supported by Discord.py. Ignore it. -            return -        thread = message.channel -        args["channel_id"] = str(thread.parent_id) -        args["thread_id"] = str(thread.id) - -    await Message.create(**args) - - -async def on_raw_message_delete(message: RawMessageDeleteEvent) -> None: -    """If a message is deleted and we have a record of it set the is_deleted flag.""" -    if message := await Message.get(str(message.message_id)): -        await message.update(is_deleted=True).apply() - - -async def on_raw_bulk_message_delete(messages: RawBulkMessageDeleteEvent) -> None: -    """If messages are deleted in bulk and we have a record of them set the is_deleted flag.""" -    for message_id in messages.message_ids: -        if message := await Message.get(str(message_id)): -            await message.update(is_deleted=True).apply() +    async def on_error(self, event: str, *args, **kwargs) -> None: +        """Log errors raised in event listeners rather than printing them to stderr.""" +        log.exception(f"Unhandled exception in {event}.") @@ -2,9 +2,13 @@  max-line-length=120  application-import-names=metricity  import-order-style=pycharm -exclude=alembic +exclude=alembic,.venv  extend-ignore=      # self params in classes.      ANN101, +    # args and kwargs +    ANN002, ANN003,      # line break before/after binary operator -    W503, W504 +    W503, W504, +    # __init__ doc strings +    D107 | 
