diff options
| -rw-r--r-- | bot/bot.py | 45 |
1 files changed, 30 insertions, 15 deletions
diff --git a/bot/bot.py b/bot/bot.py index 8f808272f..95fbae17f 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -14,18 +14,13 @@ class Bot(commands.Bot): """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, **kwargs): - # Use asyncio for DNS resolution instead of threads so threads aren't spammed. - # Use AF_INET as its socket family to prevent HTTPS related problems both locally - # and in production. - self.connector = aiohttp.TCPConnector( - resolver=aiohttp.AsyncResolver(), - family=socket.AF_INET, - ) - - super().__init__(*args, connector=self.connector, **kwargs) + super().__init__(*args, **kwargs) self.http_session: Optional[aiohttp.ClientSession] = None - self.api_client = api.APIClient(loop=self.loop, connector=self.connector) + self.api_client = api.APIClient(loop=self.loop) + + self._connector = None + self._resolver = None log.addHandler(api.APILoggingHandler(self.api_client)) @@ -35,19 +30,39 @@ class Bot(commands.Bot): log.info(f"Cog loaded: {cog.qualified_name}") def clear(self) -> None: - """Clears the internal state of the bot and resets the API client.""" + """Clears the internal state of the bot and sets the HTTPClient connector to None.""" + self.http.connector = None # Use the default connector. super().clear() - self.api_client.recreate() async def close(self) -> None: - """Close the aiohttp session after closing the Discord connection.""" + """Close the Discord connection and the aiohttp session, connector, and resolver.""" await super().close() await self.http_session.close() await self.api_client.close() + if self._connector: + await self._connector.close() + + if self._resolver: + await self._resolver.close() + async def start(self, *args, **kwargs) -> None: - """Open an aiohttp session before logging in and connecting to Discord.""" - self.http_session = aiohttp.ClientSession(connector=self.connector) + """Set up aiohttp sessions before logging in and connecting to Discord.""" + # Use asyncio for DNS resolution instead of threads so threads aren't spammed. + # Use AF_INET as its socket family to prevent HTTPS related problems both locally + # and in production. + self._resolver = aiohttp.AsyncResolver() + self._connector = aiohttp.TCPConnector( + resolver=self._resolver, + family=socket.AF_INET, + ) + + # Client.login() will call HTTPClient.static_login() which will create a session using + # this connector attribute. + self.http.connector = self._connector + + self.http_session = aiohttp.ClientSession(connector=self._connector) + self.api_client.recreate(connector=self._connector) await super().start(*args, **kwargs) |