diff options
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 22 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 34 | ||||
| -rw-r--r-- | tests/bot/cogs/test_token_remover.py | 4 | ||||
| -rw-r--r-- | tests/bot/utils/test_time.py | 3 | ||||
| -rw-r--r-- | tests/helpers.py | 190 | ||||
| -rw-r--r-- | tests/test_helpers.py | 63 | 
6 files changed, 103 insertions, 213 deletions
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 6406f0737..e164f7544 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -2,7 +2,7 @@ import asyncio  import logging  import typing  import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch  import discord @@ -293,8 +293,8 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          )          for message, expect_webhook_call, expect_attachment_call in test_values: -            with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: -                with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: +            with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: +                with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:                      with self.subTest(clean_content=message.clean_content, attachments=message.attachments):                          await self.cog.relay_message(message) @@ -303,7 +303,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):                          message.add_reaction.assert_called_once_with(self.checkmark_emoji) -    @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) +    @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -314,15 +314,15 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          for side_effect in side_effects:              send_attachments.side_effect = side_effect -            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: +            with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:                  with self.subTest(side_effect=type(side_effect).__name__):                      with self.assertNotLogs(logger=log, level=logging.ERROR):                          await self.cog.relay_message(message)                      self.assertEqual(send_webhook.call_count, 2) -    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) -    @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) +    @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) +    @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)      async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):          """The `relay_message` method should handle irretrievable attachments."""          message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -456,7 +456,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):                  channel.fetch_message.reset_mock()      @patch(f"{MODULE_PATH}.DuckPond.is_staff") -    @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) +    @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)      def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):          """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""          channel_id = 31415926535 @@ -491,8 +491,8 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          payload.emoji = self.duck_pond_emoji          for duck_count, should_relay in test_cases: -            with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: -                with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: +            with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: +                with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:                      count_ducks.return_value = duck_count                      with self.subTest(duck_count=duck_count, should_relay=should_relay):                          await self.cog.on_raw_reaction_add(payload) @@ -526,7 +526,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):              (constants.DuckPond.threshold + 1, True),          )          for duck_count, should_re_add_checkmark in test_cases: -            with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: +            with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:                  count_ducks.return_value = duck_count                  with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):                      await self.cog.on_raw_reaction_remove(payload) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index deae7ebad..f5e937356 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase):          """Test if the `role_info` command correctly returns the `moderator_role`."""          self.ctx.guild.roles.append(self.moderator_role) -        self.cog.roles_info.can_run = helpers.AsyncMock() +        self.cog.roles_info.can_run = unittest.mock.AsyncMock()          self.cog.roles_info.can_run.return_value = True          coroutine = self.cog.roles_info.callback(self.cog, self.ctx) @@ -72,7 +72,7 @@ class InformationCogTests(unittest.TestCase):          self.ctx.guild.roles.append([dummy_role, admin_role]) -        self.cog.role_info.can_run = helpers.AsyncMock() +        self.cog.role_info.can_run = unittest.mock.AsyncMock()          self.cog.role_info.can_run.return_value = True          coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) @@ -174,7 +174,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):      def setUp(self):          """Common set-up steps done before for each test."""          self.bot = helpers.MockBot() -        self.bot.api_client.get = helpers.AsyncMock() +        self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot)          self.member = helpers.MockMember(id=1234) @@ -345,10 +345,10 @@ class UserEmbedTests(unittest.TestCase):      def setUp(self):          """Common set-up steps done before for each test."""          self.bot = helpers.MockBot() -        self.bot.api_client.get = helpers.AsyncMock() +        self.bot.api_client.get = unittest.mock.AsyncMock()          self.cog = information.Information(self.bot) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -360,7 +360,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Mr. Hemlock") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_nick_in_title_if_available(self):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -372,7 +372,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_ignores_everyone_role(self):          """Created `!user` embeds should not contain mention of the @everyone-role."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -387,8 +387,8 @@ class UserEmbedTests(unittest.TestCase):          self.assertIn("&Admins", embed.description)          self.assertNotIn("&Everyone", embed.description) -    @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock) -    @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)      def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):          """The embed should contain expanded infractions and nomination info in mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -423,7 +423,7 @@ class UserEmbedTests(unittest.TestCase):              embed.description          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)      def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):          """The embed should contain only basic infraction data outside of mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -454,7 +454,7 @@ class UserEmbedTests(unittest.TestCase):              embed.description          ) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):          """The embed should be created with the colour of the top role, if a top role is available."""          ctx = helpers.MockContext() @@ -467,7 +467,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):          """The embed should be created with a blurple colour if the user has no assigned roles."""          ctx = helpers.MockContext() @@ -477,7 +477,7 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour.blurple()) -    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) +    @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))      def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):          """The embed thumbnail should be set to the user's avatar in `png` format."""          ctx = helpers.MockContext() @@ -529,7 +529,7 @@ class UserCommandTests(unittest.TestCase):          with self.assertRaises(InChannelCheckFailure, msg=msg):              asyncio.run(self.cog.user_info.callback(self.cog, ctx)) -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):          """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -542,7 +542,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):          """A user should target itself with `!user` when a `user` argument was not provided."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -555,7 +555,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction."""          constants.STAFF_ROLES = [self.moderator_role.id] @@ -568,7 +568,7 @@ class UserCommandTests(unittest.TestCase):          create_embed.assert_called_once_with(ctx, self.moderator)          ctx.send.assert_called_once() -    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) +    @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)      def test_moderators_can_target_another_member(self, create_embed, constants):          """A moderator should be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a54b839d7..33d1ec170 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,7 +1,7 @@  import asyncio  import logging  import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock  from discord import Colour @@ -11,7 +11,7 @@ from bot.cogs.token_remover import (      setup as setup_cog,  )  from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock, MockBot, MockMessage +from tests.helpers import MockBot, MockMessage  class TokenRemoverTests(unittest.TestCase): diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 69f35f2f5..de5724bca 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,12 +1,11 @@  import asyncio  import unittest  from datetime import datetime, timezone -from unittest.mock import patch +from unittest.mock import AsyncMock, patch  from dateutil.relativedelta import relativedelta  from bot.utils import time -from tests.helpers import AsyncMock  class TimeTests(unittest.TestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 01752a791..506fe9894 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,11 +1,10 @@  from __future__ import annotations  import collections -import inspect  import itertools  import logging  import unittest.mock -from typing import Any, Iterable, Optional +from typing import Iterable, Optional  import discord  from discord.ext.commands import Context @@ -51,24 +50,31 @@ class CustomMockMixin:      """      Provides common functionality for our custom Mock types. -    The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine -    function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care -    of making sure child mocks are instantiated with the correct class. By default, the mock of the -    children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute -    `child_mock_type` on the custom mock inheriting from this mixin. +    The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock +    object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the +    class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional +    attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The +    class method `spec_set` can be overwritten with the object that should be uses as the specification +    for the mock. + +    Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to +    implement custom behavior.      """      child_mock_type = unittest.mock.MagicMock      discord_id = itertools.count(0) +    spec_set = None +    additional_spec_asyncs = None -    def __init__(self, spec_set: Any = None, **kwargs): +    def __init__(self, **kwargs):          name = kwargs.pop('name', None)  # `name` has special meaning for Mock classes, so we need to set it manually. -        super().__init__(spec_set=spec_set, **kwargs) +        super().__init__(spec_set=self.spec_set, **kwargs) + +        if self.additional_spec_asyncs: +            self._spec_asyncs.extend(self.additional_spec_asyncs)          if name:              self.name = name -        if spec_set: -            self._extract_coroutine_methods_from_spec_instance(spec_set)      def _get_child_mock(self, **kw):          """ @@ -82,7 +88,16 @@ class CustomMockMixin:          This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.          """ -        klass = self.child_mock_type +        _new_name = kw.get("_new_name") +        if _new_name in self.__dict__['_spec_asyncs']: +            return unittest.mock.AsyncMock(**kw) + +        _type = type(self) +        if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics: +            # Any asynchronous magic becomes an AsyncMock +            klass = unittest.mock.AsyncMock +        else: +            klass = self.child_mock_type          if self._mock_sealed:              attribute = "." + kw["name"] if "name" in kw else "()" @@ -91,95 +106,6 @@ class CustomMockMixin:          return klass(**kw) -    def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None: -        """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes.""" -        for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction): -            setattr(self, name, AsyncMock()) - - -# TODO: Remove me in Python 3.8 -class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): -    """ -    A MagicMock subclass to mock async callables. - -    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more -    features; this stand-in only overwrites the `__call__` method to an async version. -    """ - -    async def __call__(self, *args, **kwargs): -        return super().__call__(*args, **kwargs) - - -class AsyncIteratorMock: -    """ -    A class to mock asynchronous iterators. - -    This allows async for, which is used in certain Discord.py objects. For example, -    an async iterator is returned by the Reaction.users() method. -    """ - -    def __init__(self, iterable: Iterable = None): -        if iterable is None: -            iterable = [] - -        self.iter = iter(iterable) -        self.iterable = iterable - -        self.call_count = 0 - -    def __aiter__(self): -        return self - -    async def __anext__(self): -        try: -            return next(self.iter) -        except StopIteration: -            raise StopAsyncIteration - -    def __call__(self): -        """ -        Keeps track of the number of times an instance has been called. - -        This is useful, since it typically shows that the iterator has actually been used somewhere after we have -        instantiated the mock for an attribute that normally returns an iterator when called. -        """ -        self.call_count += 1 -        return self - -    @property -    def return_value(self): -        """Makes `self.iterable` accessible as self.return_value.""" -        return self.iterable - -    @return_value.setter -    def return_value(self, iterable): -        """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" -        self.iter = iter(iterable) -        self.iterable = iterable - -    def assert_called(self): -        """Asserts if the AsyncIteratorMock instance has been called at least once.""" -        if self.call_count == 0: -            raise AssertionError("Expected AsyncIteratorMock to have been called.") - -    def assert_called_once(self): -        """Asserts if the AsyncIteratorMock instance has been called exactly once.""" -        if self.call_count != 1: -            raise AssertionError( -                f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." -            ) - -    def assert_not_called(self): -        """Asserts if the AsyncIteratorMock instance has not been called.""" -        if self.call_count != 0: -            raise AssertionError( -                f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." -            ) - -    def reset_mock(self): -        """Resets the call count, but not the return value or iterator.""" -        self.call_count = 0 -  # Create a guild instance to get a realistic Mock of `discord.Guild`  guild_data = { @@ -230,9 +156,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):      For more info, see the `Mocking` section in `tests/README.md`.      """ +    spec_set = guild_instance +      def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'members': []} -        super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles: @@ -251,9 +179,11 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.Role` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = role_instance +      def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} -        super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f'&{self.name}' @@ -276,9 +206,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin      Instances of this class will follow the specifications of `discord.Member` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = member_instance +      def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:          default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} -        super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles: @@ -299,9 +231,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.User` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = user_instance +      def __init__(self, **kwargs) -> None:          default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} -        super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f"@{self.name}" @@ -320,14 +254,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ +    spec_set = bot_instance +    additional_spec_asyncs = ("wait_for",)      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=bot_instance, **kwargs) +        super().__init__(**kwargs)          # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and          # and should therefore be awaited. (The documentation calls it a coroutine as well, which          # is technically incorrect, since it's a regular def.) -        self.wait_for = AsyncMock() +        # self.wait_for = unittest.mock.AsyncMock()          # Since calling `create_task` on our MockBot does not actually schedule the coroutine object          # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object @@ -358,10 +294,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ +    spec_set = channel_instance      def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} -        super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f"#{self.name}" @@ -400,9 +337,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Context`      instances. For more information, see the `MockGuild` docstring.      """ +    spec_set = context_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=context_instance, **kwargs) +        super().__init__(**kwargs)          self.bot = kwargs.get('bot', MockBot())          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember()) @@ -419,8 +357,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Attachment` instances. For      more information, see the `MockGuild` docstring.      """ -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=attachment_instance, **kwargs) +    spec_set = attachment_instance  class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -430,10 +367,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Message` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = message_instance      def __init__(self, **kwargs) -> None:          default_kwargs = {'attachments': []} -        super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) @@ -449,9 +387,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Emoji` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = emoji_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=emoji_instance, **kwargs) +        super().__init__(**kwargs)          self.guild = kwargs.get('guild', MockGuild()) @@ -465,9 +404,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For      more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=partial_emoji_instance, **kwargs) +    spec_set = partial_emoji_instance  reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -480,12 +417,17 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Reaction` instances. For      more information, see the `MockGuild` docstring.      """ +    spec_set = reaction_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=reaction_instance, **kwargs) +        _users = kwargs.pop("users", []) +        super().__init__(**kwargs)          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage()) -        self.users = AsyncIteratorMock(kwargs.get('users', [])) + +        user_iterator = unittest.mock.AsyncMock() +        user_iterator.__aiter__.return_value = _users +        self.users.return_value = user_iterator  webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) @@ -498,13 +440,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Webhook` instances. For      more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=webhook_instance, **kwargs) - -        # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined -        # as coroutines. That's why we need to set the methods manually. -        self.send = AsyncMock() -        self.edit = AsyncMock() -        self.delete = AsyncMock() -        self.execute = AsyncMock() +    spec_set = webhook_instance +    additional_spec_asyncs = ("send", "edit", "delete", "execute") diff --git a/tests/test_helpers.py b/tests/test_helpers.py index fe39df308..81285e009 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,4 @@  import asyncio -import inspect  import unittest  import unittest.mock @@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase):          with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"):              asyncio.run(coroutine_object) +    def test_user_mock_uses_explicitly_passed_mention_attribute(self): +        """MockUser should use an explicitly passed value for user.mention.""" +        user = helpers.MockUser(mention="hello") +        self.assertEqual(user.mention, "hello") +  class MockObjectTests(unittest.TestCase):      """Tests the mock objects and mixins we've defined.""" @@ -341,57 +345,10 @@ class MockObjectTests(unittest.TestCase):                  attribute = getattr(mock, valid_attribute)                  self.assertTrue(isinstance(attribute, mock_type.child_mock_type)) -    def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self): -        """Test if all coroutine functions are extracted, but not regular methods or attributes.""" -        class CoroutineDonor: -            def __init__(self): -                self.some_attribute = 'alpha' - -            async def first_coroutine(): -                """This coroutine function should be extracted.""" - -            async def second_coroutine(): -                """This coroutine function should be extracted.""" - -            def regular_method(): -                """This regular function should not be extracted.""" - -        class Receiver: +    def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self): +        """The CustomMockMixin should mock async magic methods with an AsyncMock.""" +        class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):              pass -        donor = CoroutineDonor() -        receiver = Receiver() - -        helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor) - -        self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock) -        self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock) -        self.assertFalse(hasattr(receiver, 'regular_method')) -        self.assertFalse(hasattr(receiver, 'some_attribute')) - -    @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) -    @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") -    def test_custom_mock_mixin_init_with_spec(self, extract_method_mock): -        """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" -        spec_set = "pydis" - -        helpers.CustomMockMixin(spec_set=spec_set) - -        extract_method_mock.assert_called_once_with(spec_set) - -    @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) -    @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") -    def test_custom_mock_mixin_init_without_spec(self, extract_method_mock): -        """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" -        helpers.CustomMockMixin() - -        extract_method_mock.assert_not_called() - -    def test_async_mock_provides_coroutine_for_dunder_call(self): -        """Test if AsyncMock objects have a coroutine for their __call__ method.""" -        async_mock = helpers.AsyncMock() -        self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__)) - -        coroutine = async_mock() -        self.assertTrue(inspect.iscoroutine(coroutine)) -        self.assertIsNotNone(asyncio.run(coroutine)) +        mock = MyMock() +        self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock)  |