diff options
| author | 2022-07-10 19:31:40 +0100 | |
|---|---|---|
| committer | 2022-07-10 11:31:40 -0700 | |
| commit | 6e4b958a719a0d312ba556e877470e8d412f666d (patch) | |
| tree | c1c6972052b72840dbaf01343cc9c27bd2318efe | |
| parent | Bump urllib3 from 1.24.3 to 1.26.5 (#2210) (diff) | |
Limit the ext cog to 1 action at a time (#2205)
| -rw-r--r-- | bot/exts/utils/extensions.py | 41 | 
1 files changed, 26 insertions, 15 deletions
| diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 0f5fc0de4..90249867f 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -34,6 +34,7 @@ class Extensions(commands.Cog):      def __init__(self, bot: Bot):          self.bot = bot +        self.action_in_progress = False      @group(name="extensions", aliases=("ext", "exts", "c", "cog", "cogs"), invoke_without_command=True)      async def extensions_group(self, ctx: Context) -> None: @@ -54,8 +55,7 @@ class Extensions(commands.Cog):          if "*" in extensions or "**" in extensions:              extensions = set(self.bot.all_extensions) - set(self.bot.extensions.keys()) -        msg = await self.batch_manage(Action.LOAD, *extensions) -        await ctx.send(msg) +        await self.batch_manage(Action.LOAD, ctx, *extensions)      @extensions_group.command(name="unload", aliases=("ul",))      async def unload_command(self, ctx: Context, *extensions: Extension) -> None: @@ -71,14 +71,12 @@ class Extensions(commands.Cog):          blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions))          if blacklisted: -            msg = f":x: The following extension(s) may not be unloaded:```\n{blacklisted}```" +            await ctx.send(f":x: The following extension(s) may not be unloaded:```\n{blacklisted}```")          else:              if "*" in extensions or "**" in extensions:                  extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST -            msg = await self.batch_manage(Action.UNLOAD, *extensions) - -        await ctx.send(msg) +            await self.batch_manage(Action.UNLOAD, ctx, *extensions)      @extensions_group.command(name="reload", aliases=("r",), root_aliases=("reload",))      async def reload_command(self, ctx: Context, *extensions: Extension) -> None: @@ -100,9 +98,7 @@ class Extensions(commands.Cog):              extensions = set(self.bot.extensions.keys()) | set(extensions)              extensions.remove("*") -        msg = await self.batch_manage(Action.RELOAD, *extensions) - -        await ctx.send(msg) +        await self.batch_manage(Action.RELOAD, ctx, *extensions)      @extensions_group.command(name="list", aliases=("all",))      async def list_command(self, ctx: Context) -> None: @@ -151,17 +147,27 @@ class Extensions(commands.Cog):          return categories -    async def batch_manage(self, action: Action, *extensions: str) -> str: +    async def batch_manage(self, action: Action, ctx: Context, *extensions: str) -> None:          """ -        Apply an action to multiple extensions and return a message with the results. +        Apply an action to multiple extensions, giving feedback to the invoker while doing so.          If only one extension is given, it is deferred to `manage()`.          """ +        if self.action_in_progress: +            await ctx.send(":x: Another action is in progress, please try again later.") +            return + +        verb = action.name.lower() + +        self.action_in_progress = True +        loading_message = await ctx.send(f":hourglass_flowing_sand: {verb} in progress, please wait...") +          if len(extensions) == 1:              msg, _ = await self.manage(action, extensions[0]) -            return msg +            await loading_message.edit(content=msg) +            self.action_in_progress = False +            return -        verb = action.name.lower()          failures = {}          for extension in extensions: @@ -178,7 +184,8 @@ class Extensions(commands.Cog):          log.debug(f"Batch {verb}ed extensions.") -        return msg +        await loading_message.edit(content=msg) +        self.action_in_progress = False      async def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:          """Apply an action to an extension and return the status message and any error message.""" @@ -215,7 +222,11 @@ class Extensions(commands.Cog):      # This cannot be static (must have a __func__ attribute).      async def cog_command_error(self, ctx: Context, error: Exception) -> None: -        """Handle BadArgument errors locally to prevent the help command from showing.""" +        """Handle errors locally to prevent the error handler cog from interfering when not wanted.""" +        # Safely clear the flag on unexpected errors to avoid deadlocks. +        self.action_in_progress = False + +        # Handle BadArgument errors locally to prevent the help command from showing.          if isinstance(error, commands.BadArgument):              await ctx.send(str(error))              error.handled = True | 
