aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ks129 <[email protected]>2020-06-20 21:33:36 +0300
committerGravatar ks129 <[email protected]>2020-06-20 21:33:36 +0300
commit429cc865309242f0cf37147f9c3f05036972eb8c (patch)
treefb6e9c852a1388407032547dd862cb5d1b5b5a43
parentInfractions: Fix cases when user leave from guild before assigning roles (diff)
Implement bot closing tasks waiting + breaking `close` to multiple parts
Made to resolve problem with Reddit cog that revoking access token raise exception because session is closed. To solve this, I made `Bot.closing_tasks` that bot wait before closing. Moved all extensions and cogs removing to `remove_extension` what is called before closing everything else because need to call `cog_unload`.
-rw-r--r--bot/bot.py30
-rw-r--r--bot/cogs/reddit.py4
2 files changed, 31 insertions, 3 deletions
diff --git a/bot/bot.py b/bot/bot.py
index 313652d11..c9eb24bb5 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -2,7 +2,7 @@ import asyncio
import logging
import socket
import warnings
-from typing import Optional
+from typing import List, Optional
import aiohttp
import aioredis
@@ -49,6 +49,9 @@ class Bot(commands.Bot):
self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot")
+ # All tasks that need to block closing until finished
+ self.closing_tasks: List[asyncio.Task] = []
+
async def _create_redis_session(self) -> None:
"""
Create the Redis connection pool, and then open the redis event gate.
@@ -89,9 +92,32 @@ class Bot(commands.Bot):
self._recreate()
super().clear()
+ def remove_extensions(self) -> None:
+ """Remove all extensions and Cog to close bot. Copy from discord.py's own `close` for right closing order."""
+ for extension in tuple(self.extensions):
+ try:
+ self.unload_extension(extension)
+ except Exception:
+ pass
+
+ for cog in tuple(self.cogs):
+ try:
+ self.remove_cog(cog)
+ except Exception:
+ pass
+
async def close(self) -> None:
"""Close the Discord connection and the aiohttp session, connector, statsd client, and resolver."""
- await super().close()
+ # Remove extensions and cogs before calling super().close() to allow task finish before HTTP session close
+ self.remove_extensions()
+
+ # Wait until all tasks that have to be completed before bot is closing is done
+ for task in self.closing_tasks:
+ log.trace(f"Waiting for task {task.get_name()} before closing.")
+ await task
+
+ # Now actually do full close of bot
+ await super(commands.Bot, self).close()
await self.api_client.close()
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 3b77538a0..5a63d71fc 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -44,7 +44,9 @@ class Reddit(Cog):
"""Stop the loop task and revoke the access token when the cog is unloaded."""
self.auto_poster_loop.cancel()
if self.access_token and self.access_token.expires_at > datetime.utcnow():
- asyncio.create_task(self.revoke_access_token())
+ task = asyncio.create_task(self.revoke_access_token())
+ task.set_name("revoke_reddit_access_token")
+ self.bot.closing_tasks.append(task)
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""