aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-08-22 19:13:21 -0700
committerGravatar MarkKoz <[email protected]>2020-08-22 20:07:03 -0700
commitf455a7908a9b07747db6ab89f9c5c53bd5ea2450 (patch)
treed4f81d577b8e453eba5902c5a527149d21a094c2
parentDefine a Command subclass with root alias support (diff)
Bot: add root alias support
Override `Bot.add_command` and `Bot.remove_command` to add/remove root aliases for a command (and recursively for any subcommands). This has to happen in `Bot` because there's no reliable way to get the `Bot` instance otherwise. Therefore, overriding the methods in `GroupMixin` unfortunately doesn't work. Otherwise, it'd be possible to avoid recursion by processing each subcommand as it got added.
-rw-r--r--bot/bot.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/bot/bot.py b/bot/bot.py
index 756449293..34254d8e8 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -130,6 +130,26 @@ class Bot(commands.Bot):
super().add_cog(cog)
log.info(f"Cog loaded: {cog.qualified_name}")
+ def add_command(self, command: commands.Command) -> None:
+ """Add `command` as normal and then add its root aliases to the bot."""
+ super().add_command(command)
+ self._add_root_aliases(command)
+
+ def remove_command(self, name: str) -> Optional[commands.Command]:
+ """
+ Remove a command/alias as normal and then remove its root aliases from the bot.
+
+ Individual root aliases cannot be removed by this function.
+ To remove them, either remove the entire command or manually edit `bot.all_commands`.
+ """
+ command = super().remove_command(name)
+ if command is None:
+ # Even if it's a root alias, there's no way to get the Bot instance to remove the alias.
+ return
+
+ self._remove_root_aliases(command)
+ return command
+
def clear(self) -> None:
"""
Clears the internal state of the bot and recreates the connector and sessions.
@@ -235,3 +255,24 @@ class Bot(commands.Bot):
scope.set_extra("kwargs", kwargs)
log.exception(f"Unhandled exception in {event}.")
+
+ def _add_root_aliases(self, command: commands.Command) -> None:
+ """Recursively add root aliases for `command` and any of its subcommands."""
+ if isinstance(command, commands.Group):
+ for subcommand in command.commands:
+ self._add_root_aliases(subcommand)
+
+ for alias in command.root_aliases:
+ if alias in self.all_commands:
+ raise commands.CommandRegistrationError(alias, alias_conflict=True)
+
+ self.all_commands[alias] = command
+
+ def _remove_root_aliases(self, command: commands.Command) -> None:
+ """Recursively remove root aliases for `command` and any of its subcommands."""
+ if isinstance(command, commands.Group):
+ for subcommand in command.commands:
+ self._remove_root_aliases(subcommand)
+
+ for alias in command.root_aliases:
+ self.all_commands.pop(alias, None)