diff options
Diffstat (limited to 'tests/bot/exts/test_cogs.py')
| -rw-r--r-- | tests/bot/exts/test_cogs.py | 82 | 
1 files changed, 82 insertions, 0 deletions
| diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py new file mode 100644 index 000000000..f8e120262 --- /dev/null +++ b/tests/bot/exts/test_cogs.py @@ -0,0 +1,82 @@ +"""Test suite for general tests which apply to all cogs.""" + +import importlib +import pkgutil +import typing as t +import unittest +from collections import defaultdict +from types import ModuleType +from unittest import mock + +from discord.ext import commands + +from bot import exts + + +class CommandNameTests(unittest.TestCase): +    """Tests for shadowing command names and aliases.""" + +    @staticmethod +    def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: +        """An iterator that recursively walks through `cog`'s commands and subcommands.""" +        # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. +        for command in cog.__cog_commands__: +            if command.parent is None: +                yield command +                if isinstance(command, commands.GroupMixin): +                    # Annoyingly it returns duplicates for each alias so use a set to fix that +                    yield from set(command.walk_commands()) + +    @staticmethod +    def walk_modules() -> t.Iterator[ModuleType]: +        """Yield imported modules from the bot.exts subpackage.""" +        def on_error(name: str) -> t.NoReturn: +            raise ImportError(name=name)  # pragma: no cover + +        # The mock prevents asyncio.get_event_loop() from being called. +        with mock.patch("discord.ext.tasks.loop"): +            prefix = f"{exts.__name__}." +            for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): +                if not module.ispkg: +                    yield importlib.import_module(module.name) + +    @staticmethod +    def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: +        """Yield all cogs defined in an extension.""" +        for obj in module.__dict__.values(): +            # Check if it's a class type cause otherwise issubclass() may raise a TypeError. +            is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) +            if is_cog and obj.__module__ == module.__name__: +                yield obj + +    @staticmethod +    def get_qualified_names(command: commands.Command) -> t.List[str]: +        """Return a list of all qualified names, including aliases, for the `command`.""" +        names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] +        names.append(command.qualified_name) +        names += getattr(command, "root_aliases", []) + +        return names + +    def get_all_commands(self) -> t.Iterator[commands.Command]: +        """Yield all commands for all cogs in all extensions.""" +        for module in self.walk_modules(): +            for cog in self.walk_cogs(module): +                for cmd in self.walk_commands(cog): +                    yield cmd + +    def test_names_dont_shadow(self): +        """Names and aliases of commands should be unique.""" +        all_names = defaultdict(list) +        for cmd in self.get_all_commands(): +            func_name = f"{cmd.module}.{cmd.callback.__qualname__}" + +            for name in self.get_qualified_names(cmd): +                with self.subTest(cmd=func_name, name=name): +                    if name in all_names:  # pragma: no cover +                        conflicts = ", ".join(all_names.get(name, "")) +                        self.fail( +                            f"Name '{name}' of the command {func_name} conflicts with {conflicts}." +                        ) + +                all_names[name].append(func_name) | 
