aboutsummaryrefslogtreecommitdiffstats
path: root/botcore/_bot.py
diff options
context:
space:
mode:
authorGravatar ChrisJL <[email protected]>2022-04-02 20:11:34 +0100
committerGravatar GitHub <[email protected]>2022-04-02 20:11:34 +0100
commit765284a5e436f7786619a5ba4c312d75df5e9dee (patch)
treef4748800698892fc9940c217b5507011ec614d27 /botcore/_bot.py
parentMerge pull request #39 from python-discord/revert-disnake (diff)
parentUse imperative mood in docstrings (diff)
Merge pull request #42 from python-discord/bump-d.py-and-add-BotBase
Bump d.py and add bot base
Diffstat (limited to 'botcore/_bot.py')
-rw-r--r--botcore/_bot.py262
1 files changed, 262 insertions, 0 deletions
diff --git a/botcore/_bot.py b/botcore/_bot.py
new file mode 100644
index 00000000..a43bb344
--- /dev/null
+++ b/botcore/_bot.py
@@ -0,0 +1,262 @@
+import asyncio
+import socket
+import types
+import warnings
+from contextlib import suppress
+from typing import Optional
+
+import aiohttp
+import discord
+from discord.ext import commands
+
+from botcore.async_stats import AsyncStatsClient
+from botcore.site_api import APIClient
+from botcore.utils._extensions import walk_extensions
+from botcore.utils.logging import get_logger
+
+try:
+ from async_rediscache import RedisSession
+except ImportError:
+ RedisSession = discord.utils._MissingSentinel
+
+log = get_logger()
+
+
+class StartupError(Exception):
+ """Exception class for startup errors."""
+
+ def __init__(self, base: Exception):
+ super().__init__()
+ self.exception = base
+
+
+class BotBase(commands.Bot):
+ """A sub-class that implements many common features that Python Discord bots use."""
+
+ def __init__(
+ self,
+ *args,
+ guild_id: int,
+ allowed_roles: list,
+ http_session: aiohttp.ClientSession,
+ redis_session: Optional[RedisSession] = None,
+ **kwargs,
+ ):
+ """
+ Initialise the base bot instance.
+
+ Args:
+ guild_id: The ID of the guild use for :func:`wait_until_guild_available`.
+ allowed_roles: A list of role IDs that the bot is allowed to mention.
+ http_session (aiohttp.ClientSession): The session to use for the bot.
+ redis_session: The
+ ``[async_rediscache.RedisSession](https://github.com/SebastiaanZ/async-rediscache#creating-a-redissession)``
+ to use for the bot.
+ """
+ super().__init__(
+ *args,
+ allowed_roles=allowed_roles,
+ **kwargs,
+ )
+
+ self.guild_id = guild_id
+ self.http_session = http_session
+
+ if redis_session and RedisSession == discord.utils._MissingSentinel:
+ warnings.warn("redis_session kwarg passed, but async-rediscache not installed!")
+ elif redis_session:
+ self.redis_session = redis_session
+
+ self.api_client: Optional[APIClient] = None
+
+ self._resolver = aiohttp.AsyncResolver()
+ self._connector = aiohttp.TCPConnector(
+ resolver=self._resolver,
+ family=socket.AF_INET,
+ )
+ self.http.connector = self._connector
+
+ self.statsd_url: Optional[str] = None
+ self._statsd_timerhandle: Optional[asyncio.TimerHandle] = None
+ self._guild_available = asyncio.Event()
+
+ self.stats: Optional[AsyncStatsClient] = None
+
+ self.all_extensions: Optional[frozenset[str]] = None
+
+ def _connect_statsd(
+ self,
+ statsd_url: str,
+ loop: asyncio.AbstractEventLoop,
+ retry_after: int = 2,
+ attempt: int = 1
+ ) -> None:
+ """Callback used to retry a connection to statsd if it should fail."""
+ if attempt >= 8:
+ log.error("Reached 8 attempts trying to reconnect AsyncStatsClient. Aborting")
+ return
+
+ try:
+ self.stats = AsyncStatsClient(loop, statsd_url, 8125, prefix="bot")
+ except socket.gaierror:
+ log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})")
+ # Use a fallback strategy for retrying, up to 8 times.
+ self._statsd_timerhandle = loop.call_later(
+ retry_after,
+ self._connect_statsd,
+ statsd_url,
+ retry_after * 2,
+ attempt + 1
+ )
+
+ async def load_extensions(self, module: types.ModuleType) -> None:
+ """Load all the extensions within the given module and save them to ``self.all_extensions``."""
+ self.all_extensions = walk_extensions(module)
+
+ for extension in self.all_extensions:
+ await self.load_extension(extension)
+
+ def _add_root_aliases(self, command: commands.Command) -> None:
+ """Recursively add root aliases for ``command`` and any of its subcommands."""
+ if isinstance(command, commands.Group):
+ for subcommand in command.commands:
+ self._add_root_aliases(subcommand)
+
+ for alias in getattr(command, "root_aliases", ()):
+ if alias in self.all_commands:
+ raise commands.CommandRegistrationError(alias, alias_conflict=True)
+
+ self.all_commands[alias] = command
+
+ def _remove_root_aliases(self, command: commands.Command) -> None:
+ """Recursively remove root aliases for ``command`` and any of its subcommands."""
+ if isinstance(command, commands.Group):
+ for subcommand in command.commands:
+ self._remove_root_aliases(subcommand)
+
+ for alias in getattr(command, "root_aliases", ()):
+ self.all_commands.pop(alias, None)
+
+ async def add_cog(self, cog: commands.Cog) -> None:
+ """Add the given ``cog`` to the bot and log the operation."""
+ await super().add_cog(cog)
+ log.info(f"Cog loaded: {cog.qualified_name}")
+
+ def add_command(self, command: commands.Command) -> None:
+ """Add ``command`` as normal and then add its root aliases to the bot."""
+ super().add_command(command)
+ self._add_root_aliases(command)
+
+ def remove_command(self, name: str) -> Optional[commands.Command]:
+ """
+ Remove a command/alias as normal and then remove its root aliases from the bot.
+
+ Individual root aliases cannot be removed by this function.
+ To remove them, either remove the entire command or manually edit `bot.all_commands`.
+ """
+ command = super().remove_command(name)
+ if command is None:
+ # Even if it's a root alias, there's no way to get the Bot instance to remove the alias.
+ return None
+
+ self._remove_root_aliases(command)
+ return command
+
+ def clear(self) -> None:
+ """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one."""
+ raise NotImplementedError("Re-using a Bot object after closing it is not supported.")
+
+ async def on_guild_unavailable(self, guild: discord.Guild) -> None:
+ """Clear the internal guild available event when self.guild_id becomes unavailable."""
+ if guild.id != self.guild_id:
+ return
+
+ self._guild_available.clear()
+
+ async def on_guild_available(self, guild: discord.Guild) -> None:
+ """
+ Set the internal guild available event when self.guild_id becomes available.
+
+ If the cache appears to still be empty (no members, no channels, or no roles), the event
+ will not be set and `guild_available_but_cache_empty` event will be emitted.
+ """
+ if guild.id != self.guild_id:
+ return
+
+ if not guild.roles or not guild.members or not guild.channels:
+ msg = "Guild available event was dispatched but the cache appears to still be empty!"
+ self.log_to_dev_log(msg)
+ return
+
+ self._guild_available.set()
+
+ async def log_to_dev_log(self, message: str) -> None:
+ """Log the given message to #dev-log."""
+ ...
+
+ async def wait_until_guild_available(self) -> None:
+ """
+ Wait until the guild that matches the ``guild_id`` given at init is available (and the cache is ready).
+
+ The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE
+ gateway event before giving up and thus not populating the cache for unavailable guilds.
+ """
+ await self._guild_available.wait()
+
+ async def setup_hook(self) -> None:
+ """
+ An async init to startup generic services.
+
+ Connects to statsd, and calls
+ :func:`AsyncStatsClient.create_socket <botcore.async_stats.AsyncStatsClient.create_socket>`
+ and :func:`ping_services`.
+ """
+ loop = asyncio.get_running_loop()
+
+ self._connect_statsd(self.statsd_url, loop)
+ self.stats = AsyncStatsClient(loop, "127.0.0.1")
+ await self.stats.create_socket()
+
+ try:
+ await self.ping_services()
+ except Exception as e:
+ raise StartupError(e)
+
+ async def ping_services() -> None:
+ """Ping all required services on setup to ensure they are up before starting."""
+ ...
+
+ async def close(self) -> None:
+ """Close the Discord connection, and the aiohttp session, connector, statsd client, and resolver."""
+ # Done before super().close() to allow tasks finish before the HTTP session closes.
+ for ext in list(self.extensions):
+ with suppress(Exception):
+ await self.unload_extension(ext)
+
+ for cog in list(self.cogs):
+ with suppress(Exception):
+ await self.remove_cog(cog)
+
+ # Now actually do full close of bot
+ await super().close()
+
+ if self.api_client:
+ await self.api_client.close()
+
+ if self.http_session:
+ await self.http_session.close()
+
+ if self._connector:
+ await self._connector.close()
+
+ if self._resolver:
+ await self._resolver.close()
+
+ if self.stats._transport:
+ self.stats._transport.close()
+
+ if getattr(self.redis_session, None):
+ await self.redis_session.close()
+
+ if self._statsd_timerhandle:
+ self._statsd_timerhandle.cancel()