diff options
author | 2022-07-09 21:49:28 +0100 | |
---|---|---|
committer | 2022-07-09 21:49:28 +0100 | |
commit | d3470a98c964152530933217cf58a6dca2bf7d94 (patch) | |
tree | 4de598f4a87f3777fc2dfb1c3608e3beffd8ccdb | |
parent | Move TZDateTime to avoid circular import (diff) |
Migrate metricity to use BotBase from botcore
-rw-r--r-- | metricity/__init__.py | 18 | ||||
-rw-r--r-- | metricity/__main__.py | 44 | ||||
-rw-r--r-- | metricity/bot.py | 438 | ||||
-rw-r--r-- | tox.ini | 8 |
4 files changed, 85 insertions, 423 deletions
diff --git a/metricity/__init__.py b/metricity/__init__.py index 5fecffc..9216f05 100644 --- a/metricity/__init__.py +++ b/metricity/__init__.py @@ -1,12 +1,20 @@ """Metric collection for the Python Discord server.""" +import asyncio import logging +import os +from typing import TYPE_CHECKING + import coloredlogs +from botcore.utils import apply_monkey_patches from metricity.config import PythonConfig -__version__ = "1.3.0" +if TYPE_CHECKING: + from metricity.bot import Bot + +__version__ = "1.4.0" # Set root log level logging.basicConfig(level=PythonConfig.log_level) @@ -18,3 +26,11 @@ logging.getLogger("discord.client").setLevel(PythonConfig.discord_log_level) # Gino has an obnoxiously loud log for all queries executed, not great when inserting # tens of thousands of users, so we can disable that (it's just a SQLAlchemy logger) logging.getLogger("gino.engine._SAEngine").setLevel(logging.WARNING) + +# On Windows, the selector event loop is required for aiodns. +if os.name == "nt": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +apply_monkey_patches() + +instance: "Bot" = None # Global Bot instance. diff --git a/metricity/__main__.py b/metricity/__main__.py index bc711b3..71fe8a7 100644 --- a/metricity/__main__.py +++ b/metricity/__main__.py @@ -1,9 +1,49 @@ """Entry point for the Metricity application.""" -from metricity.bot import bot +import asyncio + +import aiohttp +import discord +from discord.ext import commands + +import metricity +from metricity.bot import Bot from metricity.config import BotConfig +async def main() -> None: + """Entry async method for starting the bot.""" + intents = discord.Intents( + guilds=True, + members=True, + bans=False, + emojis=False, + integrations=False, + webhooks=False, + invites=False, + voice_states=False, + presences=False, + messages=True, + reactions=False, + typing=False + ) + + async with aiohttp.ClientSession() as session: + metricity.instance = Bot( + guild_id=BotConfig.guild_id, + http_session=session, + command_prefix=commands.when_mentioned, + activity=discord.Game(f"Metricity {metricity.__version__}"), + intents=intents, + max_messages=None, + allowed_mentions=None, + allowed_roles=None, + help_command=None, + ) + async with metricity.instance as _bot: + await _bot.start(BotConfig.token) + + def start() -> None: """Start the Metricity application.""" - bot.run(BotConfig.token) + asyncio.run(main()) diff --git a/metricity/bot.py b/metricity/bot.py index 73125bf..17e6a6d 100644 --- a/metricity/bot.py +++ b/metricity/bot.py @@ -1,430 +1,32 @@ """Creating and configuring a Discord client for Metricity.""" import asyncio -import logging -from typing import Any, Generator, List -from asyncpg.exceptions import UniqueViolationError -from discord import ( - CategoryChannel, - Game, - Guild, - Intents, - Member, - Message as DiscordMessage, - MessageType, - RawBulkMessageDeleteEvent, - RawMessageDeleteEvent, - Thread as ThreadChannel, - VoiceChannel, -) -from discord.abc import Messageable -from discord.ext.commands import Bot +from botcore import BotBase +from botcore.utils import logging, scheduling -from metricity import __version__ -from metricity.config import BotConfig -from metricity.database import connect, db -from metricity.models import Category, Channel, Message, Thread, User +from metricity import exts +from metricity.database import connect -log = logging.getLogger(__name__) +log = logging.get_logger(__name__) -intents = Intents( - guilds=True, - members=True, - bans=False, - emojis=False, - integrations=False, - webhooks=False, - invites=False, - voice_states=False, - presences=False, - messages=True, - reactions=False, - typing=False -) +class Bot(BotBase): + """A subclass of `botcore.BotBase` that implements bot-specific functions.""" -bot = Bot( - command_prefix="", - help_command=None, - intents=intents, - max_messages=None, - activity=Game(f"Metricity {__version__}") -) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) -sync_process_complete = asyncio.Event() -channel_sync_in_progress = asyncio.Event() -db_ready = asyncio.Event() + self.sync_process_complete = asyncio.Event() + self.channel_sync_in_progress = asyncio.Event() + async def setup_hook(self) -> None: + """Connect to db and load cogs.""" + await super().setup_hook() + log.info(f"Metricity is online, logged in as {self.user}") + await connect() + scheduling.create_task(self.load_extensions(exts)) -async def insert_thread(thread: ThreadChannel) -> None: - """Insert the given thread to the database.""" - await Thread.create( - id=str(thread.id), - parent_channel_id=str(thread.parent_id), - name=thread.name, - archived=thread.archived, - auto_archive_duration=thread.auto_archive_duration, - locked=thread.locked, - type=thread.type.name, - ) - - -async def sync_channels(guild: Guild) -> None: - """Sync channels and categories with the database.""" - channel_sync_in_progress.clear() - - log.info("Beginning category synchronisation process") - - for channel in guild.channels: - if isinstance(channel, CategoryChannel): - if db_cat := await Category.get(str(channel.id)): - await db_cat.update(name=channel.name).apply() - else: - await Category.create(id=str(channel.id), name=channel.name) - - log.info("Category synchronisation process complete, synchronising channels") - - for channel in guild.channels: - if channel.category: - if channel.category.id in BotConfig.ignore_categories: - continue - - if ( - not isinstance(channel, CategoryChannel) and - not isinstance(channel, VoiceChannel) - ): - category_id = str(channel.category.id) if channel.category else None - # Cast to bool so is_staff is False if channel.category is None - is_staff = bool( - channel.category - and channel.category.id in BotConfig.staff_categories - ) - if db_chan := await Channel.get(str(channel.id)): - await db_chan.update( - name=channel.name, - category_id=category_id, - is_staff=is_staff, - ).apply() - else: - await Channel.create( - id=str(channel.id), - name=channel.name, - category_id=category_id, - is_staff=is_staff, - ) - - log.info("Channel synchronisation process complete, synchronising threads") - - for thread in guild.threads: - if thread.parent and thread.parent.category: - if thread.parent.category.id in BotConfig.ignore_categories: - continue - else: - # This is a forum channel, not currently supported by Discord.py. Ignore it. - continue - - if db_thread := await Thread.get(str(thread.id)): - await db_thread.update( - name=thread.name, - archived=thread.archived, - auto_archive_duration=thread.auto_archive_duration, - locked=thread.locked, - type=thread.type.name, - ).apply() - else: - await insert_thread(thread) - channel_sync_in_progress.set() - - -async def sync_thread_archive_state(guild: Guild) -> None: - """Sync the archive state of all threads in the database with the state in guild.""" - active_thread_ids = [str(thread.id) for thread in guild.threads] - async with db.transaction() as tx: - async for db_thread in tx.connection.iterate(Thread.query): - await db_thread.update(archived=db_thread.id not in active_thread_ids).apply() - - -def gen_chunks( - chunk_src: List[Any], - chunk_size: int -) -> Generator[List[Any], None, List[Any]]: - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(chunk_src), chunk_size): - yield chunk_src[i:i + chunk_size] - - -async def on_ready() -> None: - """Initiate tasks when the bot comes online.""" - log.info(f"Metricity is online, logged in as {bot.user}") - await connect() - db_ready.set() - - -async def on_guild_channel_create(channel: Messageable) -> None: - """Sync the channels when one is created.""" - await db_ready.wait() - - if channel.guild.id != BotConfig.guild_id: - return - - await sync_channels(channel.guild) - - -async def on_guild_channel_update(_before: Messageable, channel: Messageable) -> None: - """Sync the channels when one is updated.""" - await db_ready.wait() - - if channel.guild.id != BotConfig.guild_id: - return - - await sync_channels(channel.guild) - - -async def on_thread_join(thread: ThreadChannel) -> None: - """ - Sync channels when thread join is triggered. - - Unlike what the name suggested, this is also triggered when: - - A thread is created. - - An un-cached thread is un-archived. - """ - await db_ready.wait() - - if thread.guild.id != BotConfig.guild_id: - return - - await sync_channels(thread.guild) - - -async def on_thread_update(_before: Messageable, thread: Messageable) -> None: - """Sync the channels when one is updated.""" - await db_ready.wait() - - if thread.guild.id != BotConfig.guild_id: - return - - await sync_channels(thread.guild) - - -async def on_guild_available(guild: Guild) -> None: - """Synchronize the user table with the Discord users.""" - await db_ready.wait() - - log.info(f"Received guild available for {guild.id}") - - if guild.id != BotConfig.guild_id: - return log.info("Guild was not the configured guild, discarding event") - - await sync_channels(guild) - - log.info("Beginning thread archive state synchronisation process") - await sync_thread_archive_state(guild) - - log.info("Beginning user synchronisation process") - - await User.update.values(in_guild=False).gino.status() - - users = [] - - for user in guild.members: - users.append({ - "id": str(user.id), - "name": user.name, - "avatar_hash": getattr(user.avatar, "key", None), - "guild_avatar_hash": getattr(user.guild_avatar, "key", None), - "joined_at": user.joined_at, - "created_at": user.created_at, - "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], - "bot": user.bot, - "in_guild": True, - "public_flags": dict(user.public_flags), - "pending": user.pending - }) - - log.info(f"Performing bulk upsert of {len(users)} rows") - - user_chunks = gen_chunks(users, 500) - - for chunk in user_chunks: - log.info(f"Upserting chunk of {len(chunk)}") - await User.bulk_upsert(chunk) - - log.info("User upsert complete") - - sync_process_complete.set() - - -async def on_member_join(member: Member) -> None: - """On a user joining the server add them to the database.""" - await db_ready.wait() - await sync_process_complete.wait() - - if member.guild.id != BotConfig.guild_id: - return - - if db_user := await User.get(str(member.id)): - await db_user.update( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], - public_flags=dict(member.public_flags), - pending=member.pending, - in_guild=True - ).apply() - else: - try: - await User.create( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in [role.id for role in member.roles], - public_flags=dict(member.public_flags), - pending=member.pending, - in_guild=True - ) - except UniqueViolationError: - pass - - -async def on_member_remove(member: Member) -> None: - """On a user leaving the server mark in_guild as False.""" - await db_ready.wait() - await sync_process_complete.wait() - - if member.guild.id != BotConfig.guild_id: - return - - if db_user := await User.get(str(member.id)): - await db_user.update( - in_guild=False - ).apply() - - -async def on_member_update(before: Member, member: Member) -> None: - """When a member updates their profile, update the DB record.""" - await sync_process_complete.wait() - - if member.guild.id != BotConfig.guild_id: - return - - # Joined at will be null if we are not ready to process events yet - if not member.joined_at: - return - - roles = set([role.id for role in member.roles]) - - if db_user := await User.get(str(member.id)): - if ( - db_user.name != member.name or - db_user.avatar_hash != getattr(member.avatar, "key", None) or - db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or - BotConfig.staff_role_id in - [role.id for role in member.roles] != db_user.is_staff - or db_user.pending is not member.pending - ): - await db_user.update( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in roles, - public_flags=dict(member.public_flags), - in_guild=True, - pending=member.pending - ).apply() - else: - try: - await User.create( - id=str(member.id), - name=member.name, - avatar_hash=getattr(member.avatar, "key", None), - guild_avatar_hash=getattr(member.guild_avatar, "key", None), - joined_at=member.joined_at, - created_at=member.created_at, - is_staff=BotConfig.staff_role_id in roles, - public_flags=dict(member.public_flags), - in_guild=True, - pending=member.pending - ) - except UniqueViolationError: - pass - - -async def on_message(message: DiscordMessage) -> None: - """Add a message to the table when one is sent providing the author has accepted.""" - await db_ready.wait() - - if not message.guild: - return - - if message.author.bot: - return - - if message.guild.id != BotConfig.guild_id: - return - - if message.type == MessageType.thread_created: - return - - await sync_process_complete.wait() - await channel_sync_in_progress.wait() - - if not await User.get(str(message.author.id)): - return - - cat_id = message.channel.category.id if message.channel.category else None - - if cat_id in BotConfig.ignore_categories: - return - - args = { - "id": str(message.id), - "channel_id": str(message.channel.id), - "author_id": str(message.author.id), - "created_at": message.created_at - } - - if isinstance(message.channel, ThreadChannel): - if not message.channel.parent: - # This is a forum channel, not currently supported by Discord.py. Ignore it. - return - thread = message.channel - args["channel_id"] = str(thread.parent_id) - args["thread_id"] = str(thread.id) - - await Message.create(**args) - - -async def on_raw_message_delete(message: RawMessageDeleteEvent) -> None: - """If a message is deleted and we have a record of it set the is_deleted flag.""" - if message := await Message.get(str(message.message_id)): - await message.update(is_deleted=True).apply() - - -async def on_raw_bulk_message_delete(messages: RawBulkMessageDeleteEvent) -> None: - """If messages are deleted in bulk and we have a record of them set the is_deleted flag.""" - for message_id in messages.message_ids: - if message := await Message.get(str(message_id)): - await message.update(is_deleted=True).apply() + async def on_error(self, event: str, *args, **kwargs) -> None: + """Log errors raised in event listeners rather than printing them to stderr.""" + log.exception(f"Unhandled exception in {event}.") @@ -2,9 +2,13 @@ max-line-length=120 application-import-names=metricity import-order-style=pycharm -exclude=alembic +exclude=alembic,.venv extend-ignore= # self params in classes. ANN101, + # args and kwargs + ANN002, ANN003, # line break before/after binary operator - W503, W504 + W503, W504, + # __init__ doc strings + D107 |