aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/off_topic_names.py16
-rw-r--r--bot/cogs/reddit.py11
-rw-r--r--bot/cogs/site.py4
-rw-r--r--bot/constants.py4
-rw-r--r--config-default.yml4
-rw-r--r--tests/bot/test_utils.py52
-rw-r--r--tests/helpers.py315
-rw-r--r--tests/test_helpers.py68
8 files changed, 277 insertions, 197 deletions
diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py
index 1f9fb0b4f..78792240f 100644
--- a/bot/cogs/off_topic_names.py
+++ b/bot/cogs/off_topic_names.py
@@ -24,6 +24,9 @@ class OffTopicName(Converter):
"""Attempt to replace any invalid characters with their approximate Unicode equivalent."""
allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-"
+ # Chain multiple words to a single one
+ argument = "-".join(argument.split())
+
if not (2 <= len(argument) <= 96):
raise BadArgument("Channel name must be between 2 and 96 chars long")
@@ -97,15 +100,12 @@ class OffTopicNames(Cog):
@otname_group.command(name='add', aliases=('a',))
@with_role(*MODERATION_ROLES)
- async def add_command(self, ctx: Context, *names: OffTopicName) -> None:
+ async def add_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""
Adds a new off-topic name to the rotation.
The name is not added if it is too similar to an existing name.
"""
- # Chain multiple words to a single one
- name = "-".join(names)
-
existing_names = await self.bot.api_client.get('bot/off-topic-channel-names')
close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8)
@@ -123,10 +123,8 @@ class OffTopicNames(Cog):
@otname_group.command(name='forceadd', aliases=('fa',))
@with_role(*MODERATION_ROLES)
- async def force_add_command(self, ctx: Context, *names: OffTopicName) -> None:
+ async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""Forcefully adds a new off-topic name to the rotation."""
- # Chain multiple words to a single one
- name = "-".join(names)
await self._add_name(ctx, name)
async def _add_name(self, ctx: Context, name: str) -> None:
@@ -138,10 +136,8 @@ class OffTopicNames(Cog):
@otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd'))
@with_role(*MODERATION_ROLES)
- async def delete_command(self, ctx: Context, *names: OffTopicName) -> None:
+ async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""Removes a off-topic name from the rotation."""
- # Chain multiple words to a single one
- name = "-".join(names)
await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}')
log.info(f"{ctx.author} deleted the off-topic channel name '{name}'")
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 7749d237f..0d06e9c26 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -2,14 +2,14 @@ import asyncio
import logging
import random
import textwrap
-from datetime import datetime
+from datetime import datetime, timedelta
from typing import List
from discord import Colour, Embed, TextChannel
from discord.ext.commands import Bot, Cog, Context, group
from discord.ext.tasks import loop
-from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES, Webhooks
+from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks
from bot.converters import Subreddit
from bot.decorators import with_role
from bot.pagination import LinePaginator
@@ -117,9 +117,9 @@ class Reddit(Cog):
link = self.URL + data["permalink"]
embed.description += (
- f"[**{title}**]({link})\n"
+ f"**[{title}]({link})**\n"
f"{text}"
- f"| {ups} upvotes | {comments} comments | u/{author} | {subreddit} |\n\n"
+ f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n"
)
embed.colour = Colour.blurple()
@@ -130,7 +130,8 @@ class Reddit(Cog):
"""Post the top 5 posts daily, and the top 5 posts weekly."""
# once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter
now = datetime.utcnow()
- midnight_tomorrow = now.replace(day=now.day + 1, hour=0, minute=0, second=0)
+ tomorrow = now + timedelta(days=1)
+ midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0)
seconds_until = (midnight_tomorrow - now).total_seconds()
await asyncio.sleep(seconds_until)
diff --git a/bot/cogs/site.py b/bot/cogs/site.py
index d95359159..683613788 100644
--- a/bot/cogs/site.py
+++ b/bot/cogs/site.py
@@ -3,8 +3,7 @@ import logging
from discord import Colour, Embed
from discord.ext.commands import Bot, Cog, Context, group
-from bot.constants import Channels, STAFF_ROLES, URLs
-from bot.decorators import redirect_output
+from bot.constants import URLs
from bot.pagination import LinePaginator
log = logging.getLogger(__name__)
@@ -105,7 +104,6 @@ class Site(Cog):
await ctx.send(embed=embed)
@site_group.command(aliases=['r', 'rule'], name='rules')
- @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)
async def site_rules(self, ctx: Context, *rules: int) -> None:
"""Provides a link to all rules or, if specified, displays specific rule(s)."""
rules_embed = Embed(title='Rules', color=Colour.blurple())
diff --git a/bot/constants.py b/bot/constants.py
index dbbf32063..c78c06227 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -269,6 +269,10 @@ class Emojis(metaclass=YAMLGetter):
ducky_devil: int
ducky_tube: int
+ upvotes: str
+ comments: str
+ user: str
+
class Icons(metaclass=YAMLGetter):
section = "style"
diff --git a/config-default.yml b/config-default.yml
index 76892677e..8583cb6da 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -40,6 +40,10 @@ style:
ducky_devil: &DUCKY_DEVIL 637925314982576139
ducky_tube: &DUCKY_TUBE 637881368008851456
+ upvotes: "<:upvotes:638729835245731840>"
+ comments: "<:comments:638729835073765387>"
+ user: "<:user:638729835442602003>"
+
icons:
crown_blurple: "https://cdn.discordapp.com/emojis/469964153289965568.png"
crown_green: "https://cdn.discordapp.com/emojis/469964154719961088.png"
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
new file mode 100644
index 000000000..58ae2a81a
--- /dev/null
+++ b/tests/bot/test_utils.py
@@ -0,0 +1,52 @@
+import unittest
+
+from bot import utils
+
+
+class CaseInsensitiveDictTests(unittest.TestCase):
+ """Tests for the `CaseInsensitiveDict` container."""
+
+ def test_case_insensitive_key_access(self):
+ """Tests case insensitive key access and storage."""
+ instance = utils.CaseInsensitiveDict()
+
+ key = 'LEMON'
+ value = 'trees'
+
+ instance[key] = value
+ self.assertIn(key, instance)
+ self.assertEqual(instance.get(key), value)
+ self.assertEqual(instance.get(key.casefold()), value)
+ self.assertEqual(instance.pop(key.casefold()), value)
+ self.assertNotIn(key, instance)
+ self.assertNotIn(key.casefold(), instance)
+
+ instance.setdefault(key, value)
+ del instance[key]
+ self.assertNotIn(key, instance)
+
+ def test_initialization_from_kwargs(self):
+ """Tests creating the dictionary from keyword arguments."""
+ instance = utils.CaseInsensitiveDict({'FOO': 'bar'})
+ self.assertEqual(instance['foo'], 'bar')
+
+ def test_update_from_other_mapping(self):
+ """Tests updating the dictionary from another mapping."""
+ instance = utils.CaseInsensitiveDict()
+ instance.update({'FOO': 'bar'})
+ self.assertEqual(instance['foo'], 'bar')
+
+
+class ChunkTests(unittest.TestCase):
+ """Tests the `chunk` method."""
+
+ def test_empty_chunking(self):
+ """Tests chunking on an empty iterable."""
+ generator = utils.chunks(iterable=[], size=5)
+ self.assertEqual(list(generator), [])
+
+ def test_list_chunking(self):
+ """Tests chunking a non-empty list."""
+ iterable = [1, 2, 3, 4, 5]
+ generator = utils.chunks(iterable=iterable, size=2)
+ self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])
diff --git a/tests/helpers.py b/tests/helpers.py
index 892d42e6c..8496ba031 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -2,8 +2,9 @@ from __future__ import annotations
import asyncio
import functools
+import inspect
import unittest.mock
-from typing import Iterable, Optional
+from typing import Any, Iterable, Optional
import discord
from discord.ext.commands import Bot, Context
@@ -24,19 +25,6 @@ def async_test(wrapped):
return wrapper
-# TODO: Remove me in Python 3.8
-class AsyncMock(unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock async callables.
-
- Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
- features; this stand-in only overwrites the `__call__` method to an async version.
- """
-
- async def __call__(self, *args, **kwargs):
- return super(AsyncMock, self).__call__(*args, **kwargs)
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.
@@ -61,15 +49,61 @@ class ColourMixin:
self.colour = color
-class AttributeMock:
- """Ensures attributes of our mock types will be instantiated with the correct mock type."""
+class CustomMockMixin:
+ """
+ Provides common functionality for our custom Mock types.
- def __new__(cls, *args, **kwargs):
- """Stops the regular parent class from propagating to newly mocked attributes."""
- if 'parent' in kwargs:
- return cls.attribute_mocktype(*args, **kwargs)
+ The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
+ function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
+ of making sure child mocks are instantiated with the correct class. By default, the mock of the
+ children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
+ `child_mock_type` on the custom mock inheriting from this mixin.
+ """
- return super().__new__(cls)
+ child_mock_type = unittest.mock.MagicMock
+
+ def __init__(self, spec: Any = None, **kwargs):
+ super().__init__(spec=spec, **kwargs)
+ if spec:
+ self._extract_coroutine_methods_from_spec_instance(spec)
+
+ def _get_child_mock(self, **kw):
+ """
+ Overwrite of the `_get_child_mock` method to stop the propagation of our custom mock classes.
+
+ Mock objects automatically create children when you access an attribute or call a method on them. By default,
+ the class of these children is the type of the parent itself. However, this would mean that the children created
+ for our custom mock types would also be instances of that custom mock type. This is not desirable, as attributes
+ of, e.g., a `Bot` object are not `Bot` objects themselves. The Python docs for `unittest.mock` hint that
+ overwriting this method is the best way to deal with that.
+
+ This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
+ """
+ klass = self.child_mock_type
+
+ if self._mock_sealed:
+ attribute = "." + kw["name"] if "name" in kw else "()"
+ mock_name = self._extract_mock_name() + attribute
+ raise AttributeError(mock_name)
+
+ return klass(**kw)
+
+ def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
+ """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
+ for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
+ setattr(self, name, AsyncMock())
+
+
+# TODO: Remove me in Python 3.8
+class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock async callables.
+
+ Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
+ features; this stand-in only overwrites the `__call__` method to an async version.
+ """
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
# Create a guild instance to get a realistic Mock of `discord.Guild`
@@ -95,7 +129,7 @@ guild_data = {
guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock())
-class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin):
+class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
A `Mock` subclass to mock `discord.Guild` objects.
@@ -121,9 +155,6 @@ class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin):
For more info, see the `Mocking` section in `tests/README.md`.
"""
-
- attribute_mocktype = unittest.mock.MagicMock
-
def __init__(
self,
guild_id: int = 1,
@@ -143,48 +174,19 @@ class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin):
if members:
self.members.extend(members)
- # `discord.Guild` coroutines
- self.create_category_channel = AsyncMock()
- self.ban = AsyncMock()
- self.bans = AsyncMock()
- self.create_category = AsyncMock()
- self.create_custom_emoji = AsyncMock()
- self.create_role = AsyncMock()
- self.create_text_channel = AsyncMock()
- self.create_voice_channel = AsyncMock()
- self.delete = AsyncMock()
- self.edit = AsyncMock()
- self.estimate_pruned_members = AsyncMock()
- self.fetch_ban = AsyncMock()
- self.fetch_channels = AsyncMock()
- self.fetch_emoji = AsyncMock()
- self.fetch_emojis = AsyncMock()
- self.fetch_member = AsyncMock()
- self.invites = AsyncMock()
- self.kick = AsyncMock()
- self.leave = AsyncMock()
- self.prune_members = AsyncMock()
- self.unban = AsyncMock()
- self.vanity_invite = AsyncMock()
- self.webhooks = AsyncMock()
- self.widget = AsyncMock()
-
# Create a Role instance to get a realistic Mock of `discord.Role`
role_data = {'name': 'role', 'id': 1}
role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data)
-class MockRole(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin):
+class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
"""
A Mock subclass to mock `discord.Role` objects.
Instances of this class will follow the specifications of `discord.Role` instances. For more
information, see the `MockGuild` docstring.
"""
-
- attribute_mocktype = unittest.mock.MagicMock
-
def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None:
super().__init__(spec=role_instance, **kwargs)
@@ -193,10 +195,6 @@ class MockRole(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin):
self.position = position
self.mention = f'&{self.name}'
- # 'discord.Role' coroutines
- self.delete = AsyncMock()
- self.edit = AsyncMock()
-
def __lt__(self, other):
"""Simplified position-based comparisons similar to those of `discord.Role`."""
return self.position < other.position
@@ -208,16 +206,13 @@ state_mock = unittest.mock.MagicMock()
member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock)
-class MockMember(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin):
+class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
"""
A Mock subclass to mock Member objects.
Instances of this class will follow the specifications of `discord.Member` instances. For more
information, see the `MockGuild` docstring.
"""
-
- attribute_mocktype = unittest.mock.MagicMock
-
def __init__(
self,
name: str = "member",
@@ -236,98 +231,29 @@ class MockMember(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin):
self.mention = f"@{self.name}"
- # `discord.Member` coroutines
- self.add_roles = AsyncMock()
- self.ban = AsyncMock()
- self.edit = AsyncMock()
- self.fetch_message = AsyncMock()
- self.kick = AsyncMock()
- self.move_to = AsyncMock()
- self.pins = AsyncMock()
- self.remove_roles = AsyncMock()
- self.send = AsyncMock()
- self.trigger_typing = AsyncMock()
- self.unban = AsyncMock()
-
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
-class MockBot(AttributeMock, unittest.mock.MagicMock):
+class MockBot(CustomMockMixin, unittest.mock.MagicMock):
"""
A MagicMock subclass to mock Bot objects.
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
-
- attribute_mocktype = unittest.mock.MagicMock
-
def __init__(self, **kwargs) -> None:
super().__init__(spec=bot_instance, **kwargs)
- # `discord.ext.commands.Bot` coroutines
- self._before_invoke = AsyncMock()
- self._after_invoke = AsyncMock()
- self.application_info = AsyncMock()
- self.change_presence = AsyncMock()
- self.connect = AsyncMock()
- self.close = AsyncMock()
- self.create_guild = AsyncMock()
- self.delete_invite = AsyncMock()
- self.fetch_channel = AsyncMock()
- self.fetch_guild = AsyncMock()
- self.fetch_guilds = AsyncMock()
- self.fetch_invite = AsyncMock()
- self.fetch_user = AsyncMock()
- self.fetch_user_profile = AsyncMock()
- self.fetch_webhook = AsyncMock()
- self.fetch_widget = AsyncMock()
- self.get_context = AsyncMock()
- self.get_prefix = AsyncMock()
- self.invoke = AsyncMock()
- self.is_owner = AsyncMock()
- self.login = AsyncMock()
- self.logout = AsyncMock()
- self.on_command_error = AsyncMock()
- self.on_error = AsyncMock()
- self.process_commands = AsyncMock()
- self.request_offline_members = AsyncMock()
- self.start = AsyncMock()
- self.wait_until_ready = AsyncMock()
- self.wait_for = AsyncMock()
-
-
-# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
-context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
-
-
-class MockContext(AttributeMock, unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock Context objects.
-
- Instances of this class will follow the specifications of `discord.ext.commands.Context`
- instances. For more information, see the `MockGuild` docstring.
- """
-
- attribute_mocktype = unittest.mock.MagicMock
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec=context_instance, **kwargs)
- self.bot = MockBot()
- self.guild = MockGuild()
- self.author = MockMember()
- self.command = unittest.mock.MagicMock()
+ # Our custom attributes and methods
+ self.http_session = unittest.mock.MagicMock()
+ self.api_client = unittest.mock.MagicMock()
- # `discord.ext.commands.Context` coroutines
- self.fetch_message = AsyncMock()
- self.invoke = AsyncMock()
- self.pins = AsyncMock()
- self.reinvoke = AsyncMock()
- self.send = AsyncMock()
- self.send_help = AsyncMock()
- self.trigger_typing = AsyncMock()
+ # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
+ # and should therefore be awaited. (The documentation calls it a coroutine as well, which
+ # is technically incorrect, since it's a regular def.)
+ self.wait_for = AsyncMock()
# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`
@@ -346,39 +272,20 @@ guild = unittest.mock.MagicMock()
channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data)
-class MockTextChannel(AttributeMock, unittest.mock.Mock, HashableMixin):
+class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
A MagicMock subclass to mock TextChannel objects.
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
-
- attribute_mocktype = unittest.mock.MagicMock
-
def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
super().__init__(spec=channel_instance, **kwargs)
self.id = channel_id
self.name = name
- self.guild = MockGuild()
+ self.guild = kwargs.get('guild', MockGuild())
self.mention = f"#{self.name}"
- # `discord.TextChannel` coroutines
- self.clone = AsyncMock()
- self.create_invite = AsyncMock()
- self.create_webhook = AsyncMock()
- self.delete = AsyncMock()
- self.delete_messages = AsyncMock()
- self.edit = AsyncMock()
- self.fetch_message = AsyncMock()
- self.invites = AsyncMock()
- self.pins = AsyncMock()
- self.purge = AsyncMock()
- self.send = AsyncMock()
- self.set_permissions = AsyncMock()
- self.trigger_typing = AsyncMock()
- self.webhooks = AsyncMock()
-
# Create a Message instance to get a realistic MagicMock of `discord.Message`
message_data = {
@@ -402,27 +309,83 @@ channel = unittest.mock.MagicMock()
message_instance = discord.Message(state=state, channel=channel, data=message_data)
-class MockMessage(AttributeMock, unittest.mock.MagicMock):
+# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
+context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
+
+
+class MockContext(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Context objects.
+
+ Instances of this class will follow the specifications of `discord.ext.commands.Context`
+ instances. For more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=context_instance, **kwargs)
+ self.bot = kwargs.get('bot', MockBot())
+ self.guild = kwargs.get('guild', MockGuild())
+ self.author = kwargs.get('author', MockMember())
+ self.channel = kwargs.get('channel', MockTextChannel())
+ self.command = kwargs.get('command', unittest.mock.MagicMock())
+
+
+class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
"""
A MagicMock subclass to mock Message objects.
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=message_instance, **kwargs)
+ self.author = kwargs.get('author', MockMember())
+ self.channel = kwargs.get('channel', MockTextChannel())
- attribute_mocktype = unittest.mock.MagicMock
+emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'}
+emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data)
+
+
+class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Emoji objects.
+
+ Instances of this class will follow the specifications of `discord.Emoji` instances. For more
+ information, see the `MockGuild` docstring.
+ """
def __init__(self, **kwargs) -> None:
- super().__init__(spec=message_instance, **kwargs)
- self.author = MockMember()
- self.channel = MockTextChannel()
-
- # `discord.Message` coroutines
- self.ack = AsyncMock()
- self.add_reaction = AsyncMock()
- self.clear_reactions = AsyncMock()
- self.delete = AsyncMock()
- self.edit = AsyncMock()
- self.pin = AsyncMock()
- self.remove_reaction = AsyncMock()
- self.unpin = AsyncMock()
+ super().__init__(spec=emoji_instance, **kwargs)
+ self.guild = kwargs.get('guild', MockGuild())
+
+ # Get all coroutine functions and set them as AsyncMock attributes
+ self._extract_coroutine_methods_from_spec_instance(emoji_instance)
+
+
+partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido')
+
+
+class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock PartialEmoji objects.
+
+ Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=partial_emoji_instance, **kwargs)
+
+
+reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
+
+
+class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Reaction objects.
+
+ Instances of this class will follow the specifications of `discord.Reaction` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=reaction_instance, **kwargs)
+ self.emoji = kwargs.get('emoji', MockEmoji())
+ self.message = kwargs.get('message', MockMessage())
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index f08239981..2b58634dd 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -221,10 +221,10 @@ class DiscordMocksTests(unittest.TestCase):
@unittest.mock.patch(f'{__name__}.DiscordMocksTests.subTest')
def test_the_custom_mock_methods_test(self, subtest_mock):
"""The custom method test should raise AssertionError for invalid methods."""
- class FakeMockBot(helpers.AttributeMock, unittest.mock.MagicMock):
+ class FakeMockBot(helpers.CustomMockMixin, unittest.mock.MagicMock):
"""Fake MockBot class with invalid attribute/method `release_the_walrus`."""
- attribute_mocktype = unittest.mock.MagicMock
+ child_mock_type = unittest.mock.MagicMock
def __init__(self, **kwargs):
super().__init__(spec=helpers.bot_instance, **kwargs)
@@ -331,6 +331,18 @@ class MockObjectTests(unittest.TestCase):
self.assertFalse(instance_one != instance_two)
self.assertTrue(instance_one != instance_three)
+ def test_custom_mock_mixin_accepts_mock_seal(self):
+ """The `CustomMockMixin` should support `unittest.mock.seal`."""
+ class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):
+
+ child_mock_type = unittest.mock.MagicMock
+ pass
+
+ mock = MyMock()
+ unittest.mock.seal(mock)
+ with self.assertRaises(AttributeError, msg="MyMock.shirayuki"):
+ mock.shirayuki = "hello!"
+
def test_spec_propagation_of_mock_subclasses(self):
"""Test if the `spec` does not propagate to attributes of the mock object."""
test_values = (
@@ -339,6 +351,10 @@ class MockObjectTests(unittest.TestCase):
(helpers.MockMember, "display_name"),
(helpers.MockBot, "owner_id"),
(helpers.MockContext, "command_failed"),
+ (helpers.MockMessage, "mention_everyone"),
+ (helpers.MockEmoji, 'managed'),
+ (helpers.MockPartialEmoji, 'url'),
+ (helpers.MockReaction, 'me'),
)
for mock_type, valid_attribute in test_values:
@@ -346,7 +362,53 @@ class MockObjectTests(unittest.TestCase):
mock = mock_type()
self.assertTrue(isinstance(mock, mock_type))
attribute = getattr(mock, valid_attribute)
- self.assertTrue(isinstance(attribute, mock_type.attribute_mocktype))
+ self.assertTrue(isinstance(attribute, mock_type.child_mock_type))
+
+ def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self):
+ """Test if all coroutine functions are extracted, but not regular methods or attributes."""
+ class CoroutineDonor:
+ def __init__(self):
+ self.some_attribute = 'alpha'
+
+ async def first_coroutine():
+ """This coroutine function should be extracted."""
+
+ async def second_coroutine():
+ """This coroutine function should be extracted."""
+
+ def regular_method():
+ """This regular function should not be extracted."""
+
+ class Receiver:
+ pass
+
+ donor = CoroutineDonor()
+ receiver = Receiver()
+
+ helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor)
+
+ self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock)
+ self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock)
+ self.assertFalse(hasattr(receiver, 'regular_method'))
+ self.assertFalse(hasattr(receiver, 'some_attribute'))
+
+ @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
+ @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
+ def test_custom_mock_mixin_init_with_spec(self, extract_method_mock):
+ """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
+ spec = "pydis"
+
+ helpers.CustomMockMixin(spec=spec)
+
+ extract_method_mock.assert_called_once_with(spec)
+
+ @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
+ @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
+ def test_custom_mock_mixin_init_without_spec(self, extract_method_mock):
+ """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
+ helpers.CustomMockMixin()
+
+ extract_method_mock.assert_not_called()
def test_async_mock_provides_coroutine_for_dunder_call(self):
"""Test if AsyncMock objects have a coroutine for their __call__ method."""