diff options
| -rw-r--r-- | bot/cogs/off_topic_names.py | 16 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 11 | ||||
| -rw-r--r-- | bot/cogs/site.py | 4 | ||||
| -rw-r--r-- | bot/constants.py | 4 | ||||
| -rw-r--r-- | config-default.yml | 4 | ||||
| -rw-r--r-- | tests/bot/test_utils.py | 52 | ||||
| -rw-r--r-- | tests/helpers.py | 315 | ||||
| -rw-r--r-- | tests/test_helpers.py | 68 | 
8 files changed, 277 insertions, 197 deletions
| diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 1f9fb0b4f..78792240f 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -24,6 +24,9 @@ class OffTopicName(Converter):          """Attempt to replace any invalid characters with their approximate Unicode equivalent."""          allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" +        # Chain multiple words to a single one +        argument = "-".join(argument.split()) +          if not (2 <= len(argument) <= 96):              raise BadArgument("Channel name must be between 2 and 96 chars long") @@ -97,15 +100,12 @@ class OffTopicNames(Cog):      @otname_group.command(name='add', aliases=('a',))      @with_role(*MODERATION_ROLES) -    async def add_command(self, ctx: Context, *names: OffTopicName) -> None: +    async def add_command(self, ctx: Context, *, name: OffTopicName) -> None:          """          Adds a new off-topic name to the rotation.          The name is not added if it is too similar to an existing name.          """ -        # Chain multiple words to a single one -        name = "-".join(names) -          existing_names = await self.bot.api_client.get('bot/off-topic-channel-names')          close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8) @@ -123,10 +123,8 @@ class OffTopicNames(Cog):      @otname_group.command(name='forceadd', aliases=('fa',))      @with_role(*MODERATION_ROLES) -    async def force_add_command(self, ctx: Context, *names: OffTopicName) -> None: +    async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None:          """Forcefully adds a new off-topic name to the rotation.""" -        # Chain multiple words to a single one -        name = "-".join(names)          await self._add_name(ctx, name)      async def _add_name(self, ctx: Context, name: str) -> None: @@ -138,10 +136,8 @@ class OffTopicNames(Cog):      @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd'))      @with_role(*MODERATION_ROLES) -    async def delete_command(self, ctx: Context, *names: OffTopicName) -> None: +    async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None:          """Removes a off-topic name from the rotation.""" -        # Chain multiple words to a single one -        name = "-".join(names)          await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}')          log.info(f"{ctx.author} deleted the off-topic channel name '{name}'") diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7749d237f..0d06e9c26 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,14 +2,14 @@ import asyncio  import logging  import random  import textwrap -from datetime import datetime +from datetime import datetime, timedelta  from typing import List  from discord import Colour, Embed, TextChannel  from discord.ext.commands import Bot, Cog, Context, group  from discord.ext.tasks import loop -from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES, Webhooks +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks  from bot.converters import Subreddit  from bot.decorators import with_role  from bot.pagination import LinePaginator @@ -117,9 +117,9 @@ class Reddit(Cog):              link = self.URL + data["permalink"]              embed.description += ( -                f"[**{title}**]({link})\n" +                f"**[{title}]({link})**\n"                  f"{text}" -                f"| {ups} upvotes | {comments} comments | u/{author} | {subreddit} |\n\n" +                f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n"              )          embed.colour = Colour.blurple() @@ -130,7 +130,8 @@ class Reddit(Cog):          """Post the top 5 posts daily, and the top 5 posts weekly."""          # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter          now = datetime.utcnow() -        midnight_tomorrow = now.replace(day=now.day + 1, hour=0, minute=0, second=0) +        tomorrow = now + timedelta(days=1) +        midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0)          seconds_until = (midnight_tomorrow - now).total_seconds()          await asyncio.sleep(seconds_until) diff --git a/bot/cogs/site.py b/bot/cogs/site.py index d95359159..683613788 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -3,8 +3,7 @@ import logging  from discord import Colour, Embed  from discord.ext.commands import Bot, Cog, Context, group -from bot.constants import Channels, STAFF_ROLES, URLs -from bot.decorators import redirect_output +from bot.constants import URLs  from bot.pagination import LinePaginator  log = logging.getLogger(__name__) @@ -105,7 +104,6 @@ class Site(Cog):          await ctx.send(embed=embed)      @site_group.command(aliases=['r', 'rule'], name='rules') -    @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)      async def site_rules(self, ctx: Context, *rules: int) -> None:          """Provides a link to all rules or, if specified, displays specific rule(s)."""          rules_embed = Embed(title='Rules', color=Colour.blurple()) diff --git a/bot/constants.py b/bot/constants.py index dbbf32063..c78c06227 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -269,6 +269,10 @@ class Emojis(metaclass=YAMLGetter):      ducky_devil: int      ducky_tube: int +    upvotes: str +    comments: str +    user: str +  class Icons(metaclass=YAMLGetter):      section = "style" diff --git a/config-default.yml b/config-default.yml index 76892677e..8583cb6da 100644 --- a/config-default.yml +++ b/config-default.yml @@ -40,6 +40,10 @@ style:          ducky_devil:    &DUCKY_DEVIL    637925314982576139          ducky_tube:     &DUCKY_TUBE     637881368008851456 +        upvotes:        "<:upvotes:638729835245731840>" +        comments:       "<:comments:638729835073765387>" +        user:           "<:user:638729835442602003>" +      icons:          crown_blurple: "https://cdn.discordapp.com/emojis/469964153289965568.png"          crown_green:   "https://cdn.discordapp.com/emojis/469964154719961088.png" diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py new file mode 100644 index 000000000..58ae2a81a --- /dev/null +++ b/tests/bot/test_utils.py @@ -0,0 +1,52 @@ +import unittest + +from bot import utils + + +class CaseInsensitiveDictTests(unittest.TestCase): +    """Tests for the `CaseInsensitiveDict` container.""" + +    def test_case_insensitive_key_access(self): +        """Tests case insensitive key access and storage.""" +        instance = utils.CaseInsensitiveDict() + +        key = 'LEMON' +        value = 'trees' + +        instance[key] = value +        self.assertIn(key, instance) +        self.assertEqual(instance.get(key), value) +        self.assertEqual(instance.get(key.casefold()), value) +        self.assertEqual(instance.pop(key.casefold()), value) +        self.assertNotIn(key, instance) +        self.assertNotIn(key.casefold(), instance) + +        instance.setdefault(key, value) +        del instance[key] +        self.assertNotIn(key, instance) + +    def test_initialization_from_kwargs(self): +        """Tests creating the dictionary from keyword arguments.""" +        instance = utils.CaseInsensitiveDict({'FOO': 'bar'}) +        self.assertEqual(instance['foo'], 'bar') + +    def test_update_from_other_mapping(self): +        """Tests updating the dictionary from another mapping.""" +        instance = utils.CaseInsensitiveDict() +        instance.update({'FOO': 'bar'}) +        self.assertEqual(instance['foo'], 'bar') + + +class ChunkTests(unittest.TestCase): +    """Tests the `chunk` method.""" + +    def test_empty_chunking(self): +        """Tests chunking on an empty iterable.""" +        generator = utils.chunks(iterable=[], size=5) +        self.assertEqual(list(generator), []) + +    def test_list_chunking(self): +        """Tests chunking a non-empty list.""" +        iterable = [1, 2, 3, 4, 5] +        generator = utils.chunks(iterable=iterable, size=2) +        self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) diff --git a/tests/helpers.py b/tests/helpers.py index 892d42e6c..8496ba031 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,8 +2,9 @@ from __future__ import annotations  import asyncio  import functools +import inspect  import unittest.mock -from typing import Iterable, Optional +from typing import Any, Iterable, Optional  import discord  from discord.ext.commands import Bot, Context @@ -24,19 +25,6 @@ def async_test(wrapped):      return wrapper -# TODO: Remove me in Python 3.8 -class AsyncMock(unittest.mock.MagicMock): -    """ -    A MagicMock subclass to mock async callables. - -    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more -    features; this stand-in only overwrites the `__call__` method to an async version. -    """ - -    async def __call__(self, *args, **kwargs): -        return super(AsyncMock, self).__call__(*args, **kwargs) - -  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. @@ -61,15 +49,61 @@ class ColourMixin:          self.colour = color -class AttributeMock: -    """Ensures attributes of our mock types will be instantiated with the correct mock type.""" +class CustomMockMixin: +    """ +    Provides common functionality for our custom Mock types. -    def __new__(cls, *args, **kwargs): -        """Stops the regular parent class from propagating to newly mocked attributes.""" -        if 'parent' in kwargs: -            return cls.attribute_mocktype(*args, **kwargs) +    The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine +    function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care +    of making sure child mocks are instantiated with the correct class. By default, the mock of the +    children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute +    `child_mock_type` on the custom mock inheriting from this mixin. +    """ -        return super().__new__(cls) +    child_mock_type = unittest.mock.MagicMock + +    def __init__(self, spec: Any = None, **kwargs): +        super().__init__(spec=spec, **kwargs) +        if spec: +            self._extract_coroutine_methods_from_spec_instance(spec) + +    def _get_child_mock(self, **kw): +        """ +        Overwrite of the `_get_child_mock` method to stop the propagation of our custom mock classes. + +        Mock objects automatically create children when you access an attribute or call a method on them. By default, +        the class of these children is the type of the parent itself. However, this would mean that the children created +        for our custom mock types would also be instances of that custom mock type. This is not desirable, as attributes +        of, e.g., a `Bot` object are not `Bot` objects themselves. The Python docs for `unittest.mock` hint that +        overwriting this method is the best way to deal with that. + +        This override will look for an attribute called `child_mock_type` and use that as the type of the child mock. +        """ +        klass = self.child_mock_type + +        if self._mock_sealed: +            attribute = "." + kw["name"] if "name" in kw else "()" +            mock_name = self._extract_mock_name() + attribute +            raise AttributeError(mock_name) + +        return klass(**kw) + +    def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None: +        """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes.""" +        for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction): +            setattr(self, name, AsyncMock()) + + +# TODO: Remove me in Python 3.8 +class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock async callables. + +    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more +    features; this stand-in only overwrites the `__call__` method to an async version. +    """ +    async def __call__(self, *args, **kwargs): +        return super(AsyncMock, self).__call__(*args, **kwargs)  # Create a guild instance to get a realistic Mock of `discord.Guild` @@ -95,7 +129,7 @@ guild_data = {  guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock()) -class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin): +class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):      """      A `Mock` subclass to mock `discord.Guild` objects. @@ -121,9 +155,6 @@ class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin):      For more info, see the `Mocking` section in `tests/README.md`.      """ - -    attribute_mocktype = unittest.mock.MagicMock -      def __init__(          self,          guild_id: int = 1, @@ -143,48 +174,19 @@ class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin):          if members:              self.members.extend(members) -        # `discord.Guild` coroutines -        self.create_category_channel = AsyncMock() -        self.ban = AsyncMock() -        self.bans = AsyncMock() -        self.create_category = AsyncMock() -        self.create_custom_emoji = AsyncMock() -        self.create_role = AsyncMock() -        self.create_text_channel = AsyncMock() -        self.create_voice_channel = AsyncMock() -        self.delete = AsyncMock() -        self.edit = AsyncMock() -        self.estimate_pruned_members = AsyncMock() -        self.fetch_ban = AsyncMock() -        self.fetch_channels = AsyncMock() -        self.fetch_emoji = AsyncMock() -        self.fetch_emojis = AsyncMock() -        self.fetch_member = AsyncMock() -        self.invites = AsyncMock() -        self.kick = AsyncMock() -        self.leave = AsyncMock() -        self.prune_members = AsyncMock() -        self.unban = AsyncMock() -        self.vanity_invite = AsyncMock() -        self.webhooks = AsyncMock() -        self.widget = AsyncMock() -  # Create a Role instance to get a realistic Mock of `discord.Role`  role_data = {'name': 'role', 'id': 1}  role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) -class MockRole(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin): +class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      """      A Mock subclass to mock `discord.Role` objects.      Instances of this class will follow the specifications of `discord.Role` instances. For more      information, see the `MockGuild` docstring.      """ - -    attribute_mocktype = unittest.mock.MagicMock -      def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None:          super().__init__(spec=role_instance, **kwargs) @@ -193,10 +195,6 @@ class MockRole(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin):          self.position = position          self.mention = f'&{self.name}' -        # 'discord.Role' coroutines -        self.delete = AsyncMock() -        self.edit = AsyncMock() -      def __lt__(self, other):          """Simplified position-based comparisons similar to those of `discord.Role`."""          return self.position < other.position @@ -208,16 +206,13 @@ state_mock = unittest.mock.MagicMock()  member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) -class MockMember(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin): +class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):      """      A Mock subclass to mock Member objects.      Instances of this class will follow the specifications of `discord.Member` instances. For more      information, see the `MockGuild` docstring.      """ - -    attribute_mocktype = unittest.mock.MagicMock -      def __init__(          self,          name: str = "member", @@ -236,98 +231,29 @@ class MockMember(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin):          self.mention = f"@{self.name}" -        # `discord.Member` coroutines -        self.add_roles = AsyncMock() -        self.ban = AsyncMock() -        self.edit = AsyncMock() -        self.fetch_message = AsyncMock() -        self.kick = AsyncMock() -        self.move_to = AsyncMock() -        self.pins = AsyncMock() -        self.remove_roles = AsyncMock() -        self.send = AsyncMock() -        self.trigger_typing = AsyncMock() -        self.unban = AsyncMock() -  # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`  bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) -class MockBot(AttributeMock, unittest.mock.MagicMock): +class MockBot(CustomMockMixin, unittest.mock.MagicMock):      """      A MagicMock subclass to mock Bot objects.      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ - -    attribute_mocktype = unittest.mock.MagicMock -      def __init__(self, **kwargs) -> None:          super().__init__(spec=bot_instance, **kwargs) -        # `discord.ext.commands.Bot` coroutines -        self._before_invoke = AsyncMock() -        self._after_invoke = AsyncMock() -        self.application_info = AsyncMock() -        self.change_presence = AsyncMock() -        self.connect = AsyncMock() -        self.close = AsyncMock() -        self.create_guild = AsyncMock() -        self.delete_invite = AsyncMock() -        self.fetch_channel = AsyncMock() -        self.fetch_guild = AsyncMock() -        self.fetch_guilds = AsyncMock() -        self.fetch_invite = AsyncMock() -        self.fetch_user = AsyncMock() -        self.fetch_user_profile = AsyncMock() -        self.fetch_webhook = AsyncMock() -        self.fetch_widget = AsyncMock() -        self.get_context = AsyncMock() -        self.get_prefix = AsyncMock() -        self.invoke = AsyncMock() -        self.is_owner = AsyncMock() -        self.login = AsyncMock() -        self.logout = AsyncMock() -        self.on_command_error = AsyncMock() -        self.on_error = AsyncMock() -        self.process_commands = AsyncMock() -        self.request_offline_members = AsyncMock() -        self.start = AsyncMock() -        self.wait_until_ready = AsyncMock() -        self.wait_for = AsyncMock() - - -# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` -context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock()) - - -class MockContext(AttributeMock, unittest.mock.MagicMock): -    """ -    A MagicMock subclass to mock Context objects. - -    Instances of this class will follow the specifications of `discord.ext.commands.Context` -    instances. For more information, see the `MockGuild` docstring. -    """ - -    attribute_mocktype = unittest.mock.MagicMock - -    def __init__(self, **kwargs) -> None: -        super().__init__(spec=context_instance, **kwargs) -        self.bot = MockBot() -        self.guild = MockGuild() -        self.author = MockMember() -        self.command = unittest.mock.MagicMock() +        # Our custom attributes and methods +        self.http_session = unittest.mock.MagicMock() +        self.api_client = unittest.mock.MagicMock() -        # `discord.ext.commands.Context` coroutines -        self.fetch_message = AsyncMock() -        self.invoke = AsyncMock() -        self.pins = AsyncMock() -        self.reinvoke = AsyncMock() -        self.send = AsyncMock() -        self.send_help = AsyncMock() -        self.trigger_typing = AsyncMock() +        # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and +        # and should therefore be awaited. (The documentation calls it a coroutine as well, which +        # is technically incorrect, since it's a regular def.) +        self.wait_for = AsyncMock()  # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` @@ -346,39 +272,20 @@ guild = unittest.mock.MagicMock()  channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) -class MockTextChannel(AttributeMock, unittest.mock.Mock, HashableMixin): +class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      """      A MagicMock subclass to mock TextChannel objects.      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ - -    attribute_mocktype = unittest.mock.MagicMock -      def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:          super().__init__(spec=channel_instance, **kwargs)          self.id = channel_id          self.name = name -        self.guild = MockGuild() +        self.guild = kwargs.get('guild', MockGuild())          self.mention = f"#{self.name}" -        # `discord.TextChannel` coroutines -        self.clone = AsyncMock() -        self.create_invite = AsyncMock() -        self.create_webhook = AsyncMock() -        self.delete = AsyncMock() -        self.delete_messages = AsyncMock() -        self.edit = AsyncMock() -        self.fetch_message = AsyncMock() -        self.invites = AsyncMock() -        self.pins = AsyncMock() -        self.purge = AsyncMock() -        self.send = AsyncMock() -        self.set_permissions = AsyncMock() -        self.trigger_typing = AsyncMock() -        self.webhooks = AsyncMock() -  # Create a Message instance to get a realistic MagicMock of `discord.Message`  message_data = { @@ -402,27 +309,83 @@ channel = unittest.mock.MagicMock()  message_instance = discord.Message(state=state, channel=channel, data=message_data) -class MockMessage(AttributeMock, unittest.mock.MagicMock): +# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` +context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock()) + + +class MockContext(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Context objects. + +    Instances of this class will follow the specifications of `discord.ext.commands.Context` +    instances. For more information, see the `MockGuild` docstring. +    """ +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=context_instance, **kwargs) +        self.bot = kwargs.get('bot', MockBot()) +        self.guild = kwargs.get('guild', MockGuild()) +        self.author = kwargs.get('author', MockMember()) +        self.channel = kwargs.get('channel', MockTextChannel()) +        self.command = kwargs.get('command', unittest.mock.MagicMock()) + + +class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      """      A MagicMock subclass to mock Message objects.      Instances of this class will follow the specifications of `discord.Message` instances. For more      information, see the `MockGuild` docstring.      """ +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=message_instance, **kwargs) +        self.author = kwargs.get('author', MockMember()) +        self.channel = kwargs.get('channel', MockTextChannel()) -    attribute_mocktype = unittest.mock.MagicMock +emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} +emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) + + +class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Emoji objects. + +    Instances of this class will follow the specifications of `discord.Emoji` instances. For more +    information, see the `MockGuild` docstring. +    """      def __init__(self, **kwargs) -> None: -        super().__init__(spec=message_instance, **kwargs) -        self.author = MockMember() -        self.channel = MockTextChannel() - -        # `discord.Message` coroutines -        self.ack = AsyncMock() -        self.add_reaction = AsyncMock() -        self.clear_reactions = AsyncMock() -        self.delete = AsyncMock() -        self.edit = AsyncMock() -        self.pin = AsyncMock() -        self.remove_reaction = AsyncMock() -        self.unpin = AsyncMock() +        super().__init__(spec=emoji_instance, **kwargs) +        self.guild = kwargs.get('guild', MockGuild()) + +        # Get all coroutine functions and set them as AsyncMock attributes +        self._extract_coroutine_methods_from_spec_instance(emoji_instance) + + +partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') + + +class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock PartialEmoji objects. + +    Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For +    more information, see the `MockGuild` docstring. +    """ +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=partial_emoji_instance, **kwargs) + + +reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) + + +class MockReaction(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Reaction objects. + +    Instances of this class will follow the specifications of `discord.Reaction` instances. For +    more information, see the `MockGuild` docstring. +    """ +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=reaction_instance, **kwargs) +        self.emoji = kwargs.get('emoji', MockEmoji()) +        self.message = kwargs.get('message', MockMessage()) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index f08239981..2b58634dd 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -221,10 +221,10 @@ class DiscordMocksTests(unittest.TestCase):      @unittest.mock.patch(f'{__name__}.DiscordMocksTests.subTest')      def test_the_custom_mock_methods_test(self, subtest_mock):          """The custom method test should raise AssertionError for invalid methods.""" -        class FakeMockBot(helpers.AttributeMock, unittest.mock.MagicMock): +        class FakeMockBot(helpers.CustomMockMixin, unittest.mock.MagicMock):              """Fake MockBot class with invalid attribute/method `release_the_walrus`.""" -            attribute_mocktype = unittest.mock.MagicMock +            child_mock_type = unittest.mock.MagicMock              def __init__(self, **kwargs):                  super().__init__(spec=helpers.bot_instance, **kwargs) @@ -331,6 +331,18 @@ class MockObjectTests(unittest.TestCase):                  self.assertFalse(instance_one != instance_two)                  self.assertTrue(instance_one != instance_three) +    def test_custom_mock_mixin_accepts_mock_seal(self): +        """The `CustomMockMixin` should support `unittest.mock.seal`.""" +        class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock): + +            child_mock_type = unittest.mock.MagicMock +            pass + +        mock = MyMock() +        unittest.mock.seal(mock) +        with self.assertRaises(AttributeError, msg="MyMock.shirayuki"): +            mock.shirayuki = "hello!" +      def test_spec_propagation_of_mock_subclasses(self):          """Test if the `spec` does not propagate to attributes of the mock object."""          test_values = ( @@ -339,6 +351,10 @@ class MockObjectTests(unittest.TestCase):              (helpers.MockMember, "display_name"),              (helpers.MockBot, "owner_id"),              (helpers.MockContext, "command_failed"), +            (helpers.MockMessage, "mention_everyone"), +            (helpers.MockEmoji, 'managed'), +            (helpers.MockPartialEmoji, 'url'), +            (helpers.MockReaction, 'me'),          )          for mock_type, valid_attribute in test_values: @@ -346,7 +362,53 @@ class MockObjectTests(unittest.TestCase):                  mock = mock_type()                  self.assertTrue(isinstance(mock, mock_type))                  attribute = getattr(mock, valid_attribute) -                self.assertTrue(isinstance(attribute, mock_type.attribute_mocktype)) +                self.assertTrue(isinstance(attribute, mock_type.child_mock_type)) + +    def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self): +        """Test if all coroutine functions are extracted, but not regular methods or attributes.""" +        class CoroutineDonor: +            def __init__(self): +                self.some_attribute = 'alpha' + +            async def first_coroutine(): +                """This coroutine function should be extracted.""" + +            async def second_coroutine(): +                """This coroutine function should be extracted.""" + +            def regular_method(): +                """This regular function should not be extracted.""" + +        class Receiver: +            pass + +        donor = CoroutineDonor() +        receiver = Receiver() + +        helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor) + +        self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock) +        self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock) +        self.assertFalse(hasattr(receiver, 'regular_method')) +        self.assertFalse(hasattr(receiver, 'some_attribute')) + +    @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) +    @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") +    def test_custom_mock_mixin_init_with_spec(self, extract_method_mock): +        """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" +        spec = "pydis" + +        helpers.CustomMockMixin(spec=spec) + +        extract_method_mock.assert_called_once_with(spec) + +    @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) +    @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") +    def test_custom_mock_mixin_init_without_spec(self, extract_method_mock): +        """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" +        helpers.CustomMockMixin() + +        extract_method_mock.assert_not_called()      def test_async_mock_provides_coroutine_for_dunder_call(self):          """Test if AsyncMock objects have a coroutine for their __call__ method.""" | 
