diff options
-rw-r--r-- | bot/exts/utils/extensions.py | 41 |
1 files changed, 26 insertions, 15 deletions
diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 0f5fc0de4..90249867f 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -34,6 +34,7 @@ class Extensions(commands.Cog): def __init__(self, bot: Bot): self.bot = bot + self.action_in_progress = False @group(name="extensions", aliases=("ext", "exts", "c", "cog", "cogs"), invoke_without_command=True) async def extensions_group(self, ctx: Context) -> None: @@ -54,8 +55,7 @@ class Extensions(commands.Cog): if "*" in extensions or "**" in extensions: extensions = set(self.bot.all_extensions) - set(self.bot.extensions.keys()) - msg = await self.batch_manage(Action.LOAD, *extensions) - await ctx.send(msg) + await self.batch_manage(Action.LOAD, ctx, *extensions) @extensions_group.command(name="unload", aliases=("ul",)) async def unload_command(self, ctx: Context, *extensions: Extension) -> None: @@ -71,14 +71,12 @@ class Extensions(commands.Cog): blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) if blacklisted: - msg = f":x: The following extension(s) may not be unloaded:```\n{blacklisted}```" + await ctx.send(f":x: The following extension(s) may not be unloaded:```\n{blacklisted}```") else: if "*" in extensions or "**" in extensions: extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST - msg = await self.batch_manage(Action.UNLOAD, *extensions) - - await ctx.send(msg) + await self.batch_manage(Action.UNLOAD, ctx, *extensions) @extensions_group.command(name="reload", aliases=("r",), root_aliases=("reload",)) async def reload_command(self, ctx: Context, *extensions: Extension) -> None: @@ -100,9 +98,7 @@ class Extensions(commands.Cog): extensions = set(self.bot.extensions.keys()) | set(extensions) extensions.remove("*") - msg = await self.batch_manage(Action.RELOAD, *extensions) - - await ctx.send(msg) + await self.batch_manage(Action.RELOAD, ctx, *extensions) @extensions_group.command(name="list", aliases=("all",)) async def list_command(self, ctx: Context) -> None: @@ -151,17 +147,27 @@ class Extensions(commands.Cog): return categories - async def batch_manage(self, action: Action, *extensions: str) -> str: + async def batch_manage(self, action: Action, ctx: Context, *extensions: str) -> None: """ - Apply an action to multiple extensions and return a message with the results. + Apply an action to multiple extensions, giving feedback to the invoker while doing so. If only one extension is given, it is deferred to `manage()`. """ + if self.action_in_progress: + await ctx.send(":x: Another action is in progress, please try again later.") + return + + verb = action.name.lower() + + self.action_in_progress = True + loading_message = await ctx.send(f":hourglass_flowing_sand: {verb} in progress, please wait...") + if len(extensions) == 1: msg, _ = await self.manage(action, extensions[0]) - return msg + await loading_message.edit(content=msg) + self.action_in_progress = False + return - verb = action.name.lower() failures = {} for extension in extensions: @@ -178,7 +184,8 @@ class Extensions(commands.Cog): log.debug(f"Batch {verb}ed extensions.") - return msg + await loading_message.edit(content=msg) + self.action_in_progress = False async def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: """Apply an action to an extension and return the status message and any error message.""" @@ -215,7 +222,11 @@ class Extensions(commands.Cog): # This cannot be static (must have a __func__ attribute). async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Handle BadArgument errors locally to prevent the help command from showing.""" + """Handle errors locally to prevent the error handler cog from interfering when not wanted.""" + # Safely clear the flag on unexpected errors to avoid deadlocks. + self.action_in_progress = False + + # Handle BadArgument errors locally to prevent the help command from showing. if isinstance(error, commands.BadArgument): await ctx.send(str(error)) error.handled = True |