aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Chris Lovering <[email protected]>2022-07-09 21:49:28 +0100
committerGravatar Chris Lovering <[email protected]>2022-07-09 21:49:28 +0100
commitd3470a98c964152530933217cf58a6dca2bf7d94 (patch)
tree4de598f4a87f3777fc2dfb1c3608e3beffd8ccdb
parentMove TZDateTime to avoid circular import (diff)
Migrate metricity to use BotBase from botcore
-rw-r--r--metricity/__init__.py18
-rw-r--r--metricity/__main__.py44
-rw-r--r--metricity/bot.py438
-rw-r--r--tox.ini8
4 files changed, 85 insertions, 423 deletions
diff --git a/metricity/__init__.py b/metricity/__init__.py
index 5fecffc..9216f05 100644
--- a/metricity/__init__.py
+++ b/metricity/__init__.py
@@ -1,12 +1,20 @@
"""Metric collection for the Python Discord server."""
+import asyncio
import logging
+import os
+from typing import TYPE_CHECKING
+
import coloredlogs
+from botcore.utils import apply_monkey_patches
from metricity.config import PythonConfig
-__version__ = "1.3.0"
+if TYPE_CHECKING:
+ from metricity.bot import Bot
+
+__version__ = "1.4.0"
# Set root log level
logging.basicConfig(level=PythonConfig.log_level)
@@ -18,3 +26,11 @@ logging.getLogger("discord.client").setLevel(PythonConfig.discord_log_level)
# Gino has an obnoxiously loud log for all queries executed, not great when inserting
# tens of thousands of users, so we can disable that (it's just a SQLAlchemy logger)
logging.getLogger("gino.engine._SAEngine").setLevel(logging.WARNING)
+
+# On Windows, the selector event loop is required for aiodns.
+if os.name == "nt":
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+apply_monkey_patches()
+
+instance: "Bot" = None # Global Bot instance.
diff --git a/metricity/__main__.py b/metricity/__main__.py
index bc711b3..71fe8a7 100644
--- a/metricity/__main__.py
+++ b/metricity/__main__.py
@@ -1,9 +1,49 @@
"""Entry point for the Metricity application."""
-from metricity.bot import bot
+import asyncio
+
+import aiohttp
+import discord
+from discord.ext import commands
+
+import metricity
+from metricity.bot import Bot
from metricity.config import BotConfig
+async def main() -> None:
+ """Entry async method for starting the bot."""
+ intents = discord.Intents(
+ guilds=True,
+ members=True,
+ bans=False,
+ emojis=False,
+ integrations=False,
+ webhooks=False,
+ invites=False,
+ voice_states=False,
+ presences=False,
+ messages=True,
+ reactions=False,
+ typing=False
+ )
+
+ async with aiohttp.ClientSession() as session:
+ metricity.instance = Bot(
+ guild_id=BotConfig.guild_id,
+ http_session=session,
+ command_prefix=commands.when_mentioned,
+ activity=discord.Game(f"Metricity {metricity.__version__}"),
+ intents=intents,
+ max_messages=None,
+ allowed_mentions=None,
+ allowed_roles=None,
+ help_command=None,
+ )
+ async with metricity.instance as _bot:
+ await _bot.start(BotConfig.token)
+
+
def start() -> None:
"""Start the Metricity application."""
- bot.run(BotConfig.token)
+ asyncio.run(main())
diff --git a/metricity/bot.py b/metricity/bot.py
index 73125bf..17e6a6d 100644
--- a/metricity/bot.py
+++ b/metricity/bot.py
@@ -1,430 +1,32 @@
"""Creating and configuring a Discord client for Metricity."""
import asyncio
-import logging
-from typing import Any, Generator, List
-from asyncpg.exceptions import UniqueViolationError
-from discord import (
- CategoryChannel,
- Game,
- Guild,
- Intents,
- Member,
- Message as DiscordMessage,
- MessageType,
- RawBulkMessageDeleteEvent,
- RawMessageDeleteEvent,
- Thread as ThreadChannel,
- VoiceChannel,
-)
-from discord.abc import Messageable
-from discord.ext.commands import Bot
+from botcore import BotBase
+from botcore.utils import logging, scheduling
-from metricity import __version__
-from metricity.config import BotConfig
-from metricity.database import connect, db
-from metricity.models import Category, Channel, Message, Thread, User
+from metricity import exts
+from metricity.database import connect
-log = logging.getLogger(__name__)
+log = logging.get_logger(__name__)
-intents = Intents(
- guilds=True,
- members=True,
- bans=False,
- emojis=False,
- integrations=False,
- webhooks=False,
- invites=False,
- voice_states=False,
- presences=False,
- messages=True,
- reactions=False,
- typing=False
-)
+class Bot(BotBase):
+ """A subclass of `botcore.BotBase` that implements bot-specific functions."""
-bot = Bot(
- command_prefix="",
- help_command=None,
- intents=intents,
- max_messages=None,
- activity=Game(f"Metricity {__version__}")
-)
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
-sync_process_complete = asyncio.Event()
-channel_sync_in_progress = asyncio.Event()
-db_ready = asyncio.Event()
+ self.sync_process_complete = asyncio.Event()
+ self.channel_sync_in_progress = asyncio.Event()
+ async def setup_hook(self) -> None:
+ """Connect to db and load cogs."""
+ await super().setup_hook()
+ log.info(f"Metricity is online, logged in as {self.user}")
+ await connect()
+ scheduling.create_task(self.load_extensions(exts))
-async def insert_thread(thread: ThreadChannel) -> None:
- """Insert the given thread to the database."""
- await Thread.create(
- id=str(thread.id),
- parent_channel_id=str(thread.parent_id),
- name=thread.name,
- archived=thread.archived,
- auto_archive_duration=thread.auto_archive_duration,
- locked=thread.locked,
- type=thread.type.name,
- )
-
-
-async def sync_channels(guild: Guild) -> None:
- """Sync channels and categories with the database."""
- channel_sync_in_progress.clear()
-
- log.info("Beginning category synchronisation process")
-
- for channel in guild.channels:
- if isinstance(channel, CategoryChannel):
- if db_cat := await Category.get(str(channel.id)):
- await db_cat.update(name=channel.name).apply()
- else:
- await Category.create(id=str(channel.id), name=channel.name)
-
- log.info("Category synchronisation process complete, synchronising channels")
-
- for channel in guild.channels:
- if channel.category:
- if channel.category.id in BotConfig.ignore_categories:
- continue
-
- if (
- not isinstance(channel, CategoryChannel) and
- not isinstance(channel, VoiceChannel)
- ):
- category_id = str(channel.category.id) if channel.category else None
- # Cast to bool so is_staff is False if channel.category is None
- is_staff = bool(
- channel.category
- and channel.category.id in BotConfig.staff_categories
- )
- if db_chan := await Channel.get(str(channel.id)):
- await db_chan.update(
- name=channel.name,
- category_id=category_id,
- is_staff=is_staff,
- ).apply()
- else:
- await Channel.create(
- id=str(channel.id),
- name=channel.name,
- category_id=category_id,
- is_staff=is_staff,
- )
-
- log.info("Channel synchronisation process complete, synchronising threads")
-
- for thread in guild.threads:
- if thread.parent and thread.parent.category:
- if thread.parent.category.id in BotConfig.ignore_categories:
- continue
- else:
- # This is a forum channel, not currently supported by Discord.py. Ignore it.
- continue
-
- if db_thread := await Thread.get(str(thread.id)):
- await db_thread.update(
- name=thread.name,
- archived=thread.archived,
- auto_archive_duration=thread.auto_archive_duration,
- locked=thread.locked,
- type=thread.type.name,
- ).apply()
- else:
- await insert_thread(thread)
- channel_sync_in_progress.set()
-
-
-async def sync_thread_archive_state(guild: Guild) -> None:
- """Sync the archive state of all threads in the database with the state in guild."""
- active_thread_ids = [str(thread.id) for thread in guild.threads]
- async with db.transaction() as tx:
- async for db_thread in tx.connection.iterate(Thread.query):
- await db_thread.update(archived=db_thread.id not in active_thread_ids).apply()
-
-
-def gen_chunks(
- chunk_src: List[Any],
- chunk_size: int
-) -> Generator[List[Any], None, List[Any]]:
- """Yield successive n-sized chunks from lst."""
- for i in range(0, len(chunk_src), chunk_size):
- yield chunk_src[i:i + chunk_size]
-
-
-async def on_ready() -> None:
- """Initiate tasks when the bot comes online."""
- log.info(f"Metricity is online, logged in as {bot.user}")
- await connect()
- db_ready.set()
-
-
-async def on_guild_channel_create(channel: Messageable) -> None:
- """Sync the channels when one is created."""
- await db_ready.wait()
-
- if channel.guild.id != BotConfig.guild_id:
- return
-
- await sync_channels(channel.guild)
-
-
-async def on_guild_channel_update(_before: Messageable, channel: Messageable) -> None:
- """Sync the channels when one is updated."""
- await db_ready.wait()
-
- if channel.guild.id != BotConfig.guild_id:
- return
-
- await sync_channels(channel.guild)
-
-
-async def on_thread_join(thread: ThreadChannel) -> None:
- """
- Sync channels when thread join is triggered.
-
- Unlike what the name suggested, this is also triggered when:
- - A thread is created.
- - An un-cached thread is un-archived.
- """
- await db_ready.wait()
-
- if thread.guild.id != BotConfig.guild_id:
- return
-
- await sync_channels(thread.guild)
-
-
-async def on_thread_update(_before: Messageable, thread: Messageable) -> None:
- """Sync the channels when one is updated."""
- await db_ready.wait()
-
- if thread.guild.id != BotConfig.guild_id:
- return
-
- await sync_channels(thread.guild)
-
-
-async def on_guild_available(guild: Guild) -> None:
- """Synchronize the user table with the Discord users."""
- await db_ready.wait()
-
- log.info(f"Received guild available for {guild.id}")
-
- if guild.id != BotConfig.guild_id:
- return log.info("Guild was not the configured guild, discarding event")
-
- await sync_channels(guild)
-
- log.info("Beginning thread archive state synchronisation process")
- await sync_thread_archive_state(guild)
-
- log.info("Beginning user synchronisation process")
-
- await User.update.values(in_guild=False).gino.status()
-
- users = []
-
- for user in guild.members:
- users.append({
- "id": str(user.id),
- "name": user.name,
- "avatar_hash": getattr(user.avatar, "key", None),
- "guild_avatar_hash": getattr(user.guild_avatar, "key", None),
- "joined_at": user.joined_at,
- "created_at": user.created_at,
- "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles],
- "bot": user.bot,
- "in_guild": True,
- "public_flags": dict(user.public_flags),
- "pending": user.pending
- })
-
- log.info(f"Performing bulk upsert of {len(users)} rows")
-
- user_chunks = gen_chunks(users, 500)
-
- for chunk in user_chunks:
- log.info(f"Upserting chunk of {len(chunk)}")
- await User.bulk_upsert(chunk)
-
- log.info("User upsert complete")
-
- sync_process_complete.set()
-
-
-async def on_member_join(member: Member) -> None:
- """On a user joining the server add them to the database."""
- await db_ready.wait()
- await sync_process_complete.wait()
-
- if member.guild.id != BotConfig.guild_id:
- return
-
- if db_user := await User.get(str(member.id)):
- await db_user.update(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
- public_flags=dict(member.public_flags),
- pending=member.pending,
- in_guild=True
- ).apply()
- else:
- try:
- await User.create(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in [role.id for role in member.roles],
- public_flags=dict(member.public_flags),
- pending=member.pending,
- in_guild=True
- )
- except UniqueViolationError:
- pass
-
-
-async def on_member_remove(member: Member) -> None:
- """On a user leaving the server mark in_guild as False."""
- await db_ready.wait()
- await sync_process_complete.wait()
-
- if member.guild.id != BotConfig.guild_id:
- return
-
- if db_user := await User.get(str(member.id)):
- await db_user.update(
- in_guild=False
- ).apply()
-
-
-async def on_member_update(before: Member, member: Member) -> None:
- """When a member updates their profile, update the DB record."""
- await sync_process_complete.wait()
-
- if member.guild.id != BotConfig.guild_id:
- return
-
- # Joined at will be null if we are not ready to process events yet
- if not member.joined_at:
- return
-
- roles = set([role.id for role in member.roles])
-
- if db_user := await User.get(str(member.id)):
- if (
- db_user.name != member.name or
- db_user.avatar_hash != getattr(member.avatar, "key", None) or
- db_user.guild_avatar_hash != getattr(member.guild_avatar, "key", None) or
- BotConfig.staff_role_id in
- [role.id for role in member.roles] != db_user.is_staff
- or db_user.pending is not member.pending
- ):
- await db_user.update(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in roles,
- public_flags=dict(member.public_flags),
- in_guild=True,
- pending=member.pending
- ).apply()
- else:
- try:
- await User.create(
- id=str(member.id),
- name=member.name,
- avatar_hash=getattr(member.avatar, "key", None),
- guild_avatar_hash=getattr(member.guild_avatar, "key", None),
- joined_at=member.joined_at,
- created_at=member.created_at,
- is_staff=BotConfig.staff_role_id in roles,
- public_flags=dict(member.public_flags),
- in_guild=True,
- pending=member.pending
- )
- except UniqueViolationError:
- pass
-
-
-async def on_message(message: DiscordMessage) -> None:
- """Add a message to the table when one is sent providing the author has accepted."""
- await db_ready.wait()
-
- if not message.guild:
- return
-
- if message.author.bot:
- return
-
- if message.guild.id != BotConfig.guild_id:
- return
-
- if message.type == MessageType.thread_created:
- return
-
- await sync_process_complete.wait()
- await channel_sync_in_progress.wait()
-
- if not await User.get(str(message.author.id)):
- return
-
- cat_id = message.channel.category.id if message.channel.category else None
-
- if cat_id in BotConfig.ignore_categories:
- return
-
- args = {
- "id": str(message.id),
- "channel_id": str(message.channel.id),
- "author_id": str(message.author.id),
- "created_at": message.created_at
- }
-
- if isinstance(message.channel, ThreadChannel):
- if not message.channel.parent:
- # This is a forum channel, not currently supported by Discord.py. Ignore it.
- return
- thread = message.channel
- args["channel_id"] = str(thread.parent_id)
- args["thread_id"] = str(thread.id)
-
- await Message.create(**args)
-
-
-async def on_raw_message_delete(message: RawMessageDeleteEvent) -> None:
- """If a message is deleted and we have a record of it set the is_deleted flag."""
- if message := await Message.get(str(message.message_id)):
- await message.update(is_deleted=True).apply()
-
-
-async def on_raw_bulk_message_delete(messages: RawBulkMessageDeleteEvent) -> None:
- """If messages are deleted in bulk and we have a record of them set the is_deleted flag."""
- for message_id in messages.message_ids:
- if message := await Message.get(str(message_id)):
- await message.update(is_deleted=True).apply()
+ async def on_error(self, event: str, *args, **kwargs) -> None:
+ """Log errors raised in event listeners rather than printing them to stderr."""
+ log.exception(f"Unhandled exception in {event}.")
diff --git a/tox.ini b/tox.ini
index 4778c16..4c3386e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,9 +2,13 @@
max-line-length=120
application-import-names=metricity
import-order-style=pycharm
-exclude=alembic
+exclude=alembic,.venv
extend-ignore=
# self params in classes.
ANN101,
+ # args and kwargs
+ ANN002, ANN003,
# line break before/after binary operator
- W503, W504
+ W503, W504,
+ # __init__ doc strings
+ D107