diff options
Diffstat (limited to '')
-rw-r--r-- | botcore/__init__.py | 3 | ||||
-rw-r--r-- | botcore/_bot.py | 265 | ||||
-rw-r--r-- | botcore/utils/__init__.py | 3 | ||||
-rw-r--r-- | botcore/utils/_extensions.py (renamed from botcore/utils/extensions.py) | 9 | ||||
-rw-r--r-- | docs/conf.py | 8 |
5 files changed, 280 insertions, 8 deletions
diff --git a/botcore/__init__.py b/botcore/__init__.py index 59081e57..f0c4e6bb 100644 --- a/botcore/__init__.py +++ b/botcore/__init__.py @@ -1,12 +1,15 @@ """Useful utilities and tools for Discord bot development.""" from botcore import async_stats, exts, site_api, utils +from botcore._bot import BotBase, StartupError __all__ = [ async_stats, + BotBase, exts, utils, site_api, + StartupError, ] __all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/_bot.py b/botcore/_bot.py new file mode 100644 index 00000000..28f5a1a2 --- /dev/null +++ b/botcore/_bot.py @@ -0,0 +1,265 @@ +import asyncio +import socket +import types +from abc import abstractmethod +from contextlib import suppress +from typing import Optional + +import aiohttp +import discord +from async_rediscache import RedisSession +from discord.ext import commands + +from botcore.async_stats import AsyncStatsClient +from botcore.site_api import APIClient +from botcore.utils._extensions import walk_extensions +from botcore.utils.logging import get_logger + +log = get_logger() + + +class StartupError(Exception): + """Exception class for startup errors.""" + + def __init__(self, base: Exception): + super().__init__() + self.exception = base + + +class BotBase(commands.Bot): + """A sub-class that implements many common features that Python Discord bots use.""" + + def __init__( + self, + *args, + guild_id: int, + prefix: str, + allowed_roles: list, + intents: discord.Intents, + http_session: aiohttp.ClientSession, + redis_session: Optional[RedisSession] = None, + **kwargs, + ): + """ + Initialise the base bot instance. + + Args: + guild_id: The ID of the guild use for :func:`wait_until_guild_available`. + prefix: The prefix to use for the bot. + allowed_roles: A list of role IDs that the bot is allowed to mention. + intents: The :obj:`discord.Intents` to use for the bot. + http_session (aiohttp.ClientSession): The session to use for the bot. + redis_session: The + ``[async_rediscache.RedisSession](https://github.com/SebastiaanZ/async-rediscache#creating-a-redissession)`` + to use for the bot. + """ + super().__init__( + *args, + prefix=prefix, + allowed_roles=allowed_roles, + intents=intents, + **kwargs, + ) + + self.guild_id = guild_id + self.http_session = http_session + if redis_session: + self.redis_session = redis_session + + self.api_client: Optional[APIClient] = None + + self._resolver = aiohttp.AsyncResolver() + self._connector = aiohttp.TCPConnector( + resolver=self._resolver, + family=socket.AF_INET, + ) + self.http.connector = self._connector + + self.statsd_url: Optional[str] = None + self._statsd_timerhandle: Optional[asyncio.TimerHandle] = None + self._guild_available = asyncio.Event() + + self.stats: Optional[AsyncStatsClient] = None + + def _connect_statsd( + self, + statsd_url: str, + loop: asyncio.AbstractEventLoop, + retry_after: int = 2, + attempt: int = 1 + ) -> None: + """Callback used to retry a connection to statsd if it should fail.""" + if attempt >= 8: + log.error("Reached 8 attempts trying to reconnect AsyncStatsClient. Aborting") + return + + try: + self.stats = AsyncStatsClient(loop, statsd_url, 8125, prefix="bot") + except socket.gaierror: + log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})") + # Use a fallback strategy for retrying, up to 8 times. + self._statsd_timerhandle = loop.call_later( + retry_after, + self._connect_statsd, + statsd_url, + retry_after * 2, + attempt + 1 + ) + + # All tasks that need to block closing until finished + self.closing_tasks: list[asyncio.Task] = [] + + async def load_extensions(self, module: types.ModuleType) -> None: + """Load all the extensions within the given module.""" + for extension in walk_extensions(module): + await self.load_extension(extension) + + def _add_root_aliases(self, command: commands.Command) -> None: + """Recursively add root aliases for ``command`` and any of its subcommands.""" + if isinstance(command, commands.Group): + for subcommand in command.commands: + self._add_root_aliases(subcommand) + + for alias in getattr(command, "root_aliases", ()): + if alias in self.all_commands: + raise commands.CommandRegistrationError(alias, alias_conflict=True) + + self.all_commands[alias] = command + + def _remove_root_aliases(self, command: commands.Command) -> None: + """Recursively remove root aliases for ``command`` and any of its subcommands.""" + if isinstance(command, commands.Group): + for subcommand in command.commands: + self._remove_root_aliases(subcommand) + + for alias in getattr(command, "root_aliases", ()): + self.all_commands.pop(alias, None) + + async def add_cog(self, cog: commands.Cog) -> None: + """Adds the given ``cog`` to the bot and logs the operation.""" + await super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + def add_command(self, command: commands.Command) -> None: + """Add ``command`` as normal and then add its root aliases to the bot.""" + super().add_command(command) + self._add_root_aliases(command) + + def remove_command(self, name: str) -> Optional[commands.Command]: + """ + Remove a command/alias as normal and then remove its root aliases from the bot. + + Individual root aliases cannot be removed by this function. + To remove them, either remove the entire command or manually edit `bot.all_commands`. + """ + command = super().remove_command(name) + if command is None: + # Even if it's a root alias, there's no way to get the Bot instance to remove the alias. + return None + + self._remove_root_aliases(command) + return command + + def clear(self) -> None: + """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one.""" + raise NotImplementedError("Re-using a Bot object after closing it is not supported.") + + async def on_guild_unavailable(self, guild: discord.Guild) -> None: + """Clear the internal guild available event when self.guild_id becomes unavailable.""" + if guild.id != self.guild_id: + return + + self._guild_available.clear() + + async def on_guild_available(self, guild: discord.Guild) -> None: + """ + Set the internal guild available event when self.guild_id becomes available. + + If the cache appears to still be empty (no members, no channels, or no roles), the event + will not be set and `guild_available_but_cache_empty` event will be emitted. + """ + if guild.id != self.guild_id: + return + + if not guild.roles or not guild.members or not guild.channels: + msg = "Guild available event was dispatched but the cache appears to still be empty!" + self.log_to_dev_log(msg) + return + + self._guild_available.set() + + @abstractmethod + async def log_to_dev_log(self, message: str) -> None: + """Log the given message to #dev-log.""" + ... + + async def wait_until_guild_available(self) -> None: + """ + Wait until the guild that matches the ``guild_id`` given at init is available (and the cache is ready). + + The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE + gateway event before giving up and thus not populating the cache for unavailable guilds. + """ + await self._guild_available.wait() + + async def setup_hook(self) -> None: + """ + An async init to startup generic services. + + Connects to statsd, and calls + :func:`AsyncStatsClient.create_socket <botcore.async_stats.AsyncStatsClient.create_socket>` + and :func:`ping_services`. + """ + loop = asyncio.get_running_loop() + self._connect_statsd(self.statsd_url, loop) + self.stats = AsyncStatsClient(loop, "127.0.0.1") + await self.stats.create_socket() + + try: + await self.ping_services() + except Exception as e: + raise StartupError(e) + + @abstractmethod + async def ping_services() -> None: + """Ping all required services on setup to ensure they are up before starting.""" + ... + + async def close(self) -> None: + """Close the Discord connection, and the aiohttp session, connector, statsd client, and resolver.""" + # Done before super().close() to allow tasks finish before the HTTP session closes. + for ext in list(self.extensions): + with suppress(Exception): + await self.unload_extension(ext) + + for cog in list(self.cogs): + with suppress(Exception): + await self.remove_cog(cog) + + # Wait until all tasks that have to be completed before bot is closing is done + log.trace("Waiting for tasks before closing.") + await asyncio.gather(*self.closing_tasks) + + # Now actually do full close of bot + await super().close() + + if self.api_client: + await self.api_client.close() + + if self.http_session: + await self.http_session.close() + + if self._connector: + await self._connector.close() + + if self._resolver: + await self._resolver.close() + + if self.stats._transport: + self.stats._transport.close() + + if self.redis_session: + await self.redis_session.close() + + if self._statsd_timerhandle: + self._statsd_timerhandle.cancel() diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index fe906075..7e6ea788 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,6 @@ """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, extensions, logging, members, regex, scheduling +from botcore.utils import _monkey_patches, caching, channel, logging, members, regex, scheduling def apply_monkey_patches() -> None: @@ -23,7 +23,6 @@ __all__ = [ apply_monkey_patches, caching, channel, - extensions, logging, members, regex, diff --git a/botcore/utils/extensions.py b/botcore/utils/_extensions.py index 841120c9..6848fae6 100644 --- a/botcore/utils/extensions.py +++ b/botcore/utils/_extensions.py @@ -28,14 +28,13 @@ def walk_extensions(module: types.ModuleType) -> frozenset[str]: module (types.ModuleType): The module to look for extensions in. Returns: - A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. + An iterator object, that returns a string that can be passed directly to + :obj:`discord.ext.commands.Bot.load_extension` on call to next(). """ def on_error(name: str) -> NoReturn: raise ImportError(name=name) # pragma: no cover - modules = set() - for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): if unqualify(module_info.name).startswith("_"): # Ignore module/package names starting with an underscore. @@ -47,6 +46,4 @@ def walk_extensions(module: types.ModuleType) -> frozenset[str]: # If it lacks a setup function, it's not an extension. continue - modules.add(module_info.name) - - return frozenset(modules) + yield module_info.name diff --git a/docs/conf.py b/docs/conf.py index 47c788df..8b55b58a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -107,6 +107,14 @@ def setup(app: Sphinx) -> None: app.connect("autodoc-skip-member", skip) +ignored_modules = [ + "async_rediscache", +] + +nitpick_ignore_regex = [ + ("py:.*", "|".join([f".*{entry}.*" for entry in ignored_modules])), +] + # -- Extension configuration ------------------------------------------------- # -- Options for todo extension ---------------------------------------------- |