aboutsummaryrefslogtreecommitdiffstats
path: root/tests/bot/exts/test_cogs.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/bot/exts/test_cogs.py')
-rw-r--r--tests/bot/exts/test_cogs.py82
1 files changed, 82 insertions, 0 deletions
diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py
new file mode 100644
index 000000000..f8e120262
--- /dev/null
+++ b/tests/bot/exts/test_cogs.py
@@ -0,0 +1,82 @@
+"""Test suite for general tests which apply to all cogs."""
+
+import importlib
+import pkgutil
+import typing as t
+import unittest
+from collections import defaultdict
+from types import ModuleType
+from unittest import mock
+
+from discord.ext import commands
+
+from bot import exts
+
+
+class CommandNameTests(unittest.TestCase):
+ """Tests for shadowing command names and aliases."""
+
+ @staticmethod
+ def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]:
+ """An iterator that recursively walks through `cog`'s commands and subcommands."""
+ # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods.
+ for command in cog.__cog_commands__:
+ if command.parent is None:
+ yield command
+ if isinstance(command, commands.GroupMixin):
+ # Annoyingly it returns duplicates for each alias so use a set to fix that
+ yield from set(command.walk_commands())
+
+ @staticmethod
+ def walk_modules() -> t.Iterator[ModuleType]:
+ """Yield imported modules from the bot.exts subpackage."""
+ def on_error(name: str) -> t.NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ # The mock prevents asyncio.get_event_loop() from being called.
+ with mock.patch("discord.ext.tasks.loop"):
+ prefix = f"{exts.__name__}."
+ for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error):
+ if not module.ispkg:
+ yield importlib.import_module(module.name)
+
+ @staticmethod
+ def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]:
+ """Yield all cogs defined in an extension."""
+ for obj in module.__dict__.values():
+ # Check if it's a class type cause otherwise issubclass() may raise a TypeError.
+ is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog)
+ if is_cog and obj.__module__ == module.__name__:
+ yield obj
+
+ @staticmethod
+ def get_qualified_names(command: commands.Command) -> t.List[str]:
+ """Return a list of all qualified names, including aliases, for the `command`."""
+ names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases]
+ names.append(command.qualified_name)
+ names += getattr(command, "root_aliases", [])
+
+ return names
+
+ def get_all_commands(self) -> t.Iterator[commands.Command]:
+ """Yield all commands for all cogs in all extensions."""
+ for module in self.walk_modules():
+ for cog in self.walk_cogs(module):
+ for cmd in self.walk_commands(cog):
+ yield cmd
+
+ def test_names_dont_shadow(self):
+ """Names and aliases of commands should be unique."""
+ all_names = defaultdict(list)
+ for cmd in self.get_all_commands():
+ func_name = f"{cmd.module}.{cmd.callback.__qualname__}"
+
+ for name in self.get_qualified_names(cmd):
+ with self.subTest(cmd=func_name, name=name):
+ if name in all_names: # pragma: no cover
+ conflicts = ", ".join(all_names.get(name, ""))
+ self.fail(
+ f"Name '{name}' of the command {func_name} conflicts with {conflicts}."
+ )
+
+ all_names[name].append(func_name)