diff options
-rw-r--r-- | bot/__main__.py | 72 | ||||
-rw-r--r-- | bot/bot.py | 295 |
2 files changed, 84 insertions, 283 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index 0d3fce180..67d512eca 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,16 +1,80 @@ +import asyncio + import aiohttp +import discord +from async_rediscache import RedisSession +from botcore import StartupError +from botcore.site_api import APIClient +from discord.ext import commands import bot from bot import constants -from bot.bot import Bot, StartupError +from bot.bot import Bot from bot.log import get_logger, setup_sentry setup_sentry() +LOCALHOST = "127.0.0.1" + + +async def _create_redis_session() -> RedisSession: + """Create and connect to a redis session.""" + redis_session = RedisSession( + address=(constants.Redis.host, constants.Redis.port), + password=constants.Redis.password, + minsize=1, + maxsize=20, + use_fakeredis=constants.Redis.use_fakeredis, + global_namespace="bot", + ) + try: + await redis_session.connect() + except OSError as e: + raise StartupError(e) + return redis_session + + +async def main() -> None: + """Entry Async method for starting the bot.""" + statsd_url = constants.Stats.statsd_host + if constants.DEBUG_MODE: + # Since statsd is UDP, there are no errors for sending to a down port. + # For this reason, setting the statsd host to 127.0.0.1 for development + # will effectively disable stats. + statsd_url = LOCALHOST + + allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}) + intents = discord.Intents.all() + intents.presences = False + intents.dm_typing = False + intents.dm_reactions = False + intents.invites = False + intents.webhooks = False + intents.integrations = False + + async with aiohttp.ClientSession() as session: + bot.instance = Bot( + guild_id=constants.Guild.id, + http_session=session, + redis_session=await _create_redis_session(), + statsd_url=statsd_url, + command_prefix=commands.when_mentioned_or(constants.Bot.prefix), + activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), + case_insensitive=True, + max_messages=10_000, + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), + intents=intents, + allowed_roles=list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}), + ) + async with bot.instance as _bot: + _bot.api_client = APIClient( + site_api_url=f"{constants.URLs.site_api_schema}{constants.URLs.site_api}", + site_api_token=constants.Keys.site_api, + ) + await _bot.start(constants.Bot.token) + try: - bot.instance = Bot.create() - bot.instance.load_extensions() - bot.instance.run(constants.Bot.token) + asyncio.run(main()) except StartupError as e: message = "Unknown Startup Error Occurred." if isinstance(e.exception, (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError)): diff --git a/bot/bot.py b/bot/bot.py index 94783a466..1239727a2 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,22 +1,15 @@ import asyncio -import socket -import warnings from collections import defaultdict -from contextlib import suppress -from typing import Dict, List, Optional import aiohttp -import discord -from async_rediscache import RedisSession -from discord.ext import commands +from botcore import BotBase +from botcore.utils import scheduling from sentry_sdk import push_scope -from bot import api, constants -from bot.async_stats import AsyncStatsClient +from bot import constants, exts from bot.log import get_logger log = get_logger('bot') -LOCALHOST = "127.0.0.1" class StartupError(Exception): @@ -27,68 +20,15 @@ class StartupError(Exception): self.exception = base -class Bot(commands.Bot): - """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" +class Bot(BotBase): + """A subclass of `botcore.BotBase` that implements bot-specific functions.""" - def __init__(self, *args, redis_session: RedisSession, **kwargs): - if "connector" in kwargs: - warnings.warn( - "If login() is called (or the bot is started), the connector will be overwritten " - "with an internal one" - ) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.http_session: Optional[aiohttp.ClientSession] = None - self.redis_session = redis_session - self.api_client: Optional[api.APIClient] = None self.filter_list_cache = defaultdict(dict) - self._connector = None - self._resolver = None - self._statsd_timerhandle: asyncio.TimerHandle = None - self._guild_available = asyncio.Event() - - statsd_url = constants.Stats.statsd_host - - if constants.DEBUG_MODE: - # Since statsd is UDP, there are no errors for sending to a down port. - # For this reason, setting the statsd host to 127.0.0.1 for development - # will effectively disable stats. - statsd_url = LOCALHOST - - self.stats = AsyncStatsClient(self.loop, LOCALHOST) - self._connect_statsd(statsd_url) - - def _connect_statsd(self, statsd_url: str, retry_after: int = 2, attempt: int = 1) -> None: - """Callback used to retry a connection to statsd if it should fail.""" - if attempt >= 8: - log.error("Reached 8 attempts trying to reconnect AsyncStatsClient. Aborting") - return - - try: - self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot") - except socket.gaierror: - log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})") - # Use a fallback strategy for retrying, up to 8 times. - self._statsd_timerhandle = self.loop.call_later( - retry_after, - self._connect_statsd, - statsd_url, - retry_after * 2, - attempt + 1 - ) - - # All tasks that need to block closing until finished - self.closing_tasks: List[asyncio.Task] = [] - - async def cache_filter_list_data(self) -> None: - """Cache all the data in the FilterList on the site.""" - full_cache = await self.api_client.get('bot/filter-lists') - - for item in full_cache: - self.insert_item_into_filter_list_cache(item) - async def ping_services(self) -> None: """A helper to make sure all the services the bot relies on are available on startup.""" # Connect Site/API @@ -105,112 +45,7 @@ class Bot(commands.Bot): raise await asyncio.sleep(constants.URLs.connect_cooldown) - @classmethod - def create(cls) -> "Bot": - """Create and return an instance of a Bot.""" - loop = asyncio.get_event_loop() - allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}) - - intents = discord.Intents.all() - intents.presences = False - intents.dm_typing = False - intents.dm_reactions = False - intents.invites = False - intents.webhooks = False - intents.integrations = False - - return cls( - redis_session=_create_redis_session(loop), - loop=loop, - command_prefix=commands.when_mentioned_or(constants.Bot.prefix), - activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), - case_insensitive=True, - max_messages=10_000, - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), - intents=intents, - ) - - def load_extensions(self) -> None: - """Load all enabled extensions.""" - # Must be done here to avoid a circular import. - from bot.utils.extensions import EXTENSIONS - - extensions = set(EXTENSIONS) # Create a mutable copy. - if not constants.HelpChannels.enable: - extensions.remove("bot.exts.help_channels") - - for extension in extensions: - self.load_extension(extension) - - def add_cog(self, cog: commands.Cog) -> None: - """Adds a "cog" to the bot and logs the operation.""" - super().add_cog(cog) - log.info(f"Cog loaded: {cog.qualified_name}") - - def add_command(self, command: commands.Command) -> None: - """Add `command` as normal and then add its root aliases to the bot.""" - super().add_command(command) - self._add_root_aliases(command) - - def remove_command(self, name: str) -> Optional[commands.Command]: - """ - Remove a command/alias as normal and then remove its root aliases from the bot. - - Individual root aliases cannot be removed by this function. - To remove them, either remove the entire command or manually edit `bot.all_commands`. - """ - command = super().remove_command(name) - if command is None: - # Even if it's a root alias, there's no way to get the Bot instance to remove the alias. - return - - self._remove_root_aliases(command) - return command - - def clear(self) -> None: - """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one.""" - raise NotImplementedError("Re-using a Bot object after closing it is not supported.") - - async def close(self) -> None: - """Close the Discord connection and the aiohttp session, connector, statsd client, and resolver.""" - # Done before super().close() to allow tasks finish before the HTTP session closes. - for ext in list(self.extensions): - with suppress(Exception): - self.unload_extension(ext) - - for cog in list(self.cogs): - with suppress(Exception): - self.remove_cog(cog) - - # Wait until all tasks that have to be completed before bot is closing is done - log.trace("Waiting for tasks before closing.") - await asyncio.gather(*self.closing_tasks) - - # Now actually do full close of bot - await super().close() - - if self.api_client: - await self.api_client.close() - - if self.http_session: - await self.http_session.close() - - if self._connector: - await self._connector.close() - - if self._resolver: - await self._resolver.close() - - if self.stats._transport: - self.stats._transport.close() - - if self.redis_session: - await self.redis_session.close() - - if self._statsd_timerhandle: - self._statsd_timerhandle.cancel() - - def insert_item_into_filter_list_cache(self, item: Dict[str, str]) -> None: + def insert_item_into_filter_list_cache(self, item: dict[str, str]) -> None: """Add an item to the bots filter_list_cache.""" type_ = item["type"] allowed = item["allowed"] @@ -223,81 +58,26 @@ class Bot(commands.Bot): "updated_at": item["updated_at"], } - async def login(self, *args, **kwargs) -> None: - """Re-create the connector and set up sessions before logging into Discord.""" - # Use asyncio for DNS resolution instead of threads so threads aren't spammed. - self._resolver = aiohttp.AsyncResolver() - - # Use AF_INET as its socket family to prevent HTTPS related problems both locally - # and in production. - self._connector = aiohttp.TCPConnector( - resolver=self._resolver, - family=socket.AF_INET, - ) + async def cache_filter_list_data(self) -> None: + """Cache all the data in the FilterList on the site.""" + full_cache = await self.api_client.get('bot/filter-lists') - # Client.login() will call HTTPClient.static_login() which will create a session using - # this connector attribute. - self.http.connector = self._connector + for item in full_cache: + self.insert_item_into_filter_list_cache(item) - self.http_session = aiohttp.ClientSession(connector=self._connector) - self.api_client = api.APIClient(connector=self._connector) + async def setup_hook(self) -> None: + """Default Async initialisation method for Discord.py.""" + await super().setup_hook() if self.redis_session.closed: # If the RedisSession was somehow closed, we try to reconnect it # here. Normally, this shouldn't happen. await self.redis_session.connect() - try: - await self.ping_services() - except Exception as e: - raise StartupError(e) - # Build the FilterList cache await self.cache_filter_list_data() - await self.stats.create_socket() - await super().login(*args, **kwargs) - - async def on_guild_available(self, guild: discord.Guild) -> None: - """ - Set the internal guild available event when constants.Guild.id becomes available. - - If the cache appears to still be empty (no members, no channels, or no roles), the event - will not be set. - """ - if guild.id != constants.Guild.id: - return - - if not guild.roles or not guild.members or not guild.channels: - msg = "Guild available event was dispatched but the cache appears to still be empty!" - log.warning(msg) - - try: - webhook = await self.fetch_webhook(constants.Webhooks.dev_log) - except discord.HTTPException as e: - log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") - else: - await webhook.send(f"<@&{constants.Roles.admin}> {msg}") - - return - - self._guild_available.set() - - async def on_guild_unavailable(self, guild: discord.Guild) -> None: - """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" - if guild.id != constants.Guild.id: - return - - self._guild_available.clear() - - async def wait_until_guild_available(self) -> None: - """ - Wait until the constants.Guild.id guild is available (and the cache is ready). - - The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE - gateway event before giving up and thus not populating the cache for unavailable guilds. - """ - await self._guild_available.wait() + scheduling.create_task(self.load_extensions(exts)) async def on_error(self, event: str, *args, **kwargs) -> None: """Log errors raised in event listeners rather than printing them to stderr.""" @@ -309,46 +89,3 @@ class Bot(commands.Bot): scope.set_extra("kwargs", kwargs) log.exception(f"Unhandled exception in {event}.") - - def _add_root_aliases(self, command: commands.Command) -> None: - """Recursively add root aliases for `command` and any of its subcommands.""" - if isinstance(command, commands.Group): - for subcommand in command.commands: - self._add_root_aliases(subcommand) - - for alias in getattr(command, "root_aliases", ()): - if alias in self.all_commands: - raise commands.CommandRegistrationError(alias, alias_conflict=True) - - self.all_commands[alias] = command - - def _remove_root_aliases(self, command: commands.Command) -> None: - """Recursively remove root aliases for `command` and any of its subcommands.""" - if isinstance(command, commands.Group): - for subcommand in command.commands: - self._remove_root_aliases(subcommand) - - for alias in getattr(command, "root_aliases", ()): - self.all_commands.pop(alias, None) - - -def _create_redis_session(loop: asyncio.AbstractEventLoop) -> RedisSession: - """ - Create and connect to a redis session. - - Ensure the connection is established before returning to prevent race conditions. - `loop` is the event loop on which to connect. The Bot should use this same event loop. - """ - redis_session = RedisSession( - address=(constants.Redis.host, constants.Redis.port), - password=constants.Redis.password, - minsize=1, - maxsize=20, - use_fakeredis=constants.Redis.use_fakeredis, - global_namespace="bot", - ) - try: - loop.run_until_complete(redis_session.connect()) - except OSError as e: - raise StartupError(e) - return redis_session |