diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/api.py | 40 | ||||
| -rw-r--r-- | bot/bot.py | 84 | ||||
| -rw-r--r-- | bot/cogs/watchchannels/watchchannel.py | 11 | 
3 files changed, 102 insertions, 33 deletions
| diff --git a/bot/api.py b/bot/api.py index e59916114..4b8520582 100644 --- a/bot/api.py +++ b/bot/api.py @@ -52,7 +52,7 @@ class APIClient:          self._ready = asyncio.Event(loop=loop)          self._creation_task = None -        self._session_args = kwargs +        self._default_session_kwargs = kwargs          self.recreate() @@ -60,25 +60,41 @@ class APIClient:      def _url_for(endpoint: str) -> str:          return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" -    async def _create_session(self) -> None: -        """Create the aiohttp session and set the ready event.""" -        self.session = aiohttp.ClientSession(**self._session_args) +    async def _create_session(self, **session_kwargs) -> None: +        """ +        Create the aiohttp session with `session_kwargs` and set the ready event. + +        `session_kwargs` is merged with `_default_session_kwargs` and overwrites its values. +        If an open session already exists, it will first be closed. +        """ +        await self.close() +        self.session = aiohttp.ClientSession(**{**self._default_session_kwargs, **session_kwargs})          self._ready.set()      async def close(self) -> None:          """Close the aiohttp session and unset the ready event.""" -        if not self._ready.is_set(): -            return +        if self.session: +            await self.session.close() -        await self.session.close()          self._ready.clear() -    def recreate(self) -> None: -        """Schedule the aiohttp session to be created if it's been closed.""" -        if self.session is None or self.session.closed: +    def recreate(self, force: bool = False, **session_kwargs) -> None: +        """ +        Schedule the aiohttp session to be created with `session_kwargs` if it's been closed. + +        If `force` is True, the session will be recreated even if an open one exists. If a task to +        create the session is pending, it will be cancelled. + +        `session_kwargs` is merged with the kwargs given when the `APIClient` was created and +        overwrites those default kwargs. +        """ +        if force or self.session is None or self.session.closed: +            if force and self._creation_task: +                self._creation_task.cancel() +              # Don't schedule a task if one is already in progress. -            if self._creation_task is None or self._creation_task.done(): -                self._creation_task = self.loop.create_task(self._create_session()) +            if force or self._creation_task is None or self._creation_task.done(): +                self._creation_task = self.loop.create_task(self._create_session(**session_kwargs))      async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None:          """Raise ResponseCodeError for non-OK response if an exception should be raised.""" diff --git a/bot/bot.py b/bot/bot.py index 19b9035c4..950ac6751 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,6 +1,7 @@  import asyncio  import logging  import socket +import warnings  from typing import Optional  import aiohttp @@ -17,20 +18,20 @@ class Bot(commands.Bot):      """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client."""      def __init__(self, *args, **kwargs): -        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. -        # Use AF_INET as its socket family to prevent HTTPS related problems both locally -        # and in production. -        self._connector = aiohttp.TCPConnector( -            resolver=aiohttp.AsyncResolver(), -            family=socket.AF_INET, -        ) +        if "connector" in kwargs: +            warnings.warn( +                "If login() is called (or the bot is started), the connector will be overwritten " +                "with an internal one" +            ) -        super().__init__(*args, connector=self._connector, **kwargs) - -        self._guild_available = asyncio.Event() +        super().__init__(*args, **kwargs)          self.http_session: Optional[aiohttp.ClientSession] = None -        self.api_client = api.APIClient(loop=self.loop, connector=self._connector) +        self.api_client = api.APIClient(loop=self.loop) + +        self._connector = None +        self._resolver = None +        self._guild_available = asyncio.Event()      def add_cog(self, cog: commands.Cog) -> None:          """Adds a "cog" to the bot and logs the operation.""" @@ -38,22 +39,67 @@ class Bot(commands.Bot):          log.info(f"Cog loaded: {cog.qualified_name}")      def clear(self) -> None: -        """Clears the internal state of the bot and resets the API client.""" +        """ +        Clears the internal state of the bot and recreates the connector and sessions. + +        Will cause a DeprecationWarning if called outside a coroutine. +        """ +        # Because discord.py recreates the HTTPClient session, may as well follow suit and recreate +        # our own stuff here too. +        self._recreate()          super().clear() -        self.api_client.recreate()      async def close(self) -> None: -        """Close the aiohttp session after closing the Discord connection.""" +        """Close the Discord connection and the aiohttp session, connector, and resolver."""          await super().close() -        await self.http_session.close()          await self.api_client.close() -    async def start(self, *args, **kwargs) -> None: -        """Open an aiohttp session before logging in and connecting to Discord.""" -        self.http_session = aiohttp.ClientSession(connector=self._connector) +        if self.http_session: +            await self.http_session.close() + +        if self._connector: +            await self._connector.close() -        await super().start(*args, **kwargs) +        if self._resolver: +            await self._resolver.close() + +    async def login(self, *args, **kwargs) -> None: +        """Re-create the connector and set up sessions before logging into Discord.""" +        self._recreate() +        await super().login(*args, **kwargs) + +    def _recreate(self) -> None: +        """Re-create the connector, aiohttp session, and the APIClient.""" +        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. +        # Doesn't seem to have any state with regards to being closed, so no need to worry? +        self._resolver = aiohttp.AsyncResolver() + +        # Its __del__ does send a warning but it doesn't always show up for some reason. +        if self._connector and not self._connector._closed: +            log.warning( +                "The previous connector was not closed; it will remain open and be overwritten" +            ) + +        # Use AF_INET as its socket family to prevent HTTPS related problems both locally +        # and in production. +        self._connector = aiohttp.TCPConnector( +            resolver=self._resolver, +            family=socket.AF_INET, +        ) + +        # Client.login() will call HTTPClient.static_login() which will create a session using +        # this connector attribute. +        self.http.connector = self._connector + +        # Its __del__ does send a warning but it doesn't always show up for some reason. +        if self.http_session and not self.http_session.closed: +            log.warning( +                "The previous session was not closed; it will remain open and be overwritten" +            ) + +        self.http_session = aiohttp.ClientSession(connector=self._connector) +        self.api_client.recreate(force=True, connector=self._connector)      async def on_guild_available(self, guild: discord.Guild) -> None:          """ diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 3667a80e8..479820444 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -9,7 +9,7 @@ from typing import Optional  import dateutil.parser  import discord -from discord import Color, Embed, HTTPException, Message, errors +from discord import Color, DMChannel, Embed, HTTPException, Message, errors  from discord.ext.commands import Cog, Context  from bot.api import ResponseCodeError @@ -273,7 +273,14 @@ class WatchChannel(metaclass=CogABCMeta):          reason = self.watched_users[user_id]['reason'] -        embed = Embed(description=f"{msg.author.mention} in [#{msg.channel.name}]({msg.jump_url})") +        if isinstance(msg.channel, DMChannel): +            # If a watched user DMs the bot there won't be a channel name or jump URL +            # This could technically include a GroupChannel but bot's can't be in those +            message_jump = "via DM" +        else: +            message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + +        embed = Embed(description=f"{msg.author.mention} {message_jump}")          embed.set_footer(text=f"Added {time_delta} by {actor} | Reason: {reason}")          await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) | 
