diff options
| -rw-r--r-- | bot/bot.py | 21 | ||||
| -rw-r--r-- | bot/exts/info/reddit.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 9 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 11 | ||||
| -rw-r--r-- | bot/exts/moderation/silence.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/watchchannels/_watchchannel.py | 53 |
6 files changed, 62 insertions, 36 deletions
diff --git a/bot/bot.py b/bot/bot.py index b2e5237fe..b51e41117 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -3,7 +3,7 @@ import logging import socket import warnings from collections import defaultdict -from typing import Dict, Optional +from typing import Dict, List, Optional import aiohttp import discord @@ -48,6 +48,9 @@ class Bot(commands.Bot): self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot") + # All tasks that need to block closing until finished + self.closing_tasks: List[asyncio.Task] = [] + async def cache_filter_list_data(self) -> None: """Cache all the data in the FilterList on the site.""" full_cache = await self.api_client.get('bot/filter-lists') @@ -131,8 +134,24 @@ class Bot(commands.Bot): self._recreate() super().clear() + def _remove_extensions(self) -> None: + """Remove all extensions to trigger cog unloads.""" + for ext in self.extensions.keys(): + try: + self.unload_extension(ext) + except Exception: + pass + async def close(self) -> None: """Close the Discord connection and the aiohttp session, connector, statsd client, and resolver.""" + # Done before super().close() to allow tasks finish before the HTTP session closes. + self.remove_extensions() + + # Wait until all tasks that have to be completed before bot is closing is done + log.trace("Waiting for tasks before closing.") + await asyncio.gather(*self.closing_tasks) + + # Now actually do full close of bot await super().close() await self.api_client.close() diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py index debe40c82..9b9c0028f 100644 --- a/bot/exts/info/reddit.py +++ b/bot/exts/info/reddit.py @@ -45,7 +45,7 @@ class Reddit(Cog): """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() if self.access_token and self.access_token.expires_at > datetime.utcnow(): - asyncio.create_task(self.revoke_access_token()) + self.bot.closing_tasks.append(asyncio.create_task(self.revoke_access_token())) async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 814b17830..99bb1ae11 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -73,8 +73,13 @@ class InfractionScheduler: return # Allowing mod log since this is a passive action that should be logged. - await apply_coro - log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") + try: + await apply_coro + except discord.NotFound: + # When user joined and then right after this left again before action completed, this can't add roles + log.info(f"Can't reapply {infraction['type']} to user {infraction['user']} because user left again.") + else: + log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") async def apply_infraction( self, diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 7cf7075e6..ccddd4530 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -240,10 +240,13 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_update, user.id) async def action() -> None: - await user.add_roles(self._muted_role, reason=reason) - - log.trace(f"Attempting to kick {user} from voice because they've been muted.") - await user.move_to(None, reason=reason) + try: + await user.add_roles(self._muted_role, reason=reason) + except discord.NotFound: + log.info(f"User {user} ({user.id}) left from guild. Can't give Muted role.") + else: + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) await self.apply_infraction(ctx, infraction, user, action()) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index ac0c1c85e..229f991a0 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -156,7 +156,7 @@ class Silence(commands.Cog): if self.muted_channels: channels_string = ''.join(channel.mention for channel in self.muted_channels) message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" - asyncio.create_task(self._mod_alerts_channel.send(message)) + self.bot.closing_tasks.append(asyncio.create_task(self._mod_alerts_channel.send(message))) # This cannot be static (must have a __func__ attribute). async def cog_check(self, ctx: Context) -> bool: diff --git a/bot/exts/moderation/watchchannels/_watchchannel.py b/bot/exts/moderation/watchchannels/_watchchannel.py index 7118dee02..4715dce14 100644 --- a/bot/exts/moderation/watchchannels/_watchchannel.py +++ b/bot/exts/moderation/watchchannels/_watchchannel.py @@ -171,32 +171,38 @@ class WatchChannel(metaclass=CogABCMeta): async def consume_messages(self, delay_consumption: bool = True) -> None: """Consumes the message queues to log watched users' messages.""" - if delay_consumption: - self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") - await asyncio.sleep(BigBrotherConfig.log_delay) + try: + if delay_consumption: + self.log.trace(f"Sleeping {BigBrotherConfig.log_delay} seconds before consuming message queue") + await asyncio.sleep(BigBrotherConfig.log_delay) - self.log.trace("Started consuming the message queue") + self.log.trace("Started consuming the message queue") - # If the previous consumption Task failed, first consume the existing comsumption_queue - if not self.consumption_queue: - self.consumption_queue = self.message_queue.copy() - self.message_queue.clear() + # If the previous consumption Task failed, first consume the existing comsumption_queue + if not self.consumption_queue: + self.consumption_queue = self.message_queue.copy() + self.message_queue.clear() - for user_channel_queues in self.consumption_queue.values(): - for channel_queue in user_channel_queues.values(): - while channel_queue: - msg = channel_queue.popleft() + for user_channel_queues in self.consumption_queue.values(): + for channel_queue in user_channel_queues.values(): + while channel_queue: + msg = channel_queue.popleft() - self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") - await self.relay_message(msg) + self.log.trace(f"Consuming message {msg.id} ({len(msg.attachments)} attachments)") + await self.relay_message(msg) - self.consumption_queue.clear() + self.consumption_queue.clear() - if self.message_queue: - self.log.trace("Channel queue not empty: Continuing consuming queues") - self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) - else: - self.log.trace("Done consuming messages.") + if self.message_queue: + self.log.trace("Channel queue not empty: Continuing consuming queues") + self._consume_task = self.bot.loop.create_task(self.consume_messages(delay_consumption=False)) + else: + self.log.trace("Done consuming messages.") + except asyncio.CancelledError as e: + self.log.exception( + "The consume task was canceled. Messages may be lost.", + exc_info=e + ) async def webhook_send( self, @@ -343,10 +349,3 @@ class WatchChannel(metaclass=CogABCMeta): self.log.trace("Unloading the cog") if self._consume_task and not self._consume_task.done(): self._consume_task.cancel() - try: - self._consume_task.result() - except asyncio.CancelledError as e: - self.log.exception( - "The consume task was canceled. Messages may be lost.", - exc_info=e - ) |