aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/__main__.py26
-rw-r--r--bot/bot.py30
2 files changed, 32 insertions, 24 deletions
diff --git a/bot/__main__.py b/bot/__main__.py
index ea7c43a12..84bc7094b 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -1,18 +1,11 @@
-import asyncio
-import logging
-import socket
-
import discord
-from aiohttp import AsyncResolver, ClientSession, TCPConnector
-from discord.ext.commands import Bot, when_mentioned_or
+from discord.ext.commands import when_mentioned_or
from bot import patches
-from bot.api import APIClient, APILoggingHandler
+from bot.bot import Bot
from bot.constants import Bot as BotConfig, DEBUG_MODE
-log = logging.getLogger('bot')
-
bot = Bot(
command_prefix=when_mentioned_or(BotConfig.prefix),
activity=discord.Game(name="Commands: !help"),
@@ -20,18 +13,6 @@ bot = Bot(
max_messages=10_000,
)
-# Global aiohttp session for all cogs
-# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads
-# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod.
-bot.http_session = ClientSession(
- connector=TCPConnector(
- resolver=AsyncResolver(),
- family=socket.AF_INET,
- )
-)
-bot.api_client = APIClient(loop=asyncio.get_event_loop())
-log.addHandler(APILoggingHandler(bot.api_client))
-
# Internal/debug
bot.load_extension("bot.cogs.error_handler")
bot.load_extension("bot.cogs.filtering")
@@ -77,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'):
patches.message_edited_at.apply_patch()
bot.run(BotConfig.token)
-
-# This calls a coroutine, so it doesn't do anything at the moment.
-# bot.http_session.close() # Close the aiohttp session when the bot finishes running
diff --git a/bot/bot.py b/bot/bot.py
new file mode 100644
index 000000000..05734ac1d
--- /dev/null
+++ b/bot/bot.py
@@ -0,0 +1,30 @@
+import asyncio
+import logging
+import socket
+
+import aiohttp
+from discord.ext import commands
+
+from bot import api
+
+log = logging.getLogger('bot')
+
+
+class Bot(commands.Bot):
+ """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Global aiohttp session for all cogs
+ # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads
+ # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod.
+ self.http_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(
+ resolver=aiohttp.AsyncResolver(),
+ family=socket.AF_INET,
+ )
+ )
+
+ self.api_client = api.APIClient(loop=asyncio.get_event_loop())
+ log.addHandler(api.APILoggingHandler(self.api_client))