aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/extensions.py41
1 files changed, 26 insertions, 15 deletions
diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py
index 0f5fc0de4..90249867f 100644
--- a/bot/exts/utils/extensions.py
+++ b/bot/exts/utils/extensions.py
@@ -34,6 +34,7 @@ class Extensions(commands.Cog):
def __init__(self, bot: Bot):
self.bot = bot
+ self.action_in_progress = False
@group(name="extensions", aliases=("ext", "exts", "c", "cog", "cogs"), invoke_without_command=True)
async def extensions_group(self, ctx: Context) -> None:
@@ -54,8 +55,7 @@ class Extensions(commands.Cog):
if "*" in extensions or "**" in extensions:
extensions = set(self.bot.all_extensions) - set(self.bot.extensions.keys())
- msg = await self.batch_manage(Action.LOAD, *extensions)
- await ctx.send(msg)
+ await self.batch_manage(Action.LOAD, ctx, *extensions)
@extensions_group.command(name="unload", aliases=("ul",))
async def unload_command(self, ctx: Context, *extensions: Extension) -> None:
@@ -71,14 +71,12 @@ class Extensions(commands.Cog):
blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions))
if blacklisted:
- msg = f":x: The following extension(s) may not be unloaded:```\n{blacklisted}```"
+ await ctx.send(f":x: The following extension(s) may not be unloaded:```\n{blacklisted}```")
else:
if "*" in extensions or "**" in extensions:
extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST
- msg = await self.batch_manage(Action.UNLOAD, *extensions)
-
- await ctx.send(msg)
+ await self.batch_manage(Action.UNLOAD, ctx, *extensions)
@extensions_group.command(name="reload", aliases=("r",), root_aliases=("reload",))
async def reload_command(self, ctx: Context, *extensions: Extension) -> None:
@@ -100,9 +98,7 @@ class Extensions(commands.Cog):
extensions = set(self.bot.extensions.keys()) | set(extensions)
extensions.remove("*")
- msg = await self.batch_manage(Action.RELOAD, *extensions)
-
- await ctx.send(msg)
+ await self.batch_manage(Action.RELOAD, ctx, *extensions)
@extensions_group.command(name="list", aliases=("all",))
async def list_command(self, ctx: Context) -> None:
@@ -151,17 +147,27 @@ class Extensions(commands.Cog):
return categories
- async def batch_manage(self, action: Action, *extensions: str) -> str:
+ async def batch_manage(self, action: Action, ctx: Context, *extensions: str) -> None:
"""
- Apply an action to multiple extensions and return a message with the results.
+ Apply an action to multiple extensions, giving feedback to the invoker while doing so.
If only one extension is given, it is deferred to `manage()`.
"""
+ if self.action_in_progress:
+ await ctx.send(":x: Another action is in progress, please try again later.")
+ return
+
+ verb = action.name.lower()
+
+ self.action_in_progress = True
+ loading_message = await ctx.send(f":hourglass_flowing_sand: {verb} in progress, please wait...")
+
if len(extensions) == 1:
msg, _ = await self.manage(action, extensions[0])
- return msg
+ await loading_message.edit(content=msg)
+ self.action_in_progress = False
+ return
- verb = action.name.lower()
failures = {}
for extension in extensions:
@@ -178,7 +184,8 @@ class Extensions(commands.Cog):
log.debug(f"Batch {verb}ed extensions.")
- return msg
+ await loading_message.edit(content=msg)
+ self.action_in_progress = False
async def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:
"""Apply an action to an extension and return the status message and any error message."""
@@ -215,7 +222,11 @@ class Extensions(commands.Cog):
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
- """Handle BadArgument errors locally to prevent the help command from showing."""
+ """Handle errors locally to prevent the error handler cog from interfering when not wanted."""
+ # Safely clear the flag on unexpected errors to avoid deadlocks.
+ self.action_in_progress = False
+
+ # Handle BadArgument errors locally to prevent the help command from showing.
if isinstance(error, commands.BadArgument):
await ctx.send(str(error))
error.handled = True