diff options
| -rw-r--r-- | bot/api.py | 73 | ||||
| -rw-r--r-- | bot/bot.py | 82 | ||||
| -rw-r--r-- | tests/bot/test_api.py | 8 | 
3 files changed, 39 insertions, 124 deletions
| diff --git a/bot/api.py b/bot/api.py index 4b8520582..d93f9f2ba 100644 --- a/bot/api.py +++ b/bot/api.py @@ -37,64 +37,27 @@ class APIClient:      session: Optional[aiohttp.ClientSession] = None      loop: asyncio.AbstractEventLoop = None -    def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs): +    def __init__(self, **session_kwargs):          auth_headers = {              'Authorization': f"Token {Keys.site_api}"          } -        if 'headers' in kwargs: -            kwargs['headers'].update(auth_headers) +        if 'headers' in session_kwargs: +            session_kwargs['headers'].update(auth_headers)          else: -            kwargs['headers'] = auth_headers +            session_kwargs['headers'] = auth_headers -        self.session = None -        self.loop = loop - -        self._ready = asyncio.Event(loop=loop) -        self._creation_task = None -        self._default_session_kwargs = kwargs - -        self.recreate() +        # aiohttp will complain if APIClient gets instantiated outside a coroutine. Thankfully, we +        # don't and shouldn't need to do that, so we can avoid scheduling a task to create it. +        self.session = aiohttp.ClientSession(**session_kwargs)      @staticmethod      def _url_for(endpoint: str) -> str:          return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" -    async def _create_session(self, **session_kwargs) -> None: -        """ -        Create the aiohttp session with `session_kwargs` and set the ready event. - -        `session_kwargs` is merged with `_default_session_kwargs` and overwrites its values. -        If an open session already exists, it will first be closed. -        """ -        await self.close() -        self.session = aiohttp.ClientSession(**{**self._default_session_kwargs, **session_kwargs}) -        self._ready.set() -      async def close(self) -> None: -        """Close the aiohttp session and unset the ready event.""" -        if self.session: -            await self.session.close() - -        self._ready.clear() - -    def recreate(self, force: bool = False, **session_kwargs) -> None: -        """ -        Schedule the aiohttp session to be created with `session_kwargs` if it's been closed. - -        If `force` is True, the session will be recreated even if an open one exists. If a task to -        create the session is pending, it will be cancelled. - -        `session_kwargs` is merged with the kwargs given when the `APIClient` was created and -        overwrites those default kwargs. -        """ -        if force or self.session is None or self.session.closed: -            if force and self._creation_task: -                self._creation_task.cancel() - -            # Don't schedule a task if one is already in progress. -            if force or self._creation_task is None or self._creation_task.done(): -                self._creation_task = self.loop.create_task(self._create_session(**session_kwargs)) +        """Close the aiohttp session.""" +        await self.session.close()      async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None:          """Raise ResponseCodeError for non-OK response if an exception should be raised.""" @@ -108,8 +71,6 @@ class APIClient:      async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:          """Send an HTTP request to the site API and return the JSON response.""" -        await self._ready.wait() -          async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json() @@ -132,25 +93,9 @@ class APIClient:      async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]:          """Site API DELETE.""" -        await self._ready.wait() -          async with self.session.delete(self._url_for(endpoint), **kwargs) as resp:              if resp.status == 204:                  return None              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json() - - -def loop_is_running() -> bool: -    """ -    Determine if there is a running asyncio event loop. - -    This helps enable "call this when event loop is running" logic (see: Twisted's `callWhenRunning`), -    which is currently not provided by asyncio. -    """ -    try: -        asyncio.get_running_loop() -    except RuntimeError: -        return False -    return True diff --git a/bot/bot.py b/bot/bot.py index f71f5d1fb..4ebe0a5c3 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -32,7 +32,7 @@ class Bot(commands.Bot):          self.http_session: Optional[aiohttp.ClientSession] = None          self.redis_session = redis_session -        self.api_client = api.APIClient(loop=self.loop) +        self.api_client: Optional[api.APIClient] = None          self.filter_list_cache = defaultdict(dict)          self._connector = None @@ -77,46 +77,6 @@ class Bot(commands.Bot):          for item in full_cache:              self.insert_item_into_filter_list_cache(item) -    def _recreate(self) -> None: -        """Re-create the connector, aiohttp session, the APIClient and the Redis session.""" -        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. -        # Doesn't seem to have any state with regards to being closed, so no need to worry? -        self._resolver = aiohttp.AsyncResolver() - -        # Its __del__ does send a warning but it doesn't always show up for some reason. -        if self._connector and not self._connector._closed: -            log.warning( -                "The previous connector was not closed; it will remain open and be overwritten" -            ) - -        if self.redis_session.closed: -            # If the RedisSession was somehow closed, we try to reconnect it -            # here. Normally, this shouldn't happen. -            self.loop.create_task(self.redis_session.connect()) - -        # Use AF_INET as its socket family to prevent HTTPS related problems both locally -        # and in production. -        self._connector = aiohttp.TCPConnector( -            resolver=self._resolver, -            family=socket.AF_INET, -        ) - -        # Client.login() will call HTTPClient.static_login() which will create a session using -        # this connector attribute. -        self.http.connector = self._connector - -        # Its __del__ does send a warning but it doesn't always show up for some reason. -        if self.http_session and not self.http_session.closed: -            log.warning( -                "The previous session was not closed; it will remain open and be overwritten" -            ) - -        self.http_session = aiohttp.ClientSession(connector=self._connector) -        self.api_client.recreate(force=True, connector=self._connector) - -        # Build the FilterList cache -        self.loop.create_task(self.cache_filter_list_data()) -      @classmethod      def create(cls) -> "Bot":          """Create and return an instance of a Bot.""" @@ -180,21 +140,15 @@ class Bot(commands.Bot):          return command      def clear(self) -> None: -        """ -        Clears the internal state of the bot and recreates the connector and sessions. - -        Will cause a DeprecationWarning if called outside a coroutine. -        """ -        # Because discord.py recreates the HTTPClient session, may as well follow suit and recreate -        # our own stuff here too. -        self._recreate() -        super().clear() +        """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one.""" +        raise NotImplementedError("Re-using a Bot object after closing it is not supported.")      async def close(self) -> None:          """Close the Discord connection and the aiohttp session, connector, statsd client, and resolver."""          await super().close() -        await self.api_client.close() +        if self.api_client: +            await self.api_client.close()          if self.http_session:              await self.http_session.close() @@ -229,7 +183,31 @@ class Bot(commands.Bot):      async def login(self, *args, **kwargs) -> None:          """Re-create the connector and set up sessions before logging into Discord.""" -        self._recreate() +        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. +        self._resolver = aiohttp.AsyncResolver() + +        # Use AF_INET as its socket family to prevent HTTPS related problems both locally +        # and in production. +        self._connector = aiohttp.TCPConnector( +            resolver=self._resolver, +            family=socket.AF_INET, +        ) + +        # Client.login() will call HTTPClient.static_login() which will create a session using +        # this connector attribute. +        self.http.connector = self._connector + +        self.http_session = aiohttp.ClientSession(connector=self._connector) +        self.api_client = api.APIClient(connector=self._connector) + +        if self.redis_session.closed: +            # If the RedisSession was somehow closed, we try to reconnect it +            # here. Normally, this shouldn't happen. +            await self.redis_session.connect() + +        # Build the FilterList cache +        await self.cache_filter_list_data() +          await self.stats.create_socket()          await super().login(*args, **kwargs) diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 99e942813..76bcb481d 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -13,14 +13,6 @@ class APIClientTests(unittest.IsolatedAsyncioTestCase):          cls.error_api_response = MagicMock()          cls.error_api_response.status = 999 -    def test_loop_is_not_running_by_default(self): -        """The event loop should not be running by default.""" -        self.assertFalse(api.loop_is_running()) - -    async def test_loop_is_running_in_async_context(self): -        """The event loop should be running in an async context.""" -        self.assertTrue(api.loop_is_running()) -      def test_response_code_error_default_initialization(self):          """Test the default initialization of `ResponseCodeError` without `text` or `json`"""          error = api.ResponseCodeError(response=self.error_api_response) | 
