diff options
-rw-r--r-- | bot/api.py | 41 | ||||
-rw-r--r-- | bot/bot.py | 34 |
2 files changed, 65 insertions, 10 deletions
diff --git a/bot/api.py b/bot/api.py index 7f26e5305..56db99828 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,7 +32,7 @@ class ResponseCodeError(ValueError): class APIClient: """Django Site API wrapper.""" - def __init__(self, **kwargs): + def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs): auth_headers = { 'Authorization': f"Token {Keys.site_api}" } @@ -42,12 +42,39 @@ class APIClient: else: kwargs['headers'] = auth_headers - self.session = aiohttp.ClientSession(**kwargs) + self.session: Optional[aiohttp.ClientSession] = None + self.loop = loop + + self._ready = asyncio.Event(loop=loop) + self._creation_task = None + self._session_args = kwargs + + self.recreate() @staticmethod def _url_for(endpoint: str) -> str: return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" + async def _create_session(self) -> None: + """Create the aiohttp session and set the ready event.""" + self.session = aiohttp.ClientSession(**self._session_args) + self._ready.set() + + async def close(self) -> None: + """Close the aiohttp session and unset the ready event.""" + if not self._ready.is_set(): + return + + await self.session.close() + self._ready.clear() + + def recreate(self) -> None: + """Schedule the aiohttp session to be created if it's been closed.""" + if self.session is None or self.session.closed: + # Don't schedule a task if one is already in progress. + if self._creation_task is None or self._creation_task.done(): + self._creation_task = self.loop.create_task(self._create_session()) + async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: """Raise ResponseCodeError for non-OK response if an exception should be raised.""" if should_raise and response.status >= 400: @@ -60,30 +87,40 @@ class APIClient: async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API GET.""" + await self._ready.wait() + async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PATCH.""" + await self._ready.wait() + async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API POST.""" + await self._ready.wait() + async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PUT.""" + await self._ready.wait() + async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: """Site API DELETE.""" + await self._ready.wait() + async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: if resp.status == 204: return None diff --git a/bot/bot.py b/bot/bot.py index f39bfb50a..4b3b991a3 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,6 +1,6 @@ -import asyncio import logging import socket +from typing import Optional import aiohttp from discord.ext import commands @@ -16,6 +16,30 @@ class Bot(commands.Bot): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.http_session: Optional[aiohttp.ClientSession] = None + self.api_client = api.APIClient(loop=self.loop) + + log.addHandler(api.APILoggingHandler(self.api_client)) + + def add_cog(self, cog: commands.Cog) -> None: + """Adds a "cog" to the bot and logs the operation.""" + super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + def clear(self) -> None: + """Clears the internal state of the bot and resets the API client.""" + super().clear() + self.api_client.recreate() + + async def close(self) -> None: + """Close the aiohttp session after closing the Discord connection.""" + await super().close() + + await self.http_session.close() + await self.api_client.close() + + async def start(self, *args, **kwargs) -> None: + """Open an aiohttp session before logging in and connecting to Discord.""" # Global aiohttp session for all cogs # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. @@ -26,10 +50,4 @@ class Bot(commands.Bot): ) ) - self.api_client = api.APIClient(loop=asyncio.get_event_loop()) - log.addHandler(api.APILoggingHandler(self.api_client)) - - def add_cog(self, cog: commands.Cog) -> None: - """Adds a "cog" to the bot and logs the operation.""" - super().add_cog(cog) - log.info(f"Cog loaded: {cog.qualified_name}") + await super().start(*args, **kwargs) |