diff options
| author | 2020-03-03 09:15:28 -0800 | |
|---|---|---|
| committer | 2020-03-03 09:15:28 -0800 | |
| commit | 96d0481ef21c943bf833a141390876ee2c67d3f2 (patch) | |
| tree | 4d8fdbdfd380724f6e699bde445b7b5d144ce3c1 | |
| parent | Adding helpers to the Filtering whitelist (diff) | |
| parent | Merge branch 'master' into bug/backend/b748/resolver-in-coro (diff) | |
Merge pull request #750 from python-discord/bug/backend/b748/resolver-in-coro
Create AsyncResolver inside a coroutine to avoid DeprecationWarning
| -rw-r--r-- | bot/api.py | 40 | ||||
| -rw-r--r-- | bot/bot.py | 84 |
2 files changed, 93 insertions, 31 deletions
diff --git a/bot/api.py b/bot/api.py index e59916114..4b8520582 100644 --- a/bot/api.py +++ b/bot/api.py @@ -52,7 +52,7 @@ class APIClient: self._ready = asyncio.Event(loop=loop) self._creation_task = None - self._session_args = kwargs + self._default_session_kwargs = kwargs self.recreate() @@ -60,25 +60,41 @@ class APIClient: def _url_for(endpoint: str) -> str: return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" - async def _create_session(self) -> None: - """Create the aiohttp session and set the ready event.""" - self.session = aiohttp.ClientSession(**self._session_args) + async def _create_session(self, **session_kwargs) -> None: + """ + Create the aiohttp session with `session_kwargs` and set the ready event. + + `session_kwargs` is merged with `_default_session_kwargs` and overwrites its values. + If an open session already exists, it will first be closed. + """ + await self.close() + self.session = aiohttp.ClientSession(**{**self._default_session_kwargs, **session_kwargs}) self._ready.set() async def close(self) -> None: """Close the aiohttp session and unset the ready event.""" - if not self._ready.is_set(): - return + if self.session: + await self.session.close() - await self.session.close() self._ready.clear() - def recreate(self) -> None: - """Schedule the aiohttp session to be created if it's been closed.""" - if self.session is None or self.session.closed: + def recreate(self, force: bool = False, **session_kwargs) -> None: + """ + Schedule the aiohttp session to be created with `session_kwargs` if it's been closed. + + If `force` is True, the session will be recreated even if an open one exists. If a task to + create the session is pending, it will be cancelled. + + `session_kwargs` is merged with the kwargs given when the `APIClient` was created and + overwrites those default kwargs. + """ + if force or self.session is None or self.session.closed: + if force and self._creation_task: + self._creation_task.cancel() + # Don't schedule a task if one is already in progress. - if self._creation_task is None or self._creation_task.done(): - self._creation_task = self.loop.create_task(self._create_session()) + if force or self._creation_task is None or self._creation_task.done(): + self._creation_task = self.loop.create_task(self._create_session(**session_kwargs)) async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: """Raise ResponseCodeError for non-OK response if an exception should be raised.""" diff --git a/bot/bot.py b/bot/bot.py index 19b9035c4..950ac6751 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,6 +1,7 @@ import asyncio import logging import socket +import warnings from typing import Optional import aiohttp @@ -17,20 +18,20 @@ class Bot(commands.Bot): """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, **kwargs): - # Use asyncio for DNS resolution instead of threads so threads aren't spammed. - # Use AF_INET as its socket family to prevent HTTPS related problems both locally - # and in production. - self._connector = aiohttp.TCPConnector( - resolver=aiohttp.AsyncResolver(), - family=socket.AF_INET, - ) + if "connector" in kwargs: + warnings.warn( + "If login() is called (or the bot is started), the connector will be overwritten " + "with an internal one" + ) - super().__init__(*args, connector=self._connector, **kwargs) - - self._guild_available = asyncio.Event() + super().__init__(*args, **kwargs) self.http_session: Optional[aiohttp.ClientSession] = None - self.api_client = api.APIClient(loop=self.loop, connector=self._connector) + self.api_client = api.APIClient(loop=self.loop) + + self._connector = None + self._resolver = None + self._guild_available = asyncio.Event() def add_cog(self, cog: commands.Cog) -> None: """Adds a "cog" to the bot and logs the operation.""" @@ -38,22 +39,67 @@ class Bot(commands.Bot): log.info(f"Cog loaded: {cog.qualified_name}") def clear(self) -> None: - """Clears the internal state of the bot and resets the API client.""" + """ + Clears the internal state of the bot and recreates the connector and sessions. + + Will cause a DeprecationWarning if called outside a coroutine. + """ + # Because discord.py recreates the HTTPClient session, may as well follow suit and recreate + # our own stuff here too. + self._recreate() super().clear() - self.api_client.recreate() async def close(self) -> None: - """Close the aiohttp session after closing the Discord connection.""" + """Close the Discord connection and the aiohttp session, connector, and resolver.""" await super().close() - await self.http_session.close() await self.api_client.close() - async def start(self, *args, **kwargs) -> None: - """Open an aiohttp session before logging in and connecting to Discord.""" - self.http_session = aiohttp.ClientSession(connector=self._connector) + if self.http_session: + await self.http_session.close() + + if self._connector: + await self._connector.close() - await super().start(*args, **kwargs) + if self._resolver: + await self._resolver.close() + + async def login(self, *args, **kwargs) -> None: + """Re-create the connector and set up sessions before logging into Discord.""" + self._recreate() + await super().login(*args, **kwargs) + + def _recreate(self) -> None: + """Re-create the connector, aiohttp session, and the APIClient.""" + # Use asyncio for DNS resolution instead of threads so threads aren't spammed. + # Doesn't seem to have any state with regards to being closed, so no need to worry? + self._resolver = aiohttp.AsyncResolver() + + # Its __del__ does send a warning but it doesn't always show up for some reason. + if self._connector and not self._connector._closed: + log.warning( + "The previous connector was not closed; it will remain open and be overwritten" + ) + + # Use AF_INET as its socket family to prevent HTTPS related problems both locally + # and in production. + self._connector = aiohttp.TCPConnector( + resolver=self._resolver, + family=socket.AF_INET, + ) + + # Client.login() will call HTTPClient.static_login() which will create a session using + # this connector attribute. + self.http.connector = self._connector + + # Its __del__ does send a warning but it doesn't always show up for some reason. + if self.http_session and not self.http_session.closed: + log.warning( + "The previous session was not closed; it will remain open and be overwritten" + ) + + self.http_session = aiohttp.ClientSession(connector=self._connector) + self.api_client.recreate(force=True, connector=self._connector) async def on_guild_available(self, guild: discord.Guild) -> None: """ |