diff options
| -rw-r--r-- | bot/bot.py | 45 | 
1 files changed, 30 insertions, 15 deletions
| diff --git a/bot/bot.py b/bot/bot.py index 8f808272f..95fbae17f 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -14,18 +14,13 @@ class Bot(commands.Bot):      """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client."""      def __init__(self, *args, **kwargs): -        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. -        # Use AF_INET as its socket family to prevent HTTPS related problems both locally -        # and in production. -        self.connector = aiohttp.TCPConnector( -            resolver=aiohttp.AsyncResolver(), -            family=socket.AF_INET, -        ) - -        super().__init__(*args, connector=self.connector, **kwargs) +        super().__init__(*args, **kwargs)          self.http_session: Optional[aiohttp.ClientSession] = None -        self.api_client = api.APIClient(loop=self.loop, connector=self.connector) +        self.api_client = api.APIClient(loop=self.loop) + +        self._connector = None +        self._resolver = None          log.addHandler(api.APILoggingHandler(self.api_client)) @@ -35,19 +30,39 @@ class Bot(commands.Bot):          log.info(f"Cog loaded: {cog.qualified_name}")      def clear(self) -> None: -        """Clears the internal state of the bot and resets the API client.""" +        """Clears the internal state of the bot and sets the HTTPClient connector to None.""" +        self.http.connector = None  # Use the default connector.          super().clear() -        self.api_client.recreate()      async def close(self) -> None: -        """Close the aiohttp session after closing the Discord connection.""" +        """Close the Discord connection and the aiohttp session, connector, and resolver."""          await super().close()          await self.http_session.close()          await self.api_client.close() +        if self._connector: +            await self._connector.close() + +        if self._resolver: +            await self._resolver.close() +      async def start(self, *args, **kwargs) -> None: -        """Open an aiohttp session before logging in and connecting to Discord.""" -        self.http_session = aiohttp.ClientSession(connector=self.connector) +        """Set up aiohttp sessions before logging in and connecting to Discord.""" +        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. +        # Use AF_INET as its socket family to prevent HTTPS related problems both locally +        # and in production. +        self._resolver = aiohttp.AsyncResolver() +        self._connector = aiohttp.TCPConnector( +            resolver=self._resolver, +            family=socket.AF_INET, +        ) + +        # Client.login() will call HTTPClient.static_login() which will create a session using +        # this connector attribute. +        self.http.connector = self._connector + +        self.http_session = aiohttp.ClientSession(connector=self._connector) +        self.api_client.recreate(connector=self._connector)          await super().start(*args, **kwargs) | 
