diff options
| -rw-r--r-- | bot/__main__.py | 26 | ||||
| -rw-r--r-- | bot/bot.py | 30 | 
2 files changed, 32 insertions, 24 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index ea7c43a12..84bc7094b 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,18 +1,11 @@ -import asyncio -import logging -import socket -  import discord -from aiohttp import AsyncResolver, ClientSession, TCPConnector -from discord.ext.commands import Bot, when_mentioned_or +from discord.ext.commands import when_mentioned_or  from bot import patches -from bot.api import APIClient, APILoggingHandler +from bot.bot import Bot  from bot.constants import Bot as BotConfig, DEBUG_MODE -log = logging.getLogger('bot') -  bot = Bot(      command_prefix=when_mentioned_or(BotConfig.prefix),      activity=discord.Game(name="Commands: !help"), @@ -20,18 +13,6 @@ bot = Bot(      max_messages=10_000,  ) -# Global aiohttp session for all cogs -# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads -# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. -bot.http_session = ClientSession( -    connector=TCPConnector( -        resolver=AsyncResolver(), -        family=socket.AF_INET, -    ) -) -bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.addHandler(APILoggingHandler(bot.api_client)) -  # Internal/debug  bot.load_extension("bot.cogs.error_handler")  bot.load_extension("bot.cogs.filtering") @@ -77,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'):      patches.message_edited_at.apply_patch()  bot.run(BotConfig.token) - -# This calls a coroutine, so it doesn't do anything at the moment. -# bot.http_session.close()  # Close the aiohttp session when the bot finishes running diff --git a/bot/bot.py b/bot/bot.py new file mode 100644 index 000000000..05734ac1d --- /dev/null +++ b/bot/bot.py @@ -0,0 +1,30 @@ +import asyncio +import logging +import socket + +import aiohttp +from discord.ext import commands + +from bot import api + +log = logging.getLogger('bot') + + +class Bot(commands.Bot): +    """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" + +    def __init__(self, *args, **kwargs): +        super().__init__(*args, **kwargs) + +        # Global aiohttp session for all cogs +        # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads +        # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. +        self.http_session = aiohttp.ClientSession( +            connector=aiohttp.TCPConnector( +                resolver=aiohttp.AsyncResolver(), +                family=socket.AF_INET, +            ) +        ) + +        self.api_client = api.APIClient(loop=asyncio.get_event_loop()) +        log.addHandler(api.APILoggingHandler(self.api_client))  |