aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Chris Lovering <[email protected]>2022-03-31 20:02:21 +0100
committerGravatar Chris Lovering <[email protected]>2022-04-18 17:44:38 +0100
commit1bae068ba66fef3524e830acd092f78da6ca4544 (patch)
treeb2f7cf1589acf861d794a385fc43bca5720efca9
parentMove to async cog loading (diff)
Use BotBase from bot-core
-rw-r--r--bot/__main__.py72
-rw-r--r--bot/bot.py295
2 files changed, 84 insertions, 283 deletions
diff --git a/bot/__main__.py b/bot/__main__.py
index 0d3fce180..67d512eca 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -1,16 +1,80 @@
+import asyncio
+
import aiohttp
+import discord
+from async_rediscache import RedisSession
+from botcore import StartupError
+from botcore.site_api import APIClient
+from discord.ext import commands
import bot
from bot import constants
-from bot.bot import Bot, StartupError
+from bot.bot import Bot
from bot.log import get_logger, setup_sentry
setup_sentry()
+LOCALHOST = "127.0.0.1"
+
+
+async def _create_redis_session() -> RedisSession:
+ """Create and connect to a redis session."""
+ redis_session = RedisSession(
+ address=(constants.Redis.host, constants.Redis.port),
+ password=constants.Redis.password,
+ minsize=1,
+ maxsize=20,
+ use_fakeredis=constants.Redis.use_fakeredis,
+ global_namespace="bot",
+ )
+ try:
+ await redis_session.connect()
+ except OSError as e:
+ raise StartupError(e)
+ return redis_session
+
+
+async def main() -> None:
+ """Entry Async method for starting the bot."""
+ statsd_url = constants.Stats.statsd_host
+ if constants.DEBUG_MODE:
+ # Since statsd is UDP, there are no errors for sending to a down port.
+ # For this reason, setting the statsd host to 127.0.0.1 for development
+ # will effectively disable stats.
+ statsd_url = LOCALHOST
+
+ allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES})
+ intents = discord.Intents.all()
+ intents.presences = False
+ intents.dm_typing = False
+ intents.dm_reactions = False
+ intents.invites = False
+ intents.webhooks = False
+ intents.integrations = False
+
+ async with aiohttp.ClientSession() as session:
+ bot.instance = Bot(
+ guild_id=constants.Guild.id,
+ http_session=session,
+ redis_session=await _create_redis_session(),
+ statsd_url=statsd_url,
+ command_prefix=commands.when_mentioned_or(constants.Bot.prefix),
+ activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),
+ case_insensitive=True,
+ max_messages=10_000,
+ allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles),
+ intents=intents,
+ allowed_roles=list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}),
+ )
+ async with bot.instance as _bot:
+ _bot.api_client = APIClient(
+ site_api_url=f"{constants.URLs.site_api_schema}{constants.URLs.site_api}",
+ site_api_token=constants.Keys.site_api,
+ )
+ await _bot.start(constants.Bot.token)
+
try:
- bot.instance = Bot.create()
- bot.instance.load_extensions()
- bot.instance.run(constants.Bot.token)
+ asyncio.run(main())
except StartupError as e:
message = "Unknown Startup Error Occurred."
if isinstance(e.exception, (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError)):
diff --git a/bot/bot.py b/bot/bot.py
index 94783a466..1239727a2 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,22 +1,15 @@
import asyncio
-import socket
-import warnings
from collections import defaultdict
-from contextlib import suppress
-from typing import Dict, List, Optional
import aiohttp
-import discord
-from async_rediscache import RedisSession
-from discord.ext import commands
+from botcore import BotBase
+from botcore.utils import scheduling
from sentry_sdk import push_scope
-from bot import api, constants
-from bot.async_stats import AsyncStatsClient
+from bot import constants, exts
from bot.log import get_logger
log = get_logger('bot')
-LOCALHOST = "127.0.0.1"
class StartupError(Exception):
@@ -27,68 +20,15 @@ class StartupError(Exception):
self.exception = base
-class Bot(commands.Bot):
- """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client."""
+class Bot(BotBase):
+ """A subclass of `botcore.BotBase` that implements bot-specific functions."""
- def __init__(self, *args, redis_session: RedisSession, **kwargs):
- if "connector" in kwargs:
- warnings.warn(
- "If login() is called (or the bot is started), the connector will be overwritten "
- "with an internal one"
- )
+ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.http_session: Optional[aiohttp.ClientSession] = None
- self.redis_session = redis_session
- self.api_client: Optional[api.APIClient] = None
self.filter_list_cache = defaultdict(dict)
- self._connector = None
- self._resolver = None
- self._statsd_timerhandle: asyncio.TimerHandle = None
- self._guild_available = asyncio.Event()
-
- statsd_url = constants.Stats.statsd_host
-
- if constants.DEBUG_MODE:
- # Since statsd is UDP, there are no errors for sending to a down port.
- # For this reason, setting the statsd host to 127.0.0.1 for development
- # will effectively disable stats.
- statsd_url = LOCALHOST
-
- self.stats = AsyncStatsClient(self.loop, LOCALHOST)
- self._connect_statsd(statsd_url)
-
- def _connect_statsd(self, statsd_url: str, retry_after: int = 2, attempt: int = 1) -> None:
- """Callback used to retry a connection to statsd if it should fail."""
- if attempt >= 8:
- log.error("Reached 8 attempts trying to reconnect AsyncStatsClient. Aborting")
- return
-
- try:
- self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot")
- except socket.gaierror:
- log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})")
- # Use a fallback strategy for retrying, up to 8 times.
- self._statsd_timerhandle = self.loop.call_later(
- retry_after,
- self._connect_statsd,
- statsd_url,
- retry_after * 2,
- attempt + 1
- )
-
- # All tasks that need to block closing until finished
- self.closing_tasks: List[asyncio.Task] = []
-
- async def cache_filter_list_data(self) -> None:
- """Cache all the data in the FilterList on the site."""
- full_cache = await self.api_client.get('bot/filter-lists')
-
- for item in full_cache:
- self.insert_item_into_filter_list_cache(item)
-
async def ping_services(self) -> None:
"""A helper to make sure all the services the bot relies on are available on startup."""
# Connect Site/API
@@ -105,112 +45,7 @@ class Bot(commands.Bot):
raise
await asyncio.sleep(constants.URLs.connect_cooldown)
- @classmethod
- def create(cls) -> "Bot":
- """Create and return an instance of a Bot."""
- loop = asyncio.get_event_loop()
- allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES})
-
- intents = discord.Intents.all()
- intents.presences = False
- intents.dm_typing = False
- intents.dm_reactions = False
- intents.invites = False
- intents.webhooks = False
- intents.integrations = False
-
- return cls(
- redis_session=_create_redis_session(loop),
- loop=loop,
- command_prefix=commands.when_mentioned_or(constants.Bot.prefix),
- activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),
- case_insensitive=True,
- max_messages=10_000,
- allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles),
- intents=intents,
- )
-
- def load_extensions(self) -> None:
- """Load all enabled extensions."""
- # Must be done here to avoid a circular import.
- from bot.utils.extensions import EXTENSIONS
-
- extensions = set(EXTENSIONS) # Create a mutable copy.
- if not constants.HelpChannels.enable:
- extensions.remove("bot.exts.help_channels")
-
- for extension in extensions:
- self.load_extension(extension)
-
- def add_cog(self, cog: commands.Cog) -> None:
- """Adds a "cog" to the bot and logs the operation."""
- super().add_cog(cog)
- log.info(f"Cog loaded: {cog.qualified_name}")
-
- def add_command(self, command: commands.Command) -> None:
- """Add `command` as normal and then add its root aliases to the bot."""
- super().add_command(command)
- self._add_root_aliases(command)
-
- def remove_command(self, name: str) -> Optional[commands.Command]:
- """
- Remove a command/alias as normal and then remove its root aliases from the bot.
-
- Individual root aliases cannot be removed by this function.
- To remove them, either remove the entire command or manually edit `bot.all_commands`.
- """
- command = super().remove_command(name)
- if command is None:
- # Even if it's a root alias, there's no way to get the Bot instance to remove the alias.
- return
-
- self._remove_root_aliases(command)
- return command
-
- def clear(self) -> None:
- """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one."""
- raise NotImplementedError("Re-using a Bot object after closing it is not supported.")
-
- async def close(self) -> None:
- """Close the Discord connection and the aiohttp session, connector, statsd client, and resolver."""
- # Done before super().close() to allow tasks finish before the HTTP session closes.
- for ext in list(self.extensions):
- with suppress(Exception):
- self.unload_extension(ext)
-
- for cog in list(self.cogs):
- with suppress(Exception):
- self.remove_cog(cog)
-
- # Wait until all tasks that have to be completed before bot is closing is done
- log.trace("Waiting for tasks before closing.")
- await asyncio.gather(*self.closing_tasks)
-
- # Now actually do full close of bot
- await super().close()
-
- if self.api_client:
- await self.api_client.close()
-
- if self.http_session:
- await self.http_session.close()
-
- if self._connector:
- await self._connector.close()
-
- if self._resolver:
- await self._resolver.close()
-
- if self.stats._transport:
- self.stats._transport.close()
-
- if self.redis_session:
- await self.redis_session.close()
-
- if self._statsd_timerhandle:
- self._statsd_timerhandle.cancel()
-
- def insert_item_into_filter_list_cache(self, item: Dict[str, str]) -> None:
+ def insert_item_into_filter_list_cache(self, item: dict[str, str]) -> None:
"""Add an item to the bots filter_list_cache."""
type_ = item["type"]
allowed = item["allowed"]
@@ -223,81 +58,26 @@ class Bot(commands.Bot):
"updated_at": item["updated_at"],
}
- async def login(self, *args, **kwargs) -> None:
- """Re-create the connector and set up sessions before logging into Discord."""
- # Use asyncio for DNS resolution instead of threads so threads aren't spammed.
- self._resolver = aiohttp.AsyncResolver()
-
- # Use AF_INET as its socket family to prevent HTTPS related problems both locally
- # and in production.
- self._connector = aiohttp.TCPConnector(
- resolver=self._resolver,
- family=socket.AF_INET,
- )
+ async def cache_filter_list_data(self) -> None:
+ """Cache all the data in the FilterList on the site."""
+ full_cache = await self.api_client.get('bot/filter-lists')
- # Client.login() will call HTTPClient.static_login() which will create a session using
- # this connector attribute.
- self.http.connector = self._connector
+ for item in full_cache:
+ self.insert_item_into_filter_list_cache(item)
- self.http_session = aiohttp.ClientSession(connector=self._connector)
- self.api_client = api.APIClient(connector=self._connector)
+ async def setup_hook(self) -> None:
+ """Default Async initialisation method for Discord.py."""
+ await super().setup_hook()
if self.redis_session.closed:
# If the RedisSession was somehow closed, we try to reconnect it
# here. Normally, this shouldn't happen.
await self.redis_session.connect()
- try:
- await self.ping_services()
- except Exception as e:
- raise StartupError(e)
-
# Build the FilterList cache
await self.cache_filter_list_data()
- await self.stats.create_socket()
- await super().login(*args, **kwargs)
-
- async def on_guild_available(self, guild: discord.Guild) -> None:
- """
- Set the internal guild available event when constants.Guild.id becomes available.
-
- If the cache appears to still be empty (no members, no channels, or no roles), the event
- will not be set.
- """
- if guild.id != constants.Guild.id:
- return
-
- if not guild.roles or not guild.members or not guild.channels:
- msg = "Guild available event was dispatched but the cache appears to still be empty!"
- log.warning(msg)
-
- try:
- webhook = await self.fetch_webhook(constants.Webhooks.dev_log)
- except discord.HTTPException as e:
- log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}")
- else:
- await webhook.send(f"<@&{constants.Roles.admin}> {msg}")
-
- return
-
- self._guild_available.set()
-
- async def on_guild_unavailable(self, guild: discord.Guild) -> None:
- """Clear the internal guild available event when constants.Guild.id becomes unavailable."""
- if guild.id != constants.Guild.id:
- return
-
- self._guild_available.clear()
-
- async def wait_until_guild_available(self) -> None:
- """
- Wait until the constants.Guild.id guild is available (and the cache is ready).
-
- The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE
- gateway event before giving up and thus not populating the cache for unavailable guilds.
- """
- await self._guild_available.wait()
+ scheduling.create_task(self.load_extensions(exts))
async def on_error(self, event: str, *args, **kwargs) -> None:
"""Log errors raised in event listeners rather than printing them to stderr."""
@@ -309,46 +89,3 @@ class Bot(commands.Bot):
scope.set_extra("kwargs", kwargs)
log.exception(f"Unhandled exception in {event}.")
-
- def _add_root_aliases(self, command: commands.Command) -> None:
- """Recursively add root aliases for `command` and any of its subcommands."""
- if isinstance(command, commands.Group):
- for subcommand in command.commands:
- self._add_root_aliases(subcommand)
-
- for alias in getattr(command, "root_aliases", ()):
- if alias in self.all_commands:
- raise commands.CommandRegistrationError(alias, alias_conflict=True)
-
- self.all_commands[alias] = command
-
- def _remove_root_aliases(self, command: commands.Command) -> None:
- """Recursively remove root aliases for `command` and any of its subcommands."""
- if isinstance(command, commands.Group):
- for subcommand in command.commands:
- self._remove_root_aliases(subcommand)
-
- for alias in getattr(command, "root_aliases", ()):
- self.all_commands.pop(alias, None)
-
-
-def _create_redis_session(loop: asyncio.AbstractEventLoop) -> RedisSession:
- """
- Create and connect to a redis session.
-
- Ensure the connection is established before returning to prevent race conditions.
- `loop` is the event loop on which to connect. The Bot should use this same event loop.
- """
- redis_session = RedisSession(
- address=(constants.Redis.host, constants.Redis.port),
- password=constants.Redis.password,
- minsize=1,
- maxsize=20,
- use_fakeredis=constants.Redis.use_fakeredis,
- global_namespace="bot",
- )
- try:
- loop.run_until_complete(redis_session.connect())
- except OSError as e:
- raise StartupError(e)
- return redis_session