aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/README.md10
-rw-r--r--tests/base.py16
-rw-r--r--tests/bot/cogs/moderation/__init__.py0
-rw-r--r--tests/bot/cogs/moderation/test_infractions.py55
-rw-r--r--tests/bot/cogs/moderation/test_modlog.py29
-rw-r--r--tests/bot/cogs/moderation/test_silence.py261
-rw-r--r--tests/bot/cogs/sync/test_base.py50
-rw-r--r--tests/bot/cogs/sync/test_cog.py52
-rw-r--r--tests/bot/cogs/sync/test_roles.py12
-rw-r--r--tests/bot/cogs/sync/test_users.py15
-rw-r--r--tests/bot/cogs/test_antimalware.py159
-rw-r--r--tests/bot/cogs/test_cogs.py80
-rw-r--r--tests/bot/cogs/test_duck_pond.py37
-rw-r--r--tests/bot/cogs/test_information.py58
-rw-r--r--tests/bot/cogs/test_snekbox.py135
-rw-r--r--tests/bot/cogs/test_token_remover.py365
-rw-r--r--tests/bot/rules/__init__.py6
-rw-r--r--tests/bot/rules/test_attachments.py4
-rw-r--r--tests/bot/rules/test_burst.py4
-rw-r--r--tests/bot/rules/test_burst_shared.py4
-rw-r--r--tests/bot/rules/test_chars.py4
-rw-r--r--tests/bot/rules/test_discord_emojis.py4
-rw-r--r--tests/bot/rules/test_duplicates.py4
-rw-r--r--tests/bot/rules/test_links.py4
-rw-r--r--tests/bot/rules/test_mentions.py4
-rw-r--r--tests/bot/rules/test_newlines.py5
-rw-r--r--tests/bot/rules/test_role_mentions.py4
-rw-r--r--tests/bot/test_api.py4
-rw-r--r--tests/bot/test_constants.py43
-rw-r--r--tests/bot/test_converters.py133
-rw-r--r--tests/bot/test_decorators.py147
-rw-r--r--tests/bot/test_utils.py37
-rw-r--r--tests/bot/utils/test_checks.py52
-rw-r--r--tests/bot/utils/test_redis_cache.py265
-rw-r--r--tests/bot/utils/test_time.py5
-rw-r--r--tests/helpers.py290
-rw-r--r--tests/test_base.py20
-rw-r--r--tests/test_helpers.py71
-rw-r--r--tests/utils/test_time.py62
39 files changed, 1761 insertions, 749 deletions
diff --git a/tests/README.md b/tests/README.md
index be78821bf..4f62edd68 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -83,7 +83,7 @@ TagContentConverter should return correct values for valid input.
As we are trying to test our "units" of code independently, we want to make sure that we do not rely objects and data generated by "external" code. If we we did, then we wouldn't know if the failure we're observing was caused by the code we are actually trying to test or something external to it.
-However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks".
+However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks".
To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.).
@@ -114,13 +114,13 @@ class BotCogTests(unittest.TestCase):
### Mocking coroutines
-By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.
+By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.
### Special mocks for some `discord.py` types
To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass.
-In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**.
+In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**.
These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR.
@@ -144,7 +144,7 @@ Finally, there are some considerations to make when writing tests, both for writ
### Test coverage is a starting point
-Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work.
+Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work.
One problem is that 100% branch coverage may be misleading if we haven't tested our code against all the realistic input it may get in production. For instance, take a look at the following `member_information` function and the test we've written for it:
@@ -169,7 +169,7 @@ class FunctionsTests(unittest.TestCase):
If you were to run this test, not only would the function pass the test, `coverage.py` will also tell us that the test provides 100% branch coverage for the function. Can you spot the bug the test suite did not catch?
-The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`).
+The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`).
Adding another test would not increase the test coverage we have, but it does ensure that we'll notice that this function can fail with realistic data:
diff --git a/tests/base.py b/tests/base.py
index 88693f382..d99b9ac31 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -22,11 +22,16 @@ class _CaptureLogHandler(logging.Handler):
self.records.append(record)
-class LoggingTestCase(unittest.TestCase):
- """TestCase subclass that adds more logging assertion tools."""
+class LoggingTestsMixin:
+ """
+ A mixin that defines additional test methods for logging behavior.
+
+ This mixin relies on the availability of the `fail` attribute defined by the
+ test classes included in Python's unittest method to signal test failure.
+ """
@contextmanager
- def assertNotLogs(self, logger=None, level=None, msg=None):
+ def assertNotLogs(self, logger=None, level=None, msg=None): # noqa: N802
"""
Asserts that no logs of `level` and higher were emitted by `logger`.
@@ -73,11 +78,10 @@ class LoggingTestCase(unittest.TestCase):
self.fail(msg)
-class CommandTestCase(unittest.TestCase):
+class CommandTestCase(unittest.IsolatedAsyncioTestCase):
"""TestCase with additional assertions that are useful for testing Discord commands."""
- @helpers.async_test
- async def assertHasPermissionsCheck(
+ async def assertHasPermissionsCheck( # noqa: N802
self,
cmd: commands.Command,
permissions: Dict[str, bool],
diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/cogs/moderation/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/cogs/moderation/__init__.py
diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py
new file mode 100644
index 000000000..da4e92ccc
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_infractions.py
@@ -0,0 +1,55 @@
+import textwrap
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from bot.cogs.moderation.infractions import Infractions
+from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole
+
+
+class TruncationTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for ban and kick command reason truncation."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = Infractions(self.bot)
+ self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10))
+ self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0))
+ self.guild = MockGuild(id=4567)
+ self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild)
+
+ @patch("bot.cogs.moderation.utils.get_active_infraction")
+ @patch("bot.cogs.moderation.utils.post_infraction")
+ async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock):
+ """Should truncate reason for `ctx.guild.ban`."""
+ get_active_mock.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.cog.apply_infraction = AsyncMock()
+ self.bot.get_cog.return_value = AsyncMock()
+ self.cog.mod_log.ignore = Mock()
+ self.ctx.guild.ban = Mock()
+
+ await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000)
+ self.ctx.guild.ban.assert_called_once_with(
+ self.target,
+ reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."),
+ delete_message_days=0
+ )
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value
+ )
+
+ @patch("bot.cogs.moderation.utils.post_infraction")
+ async def test_apply_kick_reason_truncation(self, post_infraction_mock):
+ """Should truncate reason for `Member.kick`."""
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.cog.apply_infraction = AsyncMock()
+ self.cog.mod_log.ignore = Mock()
+ self.target.kick = Mock()
+
+ await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000)
+ self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value
+ )
diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py
new file mode 100644
index 000000000..f2809f40a
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_modlog.py
@@ -0,0 +1,29 @@
+import unittest
+
+import discord
+
+from bot.cogs.moderation.modlog import ModLog
+from tests.helpers import MockBot, MockTextChannel
+
+
+class ModLogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for moderation logs."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = ModLog(self.bot)
+ self.channel = MockTextChannel()
+
+ async def test_log_entry_description_truncation(self):
+ """Test that embed description for ModLog entry is truncated."""
+ self.bot.get_channel.return_value = self.channel
+ await self.cog.send_log_message(
+ icon_url="foo",
+ colour=discord.Colour.blue(),
+ title="bar",
+ text="foo bar" * 3000
+ )
+ embed = self.channel.send.call_args[1]["embed"]
+ self.assertEqual(
+ embed.description, ("foo bar" * 3000)[:2045] + "..."
+ )
diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py
new file mode 100644
index 000000000..ab3d0742a
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_silence.py
@@ -0,0 +1,261 @@
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock, Mock
+
+from discord import PermissionOverwrite
+
+from bot.cogs.moderation.silence import Silence, SilenceNotifier
+from bot.constants import Channels, Emojis, Guild, Roles
+from tests.helpers import MockBot, MockContext, MockTextChannel
+
+
+class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
+ def setUp(self) -> None:
+ self.alert_channel = MockTextChannel()
+ self.notifier = SilenceNotifier(self.alert_channel)
+ self.notifier.stop = self.notifier_stop_mock = Mock()
+ self.notifier.start = self.notifier_start_mock = Mock()
+
+ def test_add_channel_adds_channel(self):
+ """Channel in FirstHash with current loop is added to internal set."""
+ channel = Mock()
+ with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
+ self.notifier.add_channel(channel)
+ silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop)
+
+ def test_add_channel_starts_loop(self):
+ """Loop is started if `_silenced_channels` was empty."""
+ self.notifier.add_channel(Mock())
+ self.notifier_start_mock.assert_called_once()
+
+ def test_add_channel_skips_start_with_channels(self):
+ """Loop start is not called when `_silenced_channels` is not empty."""
+ with mock.patch.object(self.notifier, "_silenced_channels"):
+ self.notifier.add_channel(Mock())
+ self.notifier_start_mock.assert_not_called()
+
+ def test_remove_channel_removes_channel(self):
+ """Channel in FirstHash is removed from `_silenced_channels`."""
+ channel = Mock()
+ with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
+ self.notifier.remove_channel(channel)
+ silenced_channels.__delitem__.assert_called_with(channel)
+
+ def test_remove_channel_stops_loop(self):
+ """Notifier loop is stopped if `_silenced_channels` is empty after remove."""
+ with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False):
+ self.notifier.remove_channel(Mock())
+ self.notifier_stop_mock.assert_called_once()
+
+ def test_remove_channel_skips_stop_with_channels(self):
+ """Notifier loop is not stopped if `_silenced_channels` is not empty after remove."""
+ self.notifier.remove_channel(Mock())
+ self.notifier_stop_mock.assert_not_called()
+
+ async def test_notifier_private_sends_alert(self):
+ """Alert is sent on 15 min intervals."""
+ test_cases = (900, 1800, 2700)
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
+ await self.notifier._notifier()
+ self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ")
+ self.alert_channel.send.reset_mock()
+
+ async def test_notifier_skips_alert(self):
+ """Alert is skipped on first loop or not an increment of 900."""
+ test_cases = (0, 15, 5000)
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
+ await self.notifier._notifier()
+ self.alert_channel.send.assert_not_called()
+
+
+class SilenceTests(unittest.IsolatedAsyncioTestCase):
+ def setUp(self) -> None:
+ self.bot = MockBot()
+ self.cog = Silence(self.bot)
+ self.ctx = MockContext()
+ self.cog._verified_role = None
+ # Set event so command callbacks can continue.
+ self.cog._get_instance_vars_event.set()
+
+ async def test_instance_vars_got_guild(self):
+ """Bot got guild after it became available."""
+ await self.cog._get_instance_vars()
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(Guild.id)
+
+ async def test_instance_vars_got_role(self):
+ """Got `Roles.verified` role from guild."""
+ await self.cog._get_instance_vars()
+ guild = self.bot.get_guild()
+ guild.get_role.assert_called_once_with(Roles.verified)
+
+ async def test_instance_vars_got_channels(self):
+ """Got channels from bot."""
+ await self.cog._get_instance_vars()
+ self.bot.get_channel.called_once_with(Channels.mod_alerts)
+ self.bot.get_channel.called_once_with(Channels.mod_log)
+
+ @mock.patch("bot.cogs.moderation.silence.SilenceNotifier")
+ async def test_instance_vars_got_notifier(self, notifier):
+ """Notifier was started with channel."""
+ mod_log = MockTextChannel()
+ self.bot.get_channel.side_effect = (None, mod_log)
+ await self.cog._get_instance_vars()
+ notifier.assert_called_once_with(mod_log)
+ self.bot.get_channel.side_effect = None
+
+ async def test_silence_sent_correct_discord_message(self):
+ """Check if proper message was sent when called with duration in channel with previous state."""
+ test_cases = (
+ (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,),
+ (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,),
+ (5, f"{Emojis.cross_mark} current channel is already silenced.", False,),
+ )
+ for duration, result_message, _silence_patch_return in test_cases:
+ with self.subTest(
+ silence_duration=duration,
+ result_message=result_message,
+ starting_unsilenced_state=_silence_patch_return
+ ):
+ with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return):
+ await self.cog.silence.callback(self.cog, self.ctx, duration)
+ self.ctx.send.assert_called_once_with(result_message)
+ self.ctx.reset_mock()
+
+ async def test_unsilence_sent_correct_discord_message(self):
+ """Check if proper message was sent when unsilencing channel."""
+ test_cases = (
+ (True, f"{Emojis.check_mark} unsilenced current channel."),
+ (False, f"{Emojis.cross_mark} current channel was not silenced.")
+ )
+ for _unsilence_patch_return, result_message in test_cases:
+ with self.subTest(
+ starting_silenced_state=_unsilence_patch_return,
+ result_message=result_message
+ ):
+ with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return):
+ await self.cog.unsilence.callback(self.cog, self.ctx)
+ self.ctx.send.assert_called_once_with(result_message)
+ self.ctx.reset_mock()
+
+ async def test_silence_private_for_false(self):
+ """Permissions are not set and `False` is returned in an already silenced channel."""
+ perm_overwrite = Mock(send_messages=False)
+ channel = Mock(overwrites_for=Mock(return_value=perm_overwrite))
+
+ self.assertFalse(await self.cog._silence(channel, True, None))
+ channel.set_permissions.assert_not_called()
+
+ async def test_silence_private_silenced_channel(self):
+ """Channel had `send_message` permissions revoked."""
+ channel = MockTextChannel()
+ self.assertTrue(await self.cog._silence(channel, False, None))
+ channel.set_permissions.assert_called_once()
+ self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages'])
+
+ async def test_silence_private_preserves_permissions(self):
+ """Previous permissions were preserved when channel was silenced."""
+ channel = MockTextChannel()
+ # Set up mock channel permission state.
+ mock_permissions = PermissionOverwrite()
+ mock_permissions_dict = dict(mock_permissions)
+ channel.overwrites_for.return_value = mock_permissions
+ await self.cog._silence(channel, False, None)
+ new_permissions = channel.set_permissions.call_args.kwargs
+ # Remove 'send_messages' key because it got changed in the method.
+ del new_permissions['send_messages']
+ del mock_permissions_dict['send_messages']
+ self.assertDictEqual(mock_permissions_dict, new_permissions)
+
+ async def test_silence_private_notifier(self):
+ """Channel should be added to notifier with `persistent` set to `True`, and the other way around."""
+ channel = MockTextChannel()
+ with mock.patch.object(self.cog, "notifier", create=True):
+ with self.subTest(persistent=True):
+ await self.cog._silence(channel, True, None)
+ self.cog.notifier.add_channel.assert_called_once()
+
+ with mock.patch.object(self.cog, "notifier", create=True):
+ with self.subTest(persistent=False):
+ await self.cog._silence(channel, False, None)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_silence_private_added_muted_channel(self):
+ """Channel was added to `muted_channels` on silence."""
+ channel = MockTextChannel()
+ with mock.patch.object(self.cog, "muted_channels") as muted_channels:
+ await self.cog._silence(channel, False, None)
+ muted_channels.add.assert_called_once_with(channel)
+
+ async def test_unsilence_private_for_false(self):
+ """Permissions are not set and `False` is returned in an unsilenced channel."""
+ channel = Mock()
+ self.assertFalse(await self.cog._unsilence(channel))
+ channel.set_permissions.assert_not_called()
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_unsilenced_channel(self, _):
+ """Channel had `send_message` permissions restored"""
+ perm_overwrite = MagicMock(send_messages=False)
+ channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
+ self.assertTrue(await self.cog._unsilence(channel))
+ channel.set_permissions.assert_called_once()
+ self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages'])
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_removed_notifier(self, notifier):
+ """Channel was removed from `notifier` on unsilence."""
+ perm_overwrite = MagicMock(send_messages=False)
+ channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
+ await self.cog._unsilence(channel)
+ notifier.remove_channel.assert_called_once_with(channel)
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_removed_muted_channel(self, _):
+ """Channel was removed from `muted_channels` on unsilence."""
+ perm_overwrite = MagicMock(send_messages=False)
+ channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
+ with mock.patch.object(self.cog, "muted_channels") as muted_channels:
+ await self.cog._unsilence(channel)
+ muted_channels.discard.assert_called_once_with(channel)
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_preserves_permissions(self, _):
+ """Previous permissions were preserved when channel was unsilenced."""
+ channel = MockTextChannel()
+ # Set up mock channel permission state.
+ mock_permissions = PermissionOverwrite(send_messages=False)
+ mock_permissions_dict = dict(mock_permissions)
+ channel.overwrites_for.return_value = mock_permissions
+ await self.cog._unsilence(channel)
+ new_permissions = channel.set_permissions.call_args.kwargs
+ # Remove 'send_messages' key because it got changed in the method.
+ del new_permissions['send_messages']
+ del mock_permissions_dict['send_messages']
+ self.assertDictEqual(mock_permissions_dict, new_permissions)
+
+ @mock.patch("bot.cogs.moderation.silence.asyncio")
+ @mock.patch.object(Silence, "_mod_alerts_channel", create=True)
+ def test_cog_unload_starts_task(self, alert_channel, asyncio_mock):
+ """Task for sending an alert was created with present `muted_channels`."""
+ with mock.patch.object(self.cog, "muted_channels"):
+ self.cog.cog_unload()
+ alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ")
+ asyncio_mock.create_task.assert_called_once_with(alert_channel.send())
+
+ @mock.patch("bot.cogs.moderation.silence.asyncio")
+ def test_cog_unload_skips_task_start(self, asyncio_mock):
+ """No task created with no channels."""
+ self.cog.cog_unload()
+ asyncio_mock.create_task.assert_not_called()
+
+ @mock.patch("bot.cogs.moderation.silence.with_role_check")
+ @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))
+ def test_cog_check(self, role_check):
+ """Role check is called with `MODERATION_ROLES`"""
+ self.cog.cog_check(self.ctx)
+ role_check.assert_called_once_with(self.ctx, *(1, 2, 3))
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
index c2e143865..70aea2bab 100644
--- a/tests/bot/cogs/sync/test_base.py
+++ b/tests/bot/cogs/sync/test_base.py
@@ -1,3 +1,4 @@
+import asyncio
import unittest
from unittest import mock
@@ -13,8 +14,8 @@ class TestSyncer(Syncer):
"""Syncer subclass with mocks for abstract methods for testing purposes."""
name = "test"
- _get_diff = helpers.AsyncMock()
- _sync = helpers.AsyncMock()
+ _get_diff = mock.AsyncMock()
+ _sync = mock.AsyncMock()
class SyncerBaseTests(unittest.TestCase):
@@ -29,7 +30,7 @@ class SyncerBaseTests(unittest.TestCase):
Syncer(self.bot)
-class SyncerSendPromptTests(unittest.TestCase):
+class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):
"""Tests for sending the sync confirmation prompt."""
def setUp(self):
@@ -61,7 +62,6 @@ class SyncerSendPromptTests(unittest.TestCase):
return mock_channel, mock_message
- @helpers.async_test
async def test_send_prompt_edits_and_returns_message(self):
"""The given message should be edited to display the prompt and then should be returned."""
msg = helpers.MockMessage()
@@ -71,7 +71,6 @@ class SyncerSendPromptTests(unittest.TestCase):
self.assertIn("content", msg.edit.call_args[1])
self.assertEqual(ret_val, msg)
- @helpers.async_test
async def test_send_prompt_gets_dev_core_channel(self):
"""The dev-core channel should be retrieved if an extant message isn't given."""
subtests = (
@@ -86,8 +85,7 @@ class SyncerSendPromptTests(unittest.TestCase):
method.assert_called_once_with(constants.Channels.dev_core)
- @helpers.async_test
- async def test_send_prompt_returns_None_if_channel_fetch_fails(self):
+ async def test_send_prompt_returns_none_if_channel_fetch_fails(self):
"""None should be returned if there's an HTTPException when fetching the channel."""
self.bot.get_channel.return_value = None
self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
@@ -96,7 +94,6 @@ class SyncerSendPromptTests(unittest.TestCase):
self.assertIsNone(ret_val)
- @helpers.async_test
async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
"""A new message mentioning core devs should be sent and returned if message isn't given."""
for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
@@ -108,7 +105,6 @@ class SyncerSendPromptTests(unittest.TestCase):
self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
self.assertEqual(ret_val, mock_message)
- @helpers.async_test
async def test_send_prompt_adds_reactions(self):
"""The message should have reactions for confirmation added."""
extant_message = helpers.MockMessage()
@@ -129,7 +125,7 @@ class SyncerSendPromptTests(unittest.TestCase):
mock_message.add_reaction.assert_has_calls(calls)
-class SyncerConfirmationTests(unittest.TestCase):
+class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):
"""Tests for waiting for a sync confirmation reaction on the prompt."""
def setUp(self):
@@ -211,13 +207,12 @@ class SyncerConfirmationTests(unittest.TestCase):
ret_val = self.syncer._reaction_check(*args)
self.assertFalse(ret_val)
- @helpers.async_test
async def test_wait_for_confirmation(self):
"""The message should always be edited and only return True if the emoji is a check mark."""
subtests = (
(constants.Emojis.check_mark, True, None),
("InVaLiD", False, None),
- (None, False, TimeoutError),
+ (None, False, asyncio.TimeoutError),
)
for emoji, ret_val, side_effect in subtests:
@@ -251,14 +246,13 @@ class SyncerConfirmationTests(unittest.TestCase):
self.assertIs(actual_return, ret_val)
-class SyncerSyncTests(unittest.TestCase):
+class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for main function orchestrating the sync."""
def setUp(self):
self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
self.syncer = TestSyncer(self.bot)
- @helpers.async_test
async def test_sync_respects_confirmation_result(self):
"""The sync should abort if confirmation fails and continue if confirmed."""
mock_message = helpers.MockMessage()
@@ -274,7 +268,7 @@ class SyncerSyncTests(unittest.TestCase):
diff = _Diff({1, 2, 3}, {4, 5}, None)
self.syncer._get_diff.return_value = diff
- self.syncer._get_confirmation_result = helpers.AsyncMock(
+ self.syncer._get_confirmation_result = mock.AsyncMock(
return_value=(confirmed, message)
)
@@ -289,7 +283,6 @@ class SyncerSyncTests(unittest.TestCase):
else:
self.syncer._sync.assert_not_called()
- @helpers.async_test
async def test_sync_diff_size(self):
"""The diff size should be correctly calculated."""
subtests = (
@@ -303,7 +296,7 @@ class SyncerSyncTests(unittest.TestCase):
with self.subTest(size=size, diff=diff):
self.syncer._get_diff.reset_mock()
self.syncer._get_diff.return_value = diff
- self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+ self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
guild = helpers.MockGuild()
await self.syncer.sync(guild)
@@ -312,7 +305,6 @@ class SyncerSyncTests(unittest.TestCase):
self.syncer._get_confirmation_result.assert_called_once()
self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
- @helpers.async_test
async def test_sync_message_edited(self):
"""The message should be edited if one was sent, even if the sync has an API error."""
subtests = (
@@ -324,7 +316,7 @@ class SyncerSyncTests(unittest.TestCase):
for message, side_effect, should_edit in subtests:
with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
self.syncer._sync.side_effect = side_effect
- self.syncer._get_confirmation_result = helpers.AsyncMock(
+ self.syncer._get_confirmation_result = mock.AsyncMock(
return_value=(True, message)
)
@@ -335,7 +327,6 @@ class SyncerSyncTests(unittest.TestCase):
message.edit.assert_called_once()
self.assertIn("content", message.edit.call_args[1])
- @helpers.async_test
async def test_sync_confirmation_context_redirect(self):
"""If ctx is given, a new message should be sent and author should be ctx's author."""
mock_member = helpers.MockMember()
@@ -349,7 +340,10 @@ class SyncerSyncTests(unittest.TestCase):
if ctx is not None:
ctx.send.return_value = message
- self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+ # Make sure `_get_diff` returns a MagicMock, not an AsyncMock
+ self.syncer._get_diff.return_value = mock.MagicMock()
+
+ self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
guild = helpers.MockGuild()
await self.syncer.sync(guild, ctx)
@@ -362,16 +356,15 @@ class SyncerSyncTests(unittest.TestCase):
self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
@mock.patch.object(constants.Sync, "max_diff", new=3)
- @helpers.async_test
async def test_confirmation_result_small_diff(self):
"""Should always return True and the given message if the diff size is too small."""
author = helpers.MockMember()
expected_message = helpers.MockMessage()
- for size in (3, 2):
+ for size in (3, 2): # pragma: no cover
with self.subTest(size=size):
- self.syncer._send_prompt = helpers.AsyncMock()
- self.syncer._wait_for_confirmation = helpers.AsyncMock()
+ self.syncer._send_prompt = mock.AsyncMock()
+ self.syncer._wait_for_confirmation = mock.AsyncMock()
coro = self.syncer._get_confirmation_result(size, author, expected_message)
result, actual_message = await coro
@@ -382,7 +375,6 @@ class SyncerSyncTests(unittest.TestCase):
self.syncer._wait_for_confirmation.assert_not_called()
@mock.patch.object(constants.Sync, "max_diff", new=3)
- @helpers.async_test
async def test_confirmation_result_large_diff(self):
"""Should return True if confirmed and False if _send_prompt fails or aborted."""
author = helpers.MockMember()
@@ -394,10 +386,10 @@ class SyncerSyncTests(unittest.TestCase):
(False, mock_message, False, "aborted"),
)
- for expected_result, expected_message, confirmed, msg in subtests:
+ for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover
with self.subTest(msg=msg):
- self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message)
- self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed)
+ self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message)
+ self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed)
coro = self.syncer._get_confirmation_result(4, author)
actual_result, actual_message = await coro
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
index 98c9afc0d..14fd909c4 100644
--- a/tests/bot/cogs/sync/test_cog.py
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -11,19 +11,7 @@ from tests import helpers
from tests.base import CommandTestCase
-class MockSyncer(helpers.CustomMockMixin, mock.MagicMock):
- """
- A MagicMock subclass to mock Syncer objects.
-
- Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer`
- instances. For more information, see the `MockGuild` docstring.
- """
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=Syncer, **kwargs)
-
-
-class SyncExtensionTests(unittest.TestCase):
+class SyncExtensionTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the sync extension."""
@staticmethod
@@ -34,22 +22,21 @@ class SyncExtensionTests(unittest.TestCase):
bot.add_cog.assert_called_once()
-class SyncCogTestCase(unittest.TestCase):
+class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):
"""Base class for Sync cog tests. Sets up patches for syncers."""
def setUp(self):
self.bot = helpers.MockBot()
- # These patch the type. When the type is called, a MockSyncer instanced is returned.
- # MockSyncer is needed so that our custom AsyncMock is used.
- # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.
self.role_syncer_patcher = mock.patch(
"bot.cogs.sync.syncers.RoleSyncer",
- new=mock.MagicMock(return_value=MockSyncer())
+ autospec=Syncer,
+ spec_set=True
)
self.user_syncer_patcher = mock.patch(
"bot.cogs.sync.syncers.UserSyncer",
- new=mock.MagicMock(return_value=MockSyncer())
+ autospec=Syncer,
+ spec_set=True
)
self.RoleSyncer = self.role_syncer_patcher.start()
self.UserSyncer = self.user_syncer_patcher.start()
@@ -72,13 +59,13 @@ class SyncCogTestCase(unittest.TestCase):
class SyncCogTests(SyncCogTestCase):
"""Tests for the Sync cog."""
- @mock.patch.object(sync.Sync, "sync_guild")
+ @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock)
def test_sync_cog_init(self, sync_guild):
"""Should instantiate syncers and run a sync for the guild."""
# Reset because a Sync cog was already instantiated in setUp.
self.RoleSyncer.reset_mock()
self.UserSyncer.reset_mock()
- self.bot.loop.create_task.reset_mock()
+ self.bot.loop.create_task = mock.MagicMock()
mock_sync_guild_coro = mock.MagicMock()
sync_guild.return_value = mock_sync_guild_coro
@@ -90,7 +77,6 @@ class SyncCogTests(SyncCogTestCase):
sync_guild.assert_called_once_with()
self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
- @helpers.async_test
async def test_sync_cog_sync_guild(self):
"""Roles and users should be synced only if a guild is successfully retrieved."""
for guild in (helpers.MockGuild(), None):
@@ -126,14 +112,12 @@ class SyncCogTests(SyncCogTestCase):
json=updated_information,
)
- @helpers.async_test
async def test_sync_cog_patch_user(self):
"""A PATCH request should be sent and 404 errors ignored."""
for side_effect in (None, self.response_error(404)):
with self.subTest(side_effect=side_effect):
await self.patch_user_helper(side_effect)
- @helpers.async_test
async def test_sync_cog_patch_user_non_404(self):
"""A PATCH request should be sent and the error raised if it's not a 404."""
with self.assertRaises(ResponseCodeError):
@@ -145,9 +129,8 @@ class SyncCogListenerTests(SyncCogTestCase):
def setUp(self):
super().setUp()
- self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user)
+ self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user)
- @helpers.async_test
async def test_sync_cog_on_guild_role_create(self):
"""A POST request should be sent with the new role's data."""
self.assertTrue(self.cog.on_guild_role_create.__cog_listener__)
@@ -164,7 +147,6 @@ class SyncCogListenerTests(SyncCogTestCase):
self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data)
- @helpers.async_test
async def test_sync_cog_on_guild_role_delete(self):
"""A DELETE request should be sent."""
self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__)
@@ -174,7 +156,6 @@ class SyncCogListenerTests(SyncCogTestCase):
self.bot.api_client.delete.assert_called_once_with("bot/roles/99")
- @helpers.async_test
async def test_sync_cog_on_guild_role_update(self):
"""A PUT request should be sent if the colour, name, permissions, or position changes."""
self.assertTrue(self.cog.on_guild_role_update.__cog_listener__)
@@ -212,7 +193,6 @@ class SyncCogListenerTests(SyncCogTestCase):
else:
self.bot.api_client.put.assert_not_called()
- @helpers.async_test
async def test_sync_cog_on_member_remove(self):
"""Member should patched to set in_guild as False."""
self.assertTrue(self.cog.on_member_remove.__cog_listener__)
@@ -225,7 +205,6 @@ class SyncCogListenerTests(SyncCogTestCase):
updated_information={"in_guild": False}
)
- @helpers.async_test
async def test_sync_cog_on_member_update_roles(self):
"""Members should be patched if their roles have changed."""
self.assertTrue(self.cog.on_member_update.__cog_listener__)
@@ -240,7 +219,6 @@ class SyncCogListenerTests(SyncCogTestCase):
data = {"roles": sorted(role.id for role in after_member.roles)}
self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data)
- @helpers.async_test
async def test_sync_cog_on_member_update_other(self):
"""Members should not be patched if other attributes have changed."""
self.assertTrue(self.cog.on_member_update.__cog_listener__)
@@ -262,7 +240,6 @@ class SyncCogListenerTests(SyncCogTestCase):
self.cog.patch_user.assert_not_called()
- @helpers.async_test
async def test_sync_cog_on_user_update(self):
"""A user should be patched only if the name, discriminator, or avatar changes."""
self.assertTrue(self.cog.on_user_update.__cog_listener__)
@@ -270,14 +247,12 @@ class SyncCogListenerTests(SyncCogTestCase):
before_data = {
"name": "old name",
"discriminator": "1234",
- "avatar": "old avatar",
"bot": False,
}
subtests = (
(True, "name", "name", "new name", "new name"),
(True, "discriminator", "discriminator", "8765", 8765),
- (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),
(False, "bot", "bot", True, True),
)
@@ -318,7 +293,6 @@ class SyncCogListenerTests(SyncCogTestCase):
)
data = {
- "avatar_hash": member.avatar,
"discriminator": int(member.discriminator),
"id": member.id,
"in_guild": True,
@@ -341,7 +315,6 @@ class SyncCogListenerTests(SyncCogTestCase):
return data
- @helpers.async_test
async def test_sync_cog_on_member_join(self):
"""Should PUT user's data or POST it if the user doesn't exist."""
for side_effect in (None, self.response_error(404)):
@@ -354,7 +327,6 @@ class SyncCogListenerTests(SyncCogTestCase):
else:
self.bot.api_client.post.assert_not_called()
- @helpers.async_test
async def test_sync_cog_on_member_join_non_404(self):
"""ResponseCodeError should be re-raised if status code isn't a 404."""
with self.assertRaises(ResponseCodeError):
@@ -366,7 +338,6 @@ class SyncCogListenerTests(SyncCogTestCase):
class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
"""Tests for the commands in the Sync cog."""
- @helpers.async_test
async def test_sync_roles_command(self):
"""sync() should be called on the RoleSyncer."""
ctx = helpers.MockContext()
@@ -374,7 +345,6 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
- @helpers.async_test
async def test_sync_users_command(self):
"""sync() should be called on the UserSyncer."""
ctx = helpers.MockContext()
@@ -382,7 +352,7 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
- def test_commands_require_admin(self):
+ async def test_commands_require_admin(self):
"""The sync commands should only run if the author has the administrator permission."""
cmds = (
self.cog.sync_group,
@@ -392,4 +362,4 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
for cmd in cmds:
with self.subTest(cmd=cmd):
- self.assertHasPermissionsCheck(cmd, {"administrator": True})
+ await self.assertHasPermissionsCheck(cmd, {"administrator": True})
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
index 14fb2577a..79eee98f4 100644
--- a/tests/bot/cogs/sync/test_roles.py
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -18,7 +18,7 @@ def fake_role(**kwargs):
return kwargs
-class RoleSyncerDiffTests(unittest.TestCase):
+class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between roles in the DB and roles in the Guild cache."""
def setUp(self):
@@ -39,7 +39,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
return guild
- @helpers.async_test
async def test_empty_diff_for_identical_roles(self):
"""No differences should be found if the roles in the guild and DB are identical."""
self.bot.api_client.get.return_value = [fake_role()]
@@ -50,7 +49,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_updated_roles(self):
"""Only updated roles should be added to the 'updated' set of the diff."""
updated_role = fake_role(id=41, name="new")
@@ -63,7 +61,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_roles(self):
"""Only new roles should be added to the 'created' set of the diff."""
new_role = fake_role(id=41, name="new")
@@ -76,7 +73,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_deleted_roles(self):
"""Only deleted roles should be added to the 'deleted' set of the diff."""
deleted_role = fake_role(id=61, name="deleted")
@@ -89,7 +85,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_updated_and_deleted_roles(self):
"""When roles are added, updated, and removed, all of them are returned properly."""
new = fake_role(id=41, name="new")
@@ -109,14 +104,13 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
-class RoleSyncerSyncTests(unittest.TestCase):
+class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the API requests that sync roles."""
def setUp(self):
self.bot = helpers.MockBot()
self.syncer = RoleSyncer(self.bot)
- @helpers.async_test
async def test_sync_created_roles(self):
"""Only POST requests should be made with the correct payload."""
roles = [fake_role(id=111), fake_role(id=222)]
@@ -132,7 +126,6 @@ class RoleSyncerSyncTests(unittest.TestCase):
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
- @helpers.async_test
async def test_sync_updated_roles(self):
"""Only PUT requests should be made with the correct payload."""
roles = [fake_role(id=111), fake_role(id=222)]
@@ -148,7 +141,6 @@ class RoleSyncerSyncTests(unittest.TestCase):
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()
- @helpers.async_test
async def test_sync_deleted_roles(self):
"""Only DELETE requests should be made with the correct payload."""
roles = [fake_role(id=111), fake_role(id=222)]
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index 421bf6bb6..002a947ad 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -10,14 +10,13 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("avatar_hash", None)
kwargs.setdefault("roles", (666,))
kwargs.setdefault("in_guild", True)
return kwargs
-class UserSyncerDiffTests(unittest.TestCase):
+class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between users in the DB and users in the Guild cache."""
def setUp(self):
@@ -32,7 +31,6 @@ class UserSyncerDiffTests(unittest.TestCase):
for member in members:
member = member.copy()
- member["avatar"] = member.pop("avatar_hash")
del member["in_guild"]
mock_member = helpers.MockMember(**member)
@@ -42,7 +40,6 @@ class UserSyncerDiffTests(unittest.TestCase):
return guild
- @helpers.async_test
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
guild = self.get_guild()
@@ -52,7 +49,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_empty_diff_for_identical_users(self):
"""No differences should be found if the users in the guild and DB are identical."""
self.bot.api_client.get.return_value = [fake_user()]
@@ -63,7 +59,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_updated_users(self):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
@@ -76,7 +71,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_users(self):
"""Only new users should be added to the 'created' set of the diff."""
new_user = fake_user(id=99, name="new")
@@ -89,7 +83,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
leaving_user = fake_user(id=63, in_guild=False)
@@ -102,7 +95,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_updated_and_leaving_users(self):
"""When users are added, updated, and removed, all of them are returned properly."""
new_user = fake_user(id=99, name="new")
@@ -117,7 +109,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_empty_diff_for_db_users_not_in_guild(self):
"""When the DB knows a user the guild doesn't, no difference is found."""
self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
@@ -129,14 +120,13 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
-class UserSyncerSyncTests(unittest.TestCase):
+class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the API requests that sync users."""
def setUp(self):
self.bot = helpers.MockBot()
self.syncer = UserSyncer(self.bot)
- @helpers.async_test
async def test_sync_created_users(self):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
@@ -152,7 +142,6 @@ class UserSyncerSyncTests(unittest.TestCase):
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
- @helpers.async_test
async def test_sync_updated_users(self):
"""Only PUT requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py
new file mode 100644
index 000000000..f219fc1ba
--- /dev/null
+++ b/tests/bot/cogs/test_antimalware.py
@@ -0,0 +1,159 @@
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from discord import NotFound
+
+from bot.cogs import antimalware
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES
+from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
+
+MODULE = "bot.cogs.antimalware"
+
+
+@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])
+class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
+ """Test the AntiMalware cog."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.bot = MockBot()
+ self.cog = antimalware.AntiMalware(self.bot)
+ self.message = MockMessage()
+
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should not be deleted"""
+ attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_message_without_attachment(self):
+ """Messages without attachments should result in no action."""
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_direct_message_with_attachment(self):
+ """Direct messages should have no action taken."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.guild = None
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_message_with_illegal_extension_gets_deleted(self):
+ """A message containing an illegal extension should send an embed."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_called_once()
+
+ async def test_message_send_by_staff(self):
+ """A message send by a member of staff should be ignored."""
+ staff_role = MockRole(id=STAFF_ROLES[0])
+ self.message.author.roles.append(staff_role)
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site"""
+ attachment = MockAttachment(filename="python.py")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+
+ self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
+
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt file should result in the correct embed."""
+ attachment = MockAttachment(filename="python.txt")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.TXT_EMBED_DESCRIPTION = Mock()
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ cmd_channel = self.bot.get_channel(Channels.bot_commands)
+
+ self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
+ antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
+
+ async def test_other_disallowed_extention_embed_description(self):
+ """Test the description for a non .py/.txt disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ meta_channel = self.bot.get_channel(Channels.meta)
+
+ self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
+ blocked_extensions_str=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+
+ async def test_removing_deleted_message_logs(self):
+ """Removing an already deleted message logs the correct message"""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_called_once()
+
+ async def test_message_with_illegal_attachment_logs(self):
+ """Deleting a message with an illegal attachment should result in a log."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (AntiMalwareConfig.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], [".disallowed"]),
+ ([".disallowed"], [".disallowed"]),
+ ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
+ disallowed_extensions = self.cog.get_disallowed_extensions(self.message)
+ self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
+
+
+class AntiMalwareSetupTests(unittest.TestCase):
+ """Tests setup of the `AntiMalware` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ antimalware.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py
new file mode 100644
index 000000000..fdda59a8f
--- /dev/null
+++ b/tests/bot/cogs/test_cogs.py
@@ -0,0 +1,80 @@
+"""Test suite for general tests which apply to all cogs."""
+
+import importlib
+import pkgutil
+import typing as t
+import unittest
+from collections import defaultdict
+from types import ModuleType
+from unittest import mock
+
+from discord.ext import commands
+
+from bot import cogs
+
+
+class CommandNameTests(unittest.TestCase):
+ """Tests for shadowing command names and aliases."""
+
+ @staticmethod
+ def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]:
+ """An iterator that recursively walks through `cog`'s commands and subcommands."""
+ # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods.
+ for command in cog.__cog_commands__:
+ if command.parent is None:
+ yield command
+ if isinstance(command, commands.GroupMixin):
+ # Annoyingly it returns duplicates for each alias so use a set to fix that
+ yield from set(command.walk_commands())
+
+ @staticmethod
+ def walk_modules() -> t.Iterator[ModuleType]:
+ """Yield imported modules from the bot.cogs subpackage."""
+ def on_error(name: str) -> t.NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ # The mock prevents asyncio.get_event_loop() from being called.
+ with mock.patch("discord.ext.tasks.loop"):
+ for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error):
+ if not module.ispkg:
+ yield importlib.import_module(module.name)
+
+ @staticmethod
+ def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]:
+ """Yield all cogs defined in an extension."""
+ for obj in module.__dict__.values():
+ # Check if it's a class type cause otherwise issubclass() may raise a TypeError.
+ is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog)
+ if is_cog and obj.__module__ == module.__name__:
+ yield obj
+
+ @staticmethod
+ def get_qualified_names(command: commands.Command) -> t.List[str]:
+ """Return a list of all qualified names, including aliases, for the `command`."""
+ names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases]
+ names.append(command.qualified_name)
+
+ return names
+
+ def get_all_commands(self) -> t.Iterator[commands.Command]:
+ """Yield all commands for all cogs in all extensions."""
+ for module in self.walk_modules():
+ for cog in self.walk_cogs(module):
+ for cmd in self.walk_commands(cog):
+ yield cmd
+
+ def test_names_dont_shadow(self):
+ """Names and aliases of commands should be unique."""
+ all_names = defaultdict(list)
+ for cmd in self.get_all_commands():
+ func_name = f"{cmd.module}.{cmd.callback.__qualname__}"
+
+ for name in self.get_qualified_names(cmd):
+ with self.subTest(cmd=func_name, name=name):
+ if name in all_names: # pragma: no cover
+ conflicts = ", ".join(all_names.get(name, ""))
+ self.fail(
+ f"Name '{name}' of the command {func_name} conflicts with {conflicts}."
+ )
+
+ all_names[name].append(func_name)
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index 5b0a3b8c3..a8c0107c6 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -2,7 +2,7 @@ import asyncio
import logging
import typing
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import discord
@@ -14,7 +14,7 @@ from tests import helpers
MODULE_PATH = "bot.cogs.duck_pond"
-class DuckPondTests(base.LoggingTestCase):
+class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
"""Tests for DuckPond functionality."""
@classmethod
@@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestCase):
self.assertEqual(cog.bot, bot)
self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond)
- bot.loop.create_loop.called_once_with(cog.fetch_webhook())
+ bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())
def test_fetch_webhook_succeeds_without_connectivity_issues(self):
"""The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute."""
@@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):
self.assertEqual(expected_return, actual_return)
- @helpers.async_test
async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):
"""The `has_green_checkmark` method should only return `True` if one is present."""
test_cases = (
@@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase):
nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]
return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers)
- @helpers.async_test
async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):
"""The `count_ducks` method should return the number of unique staffers who gave a duck."""
test_cases = (
@@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):
self.assertEqual(expected_count, actual_count)
- @helpers.async_test
async def test_relay_message_correctly_relays_content_and_attachments(self):
"""The `relay_message` method should correctly relay message content and attachments."""
send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook"
@@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase):
)
for message, expect_webhook_call, expect_attachment_call in test_values:
- with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook:
- with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments:
+ with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook:
+ with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:
with self.subTest(clean_content=message.clean_content, attachments=message.attachments):
await self.cog.relay_message(message)
@@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase):
message.add_reaction.assert_called_once_with(self.checkmark_emoji)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase):
self.cog.webhook = helpers.MockAsyncWebhook()
log = logging.getLogger("bot.cogs.duck_pond")
- for side_effect in side_effects:
+ for side_effect in side_effects: # pragma: no cover
send_attachments.side_effect = side_effect
- with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook:
+ with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:
with self.subTest(side_effect=type(side_effect).__name__):
with self.assertNotLogs(logger=log, level=logging.ERROR):
await self.cog.relay_message(message)
self.assertEqual(send_webhook.call_count, 2)
- @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock)
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji.name = emoji_name
return payload
- @helpers.async_test
async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):
"""The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""
test_values = (
@@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase):
return channel, message, member, payload
- @helpers.async_test
async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):
"""The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""
channel_id = 1234
@@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase):
channel.fetch_message.reset_mock()
@patch(f"{MODULE_PATH}.DuckPond.is_staff")
- @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock)
+ @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)
def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):
"""The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""
channel_id = 31415926535
@@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase):
# Assert that we've made it past `self.is_staff`
is_staff.assert_called_once()
- @helpers.async_test
async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):
"""The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""
test_cases = (
@@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji = self.duck_pond_emoji
for duck_count, should_relay in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_relay=should_relay):
await self.cog.on_raw_reaction_add(payload)
@@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase):
if should_relay:
relay_message.assert_called_once_with(message)
- @helpers.async_test
async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):
"""The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""
checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji)
@@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase):
(constants.DuckPond.threshold + 1, True),
)
for duck_count, should_re_add_checkmark in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):
await self.cog.on_raw_reaction_remove(payload)
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 8443cfe71..79c0e0ad3 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -7,10 +7,9 @@ import discord
from bot import constants
from bot.cogs import information
-from bot.decorators import InChannelCheckFailure
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
COG_PATH = "bot.cogs.information.Information"
@@ -34,7 +33,7 @@ class InformationCogTests(unittest.TestCase):
"""Test if the `role_info` command correctly returns the `moderator_role`."""
self.ctx.guild.roles.append(self.moderator_role)
- self.cog.roles_info.can_run = helpers.AsyncMock()
+ self.cog.roles_info.can_run = unittest.mock.AsyncMock()
self.cog.roles_info.can_run.return_value = True
coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
@@ -45,10 +44,9 @@ class InformationCogTests(unittest.TestCase):
_, kwargs = self.ctx.send.call_args
embed = kwargs.pop('embed')
- self.assertEqual(embed.title, "Role information")
+ self.assertEqual(embed.title, "Role information (Total 1 role)")
self.assertEqual(embed.colour, discord.Colour.blurple())
- self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
- self.assertEqual(embed.footer.text, "Total roles: 1")
+ self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
def test_role_info_command(self):
"""Tests the `role info` command."""
@@ -72,7 +70,7 @@ class InformationCogTests(unittest.TestCase):
self.ctx.guild.roles.append([dummy_role, admin_role])
- self.cog.role_info.can_run = helpers.AsyncMock()
+ self.cog.role_info.can_run = unittest.mock.AsyncMock()
self.cog.role_info.can_run.return_value = True
coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
@@ -150,14 +148,18 @@ class InformationCogTests(unittest.TestCase):
Voice region: {self.ctx.guild.region}
Features: {', '.join(self.ctx.guild.features)}
- **Counts**
- Members: {self.ctx.guild.member_count:,}
- Roles: {len(self.ctx.guild.roles)}
+ **Channel counts**
Category channels: 1
Text channels: 1
Voice channels: 1
+ Staff channels: 0
+
+ **Member counts**
+ Members: {self.ctx.guild.member_count:,}
+ Staff members: 0
+ Roles: {len(self.ctx.guild.roles)}
- **Members**
+ **Member statuses**
{constants.Emojis.status_online} 2
{constants.Emojis.status_idle} 1
{constants.Emojis.status_dnd} 4
@@ -174,7 +176,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
self.member = helpers.MockMember(id=1234)
@@ -345,10 +347,10 @@ class UserEmbedTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -360,7 +362,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Mr. Hemlock")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -372,7 +374,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -387,8 +389,8 @@ class UserEmbedTests(unittest.TestCase):
self.assertIn("&Admins", embed.description)
self.assertNotIn("&Everyone", embed.description)
- @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock)
- @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
"""The embed should contain expanded infractions and nomination info in mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50))
@@ -423,7 +425,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
@@ -454,7 +456,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -467,7 +469,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
"""The embed should be created with a blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
@@ -477,7 +479,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour.blurple())
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
@@ -486,7 +488,7 @@ class UserEmbedTests(unittest.TestCase):
user.avatar_url_as.return_value = "avatar url"
embed = asyncio.run(self.cog.create_user_embed(ctx, user))
- user.avatar_url_as.assert_called_once_with(format="png")
+ user.avatar_url_as.assert_called_once_with(static_format="png")
self.assertEqual(embed.thumbnail.url, "avatar url")
@@ -526,10 +528,10 @@ class UserCommandTests(unittest.TestCase):
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))
msg = "Sorry, but you may only use this command within <#50>."
- with self.assertRaises(InChannelCheckFailure, msg=msg):
+ with self.assertRaises(InWhitelistCheckFailure, msg=msg):
asyncio.run(self.cog.user_info.callback(self.cog, ctx))
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -542,7 +544,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -555,7 +557,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -568,7 +570,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.moderator)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
index 985bc66a1..cf9adbee0 100644
--- a/tests/bot/cogs/test_snekbox.py
+++ b/tests/bot/cogs/test_snekbox.py
@@ -1,74 +1,79 @@
import asyncio
import logging
import unittest
-from functools import partial
-from unittest.mock import MagicMock, Mock, call, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch
+from discord.ext import commands
+
+from bot import constants
from bot.cogs import snekbox
from bot.cogs.snekbox import Snekbox
-from bot.constants import URLs
-from tests.helpers import (
- AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test
-)
+from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser
-class SnekboxTests(unittest.TestCase):
+class SnekboxTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
"""Add mocked bot and cog to the instance."""
self.bot = MockBot()
-
- self.mocked_post = MagicMock()
- self.mocked_post.json = AsyncMock()
- self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post))
-
self.cog = Snekbox(bot=self.bot)
- @async_test
async def test_post_eval(self):
"""Post the eval code to the URLs.snekbox_eval_api endpoint."""
- self.mocked_post.json.return_value = {'lemon': 'AI'}
+ resp = MagicMock()
+ resp.json = AsyncMock(return_value="return")
- self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'})
- self.bot.http_session.post.assert_called_once_with(
- URLs.snekbox_eval_api,
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
+
+ self.assertEqual(await self.cog.post_eval("import random"), "return")
+ self.bot.http_session.post.assert_called_with(
+ constants.URLs.snekbox_eval_api,
json={"input": "import random"},
raise_for_status=True
)
+ resp.json.assert_awaited_once()
- @async_test
async def test_upload_output_reject_too_long(self):
"""Reject output longer than MAX_PASTE_LEN."""
result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))
self.assertEqual(result, "too long to upload")
- @async_test
async def test_upload_output(self):
"""Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""
- key = "RainbowDash"
- self.mocked_post.json.return_value = {"key": key}
+ key = "MarkDiamond"
+ resp = MagicMock()
+ resp.json = AsyncMock(return_value={"key": key})
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
self.assertEqual(
await self.cog.upload_output("My awesome output"),
- URLs.paste_service.format(key=key)
+ constants.URLs.paste_service.format(key=key)
)
- self.bot.http_session.post.assert_called_once_with(
- URLs.paste_service.format(key="documents"),
+ self.bot.http_session.post.assert_called_with(
+ constants.URLs.paste_service.format(key="documents"),
data="My awesome output",
raise_for_status=True
)
- @async_test
async def test_upload_output_gracefully_fallback_if_exception_during_request(self):
"""Output upload gracefully fallback if the upload fail."""
- self.mocked_post.json.side_effect = Exception
+ resp = MagicMock()
+ resp.json = AsyncMock(side_effect=Exception)
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
+
log = logging.getLogger("bot.cogs.snekbox")
with self.assertLogs(logger=log, level='ERROR'):
await self.cog.upload_output('My awesome output!')
- @async_test
async def test_upload_output_gracefully_fallback_if_no_key_in_response(self):
"""Output upload gracefully fallback if there is no key entry in the response body."""
- self.mocked_post.json.return_value = {}
self.assertEqual((await self.cog.upload_output('My awesome output!')), None)
def test_prepare_input(self):
@@ -95,15 +100,15 @@ class SnekboxTests(unittest.TestCase):
self.assertEqual(actual, expected)
@patch('bot.cogs.snekbox.Signals', side_effect=ValueError)
- def test_get_results_message_invalid_signal(self, mock_Signals: Mock):
+ def test_get_results_message_invalid_signal(self, mock_signals: Mock):
self.assertEqual(
self.cog.get_results_message({'stdout': '', 'returncode': 127}),
('Your eval job has completed with return code 127', '')
)
@patch('bot.cogs.snekbox.Signals')
- def test_get_results_message_valid_signal(self, mock_Signals: Mock):
- mock_Signals.return_value.name = 'SIGTEST'
+ def test_get_results_message_valid_signal(self, mock_signals: Mock):
+ mock_signals.return_value.name = 'SIGTEST'
self.assertEqual(
self.cog.get_results_message({'stdout': '', 'returncode': 127}),
('Your eval job has completed with return code 127 (SIGTEST)', '')
@@ -121,7 +126,6 @@ class SnekboxTests(unittest.TestCase):
actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})
self.assertEqual(actual, expected)
- @async_test
async def test_format_output(self):
"""Test output formatting."""
self.cog.upload_output = AsyncMock(return_value='https://testificate.com/')
@@ -172,7 +176,6 @@ class SnekboxTests(unittest.TestCase):
with self.subTest(msg=testname, case=case, expected=expected):
self.assertEqual(await self.cog.format_output(case), expected)
- @async_test
async def test_eval_command_evaluate_once(self):
"""Test the eval command procedure."""
ctx = MockContext()
@@ -186,7 +189,6 @@ class SnekboxTests(unittest.TestCase):
self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
self.cog.continue_eval.assert_called_once_with(ctx, response)
- @async_test
async def test_eval_command_evaluate_twice(self):
"""Test the eval and re-eval command procedure."""
ctx = MockContext()
@@ -201,7 +203,6 @@ class SnekboxTests(unittest.TestCase):
self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
self.cog.continue_eval.assert_called_with(ctx, response)
- @async_test
async def test_eval_command_reject_two_eval_at_the_same_time(self):
"""Test if the eval command rejects an eval if the author already have a running eval."""
ctx = MockContext()
@@ -214,22 +215,19 @@ class SnekboxTests(unittest.TestCase):
"@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
)
- @async_test
async def test_eval_command_call_help(self):
"""Test if the eval command call the help command if no code is provided."""
- ctx = MockContext()
- ctx.invoke = AsyncMock()
+ ctx = MockContext(command="sentinel")
await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
- ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval")
+ ctx.send_help.assert_called_once_with("sentinel")
- @async_test
async def test_send_eval(self):
"""Test the send_eval function."""
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))
+
self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})
self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
@@ -244,14 +242,13 @@ class SnekboxTests(unittest.TestCase):
self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
self.cog.format_output.assert_called_once_with('')
- @async_test
async def test_send_eval_with_paste_link(self):
"""Test the send_eval function with a too long output that generate a paste link."""
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))
+
self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
@@ -267,14 +264,12 @@ class SnekboxTests(unittest.TestCase):
self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
self.cog.format_output.assert_called_once_with('Way too long beard')
- @async_test
async def test_send_eval_with_non_zero_eval(self):
"""Test the send_eval function with a code returning a non-zero code."""
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))
self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
@@ -289,25 +284,33 @@ class SnekboxTests(unittest.TestCase):
self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
self.cog.format_output.assert_not_called()
- @async_test
- async def test_continue_eval_does_continue(self):
+ @patch("bot.cogs.snekbox.partial")
+ async def test_continue_eval_does_continue(self, partial_mock):
"""Test that the continue_eval function does continue if required conditions are met."""
ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
response = MockMessage(delete=AsyncMock())
- new_msg = MockMessage(content='!e NewCode')
+ new_msg = MockMessage()
self.bot.wait_for.side_effect = ((None, new_msg), None)
+ expected = "NewCode"
+ self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)
actual = await self.cog.continue_eval(ctx, response)
- self.assertEqual(actual, 'NewCode')
- self.bot.wait_for.has_calls(
- call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10),
- call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
+ self.cog.get_code.assert_awaited_once_with(new_msg)
+ self.assertEqual(actual, expected)
+ self.bot.wait_for.assert_has_awaits(
+ (
+ call(
+ 'message_edit',
+ check=partial_mock(snekbox.predicate_eval_message_edit, ctx),
+ timeout=snekbox.REEVAL_TIMEOUT,
+ ),
+ call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
+ )
)
ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
ctx.message.clear_reactions.assert_called_once()
response.delete.assert_called_once()
- @async_test
async def test_continue_eval_does_not_continue(self):
ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))
self.bot.wait_for.side_effect = asyncio.TimeoutError
@@ -316,6 +319,32 @@ class SnekboxTests(unittest.TestCase):
self.assertEqual(actual, None)
ctx.message.clear_reactions.assert_called_once()
+ async def test_get_code(self):
+ """Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
+ prefix = constants.Bot.prefix
+ subtests = (
+ (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"),
+ (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None),
+ (MagicMock(spec=commands.Command), f"{prefix}tags get foo"),
+ (None, "print(123)")
+ )
+
+ for command, content, *expected_code in subtests:
+ if not expected_code:
+ expected_code = content
+ else:
+ [expected_code] = expected_code
+
+ with self.subTest(content=content, expected_code=expected_code):
+ self.bot.get_context.reset_mock()
+ self.bot.get_context.return_value = MockContext(command=command)
+ message = MockMessage(content=content)
+
+ actual_code = await self.cog.get_code(message)
+
+ self.bot.get_context.assert_awaited_once_with(message)
+ self.assertEqual(actual_code, expected_code)
+
def test_predicate_eval_message_edit(self):
"""Test the predicate_eval_message_edit function."""
msg0 = MockMessage(id=1, content='abc')
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index a54b839d7..a10124d2d 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -1,56 +1,89 @@
-import asyncio
-import logging
import unittest
+from re import Match
+from unittest import mock
from unittest.mock import MagicMock
from discord import Colour
-from bot.cogs.token_remover import (
- DELETION_MESSAGE_TEMPLATE,
- TokenRemover,
- setup as setup_cog,
-)
-from bot.constants import Channels, Colours, Event, Icons
-from tests.helpers import AsyncMock, MockBot, MockMessage
+from bot import constants
+from bot.cogs import token_remover
+from bot.cogs.moderation import ModLog
+from bot.cogs.token_remover import Token, TokenRemover
+from tests.helpers import MockBot, MockMessage, autospec
-class TokenRemoverTests(unittest.TestCase):
+class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
"""Tests the `TokenRemover` cog."""
def setUp(self):
"""Adds the cog, a bot, and a message to the instance for usage in tests."""
self.bot = MockBot()
- self.bot.get_cog.return_value = MagicMock()
- self.bot.get_cog.return_value.send_log_message = AsyncMock()
self.cog = TokenRemover(bot=self.bot)
- self.msg = MockMessage(id=555, content='')
- self.msg.author.__str__ = MagicMock()
- self.msg.author.__str__.return_value = 'lemon'
- self.msg.author.bot = False
- self.msg.author.avatar_url_as.return_value = 'picture-lemon.png'
- self.msg.author.id = 42
- self.msg.author.mention = '@lemon'
+ self.msg = MockMessage(id=555, content="hello world")
self.msg.channel.mention = "#lemonade-stand"
+ self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
+ self.msg.author.avatar_url_as.return_value = "picture-lemon.png"
- def test_is_valid_user_id_is_true_for_numeric_content(self):
- """A string decoding to numeric characters is a valid user ID."""
- # MTIz = base64(123)
- self.assertTrue(TokenRemover.is_valid_user_id('MTIz'))
+ def test_is_valid_user_id_valid(self):
+ """Should consider user IDs valid if they decode entirely to ASCII digits."""
+ ids = (
+ "NDcyMjY1OTQzMDYyNDEzMzMy",
+ "NDc1MDczNjI5Mzk5NTQ3OTA0",
+ "NDY3MjIzMjMwNjUwNzc3NjQx",
+ )
+
+ for user_id in ids:
+ with self.subTest(user_id=user_id):
+ result = TokenRemover.is_valid_user_id(user_id)
+ self.assertTrue(result)
- def test_is_valid_user_id_is_false_for_alphabetic_content(self):
- """A string decoding to alphabetic characters is not a valid user ID."""
- # YWJj = base64(abc)
- self.assertFalse(TokenRemover.is_valid_user_id('YWJj'))
+ def test_is_valid_user_id_invalid(self):
+ """Should consider non-digit and non-ASCII IDs invalid."""
+ ids = (
+ ("SGVsbG8gd29ybGQ", "non-digit ASCII"),
+ ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"),
+ ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"),
+ ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"),
+ ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
- def test_is_valid_timestamp_is_true_for_valid_timestamps(self):
- """A string decoding to a valid timestamp should be recognized as such."""
- self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A'))
+ for user_id, msg in ids:
+ with self.subTest(msg=msg):
+ result = TokenRemover.is_valid_user_id(user_id)
+ self.assertFalse(result)
- def test_is_valid_timestamp_is_false_for_invalid_values(self):
- """A string not decoding to a valid timestamp should not be recognized as such."""
- # MTIz = base64(123)
- self.assertFalse(TokenRemover.is_valid_timestamp('MTIz'))
+ def test_is_valid_timestamp_valid(self):
+ """Should consider timestamps valid if they're greater than the Discord epoch."""
+ timestamps = (
+ "XsyRkw",
+ "Xrim9Q",
+ "XsyR-w",
+ "XsySD_",
+ "Dn9r_A",
+ )
+
+ for timestamp in timestamps:
+ with self.subTest(timestamp=timestamp):
+ result = TokenRemover.is_valid_timestamp(timestamp)
+ self.assertTrue(result)
+
+ def test_is_valid_timestamp_invalid(self):
+ """Should consider timestamps invalid if they're before Discord epoch or can't be parsed."""
+ timestamps = (
+ ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"),
+ ("ew", "123"),
+ ("AoIKgA", "42076800"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
+
+ for timestamp, msg in timestamps:
+ with self.subTest(msg=msg):
+ result = TokenRemover.is_valid_timestamp(timestamp)
+ self.assertFalse(result)
def test_mod_log_property(self):
"""The `mod_log` property should ask the bot to return the `ModLog` cog."""
@@ -58,74 +91,206 @@ class TokenRemoverTests(unittest.TestCase):
self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value)
self.bot.get_cog.assert_called_once_with('ModLog')
- def test_ignores_bot_messages(self):
- """When the message event handler is called with a bot message, nothing is done."""
+ async def test_on_message_edit_uses_on_message(self):
+ """The edit listener should delegate handling of the message to the normal listener."""
+ self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True)
+
+ await self.cog.on_message_edit(MockMessage(), self.msg)
+ self.cog.on_message.assert_awaited_once_with(self.msg)
+
+ @autospec(TokenRemover, "find_token_in_message", "take_action")
+ async def test_on_message_takes_action(self, find_token_in_message, take_action):
+ """Should take action if a valid token is found when a message is sent."""
+ cog = TokenRemover(self.bot)
+ found_token = "foobar"
+ find_token_in_message.return_value = found_token
+
+ await cog.on_message(self.msg)
+
+ find_token_in_message.assert_called_once_with(self.msg)
+ take_action.assert_awaited_once_with(cog, self.msg, found_token)
+
+ @autospec(TokenRemover, "find_token_in_message", "take_action")
+ async def test_on_message_skips_missing_token(self, find_token_in_message, take_action):
+ """Shouldn't take action if a valid token isn't found when a message is sent."""
+ cog = TokenRemover(self.bot)
+ find_token_in_message.return_value = False
+
+ await cog.on_message(self.msg)
+
+ find_token_in_message.assert_called_once_with(self.msg)
+ take_action.assert_not_awaited()
+
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_ignores_bot_messages(self, token_re):
+ """The token finder should ignore messages authored by bots."""
self.msg.author.bot = True
- coroutine = self.cog.on_message(self.msg)
- self.assertIsNone(asyncio.run(coroutine))
-
- def test_ignores_messages_without_tokens(self):
- """Messages without anything looking like a token are ignored."""
- for content in ('', 'lemon wins'):
- with self.subTest(content=content):
- self.msg.content = content
- coroutine = self.cog.on_message(self.msg)
- self.assertIsNone(asyncio.run(coroutine))
-
- def test_ignores_messages_with_invalid_tokens(self):
- """Messages with values that are invalid tokens are ignored."""
- for content in ('foo.bar.baz', 'x.y.'):
- with self.subTest(content=content):
- self.msg.content = content
- coroutine = self.cog.on_message(self.msg)
- self.assertIsNone(asyncio.run(coroutine))
-
- def test_censors_valid_tokens(self):
- """Valid tokens are censored."""
- cases = (
- # (content, censored_token)
- ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'),
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+ token_re.finditer.assert_not_called()
+
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_no_matches(self, token_re):
+ """None should be returned if the regex matches no tokens in a message."""
+ token_re.finditer.return_value = ()
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+ token_re.finditer.assert_called_once_with(self.msg.content)
+
+ @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp")
+ @autospec("bot.cogs.token_remover", "Token")
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp):
+ """The first match with a valid user ID and timestamp should be returned as a `Token`."""
+ matches = [
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ ]
+ tokens = [
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ ]
+
+ token_re.finditer.return_value = matches
+ token_cls.side_effect = tokens
+ is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid.
+ is_valid_timestamp.return_value = True
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertEqual(tokens[1], return_value)
+ token_re.finditer.assert_called_once_with(self.msg.content)
+
+ @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp")
+ @autospec("bot.cogs.token_remover", "Token")
+ @autospec("bot.cogs.token_remover", "TOKEN_RE")
+ def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp):
+ """None should be returned if no matches have valid user IDs or timestamps."""
+ token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
+ token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
+ is_valid_id.return_value = False
+ is_valid_timestamp.return_value = False
+
+ return_value = TokenRemover.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+ token_re.finditer.assert_called_once_with(self.msg.content)
+
+ def test_regex_invalid_tokens(self):
+ """Messages without anything looking like a token are not matched."""
+ tokens = (
+ "",
+ "lemon wins",
+ "..",
+ "x.y",
+ "x.y.",
+ ".y.z",
+ ".y.",
+ "..z",
+ "x..z",
+ " . . ",
+ "\n.\n.\n",
+ "hellö.world.bye",
+ "base64.nötbåse64.morebase64",
+ "19jd3J.dfkm3d.€víł§tüff",
+ )
+
+ for token in tokens:
+ with self.subTest(token=token):
+ results = token_remover.TOKEN_RE.findall(token)
+ self.assertEqual(len(results), 0)
+
+ def test_regex_valid_tokens(self):
+ """Messages that look like tokens should be matched."""
+ # Don't worry, these tokens have been invalidated.
+ tokens = (
+ "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8",
+ "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds",
+ "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",
)
- for content, censored_token in cases:
- with self.subTest(content=content, censored_token=censored_token):
- self.msg.content = content
- coroutine = self.cog.on_message(self.msg)
- with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm:
- self.assertIsNone(asyncio.run(coroutine)) # no return value
-
- [line] = cm.output
- log_message = (
- "Censored a seemingly valid token sent by "
- "lemon (`42`) in #lemonade-stand, "
- f"token was `{censored_token}`"
- )
- self.assertIn(log_message, line)
-
- self.msg.delete.assert_called_once_with()
- self.msg.channel.send.assert_called_once_with(
- DELETION_MESSAGE_TEMPLATE.format(mention='@lemon')
- )
- self.bot.get_cog.assert_called_with('ModLog')
- self.msg.author.avatar_url_as.assert_called_once_with(static_format='png')
-
- mod_log = self.bot.get_cog.return_value
- mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id)
- mod_log.send_log_message.assert_called_once_with(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Token removed!",
- text=log_message,
- thumbnail='picture-lemon.png',
- channel_id=Channels.mod_alerts
- )
-
-
-class TokenRemoverSetupTests(unittest.TestCase):
- """Tests setup of the `TokenRemover` cog."""
-
- def test_setup(self):
- """Setup of the extension should call add_cog."""
+ for token in tokens:
+ with self.subTest(token=token):
+ results = token_remover.TOKEN_RE.fullmatch(token)
+ self.assertIsNotNone(results, f"{token} was not matched by the regex")
+
+ def test_regex_matches_multiple_valid(self):
+ """Should support multiple matches in the middle of a string."""
+ token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8"
+ token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc"
+ message = f"garbage {token_1} hello {token_2} world"
+
+ results = token_remover.TOKEN_RE.finditer(message)
+ results = [match[0] for match in results]
+ self.assertCountEqual((token_1, token_2), results)
+
+ @autospec("bot.cogs.token_remover", "LOG_MESSAGE")
+ def test_format_log_message(self, log_message):
+ """Should correctly format the log message with info from the message and token."""
+ token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ log_message.format.return_value = "Howdy"
+
+ return_value = TokenRemover.format_log_message(self.msg, token)
+
+ self.assertEqual(return_value, log_message.format.return_value)
+ log_message.format.assert_called_once_with(
+ author=self.msg.author,
+ author_id=self.msg.author.id,
+ channel=self.msg.channel.mention,
+ user_id=token.user_id,
+ timestamp=token.timestamp,
+ hmac="x" * len(token.hmac),
+ )
+
+ @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
+ @autospec("bot.cogs.token_remover", "log")
+ @autospec(TokenRemover, "format_log_message")
+ async def test_take_action(self, format_log_message, logger, mod_log_property):
+ """Should delete the message and send a mod log."""
+ cog = TokenRemover(self.bot)
+ mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True)
+ token = mock.create_autospec(Token, spec_set=True, instance=True)
+ log_msg = "testing123"
+
+ mod_log_property.return_value = mod_log
+ format_log_message.return_value = log_msg
+
+ await cog.take_action(self.msg, token)
+
+ self.msg.delete.assert_called_once_with()
+ self.msg.channel.send.assert_called_once_with(
+ token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention)
+ )
+
+ format_log_message.assert_called_once_with(self.msg, token)
+ logger.debug.assert_called_with(log_msg)
+ self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens")
+
+ mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id)
+ mod_log.send_log_message.assert_called_once_with(
+ icon_url=constants.Icons.token_removed,
+ colour=Colour(constants.Colours.soft_red),
+ title="Token removed!",
+ text=log_msg,
+ thumbnail=self.msg.author.avatar_url_as.return_value,
+ channel_id=constants.Channels.mod_alerts
+ )
+
+
+class TokenRemoverExtensionTests(unittest.TestCase):
+ """Tests for the token_remover extension."""
+
+ @autospec("bot.cogs.token_remover", "TokenRemover")
+ def test_extension_setup(self, cog):
+ """The TokenRemover cog should be added."""
bot = MockBot()
- setup_cog(bot)
+ token_remover.setup(bot)
+
+ cog.assert_called_once_with(bot)
bot.add_cog.assert_called_once()
+ self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover))
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index 36c986fe1..0d570f5a3 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple):
n_violations: int
-class RuleTest(unittest.TestCase, metaclass=ABCMeta):
+class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):
"""
Abstract class for antispam rule test cases.
@@ -68,9 +68,9 @@ class RuleTest(unittest.TestCase, metaclass=ABCMeta):
@abstractmethod
def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
"""Give expected relevant messages for `case`."""
- raise NotImplementedError
+ raise NotImplementedError # pragma: no cover
@abstractmethod
def get_report(self, case: DisallowedCase) -> str:
"""Give expected error report for `case`."""
- raise NotImplementedError
+ raise NotImplementedError # pragma: no cover
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index e54b4b5b8..d7e779221 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import attachments
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_attachments: int) -> MockMessage:
@@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest):
self.apply = attachments.apply
self.config = {"max": 5, "interval": 10}
- @async_test
async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
@@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
index 72f0be0c7..03682966b 100644
--- a/tests/bot/rules/test_burst.py
+++ b/tests/bot/rules/test_burst.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import burst
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str) -> MockMessage:
@@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest):
self.apply = burst.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
index 47367a5f8..3275143d5 100644
--- a/tests/bot/rules/test_burst_shared.py
+++ b/tests/bot/rules/test_burst_shared.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import burst_shared
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str) -> MockMessage:
@@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest):
self.apply = burst_shared.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""
Cases that do not violate the rule.
@@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
index 7cc36f49e..f1e3c76a7 100644
--- a/tests/bot/rules/test_chars.py
+++ b/tests/bot/rules/test_chars.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import chars
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, n_chars: int) -> MockMessage:
@@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest):
"interval": 10,
}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of chars within limit."""
cases = (
@@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the total amount of chars exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
index 0239b0b00..9a72723e2 100644
--- a/tests/bot/rules/test_discord_emojis.py
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import discord_emojis
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
@@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest):
self.apply = discord_emojis.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of discord emojis within limit."""
cases = (
@@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of discord emojis."""
cases = (
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
index 59e0fb6ef..9bd886a77 100644
--- a/tests/bot/rules/test_duplicates.py
+++ b/tests/bot/rules/test_duplicates.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import duplicates
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, content: str) -> MockMessage:
@@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest):
self.apply = duplicates.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with too many duplicate messages from the same author."""
cases = (
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index 3c3f90e5f..b091bd9d7 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import links
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_links: int) -> MockMessage:
@@ -21,7 +21,6 @@ class LinksTests(RuleTest):
"interval": 10
}
- @async_test
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
@@ -34,7 +33,6 @@ class LinksTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index ebcdabac6..6444532f2 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_mentions: int) -> MockMessage:
@@ -20,7 +20,6 @@ class TestMentions(RuleTest):
"interval": 10,
}
- @async_test
async def test_mentions_within_limit(self):
"""Messages with an allowed amount of mentions."""
cases = (
@@ -32,7 +31,6 @@ class TestMentions(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
index d61c4609d..e35377773 100644
--- a/tests/bot/rules/test_newlines.py
+++ b/tests/bot/rules/test_newlines.py
@@ -2,7 +2,7 @@ from typing import Iterable, List
from bot.rules import newlines
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
@@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest):
"interval": 10,
}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_total(self):
"""Cases which violate the rule by having too many newlines in total."""
cases = (
@@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest):
self.apply = newlines.apply
self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
- @async_test
async def test_disallows_messages_consecutive(self):
"""Cases which violate the rule due to having too many consecutive newlines."""
cases = (
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
index b339cccf7..26c05d527 100644
--- a/tests/bot/rules/test_role_mentions.py
+++ b/tests/bot/rules/test_role_mentions.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import role_mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, n_mentions: int) -> MockMessage:
@@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest):
self.apply = role_mentions.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of role mentions within limit."""
cases = (
@@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of role mentions."""
cases = (
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index bdfcc73e4..99e942813 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -2,10 +2,9 @@ import unittest
from unittest.mock import MagicMock
from bot import api
-from tests.helpers import async_test
-class APIClientTests(unittest.TestCase):
+class APIClientTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the bot's API client."""
@classmethod
@@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase):
"""The event loop should not be running by default."""
self.assertFalse(api.loop_is_running())
- @async_test
async def test_loop_is_running_in_async_context(self):
"""The event loop should be running in an async context."""
self.assertTrue(api.loop_is_running())
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py
index dae7c066c..f10d6fbe8 100644
--- a/tests/bot/test_constants.py
+++ b/tests/bot/test_constants.py
@@ -1,14 +1,40 @@
import inspect
+import typing
import unittest
from bot import constants
+def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool:
+ """
+ Return True if `value` is an instance of the type represented by `annotation`.
+
+ This doesn't account for things like Unions or checking for homogenous types in collections.
+ """
+ origin = typing.get_origin(annotation)
+
+ # This is done in case a bare e.g. `typing.List` is used.
+ # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`.
+ # `get_origin()` does this normalisation for us.
+ type_ = annotation if origin is None else origin
+
+ return isinstance(value, type_)
+
+
+def is_any_instance(value: typing.Any, types: typing.Collection) -> bool:
+ """Return True if `value` is an instance of any type in `types`."""
+ for type_ in types:
+ if is_annotation_instance(value, type_):
+ return True
+
+ return False
+
+
class ConstantsTests(unittest.TestCase):
"""Tests for our constants."""
def test_section_configuration_matches_type_specification(self):
- """The section annotations should match the actual types of the sections."""
+ """"The section annotations should match the actual types of the sections."""
sections = (
cls
@@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):
)
for section in sections:
for name, annotation in section.__annotations__.items():
- with self.subTest(section=section, name=name, annotation=annotation):
+ with self.subTest(section=section.__name__, name=name, annotation=annotation):
value = getattr(section, name)
+ origin = typing.get_origin(annotation)
+ annotation_args = typing.get_args(annotation)
+ failure_msg = f"{value} is not an instance of {annotation}"
- if getattr(annotation, '_name', None) in ('Dict', 'List'):
- self.skipTest("Cannot validate containers yet.")
-
- self.assertIsInstance(value, annotation)
+ if origin is typing.Union:
+ is_instance = is_any_instance(value, annotation_args)
+ self.assertTrue(is_instance, failure_msg)
+ else:
+ is_instance = is_annotation_instance(value, annotation)
+ self.assertTrue(is_instance, failure_msg)
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index 1e5ca62ae..c42111f3f 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -1,5 +1,5 @@
-import asyncio
import datetime
+import re
import unittest
from unittest.mock import MagicMock, patch
@@ -8,6 +8,7 @@ from discord.ext.commands import BadArgument
from bot.converters import (
Duration,
+ HushDurationConverter,
ISODateTime,
TagContentConverter,
TagNameConverter,
@@ -15,7 +16,7 @@ from bot.converters import (
)
-class ConverterTests(unittest.TestCase):
+class ConverterTests(unittest.IsolatedAsyncioTestCase):
"""Tests our custom argument converters."""
@classmethod
@@ -25,7 +26,7 @@ class ConverterTests(unittest.TestCase):
cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
- def test_tag_content_converter_for_valid(self):
+ async def test_tag_content_converter_for_valid(self):
"""TagContentConverter should return correct values for valid input."""
test_values = (
('hello', 'hello'),
@@ -34,10 +35,10 @@ class ConverterTests(unittest.TestCase):
for content, expected_conversion in test_values:
with self.subTest(content=content, expected_conversion=expected_conversion):
- conversion = asyncio.run(TagContentConverter.convert(self.context, content))
+ conversion = await TagContentConverter.convert(self.context, content)
self.assertEqual(conversion, expected_conversion)
- def test_tag_content_converter_for_invalid(self):
+ async def test_tag_content_converter_for_invalid(self):
"""TagContentConverter should raise the proper exception for invalid input."""
test_values = (
('', "Tag contents should not be empty, or filled with whitespace."),
@@ -46,10 +47,10 @@ class ConverterTests(unittest.TestCase):
for value, exception_message in test_values:
with self.subTest(tag_content=value, exception_message=exception_message):
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(TagContentConverter.convert(self.context, value))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await TagContentConverter.convert(self.context, value)
- def test_tag_name_converter_for_valid(self):
+ async def test_tag_name_converter_for_valid(self):
"""TagNameConverter should return the correct values for valid tag names."""
test_values = (
('tracebacks', 'tracebacks'),
@@ -59,10 +60,10 @@ class ConverterTests(unittest.TestCase):
for name, expected_conversion in test_values:
with self.subTest(name=name, expected_conversion=expected_conversion):
- conversion = asyncio.run(TagNameConverter.convert(self.context, name))
+ conversion = await TagNameConverter.convert(self.context, name)
self.assertEqual(conversion, expected_conversion)
- def test_tag_name_converter_for_invalid(self):
+ async def test_tag_name_converter_for_invalid(self):
"""TagNameConverter should raise the correct exception for invalid tag names."""
test_values = (
('👋', "Don't be ridiculous, you can't use that character!"),
@@ -74,29 +75,29 @@ class ConverterTests(unittest.TestCase):
for invalid_name, exception_message in test_values:
with self.subTest(invalid_name=invalid_name, exception_message=exception_message):
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(TagNameConverter.convert(self.context, invalid_name))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await TagNameConverter.convert(self.context, invalid_name)
- def test_valid_python_identifier_for_valid(self):
+ async def test_valid_python_identifier_for_valid(self):
"""ValidPythonIdentifier returns valid identifiers unchanged."""
test_values = ('foo', 'lemon')
for name in test_values:
with self.subTest(identifier=name):
- conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ conversion = await ValidPythonIdentifier.convert(self.context, name)
self.assertEqual(name, conversion)
- def test_valid_python_identifier_for_invalid(self):
+ async def test_valid_python_identifier_for_invalid(self):
"""ValidPythonIdentifier raises the proper exception for invalid identifiers."""
test_values = ('nested.stuff', '#####')
for name in test_values:
with self.subTest(identifier=name):
exception_message = f'`{name}` is not a valid Python identifier'
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await ValidPythonIdentifier.convert(self.context, name)
- def test_duration_converter_for_valid(self):
+ async def test_duration_converter_for_valid(self):
"""Duration returns the correct `datetime` for valid duration strings."""
test_values = (
# Simple duration strings
@@ -158,35 +159,35 @@ class ConverterTests(unittest.TestCase):
mock_datetime.utcnow.return_value = self.fixed_utc_now
with self.subTest(duration=duration, duration_dict=duration_dict):
- converted_datetime = asyncio.run(converter.convert(self.context, duration))
+ converted_datetime = await converter.convert(self.context, duration)
self.assertEqual(converted_datetime, expected_datetime)
- def test_duration_converter_for_invalid(self):
+ async def test_duration_converter_for_invalid(self):
"""Duration raises the right exception for invalid duration strings."""
test_values = (
# Units in wrong order
- ('1d1w'),
- ('1s1y'),
+ '1d1w',
+ '1s1y',
# Duplicated units
- ('1 year 2 years'),
- ('1 M 10 minutes'),
+ '1 year 2 years',
+ '1 M 10 minutes',
# Unknown substrings
- ('1MVes'),
- ('1y3breads'),
+ '1MVes',
+ '1y3breads',
# Missing amount
- ('ym'),
+ 'ym',
# Incorrect whitespace
- (" 1y"),
- ("1S "),
- ("1y 1m"),
+ " 1y",
+ "1S ",
+ "1y 1m",
# Garbage
- ('Guido van Rossum'),
- ('lemon lemon lemon lemon lemon lemon lemon'),
+ 'Guido van Rossum',
+ 'lemon lemon lemon lemon lemon lemon lemon',
)
converter = Duration()
@@ -194,10 +195,21 @@ class ConverterTests(unittest.TestCase):
for invalid_duration in test_values:
with self.subTest(invalid_duration=invalid_duration):
exception_message = f'`{invalid_duration}` is not a valid duration string.'
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(converter.convert(self.context, invalid_duration))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, invalid_duration)
- def test_isodatetime_converter_for_valid(self):
+ @patch("bot.converters.datetime")
+ async def test_duration_converter_out_of_range(self, mock_datetime):
+ """Duration converter should raise BadArgument if datetime raises a ValueError."""
+ mock_datetime.__add__.side_effect = ValueError
+ mock_datetime.utcnow.return_value = mock_datetime
+
+ duration = f"{datetime.MAXYEAR}y"
+ exception_message = f"`{duration}` results in a datetime outside the supported range."
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await Duration().convert(self.context, duration)
+
+ async def test_isodatetime_converter_for_valid(self):
"""ISODateTime converter returns correct datetime for valid datetime string."""
test_values = (
# `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
@@ -242,32 +254,61 @@ class ConverterTests(unittest.TestCase):
for datetime_string, expected_dt in test_values:
with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt):
- converted_dt = asyncio.run(converter.convert(self.context, datetime_string))
+ converted_dt = await converter.convert(self.context, datetime_string)
self.assertIsNone(converted_dt.tzinfo)
self.assertEqual(converted_dt, expected_dt)
- def test_isodatetime_converter_for_invalid(self):
+ async def test_isodatetime_converter_for_invalid(self):
"""ISODateTime converter raises the correct exception for invalid datetime strings."""
test_values = (
# Make sure it doesn't interfere with the Duration converter
- ('1Y'),
- ('1d'),
- ('1H'),
+ '1Y',
+ '1d',
+ '1H',
# Check if it fails when only providing the optional time part
- ('10:10:10'),
- ('10:00'),
+ '10:10:10',
+ '10:00',
# Invalid date format
- ('19-01-01'),
+ '19-01-01',
# Other non-valid strings
- ('fisk the tag master'),
+ 'fisk the tag master',
)
converter = ISODateTime()
for datetime_string in test_values:
with self.subTest(datetime_string=datetime_string):
exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string"
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(converter.convert(self.context, datetime_string))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, datetime_string)
+
+ async def test_hush_duration_converter_for_valid(self):
+ """HushDurationConverter returns correct value for minutes duration or `"forever"` strings."""
+ test_values = (
+ ("0", 0),
+ ("15", 15),
+ ("10", 10),
+ ("5m", 5),
+ ("5M", 5),
+ ("forever", None),
+ )
+ converter = HushDurationConverter()
+ for minutes_string, expected_minutes in test_values:
+ with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes):
+ converted = await converter.convert(self.context, minutes_string)
+ self.assertEqual(expected_minutes, converted)
+
+ async def test_hush_duration_converter_for_invalid(self):
+ """HushDurationConverter raises correct exception for invalid minutes duration strings."""
+ test_values = (
+ ("16", "Duration must be at most 15 minutes."),
+ ("10d", "10d is not a valid minutes duration."),
+ ("-1", "-1 is not a valid minutes duration."),
+ )
+ converter = HushDurationConverter()
+ for invalid_minutes_string, exception_message in test_values:
+ with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message):
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, invalid_minutes_string)
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
new file mode 100644
index 000000000..3d450caa0
--- /dev/null
+++ b/tests/bot/test_decorators.py
@@ -0,0 +1,147 @@
+import collections
+import unittest
+import unittest.mock
+
+from bot import constants
+from bot.decorators import in_whitelist
+from bot.utils.checks import InWhitelistCheckFailure
+from tests import helpers
+
+InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description"))
+
+
+class InWhitelistTests(unittest.TestCase):
+ """Tests for the `in_whitelist` check."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Set up helpers that only need to be defined once."""
+ cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456)
+ cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654)
+ cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666)
+ cls.dm_channel = helpers.MockDMChannel()
+
+ cls.non_staff_member = helpers.MockMember()
+ cls.staff_role = helpers.MockRole(id=121212)
+ cls.staff_member = helpers.MockMember(roles=(cls.staff_role,))
+
+ cls.channels = (cls.bot_commands.id,)
+ cls.categories = (cls.help_channel.category_id,)
+ cls.roles = (cls.staff_role.id,)
+
+ def test_predicate_returns_true_for_whitelisted_context(self):
+ """The predicate should return `True` if a whitelisted context was passed to it."""
+ test_cases = (
+ InWhitelistTestCase(
+ kwargs={"channels": self.channels},
+ ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member),
+ description="In whitelisted channels by members without whitelisted roles",
+ ),
+ InWhitelistTestCase(
+ kwargs={"redirect": self.bot_commands.id},
+ ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member),
+ description="`redirect` should be implicitly added to `channels`",
+ ),
+ InWhitelistTestCase(
+ kwargs={"categories": self.categories},
+ ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member),
+ description="Whitelisted category without whitelisted role",
+ ),
+ InWhitelistTestCase(
+ kwargs={"roles": self.roles},
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member),
+ description="Whitelisted role outside of whitelisted channel/category"
+ ),
+ InWhitelistTestCase(
+ kwargs={
+ "channels": self.channels,
+ "categories": self.categories,
+ "roles": self.roles,
+ "redirect": self.bot_commands,
+ },
+ ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member),
+ description="Case with all whitelist kwargs used",
+ ),
+ )
+
+ for test_case in test_cases:
+ # patch `commands.check` with a no-op lambda that just returns the predicate passed to it
+ # so we can test the predicate that was generated from the specified kwargs.
+ with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ predicate = in_whitelist(**test_case.kwargs)
+
+ with self.subTest(test_description=test_case.description):
+ self.assertTrue(predicate(test_case.ctx))
+
+ def test_predicate_raises_exception_for_non_whitelisted_context(self):
+ """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context."""
+ test_cases = (
+ # Failing check with explicit `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": self.bot_commands.id,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check with an explicit redirect channel",
+ ),
+
+ # Failing check with implicit `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check with an implicit redirect channel",
+ ),
+
+ # Failing check without `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": None,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check without a redirect channel",
+ ),
+
+ # Command issued in DM channel
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": None,
+ },
+ ctx=helpers.MockContext(channel=self.dm_channel, author=self.dm_channel.me),
+ description="Commands issued in DM channel should be rejected",
+ ),
+ )
+
+ for test_case in test_cases:
+ if "redirect" not in test_case.kwargs or test_case.kwargs["redirect"] is not None:
+ # There are two cases in which we have a redirect channel:
+ # 1. No redirect channel was passed; the default value of `bot_commands` is used
+ # 2. An explicit `redirect` is set that is "not None"
+ redirect_channel = test_case.kwargs.get("redirect", constants.Channels.bot_commands)
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ # If an explicit `None` was passed for `redirect`, there is no redirect channel
+ redirect_message = ""
+
+ exception_message = f"You are not allowed to use that command{redirect_message}."
+
+ # patch `commands.check` with a no-op lambda that just returns the predicate passed to it
+ # so we can test the predicate that was generated from the specified kwargs.
+ with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ predicate = in_whitelist(**test_case.kwargs)
+
+ with self.subTest(test_description=test_case.description):
+ with self.assertRaisesRegex(InWhitelistCheckFailure, exception_message):
+ predicate(test_case.ctx)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
deleted file mode 100644
index d7bcc3ba6..000000000
--- a/tests/bot/test_utils.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import unittest
-
-from bot import utils
-
-
-class CaseInsensitiveDictTests(unittest.TestCase):
- """Tests for the `CaseInsensitiveDict` container."""
-
- def test_case_insensitive_key_access(self):
- """Tests case insensitive key access and storage."""
- instance = utils.CaseInsensitiveDict()
-
- key = 'LEMON'
- value = 'trees'
-
- instance[key] = value
- self.assertIn(key, instance)
- self.assertEqual(instance.get(key), value)
- self.assertEqual(instance.get(key.casefold()), value)
- self.assertEqual(instance.pop(key.casefold()), value)
- self.assertNotIn(key, instance)
- self.assertNotIn(key.casefold(), instance)
-
- instance.setdefault(key, value)
- del instance[key]
- self.assertNotIn(key, instance)
-
- def test_initialization_from_kwargs(self):
- """Tests creating the dictionary from keyword arguments."""
- instance = utils.CaseInsensitiveDict({'FOO': 'bar'})
- self.assertEqual(instance['foo'], 'bar')
-
- def test_update_from_other_mapping(self):
- """Tests updating the dictionary from another mapping."""
- instance = utils.CaseInsensitiveDict()
- instance.update({'FOO': 'bar'})
- self.assertEqual(instance['foo'], 'bar')
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index 9610771e5..de72e5748 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -1,6 +1,8 @@
import unittest
+from unittest.mock import MagicMock
from bot.utils import checks
+from bot.utils.checks import InWhitelistCheckFailure
from tests.helpers import MockContext, MockRole
@@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):
self.ctx.author.roles.append(MockRole(id=role_id))
self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
- def test_in_channel_check_for_correct_channel(self):
- self.ctx.channel.id = 42
- self.assertTrue(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_correct_channel(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list."""
+ channel_id = 3
+ self.ctx.channel.id = channel_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id]))
- def test_in_channel_check_for_incorrect_channel(self):
- self.ctx.channel.id = 42 + 10
- self.assertFalse(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_incorrect_channel(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match."""
+ self.ctx.channel.id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, [4])
+
+ def test_in_whitelist_check_correct_category(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list."""
+ category_id = 3
+ self.ctx.channel.category_id = category_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id]))
+
+ def test_in_whitelist_check_incorrect_category(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match."""
+ self.ctx.channel.category_id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, categories=[4])
+
+ def test_in_whitelist_check_correct_role(self):
+ """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6]))
+
+ def test_in_whitelist_check_incorrect_role(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, roles=[4])
+
+ def test_in_whitelist_check_fail_silently(self):
+ """`in_whitelist_check` test no exception raised if `fail_silently` is `True`"""
+ self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True))
+
+ def test_in_whitelist_check_complex(self):
+ """`in_whitelist_check` test with multiple parameters"""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.ctx.channel.category_id = 3
+ self.ctx.channel.id = 5
+ self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2]))
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
new file mode 100644
index 000000000..a2f0fe55d
--- /dev/null
+++ b/tests/bot/utils/test_redis_cache.py
@@ -0,0 +1,265 @@
+import asyncio
+import unittest
+
+import fakeredis.aioredis
+
+from bot.utils import RedisCache
+from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError
+from tests import helpers
+
+
+class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
+ """Tests the RedisCache class from utils.redis_dict.py."""
+
+ async def asyncSetUp(self): # noqa: N802
+ """Sets up the objects that only have to be initialized once."""
+ self.bot = helpers.MockBot()
+ self.bot.redis_session = await fakeredis.aioredis.create_redis_pool()
+
+ # Okay, so this is necessary so that we can create a clean new
+ # class for every test method, and we want that because it will
+ # ensure we get a fresh loop, which is necessary for test_increment_lock
+ # to be able to pass.
+ class DummyCog:
+ """A dummy cog, for dummies."""
+
+ redis = RedisCache()
+
+ def __init__(self, bot: helpers.MockBot):
+ self.bot = bot
+
+ self.cog = DummyCog(self.bot)
+
+ await self.cog.redis.clear()
+
+ def test_class_attribute_namespace(self):
+ """Test that RedisDict creates a namespace automatically for class attributes."""
+ self.assertEqual(self.cog.redis._namespace, "DummyCog.redis")
+
+ async def test_class_attribute_required(self):
+ """Test that errors are raised when not assigned as a class attribute."""
+ bad_cache = RedisCache()
+ self.assertIs(bad_cache._namespace, None)
+
+ with self.assertRaises(RuntimeError):
+ await bad_cache.set("test", "me_up_deadman")
+
+ async def test_set_get_item(self):
+ """Test that users can set and get items from the RedisDict."""
+ test_cases = (
+ ('favorite_fruit', 'melon'),
+ ('favorite_number', 86),
+ ('favorite_fraction', 86.54),
+ ('favorite_boolean', False),
+ ('other_boolean', True),
+ )
+
+ # Test that we can get and set different types.
+ for test in test_cases:
+ await self.cog.redis.set(*test)
+ self.assertEqual(await self.cog.redis.get(test[0]), test[1])
+
+ # Test that .get allows a default value
+ self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
+
+ async def test_set_item_type(self):
+ """Test that .set rejects keys and values that are not permitted."""
+ fruits = ["lemon", "melon", "apple"]
+
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(fruits, "nice")
+
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(4.23, "nice")
+
+ async def test_delete_item(self):
+ """Test that .delete allows us to delete stuff from the RedisCache."""
+ # Add an item and verify that it gets added
+ await self.cog.redis.set("internet", "firetruck")
+ self.assertEqual(await self.cog.redis.get("internet"), "firetruck")
+
+ # Delete that item and verify that it gets deleted
+ await self.cog.redis.delete("internet")
+ self.assertIs(await self.cog.redis.get("internet"), None)
+
+ async def test_contains(self):
+ """Test that we can check membership with .contains."""
+ await self.cog.redis.set('favorite_country', "Burkina Faso")
+
+ self.assertIs(await self.cog.redis.contains('favorite_country'), True)
+ self.assertIs(await self.cog.redis.contains('favorite_dentist'), False)
+
+ async def test_items(self):
+ """Test that the RedisDict can be iterated."""
+ # Set up our test cases in the Redis cache
+ test_cases = [
+ ('favorite_turtle', 'Donatello'),
+ ('second_favorite_turtle', 'Leonardo'),
+ ('third_favorite_turtle', 'Raphael'),
+ ]
+ for key, value in test_cases:
+ await self.cog.redis.set(key, value)
+
+ # Consume the AsyncIterator into a regular list, easier to compare that way.
+ redis_items = [item for item in await self.cog.redis.items()]
+
+ # These sequences are probably in the same order now, but probably
+ # isn't good enough for tests. Let's not rely on .hgetall always
+ # returning things in sequence, and just sort both lists to be safe.
+ redis_items = sorted(redis_items)
+ test_cases = sorted(test_cases)
+
+ # If these are equal now, everything works fine.
+ self.assertSequenceEqual(test_cases, redis_items)
+
+ async def test_length(self):
+ """Test that we can get the correct .length from the RedisDict."""
+ await self.cog.redis.set('one', 1)
+ await self.cog.redis.set('two', 2)
+ await self.cog.redis.set('three', 3)
+ self.assertEqual(await self.cog.redis.length(), 3)
+
+ await self.cog.redis.set('four', 4)
+ self.assertEqual(await self.cog.redis.length(), 4)
+
+ async def test_to_dict(self):
+ """Test that the .to_dict method returns a workable dictionary copy."""
+ copy = await self.cog.redis.to_dict()
+ local_copy = {key: value for key, value in await self.cog.redis.items()}
+ self.assertIs(type(copy), dict)
+ self.assertDictEqual(copy, local_copy)
+
+ async def test_clear(self):
+ """Test that the .clear method removes the entire hash."""
+ await self.cog.redis.set('teddy', 'with me')
+ await self.cog.redis.set('in my dreams', 'you have a weird hat')
+ self.assertEqual(await self.cog.redis.length(), 2)
+
+ await self.cog.redis.clear()
+ self.assertEqual(await self.cog.redis.length(), 0)
+
+ async def test_pop(self):
+ """Test that we can .pop an item from the RedisDict."""
+ await self.cog.redis.set('john', 'was afraid')
+
+ self.assertEqual(await self.cog.redis.pop('john'), 'was afraid')
+ self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck')
+ self.assertEqual(await self.cog.redis.length(), 0)
+
+ async def test_update(self):
+ """Test that we can .update the RedisDict with multiple items."""
+ await self.cog.redis.set("reckfried", "lona")
+ await self.cog.redis.set("bel air", "prince")
+ await self.cog.redis.update({
+ "reckfried": "jona",
+ "mega": "hungry, though",
+ })
+
+ result = {
+ "reckfried": "jona",
+ "bel air": "prince",
+ "mega": "hungry, though",
+ }
+ self.assertDictEqual(await self.cog.redis.to_dict(), result)
+
+ def test_typestring_conversion(self):
+ """Test the typestring-related helper functions."""
+ conversion_tests = (
+ (12, "i|12"),
+ (12.4, "f|12.4"),
+ ("cowabunga", "s|cowabunga"),
+ )
+
+ # Test conversion to typestring
+ for _input, expected in conversion_tests:
+ self.assertEqual(self.cog.redis._value_to_typestring(_input), expected)
+
+ # Test conversion from typestrings
+ for _input, expected in conversion_tests:
+ self.assertEqual(self.cog.redis._value_from_typestring(expected), _input)
+
+ # Test that exceptions are raised on invalid input
+ with self.assertRaises(TypeError):
+ self.cog.redis._value_to_typestring(["internet"])
+ self.cog.redis._value_from_typestring("o|firedog")
+
+ async def test_increment_decrement(self):
+ """Test .increment and .decrement methods."""
+ await self.cog.redis.set("entropic", 5)
+ await self.cog.redis.set("disentropic", 12.5)
+
+ # Test default increment
+ await self.cog.redis.increment("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 6)
+
+ # Test default decrement
+ await self.cog.redis.decrement("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
+
+ # Test float increment with float
+ await self.cog.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 14.5)
+
+ # Test float increment with int
+ await self.cog.redis.increment("disentropic", 2)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 16.5)
+
+ # Test negative increments, because why not.
+ await self.cog.redis.increment("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 0)
+
+ # Negative decrements? Sure.
+ await self.cog.redis.decrement("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
+
+ # What about if we use a negative float to decrement an int?
+ # This should convert the type into a float.
+ await self.cog.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 7.5)
+
+ # Let's test that they raise the right errors
+ with self.assertRaises(KeyError):
+ await self.cog.redis.increment("doesn't_exist!")
+
+ await self.cog.redis.set("stringthing", "stringthing")
+ with self.assertRaises(TypeError):
+ await self.cog.redis.increment("stringthing")
+
+ async def test_increment_lock(self):
+ """Test that we can't produce a race condition in .increment."""
+ await self.cog.redis.set("test_key", 0)
+ tasks = []
+
+ # Increment this a lot in different tasks
+ for _ in range(100):
+ task = asyncio.create_task(
+ self.cog.redis.increment("test_key", 1)
+ )
+ tasks.append(task)
+ await asyncio.gather(*tasks)
+
+ # Confirm that the value has been incremented the exact right number of times.
+ value = await self.cog.redis.get("test_key")
+ self.assertEqual(value, 100)
+
+ async def test_exceptions_raised(self):
+ """Testing that the various RuntimeErrors are reachable."""
+ class MyCog:
+ cache = RedisCache()
+
+ def __init__(self):
+ self.other_cache = RedisCache()
+
+ cog = MyCog()
+
+ # Raises "No Bot instance"
+ with self.assertRaises(NoBotInstanceError):
+ await cog.cache.get("john")
+
+ # Raises "RedisCache has no namespace"
+ with self.assertRaises(NoNamespaceError):
+ await cog.other_cache.get("was")
+
+ # Raises "You must access the RedisCache instance through the cog instance"
+ with self.assertRaises(NoParentInstanceError):
+ await MyCog.cache.get("afraid")
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
index 69f35f2f5..694d3a40f 100644
--- a/tests/bot/utils/test_time.py
+++ b/tests/bot/utils/test_time.py
@@ -1,12 +1,11 @@
import asyncio
import unittest
from datetime import datetime, timezone
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
from dateutil.relativedelta import relativedelta
from bot.utils import time
-from tests.helpers import AsyncMock
class TimeTests(unittest.TestCase):
@@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase):
for max_units in test_cases:
with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
- self.assertEqual(str(error), 'max_units must be positive')
+ self.assertEqual(str(error.exception), 'max_units must be positive')
def test_parse_rfc1123(self):
"""Testing parse_rfc1123."""
diff --git a/tests/helpers.py b/tests/helpers.py
index 6f50f6ae3..facc4e1af 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,18 +1,18 @@
from __future__ import annotations
-import asyncio
import collections
-import functools
-import inspect
import itertools
import logging
import unittest.mock
-from typing import Any, Iterable, Optional
+from asyncio import AbstractEventLoop
+from typing import Callable, Iterable, Optional
import discord
+from aiohttp import ClientSession
from discord.ext.commands import Context
from bot.api import APIClient
+from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
@@ -26,19 +26,22 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def async_test(wrapped):
- """
- Run a test case via asyncio.
- Example:
- >>> @async_test
- ... async def lemon_wins():
- ... assert True
- """
+def autospec(target, *attributes: str, **kwargs) -> Callable:
+ """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True."""
+ # Caller's kwargs should take priority and overwrite the defaults.
+ kwargs = {'spec_set': True, 'autospec': True, **kwargs}
- @functools.wraps(wrapped)
- def wrapper(*args, **kwargs):
- return asyncio.run(wrapped(*args, **kwargs))
- return wrapper
+ # Import the target if it's a string.
+ # This is to support both object and string targets like patch.multiple.
+ if type(target) is str:
+ target = unittest.mock._importer(target)
+
+ def decorator(func):
+ for attribute in attributes:
+ patcher = unittest.mock.patch.object(target, attribute, **kwargs)
+ func = patcher(func)
+ return func
+ return decorator
class HashableMixin(discord.mixins.EqualityComparable):
@@ -69,24 +72,31 @@ class CustomMockMixin:
"""
Provides common functionality for our custom Mock types.
- The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
- function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
- of making sure child mocks are instantiated with the correct class. By default, the mock of the
- children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
- `child_mock_type` on the custom mock inheriting from this mixin.
+ The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock
+ object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the
+ class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional
+ attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The
+ class method `spec_set` can be overwritten with the object that should be uses as the specification
+ for the mock.
+
+ Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to
+ implement custom behavior.
"""
child_mock_type = unittest.mock.MagicMock
discord_id = itertools.count(0)
+ spec_set = None
+ additional_spec_asyncs = None
- def __init__(self, spec_set: Any = None, **kwargs):
+ def __init__(self, **kwargs):
name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually.
- super().__init__(spec_set=spec_set, **kwargs)
+ super().__init__(spec_set=self.spec_set, **kwargs)
+
+ if self.additional_spec_asyncs:
+ self._spec_asyncs.extend(self.additional_spec_asyncs)
if name:
self.name = name
- if spec_set:
- self._extract_coroutine_methods_from_spec_instance(spec_set)
def _get_child_mock(self, **kw):
"""
@@ -100,7 +110,16 @@ class CustomMockMixin:
This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
"""
- klass = self.child_mock_type
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return unittest.mock.AsyncMock(**kw)
+
+ _type = type(self)
+ if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics:
+ # Any asynchronous magic becomes an AsyncMock
+ klass = unittest.mock.AsyncMock
+ else:
+ klass = self.child_mock_type
if self._mock_sealed:
attribute = "." + kw["name"] if "name" in kw else "()"
@@ -109,107 +128,6 @@ class CustomMockMixin:
return klass(**kw)
- def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
- """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
- for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
- setattr(self, name, AsyncMock())
-
-
-# TODO: Remove me in Python 3.8
-class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock async callables.
-
- Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
- features; this stand-in only overwrites the `__call__` method to an async version.
- """
-
- async def __call__(self, *args, **kwargs):
- return super().__call__(*args, **kwargs)
-
-
-class AsyncContextManagerMock(unittest.mock.MagicMock):
- def __init__(self, return_value: Any):
- super().__init__()
- self._return_value = return_value
-
- async def __aenter__(self):
- return self._return_value
-
- async def __aexit__(self, *args):
- pass
-
-
-class AsyncIteratorMock:
- """
- A class to mock asynchronous iterators.
-
- This allows async for, which is used in certain Discord.py objects. For example,
- an async iterator is returned by the Reaction.users() method.
- """
-
- def __init__(self, iterable: Iterable = None):
- if iterable is None:
- iterable = []
-
- self.iter = iter(iterable)
- self.iterable = iterable
-
- self.call_count = 0
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- return next(self.iter)
- except StopIteration:
- raise StopAsyncIteration
-
- def __call__(self):
- """
- Keeps track of the number of times an instance has been called.
-
- This is useful, since it typically shows that the iterator has actually been used somewhere after we have
- instantiated the mock for an attribute that normally returns an iterator when called.
- """
- self.call_count += 1
- return self
-
- @property
- def return_value(self):
- """Makes `self.iterable` accessible as self.return_value."""
- return self.iterable
-
- @return_value.setter
- def return_value(self, iterable):
- """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
- self.iter = iter(iterable)
- self.iterable = iterable
-
- def assert_called(self):
- """Asserts if the AsyncIteratorMock instance has been called at least once."""
- if self.call_count == 0:
- raise AssertionError("Expected AsyncIteratorMock to have been called.")
-
- def assert_called_once(self):
- """Asserts if the AsyncIteratorMock instance has been called exactly once."""
- if self.call_count != 1:
- raise AssertionError(
- f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
- )
-
- def assert_not_called(self):
- """Asserts if the AsyncIteratorMock instance has not been called."""
- if self.call_count != 0:
- raise AssertionError(
- f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
- )
-
- def reset_mock(self):
- """Resets the call count, but not the return value or iterator."""
- self.call_count = 0
-
# Create a guild instance to get a realistic Mock of `discord.Guild`
guild_data = {
@@ -260,9 +178,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
For more info, see the `Mocking` section in `tests/README.md`.
"""
+ spec_set = guild_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'members': []}
- super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -281,6 +201,8 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.Role` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = role_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {
'id': next(self.discord_id),
@@ -289,7 +211,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
'colour': discord.Colour(0xdeadbf),
'permissions': discord.Permissions(),
}
- super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if isinstance(self.colour, int):
self.colour = discord.Colour(self.colour)
@@ -304,6 +226,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
"""Simplified position-based comparisons similar to those of `discord.Role`."""
return self.position < other.position
+ def __ge__(self, other):
+ """Simplified position-based comparisons similar to those of `discord.Role`."""
+ return self.position >= other.position
+
# Create a Member instance to get a realistic Mock of `discord.Member`
member_data = {'user': 'lemon', 'roles': [1]}
@@ -318,9 +244,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
Instances of this class will follow the specifications of `discord.Member` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = member_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -341,9 +269,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.User` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = user_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"@{self.name}"
@@ -356,15 +286,19 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `bot.api.APIClient` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = APIClient
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=APIClient, **kwargs)
+def _get_mock_loop() -> unittest.mock.Mock:
+ """Return a mocked asyncio.AbstractEventLoop."""
+ loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True)
-# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
-bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
-bot_instance.http_session = None
-bot_instance.api_client = None
+ # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
+ # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
+ # to prevent "has not been awaited"-warnings.
+ loop.create_task.side_effect = lambda coroutine: coroutine.close()
+
+ return loop
class MockBot(CustomMockMixin, unittest.mock.MagicMock):
@@ -374,20 +308,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop())
+ additional_spec_asyncs = ("wait_for", "redis_ready")
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=bot_instance, **kwargs)
- self.api_client = MockAPIClient()
-
- # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
- # and should therefore be awaited. (The documentation calls it a coroutine as well, which
- # is technically incorrect, since it's a regular def.)
- self.wait_for = AsyncMock()
+ super().__init__(**kwargs)
- # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
- # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
- # to prevent "has not been awaited"-warnings.
- self.loop.create_task.side_effect = lambda coroutine: coroutine.close()
+ self.loop = _get_mock_loop()
+ self.api_client = MockAPIClient(loop=self.loop)
+ self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True)
+ self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)
# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`
@@ -413,15 +343,37 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = channel_instance
- def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
+ def __init__(self, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
- super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"#{self.name}"
+# Create data for the DMChannel instance
+state = unittest.mock.MagicMock()
+me = unittest.mock.MagicMock()
+dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]}
+dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data)
+
+
+class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
+ """
+ A MagicMock subclass to mock TextChannel objects.
+
+ Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ spec_set = dm_channel_instance
+
+ def __init__(self, **kwargs) -> None:
+ default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()}
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
+
+
# Create a Message instance to get a realistic MagicMock of `discord.Message`
message_data = {
'id': 1,
@@ -455,9 +407,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+ spec_set = context_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=context_instance, **kwargs)
+ super().__init__(**kwargs)
self.bot = kwargs.get('bot', MockBot())
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
@@ -474,8 +427,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Attachment` instances. For
more information, see the `MockGuild` docstring.
"""
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=attachment_instance, **kwargs)
+ spec_set = attachment_instance
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
@@ -485,10 +437,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = message_instance
def __init__(self, **kwargs) -> None:
default_kwargs = {'attachments': []}
- super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -504,9 +457,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = emoji_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=emoji_instance, **kwargs)
+ super().__init__(**kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -520,9 +474,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=partial_emoji_instance, **kwargs)
+ spec_set = partial_emoji_instance
reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
@@ -535,12 +487,18 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = reaction_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=reaction_instance, **kwargs)
+ _users = kwargs.pop("users", [])
+ super().__init__(**kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
- self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+ user_iterator = unittest.mock.AsyncMock()
+ user_iterator.__aiter__.return_value = _users
+ self.users.return_value = user_iterator
+
self.__str__.return_value = str(self.emoji)
@@ -554,13 +512,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Webhook` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=webhook_instance, **kwargs)
-
- # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
- # as coroutines. That's why we need to set the methods manually.
- self.send = AsyncMock()
- self.edit = AsyncMock()
- self.delete = AsyncMock()
- self.execute = AsyncMock()
+ spec_set = webhook_instance
+ additional_spec_asyncs = ("send", "edit", "delete", "execute")
diff --git a/tests/test_base.py b/tests/test_base.py
index a16e2af8f..a7db4bf3e 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -3,7 +3,11 @@ import unittest
import unittest.mock
-from tests.base import LoggingTestCase, _CaptureLogHandler
+from tests.base import LoggingTestsMixin, _CaptureLogHandler
+
+
+class LoggingTestCase(LoggingTestsMixin, unittest.TestCase):
+ pass
class LoggingTestCaseTests(unittest.TestCase):
@@ -18,24 +22,14 @@ class LoggingTestCaseTests(unittest.TestCase):
try:
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
pass
- except AssertionError:
+ except AssertionError: # pragma: no cover
self.fail("`self.assertNotLogs` raised an AssertionError when it should not!")
- @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs")
- def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs):
- """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly."""
- assertNotLogs.return_value = iter([None])
- assertNotLogs.side_effect = AssertionError
-
- message = "`self.assertNotLogs` raised an AssertionError when it should not!"
- with self.assertRaises(AssertionError, msg=message):
- self.test_assert_not_logs_does_not_raise_with_no_logs()
-
def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self):
"""Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted."""
msg_regex = (
r"1 logs of DEBUG or higher were triggered on root:\n"
- r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">'
+ r'<LogRecord: tests\.test_base, [\d]+, .+[/\\]tests[/\\]test_base\.py, [\d]+, "Log!">'
)
with self.assertRaisesRegex(AssertionError, msg_regex):
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 7894e104a..81285e009 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
import unittest
import unittest.mock
@@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase):
with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"):
asyncio.run(coroutine_object)
+ def test_user_mock_uses_explicitly_passed_mention_attribute(self):
+ """MockUser should use an explicitly passed value for user.mention."""
+ user = helpers.MockUser(mention="hello")
+ self.assertEqual(user.mention, "hello")
+
class MockObjectTests(unittest.TestCase):
"""Tests the mock objects and mixins we've defined."""
@@ -341,65 +345,10 @@ class MockObjectTests(unittest.TestCase):
attribute = getattr(mock, valid_attribute)
self.assertTrue(isinstance(attribute, mock_type.child_mock_type))
- def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self):
- """Test if all coroutine functions are extracted, but not regular methods or attributes."""
- class CoroutineDonor:
- def __init__(self):
- self.some_attribute = 'alpha'
-
- async def first_coroutine():
- """This coroutine function should be extracted."""
-
- async def second_coroutine():
- """This coroutine function should be extracted."""
-
- def regular_method():
- """This regular function should not be extracted."""
-
- class Receiver:
+ def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self):
+ """The CustomMockMixin should mock async magic methods with an AsyncMock."""
+ class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):
pass
- donor = CoroutineDonor()
- receiver = Receiver()
-
- helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor)
-
- self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock)
- self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock)
- self.assertFalse(hasattr(receiver, 'regular_method'))
- self.assertFalse(hasattr(receiver, 'some_attribute'))
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_with_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- spec_set = "pydis"
-
- helpers.CustomMockMixin(spec_set=spec_set)
-
- extract_method_mock.assert_called_once_with(spec_set)
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_without_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- helpers.CustomMockMixin()
-
- extract_method_mock.assert_not_called()
-
- def test_async_mock_provides_coroutine_for_dunder_call(self):
- """Test if AsyncMock objects have a coroutine for their __call__ method."""
- async_mock = helpers.AsyncMock()
- self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__))
-
- coroutine = async_mock()
- self.assertTrue(inspect.iscoroutine(coroutine))
- self.assertIsNotNone(asyncio.run(coroutine))
-
- def test_async_test_decorator_allows_synchronous_call_to_async_def(self):
- """Test if the `async_test` decorator allows an `async def` to be called synchronously."""
- @helpers.async_test
- async def kosayoda():
- return "return value"
-
- self.assertEqual(kosayoda(), "return value")
+ mock = MyMock()
+ self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock)
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
deleted file mode 100644
index 4baa6395c..000000000
--- a/tests/utils/test_time.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import asyncio
-from datetime import datetime, timezone
-from unittest.mock import patch
-
-import pytest
-from dateutil.relativedelta import relativedelta
-
-from bot.utils import time
-from tests.helpers import AsyncMock
-
-
- ('delta', 'precision', 'max_units', 'expected'),
- (
- (relativedelta(days=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
- (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
-
- # Does not abort for unknown units, as the unit name is checked
- # against the attribute of the relativedelta instance.
- (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'),
-
- # Very high maximum units, but it only ever iterates over
- # each value the relativedelta might have.
- (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'),
- )
-)
-def test_humanize_delta(
- delta: relativedelta,
- precision: str,
- max_units: int,
- expected: str
-):
- assert time.humanize_delta(delta, precision, max_units) == expected
-
-
[email protected]('max_units', (-1, 0))
-def test_humanize_delta_raises_for_invalid_max_units(max_units: int):
- with pytest.raises(ValueError, match='max_units must be positive'):
- time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
-
-
- ('stamp', 'expected'),
- (
- ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)),
- )
-)
-def test_parse_rfc1123(stamp: str, expected: str):
- assert time.parse_rfc1123(stamp) == expected
-
-
-@patch('asyncio.sleep', new_callable=AsyncMock)
-def test_wait_until(sleep_patch):
- start = datetime(2019, 1, 1, 0, 0)
- then = datetime(2019, 1, 1, 0, 10)
-
- # No return value
- assert asyncio.run(time.wait_until(then, start)) is None
-
- sleep_patch.assert_called_once_with(10 * 60)