aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-03-09 15:02:22 -0700
committerGravatar MarkKoz <[email protected]>2020-03-13 17:11:27 -0700
commitbbcdf24a4b5d4f84834bbc8a8da7db2da627541f (patch)
tree9f6b4334f8da8b21a473945cb19219f118b2ea80 /tests
parentCog tests: fix duplicate cogs being yielded (diff)
Cog tests: fix nested modules not being found
* Rename `walk_extensions` to `walk_modules` because some extensions don't consist of a single module
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/cogs/test_cogs.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py
index de0982c93..3a9f07db6 100644
--- a/tests/bot/cogs/test_cogs.py
+++ b/tests/bot/cogs/test_cogs.py
@@ -24,17 +24,21 @@ class CommandNameTests(unittest.TestCase):
yield from command.walk_commands()
@staticmethod
- def walk_extensions() -> t.Iterator[ModuleType]:
- """Yield imported extensions (modules) from the bot.cogs subpackage."""
- for module in pkgutil.iter_modules(cogs.__path__, "bot.cogs."):
- yield importlib.import_module(module.name)
+ def walk_modules() -> t.Iterator[ModuleType]:
+ """Yield imported modules from the bot.cogs subpackage."""
+ def on_error(name: str) -> t.NoReturn:
+ raise ImportError(name=name)
+
+ for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error):
+ if not module.ispkg:
+ yield importlib.import_module(module.name)
@staticmethod
- def walk_cogs(extension: ModuleType) -> t.Iterator[commands.Cog]:
+ def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]:
"""Yield all cogs defined in an extension."""
- for obj in extension.__dict__.values():
+ for obj in module.__dict__.values():
is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog)
- if is_cog and obj.__module__ == extension.__name__:
+ if is_cog and obj.__module__ == module.__name__:
yield obj
@staticmethod
@@ -47,7 +51,7 @@ class CommandNameTests(unittest.TestCase):
def get_all_commands(self) -> t.Iterator[commands.Command]:
"""Yield all commands for all cogs in all extensions."""
- for extension in self.walk_extensions():
- for cog in self.walk_cogs(extension):
+ for module in self.walk_modules():
+ for cog in self.walk_cogs(module):
for cmd in self.walk_commands(cog):
yield cmd