diff options
Diffstat (limited to 'tests/helpers.py')
-rw-r--r-- | tests/helpers.py | 115 |
1 files changed, 55 insertions, 60 deletions
diff --git a/tests/helpers.py b/tests/helpers.py index fd79141ec..22f07934f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,8 +1,11 @@ from __future__ import annotations import asyncio +import collections import functools import inspect +import itertools +import logging import unittest.mock from typing import Any, Iterable, Optional @@ -10,6 +13,16 @@ import discord from discord.ext.commands import Bot, Context +for logger in logging.Logger.manager.loggerDict.values(): + # Set all loggers to CRITICAL by default to prevent screen clutter during testing + + if not isinstance(logger, logging.Logger): + # There might be some logging.PlaceHolder objects in there + continue + + logger.setLevel(logging.CRITICAL) + + def async_test(wrapped): """ Run a test case via asyncio. @@ -61,11 +74,16 @@ class CustomMockMixin: """ child_mock_type = unittest.mock.MagicMock + discord_id = itertools.count(0) + + def __init__(self, spec_set: Any = None, **kwargs): + name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually. + super().__init__(spec_set=spec_set, **kwargs) - def __init__(self, spec: Any = None, **kwargs): - super().__init__(spec=spec, **kwargs) - if spec: - self._extract_coroutine_methods_from_spec_instance(spec) + if name: + self.name = name + if spec_set: + self._extract_coroutine_methods_from_spec_instance(spec_set) def _get_child_mock(self, **kw): """ @@ -177,26 +195,14 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): For more info, see the `Mocking` section in `tests/README.md`. """ + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: + default_kwargs = {'id': next(self.discord_id), 'members': []} + super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) - def __init__( - self, - guild_id: int = 1, - roles: Optional[Iterable[MockRole]] = None, - members: Optional[Iterable[MockMember]] = None, - **kwargs, - ) -> None: - super().__init__(spec=guild_instance, **kwargs) - - self.id = guild_id - - self.roles = [MockRole("@everyone", 1)] + self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: self.roles.extend(roles) - self.members = [] - if members: - self.members.extend(members) - # Create a Role instance to get a realistic Mock of `discord.Role` role_data = {'name': 'role', 'id': 1} @@ -210,14 +216,12 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: + default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} + super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) - def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None: - super().__init__(spec=role_instance, **kwargs) - - self.name = name - self.id = role_id - self.position = position - self.mention = f'&{self.name}' + if 'mention' not in kwargs: + self.mention = f'&{self.name}' def __lt__(self, other): """Simplified position-based comparisons similar to those of `discord.Role`.""" @@ -237,28 +241,22 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: + default_kwargs = {'name': 'member', 'id': next(self.discord_id)} + super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) - def __init__( - self, - name: str = "member", - user_id: int = 1, - roles: Optional[Iterable[MockRole]] = None, - **kwargs, - ) -> None: - super().__init__(spec=member_instance, **kwargs) - - self.name = name - self.id = user_id - - self.roles = [MockRole("@everyone", 1)] + self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: self.roles.extend(roles) - self.mention = f"@{self.name}" + if 'mention' not in kwargs: + self.mention = f"@{self.name}" # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) +bot_instance.http_session = None +bot_instance.api_client = None class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -270,17 +268,18 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec=bot_instance, **kwargs) - - # Our custom attributes and methods - self.http_session = unittest.mock.MagicMock() - self.api_client = unittest.mock.MagicMock() + super().__init__(spec_set=bot_instance, **kwargs) # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and # and should therefore be awaited. (The documentation calls it a coroutine as well, which # is technically incorrect, since it's a regular def.) self.wait_for = AsyncMock() + # Since calling `create_task` on our MockBot does not actually schedule the coroutine object + # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object + # to prevent "has not been awaited"-warnings. + self.loop.create_task.side_effect = lambda coroutine: coroutine.close() + # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` channel_data = { @@ -307,11 +306,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: - super().__init__(spec=channel_instance, **kwargs) - self.id = channel_id - self.name = name - self.guild = kwargs.get('guild', MockGuild()) - self.mention = f"#{self.name}" + default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} + super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"#{self.name}" # Create a Message instance to get a realistic MagicMock of `discord.Message` @@ -349,12 +348,11 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec=context_instance, **kwargs) + super().__init__(spec_set=context_instance, **kwargs) self.bot = kwargs.get('bot', MockBot()) self.guild = kwargs.get('guild', MockGuild()) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) - self.command = kwargs.get('command', unittest.mock.MagicMock()) class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -366,7 +364,7 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec=message_instance, **kwargs) + super().__init__(spec_set=message_instance, **kwargs) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) @@ -384,12 +382,9 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec=emoji_instance, **kwargs) + super().__init__(spec_set=emoji_instance, **kwargs) self.guild = kwargs.get('guild', MockGuild()) - # Get all coroutine functions and set them as AsyncMock attributes - self._extract_coroutine_methods_from_spec_instance(emoji_instance) - partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') @@ -403,7 +398,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec=partial_emoji_instance, **kwargs) + super().__init__(spec_set=partial_emoji_instance, **kwargs) reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -418,7 +413,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec=reaction_instance, **kwargs) + super().__init__(spec_set=reaction_instance, **kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) self.user_list = AsyncIteratorMock(kwargs.get('user_list', [])) |