diff options
Diffstat (limited to 'tests/helpers.py')
-rw-r--r-- | tests/helpers.py | 190 |
1 files changed, 62 insertions, 128 deletions
diff --git a/tests/helpers.py b/tests/helpers.py index 01752a791..506fe9894 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,11 +1,10 @@ from __future__ import annotations import collections -import inspect import itertools import logging import unittest.mock -from typing import Any, Iterable, Optional +from typing import Iterable, Optional import discord from discord.ext.commands import Context @@ -51,24 +50,31 @@ class CustomMockMixin: """ Provides common functionality for our custom Mock types. - The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine - function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care - of making sure child mocks are instantiated with the correct class. By default, the mock of the - children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute - `child_mock_type` on the custom mock inheriting from this mixin. + The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock + object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the + class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional + attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The + class method `spec_set` can be overwritten with the object that should be uses as the specification + for the mock. + + Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to + implement custom behavior. """ child_mock_type = unittest.mock.MagicMock discord_id = itertools.count(0) + spec_set = None + additional_spec_asyncs = None - def __init__(self, spec_set: Any = None, **kwargs): + def __init__(self, **kwargs): name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually. - super().__init__(spec_set=spec_set, **kwargs) + super().__init__(spec_set=self.spec_set, **kwargs) + + if self.additional_spec_asyncs: + self._spec_asyncs.extend(self.additional_spec_asyncs) if name: self.name = name - if spec_set: - self._extract_coroutine_methods_from_spec_instance(spec_set) def _get_child_mock(self, **kw): """ @@ -82,7 +88,16 @@ class CustomMockMixin: This override will look for an attribute called `child_mock_type` and use that as the type of the child mock. """ - klass = self.child_mock_type + _new_name = kw.get("_new_name") + if _new_name in self.__dict__['_spec_asyncs']: + return unittest.mock.AsyncMock(**kw) + + _type = type(self) + if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics: + # Any asynchronous magic becomes an AsyncMock + klass = unittest.mock.AsyncMock + else: + klass = self.child_mock_type if self._mock_sealed: attribute = "." + kw["name"] if "name" in kw else "()" @@ -91,95 +106,6 @@ class CustomMockMixin: return klass(**kw) - def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None: - """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes.""" - for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction): - setattr(self, name, AsyncMock()) - - -# TODO: Remove me in Python 3.8 -class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): - """ - A MagicMock subclass to mock async callables. - - Python 3.8 will introduce an AsyncMock class in the standard library that will have some more - features; this stand-in only overwrites the `__call__` method to an async version. - """ - - async def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) - - -class AsyncIteratorMock: - """ - A class to mock asynchronous iterators. - - This allows async for, which is used in certain Discord.py objects. For example, - an async iterator is returned by the Reaction.users() method. - """ - - def __init__(self, iterable: Iterable = None): - if iterable is None: - iterable = [] - - self.iter = iter(iterable) - self.iterable = iterable - - self.call_count = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - - def __call__(self): - """ - Keeps track of the number of times an instance has been called. - - This is useful, since it typically shows that the iterator has actually been used somewhere after we have - instantiated the mock for an attribute that normally returns an iterator when called. - """ - self.call_count += 1 - return self - - @property - def return_value(self): - """Makes `self.iterable` accessible as self.return_value.""" - return self.iterable - - @return_value.setter - def return_value(self, iterable): - """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" - self.iter = iter(iterable) - self.iterable = iterable - - def assert_called(self): - """Asserts if the AsyncIteratorMock instance has been called at least once.""" - if self.call_count == 0: - raise AssertionError("Expected AsyncIteratorMock to have been called.") - - def assert_called_once(self): - """Asserts if the AsyncIteratorMock instance has been called exactly once.""" - if self.call_count != 1: - raise AssertionError( - f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." - ) - - def assert_not_called(self): - """Asserts if the AsyncIteratorMock instance has not been called.""" - if self.call_count != 0: - raise AssertionError( - f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." - ) - - def reset_mock(self): - """Resets the call count, but not the return value or iterator.""" - self.call_count = 0 - # Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { @@ -230,9 +156,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): For more info, see the `Mocking` section in `tests/README.md`. """ + spec_set = guild_instance + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'members': []} - super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: @@ -251,9 +179,11 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ + spec_set = role_instance + def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} - super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f'&{self.name}' @@ -276,9 +206,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ + spec_set = member_instance + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} - super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: @@ -299,9 +231,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.User` instances. For more information, see the `MockGuild` docstring. """ + spec_set = user_instance + def __init__(self, **kwargs) -> None: default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} - super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f"@{self.name}" @@ -320,14 +254,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ + spec_set = bot_instance + additional_spec_asyncs = ("wait_for",) def __init__(self, **kwargs) -> None: - super().__init__(spec_set=bot_instance, **kwargs) + super().__init__(**kwargs) # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and # and should therefore be awaited. (The documentation calls it a coroutine as well, which # is technically incorrect, since it's a regular def.) - self.wait_for = AsyncMock() + # self.wait_for = unittest.mock.AsyncMock() # Since calling `create_task` on our MockBot does not actually schedule the coroutine object # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object @@ -358,10 +294,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ + spec_set = channel_instance def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} - super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f"#{self.name}" @@ -400,9 +337,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ + spec_set = context_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=context_instance, **kwargs) + super().__init__(**kwargs) self.bot = kwargs.get('bot', MockBot()) self.guild = kwargs.get('guild', MockGuild()) self.author = kwargs.get('author', MockMember()) @@ -419,8 +357,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Attachment` instances. For more information, see the `MockGuild` docstring. """ - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=attachment_instance, **kwargs) + spec_set = attachment_instance class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -430,10 +367,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ + spec_set = message_instance def __init__(self, **kwargs) -> None: default_kwargs = {'attachments': []} - super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) @@ -449,9 +387,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ + spec_set = emoji_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=emoji_instance, **kwargs) + super().__init__(**kwargs) self.guild = kwargs.get('guild', MockGuild()) @@ -465,9 +404,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=partial_emoji_instance, **kwargs) + spec_set = partial_emoji_instance reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -480,12 +417,17 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ + spec_set = reaction_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=reaction_instance, **kwargs) + _users = kwargs.pop("users", []) + super().__init__(**kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) - self.users = AsyncIteratorMock(kwargs.get('users', [])) + + user_iterator = unittest.mock.AsyncMock() + user_iterator.__aiter__.return_value = _users + self.users.return_value = user_iterator webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) @@ -498,13 +440,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Webhook` instances. For more information, see the `MockGuild` docstring. """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=webhook_instance, **kwargs) - - # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined - # as coroutines. That's why we need to set the methods manually. - self.send = AsyncMock() - self.edit = AsyncMock() - self.delete = AsyncMock() - self.execute = AsyncMock() + spec_set = webhook_instance + additional_spec_asyncs = ("send", "edit", "delete", "execute") |