diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/api.py | 41 | ||||
| -rw-r--r-- | bot/bot.py | 34 | 
2 files changed, 65 insertions, 10 deletions
| diff --git a/bot/api.py b/bot/api.py index 7f26e5305..56db99828 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,7 +32,7 @@ class ResponseCodeError(ValueError):  class APIClient:      """Django Site API wrapper.""" -    def __init__(self, **kwargs): +    def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):          auth_headers = {              'Authorization': f"Token {Keys.site_api}"          } @@ -42,12 +42,39 @@ class APIClient:          else:              kwargs['headers'] = auth_headers -        self.session = aiohttp.ClientSession(**kwargs) +        self.session: Optional[aiohttp.ClientSession] = None +        self.loop = loop + +        self._ready = asyncio.Event(loop=loop) +        self._creation_task = None +        self._session_args = kwargs + +        self.recreate()      @staticmethod      def _url_for(endpoint: str) -> str:          return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" +    async def _create_session(self) -> None: +        """Create the aiohttp session and set the ready event.""" +        self.session = aiohttp.ClientSession(**self._session_args) +        self._ready.set() + +    async def close(self) -> None: +        """Close the aiohttp session and unset the ready event.""" +        if not self._ready.is_set(): +            return + +        await self.session.close() +        self._ready.clear() + +    def recreate(self) -> None: +        """Schedule the aiohttp session to be created if it's been closed.""" +        if self.session is None or self.session.closed: +            # Don't schedule a task if one is already in progress. +            if self._creation_task is None or self._creation_task.done(): +                self._creation_task = self.loop.create_task(self._create_session()) +      async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None:          """Raise ResponseCodeError for non-OK response if an exception should be raised."""          if should_raise and response.status >= 400: @@ -60,30 +87,40 @@ class APIClient:      async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API GET.""" +        await self._ready.wait() +          async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API PATCH.""" +        await self._ready.wait() +          async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API POST.""" +        await self._ready.wait() +          async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API PUT.""" +        await self._ready.wait() +          async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]:          """Site API DELETE.""" +        await self._ready.wait() +          async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp:              if resp.status == 204:                  return None diff --git a/bot/bot.py b/bot/bot.py index f39bfb50a..4b3b991a3 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,6 +1,6 @@ -import asyncio  import logging  import socket +from typing import Optional  import aiohttp  from discord.ext import commands @@ -16,6 +16,30 @@ class Bot(commands.Bot):      def __init__(self, *args, **kwargs):          super().__init__(*args, **kwargs) +        self.http_session: Optional[aiohttp.ClientSession] = None +        self.api_client = api.APIClient(loop=self.loop) + +        log.addHandler(api.APILoggingHandler(self.api_client)) + +    def add_cog(self, cog: commands.Cog) -> None: +        """Adds a "cog" to the bot and logs the operation.""" +        super().add_cog(cog) +        log.info(f"Cog loaded: {cog.qualified_name}") + +    def clear(self) -> None: +        """Clears the internal state of the bot and resets the API client.""" +        super().clear() +        self.api_client.recreate() + +    async def close(self) -> None: +        """Close the aiohttp session after closing the Discord connection.""" +        await super().close() + +        await self.http_session.close() +        await self.api_client.close() + +    async def start(self, *args, **kwargs) -> None: +        """Open an aiohttp session before logging in and connecting to Discord."""          # Global aiohttp session for all cogs          # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads          # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. @@ -26,10 +50,4 @@ class Bot(commands.Bot):              )          ) -        self.api_client = api.APIClient(loop=asyncio.get_event_loop()) -        log.addHandler(api.APILoggingHandler(self.api_client)) - -    def add_cog(self, cog: commands.Cog) -> None: -        """Adds a "cog" to the bot and logs the operation.""" -        super().add_cog(cog) -        log.info(f"Cog loaded: {cog.qualified_name}") +        await super().start(*args, **kwargs) | 
