diff options
Diffstat (limited to 'tests/helpers.py')
| -rw-r--r-- | tests/helpers.py | 290 | 
1 files changed, 120 insertions, 170 deletions
| diff --git a/tests/helpers.py b/tests/helpers.py index 6f50f6ae3..facc4e1af 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,18 +1,18 @@  from __future__ import annotations -import asyncio  import collections -import functools -import inspect  import itertools  import logging  import unittest.mock -from typing import Any, Iterable, Optional +from asyncio import AbstractEventLoop +from typing import Callable, Iterable, Optional  import discord +from aiohttp import ClientSession  from discord.ext.commands import Context  from bot.api import APIClient +from bot.async_stats import AsyncStatsClient  from bot.bot import Bot @@ -26,19 +26,22 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) -def async_test(wrapped): -    """ -    Run a test case via asyncio. -    Example: -        >>> @async_test -        ... async def lemon_wins(): -        ...     assert True -    """ +def autospec(target, *attributes: str, **kwargs) -> Callable: +    """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" +    # Caller's kwargs should take priority and overwrite the defaults. +    kwargs = {'spec_set': True, 'autospec': True, **kwargs} -    @functools.wraps(wrapped) -    def wrapper(*args, **kwargs): -        return asyncio.run(wrapped(*args, **kwargs)) -    return wrapper +    # Import the target if it's a string. +    # This is to support both object and string targets like patch.multiple. +    if type(target) is str: +        target = unittest.mock._importer(target) + +    def decorator(func): +        for attribute in attributes: +            patcher = unittest.mock.patch.object(target, attribute, **kwargs) +            func = patcher(func) +        return func +    return decorator  class HashableMixin(discord.mixins.EqualityComparable): @@ -69,24 +72,31 @@ class CustomMockMixin:      """      Provides common functionality for our custom Mock types. -    The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine -    function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care -    of making sure child mocks are instantiated with the correct class. By default, the mock of the -    children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute -    `child_mock_type` on the custom mock inheriting from this mixin. +    The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock +    object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the +    class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional +    attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The +    class method `spec_set` can be overwritten with the object that should be uses as the specification +    for the mock. + +    Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to +    implement custom behavior.      """      child_mock_type = unittest.mock.MagicMock      discord_id = itertools.count(0) +    spec_set = None +    additional_spec_asyncs = None -    def __init__(self, spec_set: Any = None, **kwargs): +    def __init__(self, **kwargs):          name = kwargs.pop('name', None)  # `name` has special meaning for Mock classes, so we need to set it manually. -        super().__init__(spec_set=spec_set, **kwargs) +        super().__init__(spec_set=self.spec_set, **kwargs) + +        if self.additional_spec_asyncs: +            self._spec_asyncs.extend(self.additional_spec_asyncs)          if name:              self.name = name -        if spec_set: -            self._extract_coroutine_methods_from_spec_instance(spec_set)      def _get_child_mock(self, **kw):          """ @@ -100,7 +110,16 @@ class CustomMockMixin:          This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.          """ -        klass = self.child_mock_type +        _new_name = kw.get("_new_name") +        if _new_name in self.__dict__['_spec_asyncs']: +            return unittest.mock.AsyncMock(**kw) + +        _type = type(self) +        if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics: +            # Any asynchronous magic becomes an AsyncMock +            klass = unittest.mock.AsyncMock +        else: +            klass = self.child_mock_type          if self._mock_sealed:              attribute = "." + kw["name"] if "name" in kw else "()" @@ -109,107 +128,6 @@ class CustomMockMixin:          return klass(**kw) -    def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None: -        """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes.""" -        for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction): -            setattr(self, name, AsyncMock()) - - -# TODO: Remove me in Python 3.8 -class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): -    """ -    A MagicMock subclass to mock async callables. - -    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more -    features; this stand-in only overwrites the `__call__` method to an async version. -    """ - -    async def __call__(self, *args, **kwargs): -        return super().__call__(*args, **kwargs) - - -class AsyncContextManagerMock(unittest.mock.MagicMock): -    def __init__(self, return_value: Any): -        super().__init__() -        self._return_value = return_value - -    async def __aenter__(self): -        return self._return_value - -    async def __aexit__(self, *args): -        pass - - -class AsyncIteratorMock: -    """ -    A class to mock asynchronous iterators. - -    This allows async for, which is used in certain Discord.py objects. For example, -    an async iterator is returned by the Reaction.users() method. -    """ - -    def __init__(self, iterable: Iterable = None): -        if iterable is None: -            iterable = [] - -        self.iter = iter(iterable) -        self.iterable = iterable - -        self.call_count = 0 - -    def __aiter__(self): -        return self - -    async def __anext__(self): -        try: -            return next(self.iter) -        except StopIteration: -            raise StopAsyncIteration - -    def __call__(self): -        """ -        Keeps track of the number of times an instance has been called. - -        This is useful, since it typically shows that the iterator has actually been used somewhere after we have -        instantiated the mock for an attribute that normally returns an iterator when called. -        """ -        self.call_count += 1 -        return self - -    @property -    def return_value(self): -        """Makes `self.iterable` accessible as self.return_value.""" -        return self.iterable - -    @return_value.setter -    def return_value(self, iterable): -        """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" -        self.iter = iter(iterable) -        self.iterable = iterable - -    def assert_called(self): -        """Asserts if the AsyncIteratorMock instance has been called at least once.""" -        if self.call_count == 0: -            raise AssertionError("Expected AsyncIteratorMock to have been called.") - -    def assert_called_once(self): -        """Asserts if the AsyncIteratorMock instance has been called exactly once.""" -        if self.call_count != 1: -            raise AssertionError( -                f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." -            ) - -    def assert_not_called(self): -        """Asserts if the AsyncIteratorMock instance has not been called.""" -        if self.call_count != 0: -            raise AssertionError( -                f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." -            ) - -    def reset_mock(self): -        """Resets the call count, but not the return value or iterator.""" -        self.call_count = 0 -  # Create a guild instance to get a realistic Mock of `discord.Guild`  guild_data = { @@ -260,9 +178,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):      For more info, see the `Mocking` section in `tests/README.md`.      """ +    spec_set = guild_instance +      def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'members': []} -        super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles: @@ -281,6 +201,8 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.Role` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = role_instance +      def __init__(self, **kwargs) -> None:          default_kwargs = {              'id': next(self.discord_id), @@ -289,7 +211,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):              'colour': discord.Colour(0xdeadbf),              'permissions': discord.Permissions(),          } -        super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if isinstance(self.colour, int):              self.colour = discord.Colour(self.colour) @@ -304,6 +226,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):          """Simplified position-based comparisons similar to those of `discord.Role`."""          return self.position < other.position +    def __ge__(self, other): +        """Simplified position-based comparisons similar to those of `discord.Role`.""" +        return self.position >= other.position +  # Create a Member instance to get a realistic Mock of `discord.Member`  member_data = {'user': 'lemon', 'roles': [1]} @@ -318,9 +244,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin      Instances of this class will follow the specifications of `discord.Member` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = member_instance +      def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:          default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} -        super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.roles = [MockRole(name="@everyone", position=1, id=0)]          if roles: @@ -341,9 +269,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      Instances of this class will follow the specifications of `discord.User` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = user_instance +      def __init__(self, **kwargs) -> None:          default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} -        super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f"@{self.name}" @@ -356,15 +286,19 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `bot.api.APIClient` instances.      For more information, see the `MockGuild` docstring.      """ +    spec_set = APIClient -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=APIClient, **kwargs) +def _get_mock_loop() -> unittest.mock.Mock: +    """Return a mocked asyncio.AbstractEventLoop.""" +    loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) -# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` -bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) -bot_instance.http_session = None -bot_instance.api_client = None +    # Since calling `create_task` on our MockBot does not actually schedule the coroutine object +    # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object +    # to prevent "has not been awaited"-warnings. +    loop.create_task.side_effect = lambda coroutine: coroutine.close() + +    return loop  class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -374,20 +308,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ +    spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) +    additional_spec_asyncs = ("wait_for", "redis_ready")      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=bot_instance, **kwargs) -        self.api_client = MockAPIClient() - -        # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and -        # and should therefore be awaited. (The documentation calls it a coroutine as well, which -        # is technically incorrect, since it's a regular def.) -        self.wait_for = AsyncMock() +        super().__init__(**kwargs) -        # Since calling `create_task` on our MockBot does not actually schedule the coroutine object -        # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object -        # to prevent "has not been awaited"-warnings. -        self.loop.create_task.side_effect = lambda coroutine: coroutine.close() +        self.loop = _get_mock_loop() +        self.api_client = MockAPIClient(loop=self.loop) +        self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) +        self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)  # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` @@ -413,15 +343,37 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ +    spec_set = channel_instance -    def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: +    def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} -        super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if 'mention' not in kwargs:              self.mention = f"#{self.name}" +# Create data for the DMChannel instance +state = unittest.mock.MagicMock() +me = unittest.mock.MagicMock() +dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} +dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) + + +class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): +    """ +    A MagicMock subclass to mock TextChannel objects. + +    Instances of this class will follow the specifications of `discord.TextChannel` instances. For +    more information, see the `MockGuild` docstring. +    """ +    spec_set = dm_channel_instance + +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} +        super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + +  # Create a Message instance to get a realistic MagicMock of `discord.Message`  message_data = {      'id': 1, @@ -455,9 +407,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Context`      instances. For more information, see the `MockGuild` docstring.      """ +    spec_set = context_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=context_instance, **kwargs) +        super().__init__(**kwargs)          self.bot = kwargs.get('bot', MockBot())          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember()) @@ -474,8 +427,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Attachment` instances. For      more information, see the `MockGuild` docstring.      """ -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=attachment_instance, **kwargs) +    spec_set = attachment_instance  class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -485,10 +437,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Message` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = message_instance      def __init__(self, **kwargs) -> None:          default_kwargs = {'attachments': []} -        super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) @@ -504,9 +457,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Emoji` instances. For more      information, see the `MockGuild` docstring.      """ +    spec_set = emoji_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=emoji_instance, **kwargs) +        super().__init__(**kwargs)          self.guild = kwargs.get('guild', MockGuild()) @@ -520,9 +474,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For      more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=partial_emoji_instance, **kwargs) +    spec_set = partial_emoji_instance  reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -535,12 +487,18 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Reaction` instances. For      more information, see the `MockGuild` docstring.      """ +    spec_set = reaction_instance      def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=reaction_instance, **kwargs) +        _users = kwargs.pop("users", []) +        super().__init__(**kwargs)          self.emoji = kwargs.get('emoji', MockEmoji())          self.message = kwargs.get('message', MockMessage()) -        self.users = AsyncIteratorMock(kwargs.get('users', [])) + +        user_iterator = unittest.mock.AsyncMock() +        user_iterator.__aiter__.return_value = _users +        self.users.return_value = user_iterator +          self.__str__.return_value = str(self.emoji) @@ -554,13 +512,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.Webhook` instances. For      more information, see the `MockGuild` docstring.      """ - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec_set=webhook_instance, **kwargs) - -        # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined -        # as coroutines. That's why we need to set the methods manually. -        self.send = AsyncMock() -        self.edit = AsyncMock() -        self.delete = AsyncMock() -        self.execute = AsyncMock() +    spec_set = webhook_instance +    additional_spec_asyncs = ("send", "edit", "delete", "execute") | 
