aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Mark <[email protected]>2020-03-03 09:15:28 -0800
committerGravatar GitHub <[email protected]>2020-03-03 09:15:28 -0800
commit96d0481ef21c943bf833a141390876ee2c67d3f2 (patch)
tree4d8fdbdfd380724f6e699bde445b7b5d144ce3c1
parentAdding helpers to the Filtering whitelist (diff)
parentMerge branch 'master' into bug/backend/b748/resolver-in-coro (diff)
Merge pull request #750 from python-discord/bug/backend/b748/resolver-in-coro
Create AsyncResolver inside a coroutine to avoid DeprecationWarning
-rw-r--r--bot/api.py40
-rw-r--r--bot/bot.py84
2 files changed, 93 insertions, 31 deletions
diff --git a/bot/api.py b/bot/api.py
index e59916114..4b8520582 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -52,7 +52,7 @@ class APIClient:
self._ready = asyncio.Event(loop=loop)
self._creation_task = None
- self._session_args = kwargs
+ self._default_session_kwargs = kwargs
self.recreate()
@@ -60,25 +60,41 @@ class APIClient:
def _url_for(endpoint: str) -> str:
return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}"
- async def _create_session(self) -> None:
- """Create the aiohttp session and set the ready event."""
- self.session = aiohttp.ClientSession(**self._session_args)
+ async def _create_session(self, **session_kwargs) -> None:
+ """
+ Create the aiohttp session with `session_kwargs` and set the ready event.
+
+ `session_kwargs` is merged with `_default_session_kwargs` and overwrites its values.
+ If an open session already exists, it will first be closed.
+ """
+ await self.close()
+ self.session = aiohttp.ClientSession(**{**self._default_session_kwargs, **session_kwargs})
self._ready.set()
async def close(self) -> None:
"""Close the aiohttp session and unset the ready event."""
- if not self._ready.is_set():
- return
+ if self.session:
+ await self.session.close()
- await self.session.close()
self._ready.clear()
- def recreate(self) -> None:
- """Schedule the aiohttp session to be created if it's been closed."""
- if self.session is None or self.session.closed:
+ def recreate(self, force: bool = False, **session_kwargs) -> None:
+ """
+ Schedule the aiohttp session to be created with `session_kwargs` if it's been closed.
+
+ If `force` is True, the session will be recreated even if an open one exists. If a task to
+ create the session is pending, it will be cancelled.
+
+ `session_kwargs` is merged with the kwargs given when the `APIClient` was created and
+ overwrites those default kwargs.
+ """
+ if force or self.session is None or self.session.closed:
+ if force and self._creation_task:
+ self._creation_task.cancel()
+
# Don't schedule a task if one is already in progress.
- if self._creation_task is None or self._creation_task.done():
- self._creation_task = self.loop.create_task(self._create_session())
+ if force or self._creation_task is None or self._creation_task.done():
+ self._creation_task = self.loop.create_task(self._create_session(**session_kwargs))
async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None:
"""Raise ResponseCodeError for non-OK response if an exception should be raised."""
diff --git a/bot/bot.py b/bot/bot.py
index 19b9035c4..950ac6751 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,6 +1,7 @@
import asyncio
import logging
import socket
+import warnings
from typing import Optional
import aiohttp
@@ -17,20 +18,20 @@ class Bot(commands.Bot):
"""A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client."""
def __init__(self, *args, **kwargs):
- # Use asyncio for DNS resolution instead of threads so threads aren't spammed.
- # Use AF_INET as its socket family to prevent HTTPS related problems both locally
- # and in production.
- self._connector = aiohttp.TCPConnector(
- resolver=aiohttp.AsyncResolver(),
- family=socket.AF_INET,
- )
+ if "connector" in kwargs:
+ warnings.warn(
+ "If login() is called (or the bot is started), the connector will be overwritten "
+ "with an internal one"
+ )
- super().__init__(*args, connector=self._connector, **kwargs)
-
- self._guild_available = asyncio.Event()
+ super().__init__(*args, **kwargs)
self.http_session: Optional[aiohttp.ClientSession] = None
- self.api_client = api.APIClient(loop=self.loop, connector=self._connector)
+ self.api_client = api.APIClient(loop=self.loop)
+
+ self._connector = None
+ self._resolver = None
+ self._guild_available = asyncio.Event()
def add_cog(self, cog: commands.Cog) -> None:
"""Adds a "cog" to the bot and logs the operation."""
@@ -38,22 +39,67 @@ class Bot(commands.Bot):
log.info(f"Cog loaded: {cog.qualified_name}")
def clear(self) -> None:
- """Clears the internal state of the bot and resets the API client."""
+ """
+ Clears the internal state of the bot and recreates the connector and sessions.
+
+ Will cause a DeprecationWarning if called outside a coroutine.
+ """
+ # Because discord.py recreates the HTTPClient session, may as well follow suit and recreate
+ # our own stuff here too.
+ self._recreate()
super().clear()
- self.api_client.recreate()
async def close(self) -> None:
- """Close the aiohttp session after closing the Discord connection."""
+ """Close the Discord connection and the aiohttp session, connector, and resolver."""
await super().close()
- await self.http_session.close()
await self.api_client.close()
- async def start(self, *args, **kwargs) -> None:
- """Open an aiohttp session before logging in and connecting to Discord."""
- self.http_session = aiohttp.ClientSession(connector=self._connector)
+ if self.http_session:
+ await self.http_session.close()
+
+ if self._connector:
+ await self._connector.close()
- await super().start(*args, **kwargs)
+ if self._resolver:
+ await self._resolver.close()
+
+ async def login(self, *args, **kwargs) -> None:
+ """Re-create the connector and set up sessions before logging into Discord."""
+ self._recreate()
+ await super().login(*args, **kwargs)
+
+ def _recreate(self) -> None:
+ """Re-create the connector, aiohttp session, and the APIClient."""
+ # Use asyncio for DNS resolution instead of threads so threads aren't spammed.
+ # Doesn't seem to have any state with regards to being closed, so no need to worry?
+ self._resolver = aiohttp.AsyncResolver()
+
+ # Its __del__ does send a warning but it doesn't always show up for some reason.
+ if self._connector and not self._connector._closed:
+ log.warning(
+ "The previous connector was not closed; it will remain open and be overwritten"
+ )
+
+ # Use AF_INET as its socket family to prevent HTTPS related problems both locally
+ # and in production.
+ self._connector = aiohttp.TCPConnector(
+ resolver=self._resolver,
+ family=socket.AF_INET,
+ )
+
+ # Client.login() will call HTTPClient.static_login() which will create a session using
+ # this connector attribute.
+ self.http.connector = self._connector
+
+ # Its __del__ does send a warning but it doesn't always show up for some reason.
+ if self.http_session and not self.http_session.closed:
+ log.warning(
+ "The previous session was not closed; it will remain open and be overwritten"
+ )
+
+ self.http_session = aiohttp.ClientSession(connector=self._connector)
+ self.api_client.recreate(force=True, connector=self._connector)
async def on_guild_available(self, guild: discord.Guild) -> None:
"""