aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2019-10-03 21:37:03 -0700
committerGravatar MarkKoz <[email protected]>2019-10-03 21:37:03 -0700
commit1fda5f7e1d7fc3bd7002bf047cd975dae5eb1c25 (patch)
treee7acb8d2a3271c7cfec45102ad11547ac43aef48
parentSupport giving multiple extensions to reload (diff)
Use reload_extension() instead of calling unload and reload
* Simplify output format of batch reload with only 1 list of failures * Show success/failure emoji for batch reloads * Simplify logic in the manage() function * Clean up some imports
-rw-r--r--bot/cogs/extensions.py123
1 files changed, 56 insertions, 67 deletions
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
index 5e0bd29bf..0d2cc726e 100644
--- a/bot/cogs/extensions.py
+++ b/bot/cogs/extensions.py
@@ -1,11 +1,12 @@
+import functools
import logging
-import textwrap
import typing as t
from enum import Enum
from pkgutil import iter_modules
from discord import Colour, Embed
-from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group
+from discord.ext import commands
+from discord.ext.commands import Bot, Context, group
from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
from bot.pagination import LinePaginator
@@ -24,12 +25,13 @@ EXTENSIONS = frozenset(
class Action(Enum):
"""Represents an action to perform on an extension."""
- LOAD = (Bot.load_extension,)
- UNLOAD = (Bot.unload_extension,)
- RELOAD = (Bot.unload_extension, Bot.load_extension)
+ # Need to be partial otherwise they are considered to be function definitions.
+ LOAD = functools.partial(Bot.load_extension)
+ UNLOAD = functools.partial(Bot.unload_extension)
+ RELOAD = functools.partial(Bot.reload_extension)
-class Extension(Converter):
+class Extension(commands.Converter):
"""
Fully qualify the name of an extension and ensure it exists.
@@ -50,10 +52,10 @@ class Extension(Converter):
if argument in EXTENSIONS:
return argument
else:
- raise BadArgument(f":x: Could not find the extension `{argument}`.")
+ raise commands.BadArgument(f":x: Could not find the extension `{argument}`.")
-class Extensions(Cog):
+class Extensions(commands.Cog):
"""Extension management commands."""
def __init__(self, bot: Bot):
@@ -85,12 +87,12 @@ class Extensions(Cog):
"""
Reload extensions given their fully qualified or unqualified names.
+ If an extension fails to be reloaded, it will be rolled-back to the prior working state.
+
If `*` is given as the name, all currently loaded extensions will be reloaded.
If `**` is given as the name, all extensions, including unloaded ones, will be reloaded.
"""
- if "**" in extensions:
- msg = await self.batch_reload(reload_unloaded=True)
- elif "*" in extensions or len(extensions) > 1:
+ if len(extensions) > 1:
msg = await self.batch_reload(*extensions)
else:
msg, _ = self.manage(extensions[0], Action.RELOAD)
@@ -142,48 +144,37 @@ class Extensions(Cog):
log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False)
- async def batch_reload(self, *extensions: str, reload_unloaded: bool = False) -> str:
- """Reload given extensions or all loaded ones and return a message with the results."""
- unloaded = []
- unload_failures = {}
- load_failures = {}
+ async def batch_reload(self, *extensions: str) -> str:
+ """
+ Reload given extensions and return a message with the results.
+
+ If `*` is given, all currently loaded extensions will be reloaded along with any other
+ specified extensions. If `**` is given, all extensions, including unloaded ones, will be
+ reloaded.
+ """
+ failures = {}
- if "*" in extensions:
- to_unload = set(self.bot.extensions.keys()) | set(extensions)
- to_unload.remove("*")
+ if "**" in extensions:
+ to_reload = EXTENSIONS
+ elif "*" in extensions:
+ to_reload = set(self.bot.extensions.keys()) | set(extensions)
+ to_reload.remove("*")
elif extensions:
- to_unload = extensions
+ to_reload = extensions
else:
- to_unload = self.bot.extensions.copy().keys()
+ to_reload = self.bot.extensions.copy().keys()
- for extension in to_unload:
- _, error = self.manage(extension, Action.UNLOAD)
+ for extension in to_reload:
+ _, error = self.manage(extension, Action.RELOAD)
if error:
- unload_failures[extension] = error
- else:
- unloaded.append(extension)
+ failures[extension] = error
- if reload_unloaded:
- unloaded = EXTENSIONS
-
- for extension in unloaded:
- _, error = self.manage(extension, Action.LOAD)
- if error:
- load_failures[extension] = error
+ emoji = ":x:" if failures else ":ok_hand:"
+ msg = f"{emoji} {len(to_reload) - len(failures)} / {len(to_reload)} extensions reloaded."
- msg = textwrap.dedent(f"""
- **All extensions reloaded**
- Unloaded: {len(to_unload) - len(unload_failures)} / {len(to_unload)}
- Loaded: {len(unloaded) - len(load_failures)} / {len(unloaded)}
- """).strip()
-
- if unload_failures:
- failures = '\n'.join(f'{ext}\n {err}' for ext, err in unload_failures.items())
- msg += f'\nUnload failures:```{failures}```'
-
- if load_failures:
- failures = '\n'.join(f'{ext}\n {err}' for ext, err in load_failures.items())
- msg += f'\nLoad failures:```{failures}```'
+ if failures:
+ failures = '\n'.join(f'{ext}\n {err}' for ext, err in failures.items())
+ msg += f'\nFailures:```{failures}```'
log.debug(f'Reloaded all extensions.')
@@ -194,28 +185,26 @@ class Extensions(Cog):
verb = action.name.lower()
error_msg = None
- if (
- (action is Action.LOAD and ext not in self.bot.extensions)
- or (action is Action.UNLOAD and ext in self.bot.extensions)
- or action is Action.RELOAD
- ):
- try:
- for func in action.value:
- func(self.bot, ext)
- except Exception as e:
- if hasattr(e, "original"):
- e = e.original
-
- log.exception(f"Extension '{ext}' failed to {verb}.")
-
- error_msg = f"{e.__class__.__name__}: {e}"
- msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```"
- else:
- msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`."
- log.debug(msg[10:])
- else:
+ try:
+ action.value(self.bot, ext)
+ except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):
+ if action is Action.RELOAD:
+ # When reloading, just load the extension if it was not loaded.
+ return self.manage(ext, Action.LOAD)
+
msg = f":x: Extension `{ext}` is already {verb}ed."
log.debug(msg[4:])
+ except Exception as e:
+ if hasattr(e, "original"):
+ e = e.original
+
+ log.exception(f"Extension '{ext}' failed to {verb}.")
+
+ error_msg = f"{e.__class__.__name__}: {e}"
+ msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```"
+ else:
+ msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`."
+ log.debug(msg[10:])
return msg, error_msg
@@ -227,7 +216,7 @@ class Extensions(Cog):
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""Handle BadArgument errors locally to prevent the help command from showing."""
- if isinstance(error, BadArgument):
+ if isinstance(error, commands.BadArgument):
await ctx.send(str(error))
error.handled = True