diff options
| author | 2019-12-12 00:47:53 -0800 | |
|---|---|---|
| committer | 2019-12-12 00:47:53 -0800 | |
| commit | 2728473e5d0881042d7664f0120a0549248bcce0 (patch) | |
| tree | 77b8d69aa3df5f061b32f68b9f1117bbfe50b1db /tests/helpers.py | |
| parent | apply kosa's requested changes. (diff) | |
| parent | Subclass Bot (#681) (diff) | |
Merge remote-tracking branch 'origin/master' into zen-command
Diffstat (limited to '')
| -rw-r--r-- | tests/helpers.py | 254 | 
1 files changed, 195 insertions, 59 deletions
| diff --git a/tests/helpers.py b/tests/helpers.py index 8496ba031..5df796c23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,13 +1,28 @@  from __future__ import annotations  import asyncio +import collections  import functools  import inspect +import itertools +import logging  import unittest.mock  from typing import Any, Iterable, Optional  import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot + + +for logger in logging.Logger.manager.loggerDict.values(): +    # Set all loggers to CRITICAL by default to prevent screen clutter during testing + +    if not isinstance(logger, logging.Logger): +        # There might be some logging.PlaceHolder objects in there +        continue + +    logger.setLevel(logging.CRITICAL)  def async_test(wrapped): @@ -61,11 +76,16 @@ class CustomMockMixin:      """      child_mock_type = unittest.mock.MagicMock +    discord_id = itertools.count(0) -    def __init__(self, spec: Any = None, **kwargs): -        super().__init__(spec=spec, **kwargs) -        if spec: -            self._extract_coroutine_methods_from_spec_instance(spec) +    def __init__(self, spec_set: Any = None, **kwargs): +        name = kwargs.pop('name', None)  # `name` has special meaning for Mock classes, so we need to set it manually. +        super().__init__(spec_set=spec_set, **kwargs) + +        if name: +            self.name = name +        if spec_set: +            self._extract_coroutine_methods_from_spec_instance(spec_set)      def _get_child_mock(self, **kw):          """ @@ -102,8 +122,80 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):      Python 3.8 will introduce an AsyncMock class in the standard library that will have some more      features; this stand-in only overwrites the `__call__` method to an async version.      """ +      async def __call__(self, *args, **kwargs): -        return super(AsyncMock, self).__call__(*args, **kwargs) +        return super().__call__(*args, **kwargs) + + +class AsyncIteratorMock: +    """ +    A class to mock asynchronous iterators. + +    This allows async for, which is used in certain Discord.py objects. For example, +    an async iterator is returned by the Reaction.users() method. +    """ + +    def __init__(self, iterable: Iterable = None): +        if iterable is None: +            iterable = [] + +        self.iter = iter(iterable) +        self.iterable = iterable + +        self.call_count = 0 + +    def __aiter__(self): +        return self + +    async def __anext__(self): +        try: +            return next(self.iter) +        except StopIteration: +            raise StopAsyncIteration + +    def __call__(self): +        """ +        Keeps track of the number of times an instance has been called. + +        This is useful, since it typically shows that the iterator has actually been used somewhere after we have +        instantiated the mock for an attribute that normally returns an iterator when called. +        """ +        self.call_count += 1 +        return self + +    @property +    def return_value(self): +        """Makes `self.iterable` accessible as self.return_value.""" +        return self.iterable + +    @return_value.setter +    def return_value(self, iterable): +        """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" +        self.iter = iter(iterable) +        self.iterable = iterable + +    def assert_called(self): +        """Asserts if the AsyncIteratorMock instance has been called at least once.""" +        if self.call_count == 0: +            raise AssertionError("Expected AsyncIteratorMock to have been called.") + +    def assert_called_once(self): +        """Asserts if the AsyncIteratorMock instance has been called exactly once.""" +        if self.call_count != 1: +            raise AssertionError( +                f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." +            ) + +    def assert_not_called(self): +        """Asserts if the AsyncIteratorMock instance has not been called.""" +        if self.call_count != 0: +            raise AssertionError( +                f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." +            ) + +    def reset_mock(self): +        """Resets the call count, but not the return value or iterator.""" +        self.call_count = 0  # Create a guild instance to get a realistic Mock of `discord.Guild` @@ -155,25 +247,14 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):      For more info, see the `Mocking` section in `tests/README.md`.      """ -    def __init__( -        self, -        guild_id: int = 1, -        roles: Optional[Iterable[MockRole]] = None, -        members: Optional[Iterable[MockMember]] = None, -        **kwargs, -    ) -> None: -        super().__init__(spec=guild_instance, **kwargs) - -        self.id = guild_id +    def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'members': []} +        super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) -        self.roles = [MockRole("@everyone", 1)] +        self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles:              self.roles.extend(roles) -        self.members = [] -        if members: -            self.members.extend(members) -  # Create a Role instance to get a realistic Mock of `discord.Role`  role_data = {'name': 'role', 'id': 1} @@ -187,13 +268,12 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.Role` instances. For more      information, see the `MockGuild` docstring.      """ -    def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None: -        super().__init__(spec=role_instance, **kwargs) +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} +        super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) -        self.name = name -        self.id = role_id -        self.position = position -        self.mention = f'&{self.name}' +        if 'mention' not in kwargs: +            self.mention = f'&{self.name}'      def __lt__(self, other):          """Simplified position-based comparisons similar to those of `discord.Role`.""" @@ -213,27 +293,41 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin      Instances of this class will follow the specifications of `discord.Member` instances. For more      information, see the `MockGuild` docstring.      """ -    def __init__( -        self, -        name: str = "member", -        user_id: int = 1, -        roles: Optional[Iterable[MockRole]] = None, -        **kwargs, -    ) -> None: -        super().__init__(spec=member_instance, **kwargs) +    def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: +        default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} +        super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) -        self.name = name -        self.id = user_id - -        self.roles = [MockRole("@everyone", 1)] +        self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles:              self.roles.extend(roles) -        self.mention = f"@{self.name}" +        if 'mention' not in kwargs: +            self.mention = f"@{self.name}" + + +# Create a User instance to get a realistic Mock of `discord.User` +user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock()) + + +class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): +    """ +    A Mock subclass to mock User objects. + +    Instances of this class will follow the specifications of `discord.User` instances. For more +    information, see the `MockGuild` docstring. +    """ +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} +        super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + +        if 'mention' not in kwargs: +            self.mention = f"@{self.name}"  # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`  bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) +bot_instance.http_session = None +bot_instance.api_client = None  class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -243,18 +337,20 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ -    def __init__(self, **kwargs) -> None: -        super().__init__(spec=bot_instance, **kwargs) -        # Our custom attributes and methods -        self.http_session = unittest.mock.MagicMock() -        self.api_client = unittest.mock.MagicMock() +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=bot_instance, **kwargs)          # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and          # and should therefore be awaited. (The documentation calls it a coroutine as well, which          # is technically incorrect, since it's a regular def.)          self.wait_for = AsyncMock() +        # Since calling `create_task` on our MockBot does not actually schedule the coroutine object +        # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object +        # to prevent "has not been awaited"-warnings. +        self.loop.create_task.side_effect = lambda coroutine: coroutine.close() +  # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`  channel_data = { @@ -279,12 +375,13 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ +      def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: -        super().__init__(spec=channel_instance, **kwargs) -        self.id = channel_id -        self.name = name -        self.guild = kwargs.get('guild', MockGuild()) -        self.mention = f"#{self.name}" +        default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} +        super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) + +        if 'mention' not in kwargs: +            self.mention = f"#{self.name}"  # Create a Message instance to get a realistic MagicMock of `discord.Message` @@ -320,13 +417,27 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Context`      instances. For more information, see the `MockGuild` docstring.      """ +      def __init__(self, **kwargs) -> None: -        super().__init__(spec=context_instance, **kwargs) +        super().__init__(spec_set=context_instance, **kwargs)          self.bot = kwargs.get('bot', MockBot())          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) -        self.command = kwargs.get('command', unittest.mock.MagicMock()) + + +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) + + +class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Attachment objects. + +    Instances of this class will follow the specifications of `discord.Attachment` instances. For +    more information, see the `MockGuild` docstring. +    """ +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=attachment_instance, **kwargs)  class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -336,8 +447,10 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Message` instances. For more      information, see the `MockGuild` docstring.      """ +      def __init__(self, **kwargs) -> None: -        super().__init__(spec=message_instance, **kwargs) +        default_kwargs = {'attachments': []} +        super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) @@ -353,13 +466,11 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Emoji` instances. For more      information, see the `MockGuild` docstring.      """ +      def __init__(self, **kwargs) -> None: -        super().__init__(spec=emoji_instance, **kwargs) +        super().__init__(spec_set=emoji_instance, **kwargs)          self.guild = kwargs.get('guild', MockGuild()) -        # Get all coroutine functions and set them as AsyncMock attributes -        self._extract_coroutine_methods_from_spec_instance(emoji_instance) -  partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') @@ -371,8 +482,9 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For      more information, see the `MockGuild` docstring.      """ +      def __init__(self, **kwargs) -> None: -        super().__init__(spec=partial_emoji_instance, **kwargs) +        super().__init__(spec_set=partial_emoji_instance, **kwargs)  reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -385,7 +497,31 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Reaction` instances. For      more information, see the `MockGuild` docstring.      """ +      def __init__(self, **kwargs) -> None: -        super().__init__(spec=reaction_instance, **kwargs) +        super().__init__(spec_set=reaction_instance, **kwargs)          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage()) +        self.users = AsyncIteratorMock(kwargs.get('users', [])) + + +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) + + +class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. + +    Instances of this class will follow the specifications of `discord.Webhook` instances. For +    more information, see the `MockGuild` docstring. +    """ + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec_set=webhook_instance, **kwargs) + +        # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined +        # as coroutines. That's why we need to set the methods manually. +        self.send = AsyncMock() +        self.edit = AsyncMock() +        self.delete = AsyncMock() +        self.execute = AsyncMock() | 
