aboutsummaryrefslogtreecommitdiffstats
path: root/tests/helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/helpers.py')
-rw-r--r--tests/helpers.py290
1 files changed, 120 insertions, 170 deletions
diff --git a/tests/helpers.py b/tests/helpers.py
index 6f50f6ae3..facc4e1af 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,18 +1,18 @@
from __future__ import annotations
-import asyncio
import collections
-import functools
-import inspect
import itertools
import logging
import unittest.mock
-from typing import Any, Iterable, Optional
+from asyncio import AbstractEventLoop
+from typing import Callable, Iterable, Optional
import discord
+from aiohttp import ClientSession
from discord.ext.commands import Context
from bot.api import APIClient
+from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
@@ -26,19 +26,22 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def async_test(wrapped):
- """
- Run a test case via asyncio.
- Example:
- >>> @async_test
- ... async def lemon_wins():
- ... assert True
- """
+def autospec(target, *attributes: str, **kwargs) -> Callable:
+ """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True."""
+ # Caller's kwargs should take priority and overwrite the defaults.
+ kwargs = {'spec_set': True, 'autospec': True, **kwargs}
- @functools.wraps(wrapped)
- def wrapper(*args, **kwargs):
- return asyncio.run(wrapped(*args, **kwargs))
- return wrapper
+ # Import the target if it's a string.
+ # This is to support both object and string targets like patch.multiple.
+ if type(target) is str:
+ target = unittest.mock._importer(target)
+
+ def decorator(func):
+ for attribute in attributes:
+ patcher = unittest.mock.patch.object(target, attribute, **kwargs)
+ func = patcher(func)
+ return func
+ return decorator
class HashableMixin(discord.mixins.EqualityComparable):
@@ -69,24 +72,31 @@ class CustomMockMixin:
"""
Provides common functionality for our custom Mock types.
- The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
- function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
- of making sure child mocks are instantiated with the correct class. By default, the mock of the
- children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
- `child_mock_type` on the custom mock inheriting from this mixin.
+ The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock
+ object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the
+ class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional
+ attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The
+ class method `spec_set` can be overwritten with the object that should be uses as the specification
+ for the mock.
+
+ Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to
+ implement custom behavior.
"""
child_mock_type = unittest.mock.MagicMock
discord_id = itertools.count(0)
+ spec_set = None
+ additional_spec_asyncs = None
- def __init__(self, spec_set: Any = None, **kwargs):
+ def __init__(self, **kwargs):
name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually.
- super().__init__(spec_set=spec_set, **kwargs)
+ super().__init__(spec_set=self.spec_set, **kwargs)
+
+ if self.additional_spec_asyncs:
+ self._spec_asyncs.extend(self.additional_spec_asyncs)
if name:
self.name = name
- if spec_set:
- self._extract_coroutine_methods_from_spec_instance(spec_set)
def _get_child_mock(self, **kw):
"""
@@ -100,7 +110,16 @@ class CustomMockMixin:
This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
"""
- klass = self.child_mock_type
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return unittest.mock.AsyncMock(**kw)
+
+ _type = type(self)
+ if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics:
+ # Any asynchronous magic becomes an AsyncMock
+ klass = unittest.mock.AsyncMock
+ else:
+ klass = self.child_mock_type
if self._mock_sealed:
attribute = "." + kw["name"] if "name" in kw else "()"
@@ -109,107 +128,6 @@ class CustomMockMixin:
return klass(**kw)
- def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
- """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
- for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
- setattr(self, name, AsyncMock())
-
-
-# TODO: Remove me in Python 3.8
-class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock async callables.
-
- Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
- features; this stand-in only overwrites the `__call__` method to an async version.
- """
-
- async def __call__(self, *args, **kwargs):
- return super().__call__(*args, **kwargs)
-
-
-class AsyncContextManagerMock(unittest.mock.MagicMock):
- def __init__(self, return_value: Any):
- super().__init__()
- self._return_value = return_value
-
- async def __aenter__(self):
- return self._return_value
-
- async def __aexit__(self, *args):
- pass
-
-
-class AsyncIteratorMock:
- """
- A class to mock asynchronous iterators.
-
- This allows async for, which is used in certain Discord.py objects. For example,
- an async iterator is returned by the Reaction.users() method.
- """
-
- def __init__(self, iterable: Iterable = None):
- if iterable is None:
- iterable = []
-
- self.iter = iter(iterable)
- self.iterable = iterable
-
- self.call_count = 0
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- return next(self.iter)
- except StopIteration:
- raise StopAsyncIteration
-
- def __call__(self):
- """
- Keeps track of the number of times an instance has been called.
-
- This is useful, since it typically shows that the iterator has actually been used somewhere after we have
- instantiated the mock for an attribute that normally returns an iterator when called.
- """
- self.call_count += 1
- return self
-
- @property
- def return_value(self):
- """Makes `self.iterable` accessible as self.return_value."""
- return self.iterable
-
- @return_value.setter
- def return_value(self, iterable):
- """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
- self.iter = iter(iterable)
- self.iterable = iterable
-
- def assert_called(self):
- """Asserts if the AsyncIteratorMock instance has been called at least once."""
- if self.call_count == 0:
- raise AssertionError("Expected AsyncIteratorMock to have been called.")
-
- def assert_called_once(self):
- """Asserts if the AsyncIteratorMock instance has been called exactly once."""
- if self.call_count != 1:
- raise AssertionError(
- f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
- )
-
- def assert_not_called(self):
- """Asserts if the AsyncIteratorMock instance has not been called."""
- if self.call_count != 0:
- raise AssertionError(
- f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
- )
-
- def reset_mock(self):
- """Resets the call count, but not the return value or iterator."""
- self.call_count = 0
-
# Create a guild instance to get a realistic Mock of `discord.Guild`
guild_data = {
@@ -260,9 +178,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
For more info, see the `Mocking` section in `tests/README.md`.
"""
+ spec_set = guild_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'members': []}
- super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -281,6 +201,8 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.Role` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = role_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {
'id': next(self.discord_id),
@@ -289,7 +211,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
'colour': discord.Colour(0xdeadbf),
'permissions': discord.Permissions(),
}
- super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if isinstance(self.colour, int):
self.colour = discord.Colour(self.colour)
@@ -304,6 +226,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
"""Simplified position-based comparisons similar to those of `discord.Role`."""
return self.position < other.position
+ def __ge__(self, other):
+ """Simplified position-based comparisons similar to those of `discord.Role`."""
+ return self.position >= other.position
+
# Create a Member instance to get a realistic Mock of `discord.Member`
member_data = {'user': 'lemon', 'roles': [1]}
@@ -318,9 +244,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
Instances of this class will follow the specifications of `discord.Member` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = member_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -341,9 +269,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.User` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = user_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"@{self.name}"
@@ -356,15 +286,19 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `bot.api.APIClient` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = APIClient
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=APIClient, **kwargs)
+def _get_mock_loop() -> unittest.mock.Mock:
+ """Return a mocked asyncio.AbstractEventLoop."""
+ loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True)
-# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
-bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
-bot_instance.http_session = None
-bot_instance.api_client = None
+ # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
+ # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
+ # to prevent "has not been awaited"-warnings.
+ loop.create_task.side_effect = lambda coroutine: coroutine.close()
+
+ return loop
class MockBot(CustomMockMixin, unittest.mock.MagicMock):
@@ -374,20 +308,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop())
+ additional_spec_asyncs = ("wait_for", "redis_ready")
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=bot_instance, **kwargs)
- self.api_client = MockAPIClient()
-
- # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
- # and should therefore be awaited. (The documentation calls it a coroutine as well, which
- # is technically incorrect, since it's a regular def.)
- self.wait_for = AsyncMock()
+ super().__init__(**kwargs)
- # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
- # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
- # to prevent "has not been awaited"-warnings.
- self.loop.create_task.side_effect = lambda coroutine: coroutine.close()
+ self.loop = _get_mock_loop()
+ self.api_client = MockAPIClient(loop=self.loop)
+ self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True)
+ self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)
# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`
@@ -413,15 +343,37 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = channel_instance
- def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
+ def __init__(self, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
- super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"#{self.name}"
+# Create data for the DMChannel instance
+state = unittest.mock.MagicMock()
+me = unittest.mock.MagicMock()
+dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]}
+dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data)
+
+
+class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
+ """
+ A MagicMock subclass to mock TextChannel objects.
+
+ Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ spec_set = dm_channel_instance
+
+ def __init__(self, **kwargs) -> None:
+ default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()}
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
+
+
# Create a Message instance to get a realistic MagicMock of `discord.Message`
message_data = {
'id': 1,
@@ -455,9 +407,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+ spec_set = context_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=context_instance, **kwargs)
+ super().__init__(**kwargs)
self.bot = kwargs.get('bot', MockBot())
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
@@ -474,8 +427,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Attachment` instances. For
more information, see the `MockGuild` docstring.
"""
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=attachment_instance, **kwargs)
+ spec_set = attachment_instance
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
@@ -485,10 +437,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = message_instance
def __init__(self, **kwargs) -> None:
default_kwargs = {'attachments': []}
- super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -504,9 +457,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = emoji_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=emoji_instance, **kwargs)
+ super().__init__(**kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -520,9 +474,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=partial_emoji_instance, **kwargs)
+ spec_set = partial_emoji_instance
reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
@@ -535,12 +487,18 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = reaction_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=reaction_instance, **kwargs)
+ _users = kwargs.pop("users", [])
+ super().__init__(**kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
- self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+ user_iterator = unittest.mock.AsyncMock()
+ user_iterator.__aiter__.return_value = _users
+ self.users.return_value = user_iterator
+
self.__str__.return_value = str(self.emoji)
@@ -554,13 +512,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Webhook` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=webhook_instance, **kwargs)
-
- # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
- # as coroutines. That's why we need to set the methods manually.
- self.send = AsyncMock()
- self.edit = AsyncMock()
- self.delete = AsyncMock()
- self.execute = AsyncMock()
+ spec_set = webhook_instance
+ additional_spec_asyncs = ("send", "edit", "delete", "execute")