aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/base.py10
-rw-r--r--tests/bot/cogs/test_duck_pond.py35
-rw-r--r--tests/bot/cogs/test_information.py34
-rw-r--r--tests/bot/cogs/test_token_remover.py4
-rw-r--r--tests/bot/rules/__init__.py6
-rw-r--r--tests/bot/rules/test_attachments.py4
-rw-r--r--tests/bot/rules/test_burst.py4
-rw-r--r--tests/bot/rules/test_burst_shared.py4
-rw-r--r--tests/bot/rules/test_chars.py4
-rw-r--r--tests/bot/rules/test_discord_emojis.py4
-rw-r--r--tests/bot/rules/test_duplicates.py4
-rw-r--r--tests/bot/rules/test_links.py4
-rw-r--r--tests/bot/rules/test_mentions.py4
-rw-r--r--tests/bot/rules/test_newlines.py5
-rw-r--r--tests/bot/rules/test_role_mentions.py4
-rw-r--r--tests/bot/test_api.py4
-rw-r--r--tests/bot/utils/test_time.py5
-rw-r--r--tests/helpers.py211
-rw-r--r--tests/test_base.py18
-rw-r--r--tests/test_helpers.py71
-rw-r--r--tests/utils/test_time.py62
21 files changed, 133 insertions, 368 deletions
diff --git a/tests/base.py b/tests/base.py
index 88693f382..21613110e 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -1,5 +1,4 @@
import logging
-import unittest
from contextlib import contextmanager
from typing import Dict
@@ -22,8 +21,13 @@ class _CaptureLogHandler(logging.Handler):
self.records.append(record)
-class LoggingTestCase(unittest.TestCase):
- """TestCase subclass that adds more logging assertion tools."""
+class LoggingTestsMixin:
+ """
+ A mixin that defines additional test methods for logging behavior.
+
+ This mixin relies on the availability of the `fail` attribute defined by the
+ test classes included in Python's unittest method to signal test failure.
+ """
@contextmanager
def assertNotLogs(self, logger=None, level=None, msg=None):
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index 5b0a3b8c3..7e6bfc748 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -2,7 +2,7 @@ import asyncio
import logging
import typing
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import discord
@@ -14,7 +14,7 @@ from tests import helpers
MODULE_PATH = "bot.cogs.duck_pond"
-class DuckPondTests(base.LoggingTestCase):
+class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
"""Tests for DuckPond functionality."""
@classmethod
@@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):
self.assertEqual(expected_return, actual_return)
- @helpers.async_test
async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):
"""The `has_green_checkmark` method should only return `True` if one is present."""
test_cases = (
@@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase):
nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]
return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers)
- @helpers.async_test
async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):
"""The `count_ducks` method should return the number of unique staffers who gave a duck."""
test_cases = (
@@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):
self.assertEqual(expected_count, actual_count)
- @helpers.async_test
async def test_relay_message_correctly_relays_content_and_attachments(self):
"""The `relay_message` method should correctly relay message content and attachments."""
send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook"
@@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase):
)
for message, expect_webhook_call, expect_attachment_call in test_values:
- with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook:
- with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments:
+ with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook:
+ with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:
with self.subTest(clean_content=message.clean_content, attachments=message.attachments):
await self.cog.relay_message(message)
@@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase):
message.add_reaction.assert_called_once_with(self.checkmark_emoji)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase):
self.cog.webhook = helpers.MockAsyncWebhook()
log = logging.getLogger("bot.cogs.duck_pond")
- for side_effect in side_effects:
+ for side_effect in side_effects: # pragma: no cover
send_attachments.side_effect = side_effect
- with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook:
+ with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:
with self.subTest(side_effect=type(side_effect).__name__):
with self.assertNotLogs(logger=log, level=logging.ERROR):
await self.cog.relay_message(message)
self.assertEqual(send_webhook.call_count, 2)
- @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock)
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji.name = emoji_name
return payload
- @helpers.async_test
async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):
"""The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""
test_values = (
@@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase):
return channel, message, member, payload
- @helpers.async_test
async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):
"""The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""
channel_id = 1234
@@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase):
channel.fetch_message.reset_mock()
@patch(f"{MODULE_PATH}.DuckPond.is_staff")
- @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock)
+ @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)
def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):
"""The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""
channel_id = 31415926535
@@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase):
# Assert that we've made it past `self.is_staff`
is_staff.assert_called_once()
- @helpers.async_test
async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):
"""The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""
test_cases = (
@@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji = self.duck_pond_emoji
for duck_count, should_relay in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_relay=should_relay):
await self.cog.on_raw_reaction_add(payload)
@@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase):
if should_relay:
relay_message.assert_called_once_with(message)
- @helpers.async_test
async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):
"""The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""
checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji)
@@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase):
(constants.DuckPond.threshold + 1, True),
)
for duck_count, should_re_add_checkmark in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):
await self.cog.on_raw_reaction_remove(payload)
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index deae7ebad..f5e937356 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase):
"""Test if the `role_info` command correctly returns the `moderator_role`."""
self.ctx.guild.roles.append(self.moderator_role)
- self.cog.roles_info.can_run = helpers.AsyncMock()
+ self.cog.roles_info.can_run = unittest.mock.AsyncMock()
self.cog.roles_info.can_run.return_value = True
coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
@@ -72,7 +72,7 @@ class InformationCogTests(unittest.TestCase):
self.ctx.guild.roles.append([dummy_role, admin_role])
- self.cog.role_info.can_run = helpers.AsyncMock()
+ self.cog.role_info.can_run = unittest.mock.AsyncMock()
self.cog.role_info.can_run.return_value = True
coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
@@ -174,7 +174,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
self.member = helpers.MockMember(id=1234)
@@ -345,10 +345,10 @@ class UserEmbedTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -360,7 +360,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Mr. Hemlock")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -372,7 +372,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -387,8 +387,8 @@ class UserEmbedTests(unittest.TestCase):
self.assertIn("&Admins", embed.description)
self.assertNotIn("&Everyone", embed.description)
- @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock)
- @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
"""The embed should contain expanded infractions and nomination info in mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50))
@@ -423,7 +423,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
@@ -454,7 +454,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -467,7 +467,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
"""The embed should be created with a blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
@@ -477,7 +477,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour.blurple())
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
@@ -529,7 +529,7 @@ class UserCommandTests(unittest.TestCase):
with self.assertRaises(InChannelCheckFailure, msg=msg):
asyncio.run(self.cog.user_info.callback(self.cog, ctx))
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -542,7 +542,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -555,7 +555,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -568,7 +568,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.moderator)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index a54b839d7..33d1ec170 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -1,7 +1,7 @@
import asyncio
import logging
import unittest
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
from discord import Colour
@@ -11,7 +11,7 @@ from bot.cogs.token_remover import (
setup as setup_cog,
)
from bot.constants import Channels, Colours, Event, Icons
-from tests.helpers import AsyncMock, MockBot, MockMessage
+from tests.helpers import MockBot, MockMessage
class TokenRemoverTests(unittest.TestCase):
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index 36c986fe1..0d570f5a3 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple):
n_violations: int
-class RuleTest(unittest.TestCase, metaclass=ABCMeta):
+class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):
"""
Abstract class for antispam rule test cases.
@@ -68,9 +68,9 @@ class RuleTest(unittest.TestCase, metaclass=ABCMeta):
@abstractmethod
def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
"""Give expected relevant messages for `case`."""
- raise NotImplementedError
+ raise NotImplementedError # pragma: no cover
@abstractmethod
def get_report(self, case: DisallowedCase) -> str:
"""Give expected error report for `case`."""
- raise NotImplementedError
+ raise NotImplementedError # pragma: no cover
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index e54b4b5b8..d7e779221 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import attachments
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_attachments: int) -> MockMessage:
@@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest):
self.apply = attachments.apply
self.config = {"max": 5, "interval": 10}
- @async_test
async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
@@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
index 72f0be0c7..03682966b 100644
--- a/tests/bot/rules/test_burst.py
+++ b/tests/bot/rules/test_burst.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import burst
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str) -> MockMessage:
@@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest):
self.apply = burst.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
index 47367a5f8..3275143d5 100644
--- a/tests/bot/rules/test_burst_shared.py
+++ b/tests/bot/rules/test_burst_shared.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import burst_shared
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str) -> MockMessage:
@@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest):
self.apply = burst_shared.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""
Cases that do not violate the rule.
@@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
index 7cc36f49e..f1e3c76a7 100644
--- a/tests/bot/rules/test_chars.py
+++ b/tests/bot/rules/test_chars.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import chars
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, n_chars: int) -> MockMessage:
@@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest):
"interval": 10,
}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of chars within limit."""
cases = (
@@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the total amount of chars exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
index 0239b0b00..9a72723e2 100644
--- a/tests/bot/rules/test_discord_emojis.py
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import discord_emojis
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
@@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest):
self.apply = discord_emojis.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of discord emojis within limit."""
cases = (
@@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of discord emojis."""
cases = (
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
index 59e0fb6ef..9bd886a77 100644
--- a/tests/bot/rules/test_duplicates.py
+++ b/tests/bot/rules/test_duplicates.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import duplicates
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, content: str) -> MockMessage:
@@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest):
self.apply = duplicates.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with too many duplicate messages from the same author."""
cases = (
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index 3c3f90e5f..b091bd9d7 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import links
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_links: int) -> MockMessage:
@@ -21,7 +21,6 @@ class LinksTests(RuleTest):
"interval": 10
}
- @async_test
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
@@ -34,7 +33,6 @@ class LinksTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index ebcdabac6..6444532f2 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_mentions: int) -> MockMessage:
@@ -20,7 +20,6 @@ class TestMentions(RuleTest):
"interval": 10,
}
- @async_test
async def test_mentions_within_limit(self):
"""Messages with an allowed amount of mentions."""
cases = (
@@ -32,7 +31,6 @@ class TestMentions(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
index d61c4609d..e35377773 100644
--- a/tests/bot/rules/test_newlines.py
+++ b/tests/bot/rules/test_newlines.py
@@ -2,7 +2,7 @@ from typing import Iterable, List
from bot.rules import newlines
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
@@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest):
"interval": 10,
}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_total(self):
"""Cases which violate the rule by having too many newlines in total."""
cases = (
@@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest):
self.apply = newlines.apply
self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
- @async_test
async def test_disallows_messages_consecutive(self):
"""Cases which violate the rule due to having too many consecutive newlines."""
cases = (
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
index b339cccf7..26c05d527 100644
--- a/tests/bot/rules/test_role_mentions.py
+++ b/tests/bot/rules/test_role_mentions.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import role_mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, n_mentions: int) -> MockMessage:
@@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest):
self.apply = role_mentions.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of role mentions within limit."""
cases = (
@@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of role mentions."""
cases = (
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index bdfcc73e4..99e942813 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -2,10 +2,9 @@ import unittest
from unittest.mock import MagicMock
from bot import api
-from tests.helpers import async_test
-class APIClientTests(unittest.TestCase):
+class APIClientTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the bot's API client."""
@classmethod
@@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase):
"""The event loop should not be running by default."""
self.assertFalse(api.loop_is_running())
- @async_test
async def test_loop_is_running_in_async_context(self):
"""The event loop should be running in an async context."""
self.assertTrue(api.loop_is_running())
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
index 69f35f2f5..694d3a40f 100644
--- a/tests/bot/utils/test_time.py
+++ b/tests/bot/utils/test_time.py
@@ -1,12 +1,11 @@
import asyncio
import unittest
from datetime import datetime, timezone
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
from dateutil.relativedelta import relativedelta
from bot.utils import time
-from tests.helpers import AsyncMock
class TimeTests(unittest.TestCase):
@@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase):
for max_units in test_cases:
with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
- self.assertEqual(str(error), 'max_units must be positive')
+ self.assertEqual(str(error.exception), 'max_units must be positive')
def test_parse_rfc1123(self):
"""Testing parse_rfc1123."""
diff --git a/tests/helpers.py b/tests/helpers.py
index 9d9dd5da6..7ae7ed621 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,13 +1,10 @@
from __future__ import annotations
-import asyncio
import collections
-import functools
-import inspect
import itertools
import logging
import unittest.mock
-from typing import Any, Iterable, Optional
+from typing import Iterable, Optional
import discord
from discord.ext.commands import Context
@@ -26,21 +23,6 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def async_test(wrapped):
- """
- Run a test case via asyncio.
- Example:
- >>> @async_test
- ... async def lemon_wins():
- ... assert True
- """
-
- @functools.wraps(wrapped)
- def wrapper(*args, **kwargs):
- return asyncio.run(wrapped(*args, **kwargs))
- return wrapper
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.
@@ -69,24 +51,31 @@ class CustomMockMixin:
"""
Provides common functionality for our custom Mock types.
- The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
- function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
- of making sure child mocks are instantiated with the correct class. By default, the mock of the
- children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
- `child_mock_type` on the custom mock inheriting from this mixin.
+ The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock
+ object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the
+ class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional
+ attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The
+ class method `spec_set` can be overwritten with the object that should be uses as the specification
+ for the mock.
+
+ Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to
+ implement custom behavior.
"""
child_mock_type = unittest.mock.MagicMock
discord_id = itertools.count(0)
+ spec_set = None
+ additional_spec_asyncs = None
- def __init__(self, spec_set: Any = None, **kwargs):
+ def __init__(self, **kwargs):
name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually.
- super().__init__(spec_set=spec_set, **kwargs)
+ super().__init__(spec_set=self.spec_set, **kwargs)
+
+ if self.additional_spec_asyncs:
+ self._spec_asyncs.extend(self.additional_spec_asyncs)
if name:
self.name = name
- if spec_set:
- self._extract_coroutine_methods_from_spec_instance(spec_set)
def _get_child_mock(self, **kw):
"""
@@ -100,7 +89,16 @@ class CustomMockMixin:
This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
"""
- klass = self.child_mock_type
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return unittest.mock.AsyncMock(**kw)
+
+ _type = type(self)
+ if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics:
+ # Any asynchronous magic becomes an AsyncMock
+ klass = unittest.mock.AsyncMock
+ else:
+ klass = self.child_mock_type
if self._mock_sealed:
attribute = "." + kw["name"] if "name" in kw else "()"
@@ -109,95 +107,6 @@ class CustomMockMixin:
return klass(**kw)
- def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
- """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
- for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
- setattr(self, name, AsyncMock())
-
-
-# TODO: Remove me in Python 3.8
-class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock async callables.
-
- Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
- features; this stand-in only overwrites the `__call__` method to an async version.
- """
-
- async def __call__(self, *args, **kwargs):
- return super().__call__(*args, **kwargs)
-
-
-class AsyncIteratorMock:
- """
- A class to mock asynchronous iterators.
-
- This allows async for, which is used in certain Discord.py objects. For example,
- an async iterator is returned by the Reaction.users() method.
- """
-
- def __init__(self, iterable: Iterable = None):
- if iterable is None:
- iterable = []
-
- self.iter = iter(iterable)
- self.iterable = iterable
-
- self.call_count = 0
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- return next(self.iter)
- except StopIteration:
- raise StopAsyncIteration
-
- def __call__(self):
- """
- Keeps track of the number of times an instance has been called.
-
- This is useful, since it typically shows that the iterator has actually been used somewhere after we have
- instantiated the mock for an attribute that normally returns an iterator when called.
- """
- self.call_count += 1
- return self
-
- @property
- def return_value(self):
- """Makes `self.iterable` accessible as self.return_value."""
- return self.iterable
-
- @return_value.setter
- def return_value(self, iterable):
- """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
- self.iter = iter(iterable)
- self.iterable = iterable
-
- def assert_called(self):
- """Asserts if the AsyncIteratorMock instance has been called at least once."""
- if self.call_count == 0:
- raise AssertionError("Expected AsyncIteratorMock to have been called.")
-
- def assert_called_once(self):
- """Asserts if the AsyncIteratorMock instance has been called exactly once."""
- if self.call_count != 1:
- raise AssertionError(
- f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
- )
-
- def assert_not_called(self):
- """Asserts if the AsyncIteratorMock instance has not been called."""
- if self.call_count != 0:
- raise AssertionError(
- f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
- )
-
- def reset_mock(self):
- """Resets the call count, but not the return value or iterator."""
- self.call_count = 0
-
# Create a guild instance to get a realistic Mock of `discord.Guild`
guild_data = {
@@ -248,9 +157,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
For more info, see the `Mocking` section in `tests/README.md`.
"""
+ spec_set = guild_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'members': []}
- super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -269,6 +180,8 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.Role` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = role_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {
'id': next(self.discord_id),
@@ -277,7 +190,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
'colour': discord.Colour(0xdeadbf),
'permissions': discord.Permissions(),
}
- super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if isinstance(self.colour, int):
self.colour = discord.Colour(self.colour)
@@ -306,9 +219,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
Instances of this class will follow the specifications of `discord.Member` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = member_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -329,9 +244,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.User` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = user_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"@{self.name}"
@@ -362,16 +279,13 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = bot_instance
+ additional_spec_asyncs = ("wait_for",)
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=bot_instance, **kwargs)
+ super().__init__(**kwargs)
self.api_client = MockAPIClient()
- # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
- # and should therefore be awaited. (The documentation calls it a coroutine as well, which
- # is technically incorrect, since it's a regular def.)
- self.wait_for = AsyncMock()
-
# Since calling `create_task` on our MockBot does not actually schedule the coroutine object
# as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
# to prevent "has not been awaited"-warnings.
@@ -401,10 +315,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = channel_instance
def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
- super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"#{self.name}"
@@ -443,9 +358,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+ spec_set = context_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=context_instance, **kwargs)
+ super().__init__(**kwargs)
self.bot = kwargs.get('bot', MockBot())
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
@@ -462,8 +378,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Attachment` instances. For
more information, see the `MockGuild` docstring.
"""
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=attachment_instance, **kwargs)
+ spec_set = attachment_instance
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
@@ -473,10 +388,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = message_instance
def __init__(self, **kwargs) -> None:
default_kwargs = {'attachments': []}
- super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -492,9 +408,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = emoji_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=emoji_instance, **kwargs)
+ super().__init__(**kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -508,9 +425,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=partial_emoji_instance, **kwargs)
+ spec_set = partial_emoji_instance
reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
@@ -523,12 +438,18 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = reaction_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=reaction_instance, **kwargs)
+ _users = kwargs.pop("users", [])
+ super().__init__(**kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
- self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+ user_iterator = unittest.mock.AsyncMock()
+ user_iterator.__aiter__.return_value = _users
+ self.users.return_value = user_iterator
+
self.__str__.return_value = str(self.emoji)
@@ -542,13 +463,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Webhook` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=webhook_instance, **kwargs)
-
- # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
- # as coroutines. That's why we need to set the methods manually.
- self.send = AsyncMock()
- self.edit = AsyncMock()
- self.delete = AsyncMock()
- self.execute = AsyncMock()
+ spec_set = webhook_instance
+ additional_spec_asyncs = ("send", "edit", "delete", "execute")
diff --git a/tests/test_base.py b/tests/test_base.py
index a16e2af8f..23abb1dfd 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -3,7 +3,11 @@ import unittest
import unittest.mock
-from tests.base import LoggingTestCase, _CaptureLogHandler
+from tests.base import LoggingTestsMixin, _CaptureLogHandler
+
+
+class LoggingTestCase(LoggingTestsMixin):
+ pass
class LoggingTestCaseTests(unittest.TestCase):
@@ -18,19 +22,9 @@ class LoggingTestCaseTests(unittest.TestCase):
try:
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
pass
- except AssertionError:
+ except AssertionError: # pragma: no cover
self.fail("`self.assertNotLogs` raised an AssertionError when it should not!")
- @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs")
- def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs):
- """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly."""
- assertNotLogs.return_value = iter([None])
- assertNotLogs.side_effect = AssertionError
-
- message = "`self.assertNotLogs` raised an AssertionError when it should not!"
- with self.assertRaises(AssertionError, msg=message):
- self.test_assert_not_logs_does_not_raise_with_no_logs()
-
def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self):
"""Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted."""
msg_regex = (
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 7894e104a..81285e009 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
import unittest
import unittest.mock
@@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase):
with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"):
asyncio.run(coroutine_object)
+ def test_user_mock_uses_explicitly_passed_mention_attribute(self):
+ """MockUser should use an explicitly passed value for user.mention."""
+ user = helpers.MockUser(mention="hello")
+ self.assertEqual(user.mention, "hello")
+
class MockObjectTests(unittest.TestCase):
"""Tests the mock objects and mixins we've defined."""
@@ -341,65 +345,10 @@ class MockObjectTests(unittest.TestCase):
attribute = getattr(mock, valid_attribute)
self.assertTrue(isinstance(attribute, mock_type.child_mock_type))
- def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self):
- """Test if all coroutine functions are extracted, but not regular methods or attributes."""
- class CoroutineDonor:
- def __init__(self):
- self.some_attribute = 'alpha'
-
- async def first_coroutine():
- """This coroutine function should be extracted."""
-
- async def second_coroutine():
- """This coroutine function should be extracted."""
-
- def regular_method():
- """This regular function should not be extracted."""
-
- class Receiver:
+ def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self):
+ """The CustomMockMixin should mock async magic methods with an AsyncMock."""
+ class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):
pass
- donor = CoroutineDonor()
- receiver = Receiver()
-
- helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor)
-
- self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock)
- self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock)
- self.assertFalse(hasattr(receiver, 'regular_method'))
- self.assertFalse(hasattr(receiver, 'some_attribute'))
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_with_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- spec_set = "pydis"
-
- helpers.CustomMockMixin(spec_set=spec_set)
-
- extract_method_mock.assert_called_once_with(spec_set)
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_without_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- helpers.CustomMockMixin()
-
- extract_method_mock.assert_not_called()
-
- def test_async_mock_provides_coroutine_for_dunder_call(self):
- """Test if AsyncMock objects have a coroutine for their __call__ method."""
- async_mock = helpers.AsyncMock()
- self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__))
-
- coroutine = async_mock()
- self.assertTrue(inspect.iscoroutine(coroutine))
- self.assertIsNotNone(asyncio.run(coroutine))
-
- def test_async_test_decorator_allows_synchronous_call_to_async_def(self):
- """Test if the `async_test` decorator allows an `async def` to be called synchronously."""
- @helpers.async_test
- async def kosayoda():
- return "return value"
-
- self.assertEqual(kosayoda(), "return value")
+ mock = MyMock()
+ self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock)
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
deleted file mode 100644
index 4baa6395c..000000000
--- a/tests/utils/test_time.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import asyncio
-from datetime import datetime, timezone
-from unittest.mock import patch
-
-import pytest
-from dateutil.relativedelta import relativedelta
-
-from bot.utils import time
-from tests.helpers import AsyncMock
-
-
- ('delta', 'precision', 'max_units', 'expected'),
- (
- (relativedelta(days=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
- (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
-
- # Does not abort for unknown units, as the unit name is checked
- # against the attribute of the relativedelta instance.
- (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'),
-
- # Very high maximum units, but it only ever iterates over
- # each value the relativedelta might have.
- (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'),
- )
-)
-def test_humanize_delta(
- delta: relativedelta,
- precision: str,
- max_units: int,
- expected: str
-):
- assert time.humanize_delta(delta, precision, max_units) == expected
-
-
[email protected]('max_units', (-1, 0))
-def test_humanize_delta_raises_for_invalid_max_units(max_units: int):
- with pytest.raises(ValueError, match='max_units must be positive'):
- time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
-
-
- ('stamp', 'expected'),
- (
- ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)),
- )
-)
-def test_parse_rfc1123(stamp: str, expected: str):
- assert time.parse_rfc1123(stamp) == expected
-
-
-@patch('asyncio.sleep', new_callable=AsyncMock)
-def test_wait_until(sleep_patch):
- start = datetime(2019, 1, 1, 0, 0)
- then = datetime(2019, 1, 1, 0, 10)
-
- # No return value
- assert asyncio.run(time.wait_until(then, start)) is None
-
- sleep_patch.assert_called_once_with(10 * 60)