aboutsummaryrefslogtreecommitdiffstats
path: root/tests/helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/helpers.py')
-rw-r--r--tests/helpers.py146
1 files changed, 142 insertions, 4 deletions
diff --git a/tests/helpers.py b/tests/helpers.py
index 8a14aeef4..5df796c23 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -10,7 +10,9 @@ import unittest.mock
from typing import Any, Iterable, Optional
import discord
-from discord.ext.commands import Bot, Context
+from discord.ext.commands import Context
+
+from bot.bot import Bot
for logger in logging.Logger.manager.loggerDict.values():
@@ -120,8 +122,80 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
features; this stand-in only overwrites the `__call__` method to an async version.
"""
+
async def __call__(self, *args, **kwargs):
- return super(AsyncMock, self).__call__(*args, **kwargs)
+ return super().__call__(*args, **kwargs)
+
+
+class AsyncIteratorMock:
+ """
+ A class to mock asynchronous iterators.
+
+ This allows async for, which is used in certain Discord.py objects. For example,
+ an async iterator is returned by the Reaction.users() method.
+ """
+
+ def __init__(self, iterable: Iterable = None):
+ if iterable is None:
+ iterable = []
+
+ self.iter = iter(iterable)
+ self.iterable = iterable
+
+ self.call_count = 0
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ return next(self.iter)
+ except StopIteration:
+ raise StopAsyncIteration
+
+ def __call__(self):
+ """
+ Keeps track of the number of times an instance has been called.
+
+ This is useful, since it typically shows that the iterator has actually been used somewhere after we have
+ instantiated the mock for an attribute that normally returns an iterator when called.
+ """
+ self.call_count += 1
+ return self
+
+ @property
+ def return_value(self):
+ """Makes `self.iterable` accessible as self.return_value."""
+ return self.iterable
+
+ @return_value.setter
+ def return_value(self, iterable):
+ """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
+ self.iter = iter(iterable)
+ self.iterable = iterable
+
+ def assert_called(self):
+ """Asserts if the AsyncIteratorMock instance has been called at least once."""
+ if self.call_count == 0:
+ raise AssertionError("Expected AsyncIteratorMock to have been called.")
+
+ def assert_called_once(self):
+ """Asserts if the AsyncIteratorMock instance has been called exactly once."""
+ if self.call_count != 1:
+ raise AssertionError(
+ f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
+ )
+
+ def assert_not_called(self):
+ """Asserts if the AsyncIteratorMock instance has not been called."""
+ if self.call_count != 0:
+ raise AssertionError(
+ f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
+ )
+
+ def reset_mock(self):
+ """Resets the call count, but not the return value or iterator."""
+ self.call_count = 0
# Create a guild instance to get a realistic Mock of `discord.Guild`
@@ -220,7 +294,7 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
information, see the `MockGuild` docstring.
"""
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
- default_kwargs = {'name': 'member', 'id': next(self.discord_id)}
+ default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
@@ -231,6 +305,25 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
self.mention = f"@{self.name}"
+# Create a User instance to get a realistic Mock of `discord.User`
+user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock())
+
+
+class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
+ """
+ A Mock subclass to mock User objects.
+
+ Instances of this class will follow the specifications of `discord.User` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
+ super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+
+ if 'mention' not in kwargs:
+ self.mention = f"@{self.name}"
+
+
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
bot_instance.http_session = None
@@ -244,6 +337,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=bot_instance, **kwargs)
@@ -281,6 +375,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+
def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
@@ -322,6 +417,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=context_instance, **kwargs)
self.bot = kwargs.get('bot', MockBot())
@@ -330,6 +426,20 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
self.channel = kwargs.get('channel', MockTextChannel())
+attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock())
+
+
+class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Attachment objects.
+
+ Instances of this class will follow the specifications of `discord.Attachment` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=attachment_instance, **kwargs)
+
+
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
"""
A MagicMock subclass to mock Message objects.
@@ -337,8 +447,10 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=message_instance, **kwargs)
+ default_kwargs = {'attachments': []}
+ super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -354,6 +466,7 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=emoji_instance, **kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -369,6 +482,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=partial_emoji_instance, **kwargs)
@@ -383,7 +497,31 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=reaction_instance, **kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
+ self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+
+webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock())
+
+
+class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter.
+
+ Instances of this class will follow the specifications of `discord.Webhook` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=webhook_instance, **kwargs)
+
+ # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
+ # as coroutines. That's why we need to set the methods manually.
+ self.send = AsyncMock()
+ self.edit = AsyncMock()
+ self.delete = AsyncMock()
+ self.execute = AsyncMock()