aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar mathsman5133 <[email protected]>2020-03-12 20:15:28 +1100
committerGravatar mathsman5133 <[email protected]>2020-03-12 20:15:28 +1100
commitf16d1bb143defd01cd988db007ed5a94b598c4ba (patch)
tree332d9dc042fc636ec9d3e91607620f2c5b762eec /tests
parentApply suggestions from Mark's code review. (diff)
parentMerge pull request #822 from python-discord/bug/mod/792/null-attachments (diff)
Merge branch 'master' of https://github.com/python-discord/bot into help-refactor
 Conflicts:  bot/cogs/help.py
Diffstat (limited to 'tests')
-rw-r--r--tests/README.md10
-rw-r--r--tests/base.py44
-rw-r--r--tests/bot/cogs/sync/test_base.py403
-rw-r--r--tests/bot/cogs/sync/test_cog.py368
-rw-r--r--tests/bot/cogs/sync/test_roles.py279
-rw-r--r--tests/bot/cogs/sync/test_users.py232
-rw-r--r--tests/bot/cogs/test_duck_pond.py39
-rw-r--r--tests/bot/cogs/test_information.py63
-rw-r--r--tests/bot/cogs/test_snekbox.py354
-rw-r--r--tests/bot/cogs/test_token_remover.py4
-rw-r--r--tests/bot/rules/__init__.py76
-rw-r--r--tests/bot/rules/test_attachments.py101
-rw-r--r--tests/bot/rules/test_burst.py54
-rw-r--r--tests/bot/rules/test_burst_shared.py57
-rw-r--r--tests/bot/rules/test_chars.py64
-rw-r--r--tests/bot/rules/test_discord_emojis.py52
-rw-r--r--tests/bot/rules/test_duplicates.py64
-rw-r--r--tests/bot/rules/test_links.py88
-rw-r--r--tests/bot/rules/test_mentions.py94
-rw-r--r--tests/bot/rules/test_newlines.py102
-rw-r--r--tests/bot/rules/test_role_mentions.py55
-rw-r--r--tests/bot/test_api.py68
-rw-r--r--tests/bot/test_converters.py2
-rw-r--r--tests/bot/test_utils.py52
-rw-r--r--tests/bot/utils/test_time.py5
-rw-r--r--tests/helpers.py238
-rw-r--r--tests/test_base.py20
-rw-r--r--tests/test_helpers.py71
-rw-r--r--tests/utils/test_time.py62
29 files changed, 2260 insertions, 861 deletions
diff --git a/tests/README.md b/tests/README.md
index be78821bf..4f62edd68 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -83,7 +83,7 @@ TagContentConverter should return correct values for valid input.
As we are trying to test our "units" of code independently, we want to make sure that we do not rely objects and data generated by "external" code. If we we did, then we wouldn't know if the failure we're observing was caused by the code we are actually trying to test or something external to it.
-However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks".
+However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks".
To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.).
@@ -114,13 +114,13 @@ class BotCogTests(unittest.TestCase):
### Mocking coroutines
-By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.
+By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.
### Special mocks for some `discord.py` types
To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass.
-In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**.
+In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**.
These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR.
@@ -144,7 +144,7 @@ Finally, there are some considerations to make when writing tests, both for writ
### Test coverage is a starting point
-Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work.
+Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work.
One problem is that 100% branch coverage may be misleading if we haven't tested our code against all the realistic input it may get in production. For instance, take a look at the following `member_information` function and the test we've written for it:
@@ -169,7 +169,7 @@ class FunctionsTests(unittest.TestCase):
If you were to run this test, not only would the function pass the test, `coverage.py` will also tell us that the test provides 100% branch coverage for the function. Can you spot the bug the test suite did not catch?
-The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`).
+The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`).
Adding another test would not increase the test coverage we have, but it does ensure that we'll notice that this function can fail with realistic data:
diff --git a/tests/base.py b/tests/base.py
index 029a249ed..d99b9ac31 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -1,6 +1,12 @@
import logging
import unittest
from contextlib import contextmanager
+from typing import Dict
+
+import discord
+from discord.ext import commands
+
+from tests import helpers
class _CaptureLogHandler(logging.Handler):
@@ -16,11 +22,16 @@ class _CaptureLogHandler(logging.Handler):
self.records.append(record)
-class LoggingTestCase(unittest.TestCase):
- """TestCase subclass that adds more logging assertion tools."""
+class LoggingTestsMixin:
+ """
+ A mixin that defines additional test methods for logging behavior.
+
+ This mixin relies on the availability of the `fail` attribute defined by the
+ test classes included in Python's unittest method to signal test failure.
+ """
@contextmanager
- def assertNotLogs(self, logger=None, level=None, msg=None):
+ def assertNotLogs(self, logger=None, level=None, msg=None): # noqa: N802
"""
Asserts that no logs of `level` and higher were emitted by `logger`.
@@ -65,3 +76,30 @@ class LoggingTestCase(unittest.TestCase):
standard_message = self._truncateMessage(base_message, record_message)
msg = self._formatMessage(msg, standard_message)
self.fail(msg)
+
+
+class CommandTestCase(unittest.IsolatedAsyncioTestCase):
+ """TestCase with additional assertions that are useful for testing Discord commands."""
+
+ async def assertHasPermissionsCheck( # noqa: N802
+ self,
+ cmd: commands.Command,
+ permissions: Dict[str, bool],
+ ) -> None:
+ """
+ Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`.
+
+ Every permission in `permissions` is expected to be reported as missing. In other words, do
+ not include permissions which should not raise an exception along with those which should.
+ """
+ # Invert permission values because it's more intuitive to pass to this assertion the same
+ # permissions as those given to the check decorator.
+ permissions = {k: not v for k, v in permissions.items()}
+
+ ctx = helpers.MockContext()
+ ctx.channel.permissions_for.return_value = discord.Permissions(**permissions)
+
+ with self.assertRaises(commands.MissingPermissions) as cm:
+ await cmd.can_run(ctx)
+
+ self.assertCountEqual(permissions.keys(), cm.exception.missing_perms)
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
new file mode 100644
index 000000000..6ee9dfda6
--- /dev/null
+++ b/tests/bot/cogs/sync/test_base.py
@@ -0,0 +1,403 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs.sync.syncers import Syncer, _Diff
+from tests import helpers
+
+
+class TestSyncer(Syncer):
+ """Syncer subclass with mocks for abstract methods for testing purposes."""
+
+ name = "test"
+ _get_diff = mock.AsyncMock()
+ _sync = mock.AsyncMock()
+
+
+class SyncerBaseTests(unittest.TestCase):
+ """Tests for the syncer base class."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ def test_instantiation_fails_without_abstract_methods(self):
+ """The class must have abstract methods implemented."""
+ with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"):
+ Syncer(self.bot)
+
+
+class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for sending the sync confirmation prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+
+ def mock_get_channel(self):
+ """Fixture to return a mock channel and message for when `get_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ mock_channel.send.return_value = mock_message
+ self.bot.get_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ def mock_fetch_channel(self):
+ """Fixture to return a mock channel and message for when `fetch_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ self.bot.get_channel.return_value = None
+ mock_channel.send.return_value = mock_message
+ self.bot.fetch_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ async def test_send_prompt_edits_and_returns_message(self):
+ """The given message should be edited to display the prompt and then should be returned."""
+ msg = helpers.MockMessage()
+ ret_val = await self.syncer._send_prompt(msg)
+
+ msg.edit.assert_called_once()
+ self.assertIn("content", msg.edit.call_args[1])
+ self.assertEqual(ret_val, msg)
+
+ async def test_send_prompt_gets_dev_core_channel(self):
+ """The dev-core channel should be retrieved if an extant message isn't given."""
+ subtests = (
+ (self.bot.get_channel, self.mock_get_channel),
+ (self.bot.fetch_channel, self.mock_fetch_channel),
+ )
+
+ for method, mock_ in subtests:
+ with self.subTest(method=method, msg=mock_.__name__):
+ mock_()
+ await self.syncer._send_prompt()
+
+ method.assert_called_once_with(constants.Channels.dev_core)
+
+ async def test_send_prompt_returns_none_if_channel_fetch_fails(self):
+ """None should be returned if there's an HTTPException when fetching the channel."""
+ self.bot.get_channel.return_value = None
+ self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
+
+ ret_val = await self.syncer._send_prompt()
+
+ self.assertIsNone(ret_val)
+
+ async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
+ """A new message mentioning core devs should be sent and returned if message isn't given."""
+ for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
+ with self.subTest(msg=mock_.__name__):
+ mock_channel, mock_message = mock_()
+ ret_val = await self.syncer._send_prompt()
+
+ mock_channel.send.assert_called_once()
+ self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
+ self.assertEqual(ret_val, mock_message)
+
+ async def test_send_prompt_adds_reactions(self):
+ """The message should have reactions for confirmation added."""
+ extant_message = helpers.MockMessage()
+ subtests = (
+ (extant_message, lambda: (None, extant_message)),
+ (None, self.mock_get_channel),
+ (None, self.mock_fetch_channel),
+ )
+
+ for message_arg, mock_ in subtests:
+ subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__
+
+ with self.subTest(msg=subtest_msg):
+ _, mock_message = mock_()
+ await self.syncer._send_prompt(message_arg)
+
+ calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS]
+ mock_message.add_reaction.assert_has_calls(calls)
+
+
+class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for waiting for a sync confirmation reaction on the prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+ self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers)
+
+ @staticmethod
+ def get_message_reaction(emoji):
+ """Fixture to return a mock message an reaction from the given `emoji`."""
+ message = helpers.MockMessage()
+ reaction = helpers.MockReaction(emoji=emoji, message=message)
+
+ return message, reaction
+
+ def test_reaction_check_for_valid_emoji_and_authors(self):
+ """Should return True if authors are identical or are a bot and a core dev, respectively."""
+ user_subtests = (
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMember(id=77),
+ "identical users",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "bot author and core-dev reactor",
+ ),
+ )
+
+ for emoji in self.syncer._REACTION_EMOJIS:
+ for author, user, msg in user_subtests:
+ with self.subTest(author=author, user=user, emoji=emoji, msg=msg):
+ message, reaction = self.get_message_reaction(emoji)
+ ret_val = self.syncer._reaction_check(author, message, reaction, user)
+
+ self.assertTrue(ret_val)
+
+ def test_reaction_check_for_invalid_reactions(self):
+ """Should return False for invalid reaction events."""
+ valid_emoji = self.syncer._REACTION_EMOJIS[0]
+ subtests = (
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "users are not identical",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43),
+ "reactor lacks the core-dev role",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ "reactor is a bot",
+ ),
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMessage(id=95),
+ helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)),
+ helpers.MockMember(id=77),
+ "messages are not identical",
+ ),
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction("InVaLiD"),
+ helpers.MockMember(id=77),
+ "emoji is invalid",
+ ),
+ )
+
+ for *args, msg in subtests:
+ kwargs = dict(zip(("author", "message", "reaction", "user"), args))
+ with self.subTest(**kwargs, msg=msg):
+ ret_val = self.syncer._reaction_check(*args)
+ self.assertFalse(ret_val)
+
+ async def test_wait_for_confirmation(self):
+ """The message should always be edited and only return True if the emoji is a check mark."""
+ subtests = (
+ (constants.Emojis.check_mark, True, None),
+ ("InVaLiD", False, None),
+ (None, False, TimeoutError),
+ )
+
+ for emoji, ret_val, side_effect in subtests:
+ for bot in (True, False):
+ with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot):
+ # Set up mocks
+ message = helpers.MockMessage()
+ member = helpers.MockMember(bot=bot)
+
+ self.bot.wait_for.reset_mock()
+ self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None)
+ self.bot.wait_for.side_effect = side_effect
+
+ # Call the function
+ actual_return = await self.syncer._wait_for_confirmation(member, message)
+
+ # Perform assertions
+ self.bot.wait_for.assert_called_once()
+ self.assertIn("reaction_add", self.bot.wait_for.call_args[0])
+
+ message.edit.assert_called_once()
+ kwargs = message.edit.call_args[1]
+ self.assertIn("content", kwargs)
+
+ # Core devs should only be mentioned if the author is a bot.
+ if bot:
+ self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+ else:
+ self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+
+ self.assertIs(actual_return, ret_val)
+
+
+class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for main function orchestrating the sync."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
+ self.syncer = TestSyncer(self.bot)
+
+ async def test_sync_respects_confirmation_result(self):
+ """The sync should abort if confirmation fails and continue if confirmed."""
+ mock_message = helpers.MockMessage()
+ subtests = (
+ (True, mock_message),
+ (False, None),
+ )
+
+ for confirmed, message in subtests:
+ with self.subTest(confirmed=confirmed):
+ self.syncer._sync.reset_mock()
+ self.syncer._get_diff.reset_mock()
+
+ diff = _Diff({1, 2, 3}, {4, 5}, None)
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = mock.AsyncMock(
+ return_value=(confirmed, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+
+ if confirmed:
+ self.syncer._sync.assert_called_once_with(diff)
+ else:
+ self.syncer._sync.assert_not_called()
+
+ async def test_sync_diff_size(self):
+ """The diff size should be correctly calculated."""
+ subtests = (
+ (6, _Diff({1, 2}, {3, 4}, {5, 6})),
+ (5, _Diff({1, 2, 3}, None, {4, 5})),
+ (0, _Diff(None, None, None)),
+ (0, _Diff(set(), set(), set())),
+ )
+
+ for size, diff in subtests:
+ with self.subTest(size=size, diff=diff):
+ self.syncer._get_diff.reset_mock()
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
+
+ async def test_sync_message_edited(self):
+ """The message should be edited if one was sent, even if the sync has an API error."""
+ subtests = (
+ (None, None, False),
+ (helpers.MockMessage(), None, True),
+ (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True),
+ )
+
+ for message, side_effect, should_edit in subtests:
+ with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
+ self.syncer._sync.side_effect = side_effect
+ self.syncer._get_confirmation_result = mock.AsyncMock(
+ return_value=(True, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ if should_edit:
+ message.edit.assert_called_once()
+ self.assertIn("content", message.edit.call_args[1])
+
+ async def test_sync_confirmation_context_redirect(self):
+ """If ctx is given, a new message should be sent and author should be ctx's author."""
+ mock_member = helpers.MockMember()
+ subtests = (
+ (None, self.bot.user, None),
+ (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()),
+ )
+
+ for ctx, author, message in subtests:
+ with self.subTest(ctx=ctx, author=author, message=message):
+ if ctx is not None:
+ ctx.send.return_value = message
+
+ # Make sure `_get_diff` returns a MagicMock, not an AsyncMock
+ self.syncer._get_diff.return_value = mock.MagicMock()
+
+ self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild, ctx)
+
+ if ctx is not None:
+ ctx.send.assert_called_once()
+
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author)
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ async def test_confirmation_result_small_diff(self):
+ """Should always return True and the given message if the diff size is too small."""
+ author = helpers.MockMember()
+ expected_message = helpers.MockMessage()
+
+ for size in (3, 2): # pragma: no cover
+ with self.subTest(size=size):
+ self.syncer._send_prompt = mock.AsyncMock()
+ self.syncer._wait_for_confirmation = mock.AsyncMock()
+
+ coro = self.syncer._get_confirmation_result(size, author, expected_message)
+ result, actual_message = await coro
+
+ self.assertTrue(result)
+ self.assertEqual(actual_message, expected_message)
+ self.syncer._send_prompt.assert_not_called()
+ self.syncer._wait_for_confirmation.assert_not_called()
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ async def test_confirmation_result_large_diff(self):
+ """Should return True if confirmed and False if _send_prompt fails or aborted."""
+ author = helpers.MockMember()
+ mock_message = helpers.MockMessage()
+
+ subtests = (
+ (True, mock_message, True, "confirmed"),
+ (False, None, False, "_send_prompt failed"),
+ (False, mock_message, False, "aborted"),
+ )
+
+ for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover
+ with self.subTest(msg=msg):
+ self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message)
+ self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed)
+
+ coro = self.syncer._get_confirmation_result(4, author)
+ actual_result, actual_message = await coro
+
+ self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None
+ self.assertIs(actual_result, expected_result)
+ self.assertEqual(actual_message, expected_message)
+
+ if expected_message:
+ self.syncer._wait_for_confirmation.assert_called_once_with(
+ author, expected_message
+ )
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
new file mode 100644
index 000000000..81398c61f
--- /dev/null
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -0,0 +1,368 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs import sync
+from bot.cogs.sync.syncers import Syncer
+from tests import helpers
+from tests.base import CommandTestCase
+
+
+class SyncExtensionTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the sync extension."""
+
+ @staticmethod
+ def test_extension_setup():
+ """The Sync cog should be added."""
+ bot = helpers.MockBot()
+ sync.setup(bot)
+ bot.add_cog.assert_called_once()
+
+
+class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):
+ """Base class for Sync cog tests. Sets up patches for syncers."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ self.role_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.RoleSyncer",
+ autospec=Syncer,
+ spec_set=True
+ )
+ self.user_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.UserSyncer",
+ autospec=Syncer,
+ spec_set=True
+ )
+ self.RoleSyncer = self.role_syncer_patcher.start()
+ self.UserSyncer = self.user_syncer_patcher.start()
+
+ self.cog = sync.Sync(self.bot)
+
+ def tearDown(self):
+ self.role_syncer_patcher.stop()
+ self.user_syncer_patcher.stop()
+
+ @staticmethod
+ def response_error(status: int) -> ResponseCodeError:
+ """Fixture to return a ResponseCodeError with the given status code."""
+ response = mock.MagicMock()
+ response.status = status
+
+ return ResponseCodeError(response)
+
+
+class SyncCogTests(SyncCogTestCase):
+ """Tests for the Sync cog."""
+
+ @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock)
+ def test_sync_cog_init(self, sync_guild):
+ """Should instantiate syncers and run a sync for the guild."""
+ # Reset because a Sync cog was already instantiated in setUp.
+ self.RoleSyncer.reset_mock()
+ self.UserSyncer.reset_mock()
+ self.bot.loop.create_task = mock.MagicMock()
+
+ mock_sync_guild_coro = mock.MagicMock()
+ sync_guild.return_value = mock_sync_guild_coro
+
+ sync.Sync(self.bot)
+
+ self.RoleSyncer.assert_called_once_with(self.bot)
+ self.UserSyncer.assert_called_once_with(self.bot)
+ sync_guild.assert_called_once_with()
+ self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
+
+ async def test_sync_cog_sync_guild(self):
+ """Roles and users should be synced only if a guild is successfully retrieved."""
+ for guild in (helpers.MockGuild(), None):
+ with self.subTest(guild=guild):
+ self.bot.reset_mock()
+ self.cog.role_syncer.reset_mock()
+ self.cog.user_syncer.reset_mock()
+
+ self.bot.get_guild = mock.MagicMock(return_value=guild)
+
+ await self.cog.sync_guild()
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(constants.Guild.id)
+
+ if guild is None:
+ self.cog.role_syncer.sync.assert_not_called()
+ self.cog.user_syncer.sync.assert_not_called()
+ else:
+ self.cog.role_syncer.sync.assert_called_once_with(guild)
+ self.cog.user_syncer.sync.assert_called_once_with(guild)
+
+ async def patch_user_helper(self, side_effect: BaseException) -> None:
+ """Helper to set a side effect for bot.api_client.patch and then assert it is called."""
+ self.bot.api_client.patch.reset_mock(side_effect=True)
+ self.bot.api_client.patch.side_effect = side_effect
+
+ user_id, updated_information = 5, {"key": 123}
+ await self.cog.patch_user(user_id, updated_information)
+
+ self.bot.api_client.patch.assert_called_once_with(
+ f"bot/users/{user_id}",
+ json=updated_information,
+ )
+
+ async def test_sync_cog_patch_user(self):
+ """A PATCH request should be sent and 404 errors ignored."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ await self.patch_user_helper(side_effect)
+
+ async def test_sync_cog_patch_user_non_404(self):
+ """A PATCH request should be sent and the error raised if it's not a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.patch_user_helper(self.response_error(500))
+
+
+class SyncCogListenerTests(SyncCogTestCase):
+ """Tests for the listeners of the Sync cog."""
+
+ def setUp(self):
+ super().setUp()
+ self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user)
+
+ async def test_sync_cog_on_guild_role_create(self):
+ """A POST request should be sent with the new role's data."""
+ self.assertTrue(self.cog.on_guild_role_create.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ role = helpers.MockRole(**role_data)
+ await self.cog.on_guild_role_create(role)
+
+ self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data)
+
+ async def test_sync_cog_on_guild_role_delete(self):
+ """A DELETE request should be sent."""
+ self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__)
+
+ role = helpers.MockRole(id=99)
+ await self.cog.on_guild_role_delete(role)
+
+ self.bot.api_client.delete.assert_called_once_with("bot/roles/99")
+
+ async def test_sync_cog_on_guild_role_update(self):
+ """A PUT request should be sent if the colour, name, permissions, or position changes."""
+ self.assertTrue(self.cog.on_guild_role_update.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ subtests = (
+ (True, ("colour", "name", "permissions", "position")),
+ (False, ("hoist", "mentionable")),
+ )
+
+ for should_put, attributes in subtests:
+ for attribute in attributes:
+ with self.subTest(should_put=should_put, changed_attribute=attribute):
+ self.bot.api_client.put.reset_mock()
+
+ after_role_data = role_data.copy()
+ after_role_data[attribute] = 876
+
+ before_role = helpers.MockRole(**role_data)
+ after_role = helpers.MockRole(**after_role_data)
+
+ await self.cog.on_guild_role_update(before_role, after_role)
+
+ if should_put:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/roles/{after_role.id}",
+ json=after_role_data
+ )
+ else:
+ self.bot.api_client.put.assert_not_called()
+
+ async def test_sync_cog_on_member_remove(self):
+ """Member should patched to set in_guild as False."""
+ self.assertTrue(self.cog.on_member_remove.__cog_listener__)
+
+ member = helpers.MockMember()
+ await self.cog.on_member_remove(member)
+
+ self.cog.patch_user.assert_called_once_with(
+ member.id,
+ updated_information={"in_guild": False}
+ )
+
+ async def test_sync_cog_on_member_update_roles(self):
+ """Members should be patched if their roles have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ # Roles are intentionally unsorted.
+ before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)]
+ before_member = helpers.MockMember(roles=before_roles)
+ after_member = helpers.MockMember(roles=before_roles[1:])
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ data = {"roles": sorted(role.id for role in after_member.roles)}
+ self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data)
+
+ async def test_sync_cog_on_member_update_other(self):
+ """Members should not be patched if other attributes have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ subtests = (
+ ("activities", discord.Game("Pong"), discord.Game("Frogger")),
+ ("nick", "old nick", "new nick"),
+ ("status", discord.Status.online, discord.Status.offline),
+ )
+
+ for attribute, old_value, new_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ before_member = helpers.MockMember(**{attribute: old_value})
+ after_member = helpers.MockMember(**{attribute: new_value})
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ self.cog.patch_user.assert_not_called()
+
+ async def test_sync_cog_on_user_update(self):
+ """A user should be patched only if the name, discriminator, or avatar changes."""
+ self.assertTrue(self.cog.on_user_update.__cog_listener__)
+
+ before_data = {
+ "name": "old name",
+ "discriminator": "1234",
+ "avatar": "old avatar",
+ "bot": False,
+ }
+
+ subtests = (
+ (True, "name", "name", "new name", "new name"),
+ (True, "discriminator", "discriminator", "8765", 8765),
+ (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),
+ (False, "bot", "bot", True, True),
+ )
+
+ for should_patch, attribute, api_field, value, api_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ after_data = before_data.copy()
+ after_data[attribute] = value
+ before_user = helpers.MockUser(**before_data)
+ after_user = helpers.MockUser(**after_data)
+
+ await self.cog.on_user_update(before_user, after_user)
+
+ if should_patch:
+ self.cog.patch_user.assert_called_once()
+
+ # Don't care if *all* keys are present; only the changed one is required
+ call_args = self.cog.patch_user.call_args
+ self.assertEqual(call_args[0][0], after_user.id)
+ self.assertIn("updated_information", call_args[1])
+
+ updated_information = call_args[1]["updated_information"]
+ self.assertIn(api_field, updated_information)
+ self.assertEqual(updated_information[api_field], api_value)
+ else:
+ self.cog.patch_user.assert_not_called()
+
+ async def on_member_join_helper(self, side_effect: Exception) -> dict:
+ """
+ Helper to set `side_effect` for on_member_join and assert a PUT request was sent.
+
+ The request data for the mock member is returned. All exceptions will be re-raised.
+ """
+ member = helpers.MockMember(
+ discriminator="1234",
+ roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)],
+ )
+
+ data = {
+ "avatar_hash": member.avatar,
+ "discriminator": int(member.discriminator),
+ "id": member.id,
+ "in_guild": True,
+ "name": member.name,
+ "roles": sorted(role.id for role in member.roles)
+ }
+
+ self.bot.api_client.put.reset_mock(side_effect=True)
+ self.bot.api_client.put.side_effect = side_effect
+
+ try:
+ await self.cog.on_member_join(member)
+ except Exception:
+ raise
+ finally:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/users/{member.id}",
+ json=data
+ )
+
+ return data
+
+ async def test_sync_cog_on_member_join(self):
+ """Should PUT user's data or POST it if the user doesn't exist."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ self.bot.api_client.post.reset_mock()
+ data = await self.on_member_join_helper(side_effect)
+
+ if side_effect:
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=data)
+ else:
+ self.bot.api_client.post.assert_not_called()
+
+ async def test_sync_cog_on_member_join_non_404(self):
+ """ResponseCodeError should be re-raised if status code isn't a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.on_member_join_helper(self.response_error(500))
+
+ self.bot.api_client.post.assert_not_called()
+
+
+class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
+ """Tests for the commands in the Sync cog."""
+
+ async def test_sync_roles_command(self):
+ """sync() should be called on the RoleSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_roles_command.callback(self.cog, ctx)
+
+ self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ async def test_sync_users_command(self):
+ """sync() should be called on the UserSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_users_command.callback(self.cog, ctx)
+
+ self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ async def test_commands_require_admin(self):
+ """The sync commands should only run if the author has the administrator permission."""
+ cmds = (
+ self.cog.sync_group,
+ self.cog.sync_roles_command,
+ self.cog.sync_users_command,
+ )
+
+ for cmd in cmds:
+ with self.subTest(cmd=cmd):
+ await self.assertHasPermissionsCheck(cmd, {"administrator": True})
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
index 27ae27639..79eee98f4 100644
--- a/tests/bot/cogs/sync/test_roles.py
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -1,126 +1,157 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import Role, get_roles_for_sync
-
-
-class GetRolesForSyncTests(unittest.TestCase):
- """Tests constructing the roles to synchronize with the site."""
-
- def test_get_roles_for_sync_empty_return_for_equal_roles(self):
- """No roles should be synced when no diff is found."""
- api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), set(), set())
- )
-
- def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self):
- """Roles to be synced are returned when non-ID attributes differ."""
- api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), guild_roles, set())
- )
-
- def test_get_roles_only_returns_roles_that_require_update(self):
- """Roles that require an update should be returned as the second tuple element."""
- api_roles = {
- Role(id=41, name='old name', colour=33, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
- guild_roles = {
- Role(id=41, name='new name', colour=35, permissions=0x8, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_new_roles_in_first_tuple_element(self):
- """Newly created roles are returned as the first tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=2)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=2)},
- set(),
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_update_and_new_roles(self):
- """Newly created and updated roles should be returned together."""
- api_roles = {
- Role(id=41, name='old name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='new name', colour=40, permissions=0x16, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=3)},
- {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_delete(self):
- """Roles to be deleted should be returned as the third tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- set(),
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
-
- def test_get_roles_returns_roles_to_delete_update_and_new_roles(self):
- """When roles were added, updated, and removed, all of them are returned properly."""
- api_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- Role(id=71, name='to update', colour=99, permissions=0x9, position=3),
- }
- guild_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=81, name='to create', colour=99, permissions=0x9, position=4),
- Role(id=71, name='updated', colour=101, permissions=0x5, position=3),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)},
- {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)},
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
+import discord
+
+from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role
+from tests import helpers
+
+
+def fake_role(**kwargs):
+ """Fixture to return a dictionary representing a role with default values set."""
+ kwargs.setdefault("id", 9)
+ kwargs.setdefault("name", "fake role")
+ kwargs.setdefault("colour", 7)
+ kwargs.setdefault("permissions", 0)
+ kwargs.setdefault("position", 55)
+
+ return kwargs
+
+
+class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for determining differences between roles in the DB and roles in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*roles):
+ """Fixture to return a guild object with the given roles."""
+ guild = helpers.MockGuild()
+ guild.roles = []
+
+ for role in roles:
+ mock_role = helpers.MockRole(**role)
+ mock_role.colour = discord.Colour(role["colour"])
+ mock_role.permissions = discord.Permissions(role["permissions"])
+ guild.roles.append(mock_role)
+
+ return guild
+
+ async def test_empty_diff_for_identical_roles(self):
+ """No differences should be found if the roles in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_updated_roles(self):
+ """Only updated roles should be added to the 'updated' set of the diff."""
+ updated_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]
+ guild = self.get_guild(updated_role, fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_Role(**updated_role)}, set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_new_roles(self):
+ """Only new roles should be added to the 'created' set of the diff."""
+ new_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role(), new_role)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new_role)}, set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_deleted_roles(self):
+ """Only deleted roles should be added to the 'deleted' set of the diff."""
+ deleted_role = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [fake_role(), deleted_role]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), {_Role(**deleted_role)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_new_updated_and_deleted_roles(self):
+ """When roles are added, updated, and removed, all of them are returned properly."""
+ new = fake_role(id=41, name="new")
+ updated = fake_role(id=71, name="updated")
+ deleted = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [
+ fake_role(),
+ fake_role(id=71, name="updated name"),
+ deleted,
+ ]
+ guild = self.get_guild(fake_role(), new, updated)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the API requests that sync roles."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ async def test_sync_created_roles(self):
+ """Only POST requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(role_tuples, set(), set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/roles", json=role) for role in roles]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(roles))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ async def test_sync_updated_roles(self):
+ """Only PUT requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), role_tuples, set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ async def test_sync_deleted_roles(self):
+ """Only DELETE requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), set(), role_tuples)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]
+ self.bot.api_client.delete.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.delete.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.put.assert_not_called()
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index ccaf67490..818883012 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -1,84 +1,160 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import User, get_users_for_sync
+from bot.cogs.sync.syncers import UserSyncer, _Diff, _User
+from tests import helpers
def fake_user(**kwargs):
- kwargs.setdefault('id', 43)
- kwargs.setdefault('name', 'bob the test man')
- kwargs.setdefault('discriminator', 1337)
- kwargs.setdefault('avatar_hash', None)
- kwargs.setdefault('roles', (666,))
- kwargs.setdefault('in_guild', True)
- return User(**kwargs)
-
-
-class GetUsersForSyncTests(unittest.TestCase):
- """Tests constructing the users to synchronize with the site."""
-
- def test_get_users_for_sync_returns_nothing_for_empty_params(self):
- """When no users are given, none are returned."""
- self.assertEqual(
- get_users_for_sync({}, {}),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_nothing_for_equal_users(self):
- """When no users are updated, none are returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self):
- """When a non-ID-field differs, the user to update is returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(name='new fancy name')}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(name='new fancy name')})
- )
-
- def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self):
- """When new users join the guild, they are returned as the first tuple element."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(), 63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, set())
- )
-
- def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self):
+ """Fixture to return a dictionary representing a user with default values set."""
+ kwargs.setdefault("id", 43)
+ kwargs.setdefault("name", "bob the test man")
+ kwargs.setdefault("discriminator", 1337)
+ kwargs.setdefault("avatar_hash", None)
+ kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("in_guild", True)
+
+ return kwargs
+
+
+class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for determining differences between users in the DB and users in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*members):
+ """Fixture to return a guild object with the given members."""
+ guild = helpers.MockGuild()
+ guild.members = []
+
+ for member in members:
+ member = member.copy()
+ member["avatar"] = member.pop("avatar_hash")
+ del member["in_guild"]
+
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+
+ guild.members.append(mock_member)
+
+ return guild
+
+ async def test_empty_diff_for_no_users(self):
+ """When no users are given, an empty diff should be returned."""
+ guild = self.get_guild()
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_empty_diff_for_identical_users(self):
+ """No differences should be found if the users in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_updated_users(self):
+ """Only updated users should be added to the 'updated' set of the diff."""
+ updated_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ guild = self.get_guild(updated_user, fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**updated_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_new_users(self):
+ """Only new users should be added to the 'created' set of the diff."""
+ new_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user(), new_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- api_users = {43: fake_user(), 63: fake_user(id=63)}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(id=63, in_guild=False)})
- )
-
- def test_get_users_for_sync_updates_and_creates_users_as_needed(self):
- """When one user left and another one was updated, both are returned."""
- api_users = {43: fake_user()}
- guild_users = {63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, {fake_user(in_guild=False)})
- )
-
- def test_get_users_for_sync_does_not_duplicate_update_users(self):
- """When the API knows a user the guild doesn't, nothing is performed."""
- api_users = {43: fake_user(in_guild=False)}
- guild_users = {}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_diff_for_new_updated_and_leaving_users(self):
+ """When users are added, updated, and removed, all of them are returned properly."""
+ new_user = fake_user(id=99, name="new")
+ updated_user = fake_user(id=55, name="updated")
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ guild = self.get_guild(fake_user(), new_user, updated_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ async def test_empty_diff_for_db_users_not_in_guild(self):
+ """When the DB knows a user the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the API requests that sync users."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ async def test_sync_created_users(self):
+ """Only POST requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(user_tuples, set(), None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/users", json=user) for user in users]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(users))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ async def test_sync_updated_users(self):
+ """Only PUT requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(set(), user_tuples, None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(users))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index d07b2bce1..7e6bfc748 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -2,7 +2,7 @@ import asyncio
import logging
import typing
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import discord
@@ -14,7 +14,7 @@ from tests import helpers
MODULE_PATH = "bot.cogs.duck_pond"
-class DuckPondTests(base.LoggingTestCase):
+class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
"""Tests for DuckPond functionality."""
@classmethod
@@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase):
asyncio.run(self.cog.fetch_webhook())
- self.bot.wait_until_ready.assert_called_once()
+ self.bot.wait_until_guild_available.assert_called_once()
self.bot.fetch_webhook.assert_called_once_with(1)
self.assertEqual(self.cog.webhook, "dummy webhook")
@@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase):
with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
asyncio.run(self.cog.fetch_webhook())
- self.bot.wait_until_ready.assert_called_once()
+ self.bot.wait_until_guild_available.assert_called_once()
self.bot.fetch_webhook.assert_called_once_with(1)
self.assertEqual(len(log_watcher.records), 1)
@@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):
self.assertEqual(expected_return, actual_return)
- @helpers.async_test
async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):
"""The `has_green_checkmark` method should only return `True` if one is present."""
test_cases = (
@@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase):
nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]
return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers)
- @helpers.async_test
async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):
"""The `count_ducks` method should return the number of unique staffers who gave a duck."""
test_cases = (
@@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):
self.assertEqual(expected_count, actual_count)
- @helpers.async_test
async def test_relay_message_correctly_relays_content_and_attachments(self):
"""The `relay_message` method should correctly relay message content and attachments."""
send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook"
@@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase):
)
for message, expect_webhook_call, expect_attachment_call in test_values:
- with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook:
- with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments:
+ with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook:
+ with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:
with self.subTest(clean_content=message.clean_content, attachments=message.attachments):
await self.cog.relay_message(message)
@@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase):
message.add_reaction.assert_called_once_with(self.checkmark_emoji)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase):
self.cog.webhook = helpers.MockAsyncWebhook()
log = logging.getLogger("bot.cogs.duck_pond")
- for side_effect in side_effects:
+ for side_effect in side_effects: # pragma: no cover
send_attachments.side_effect = side_effect
- with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook:
+ with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:
with self.subTest(side_effect=type(side_effect).__name__):
with self.assertNotLogs(logger=log, level=logging.ERROR):
await self.cog.relay_message(message)
self.assertEqual(send_webhook.call_count, 2)
- @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock)
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji.name = emoji_name
return payload
- @helpers.async_test
async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):
"""The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""
test_values = (
@@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase):
return channel, message, member, payload
- @helpers.async_test
async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):
"""The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""
channel_id = 1234
@@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase):
channel.fetch_message.reset_mock()
@patch(f"{MODULE_PATH}.DuckPond.is_staff")
- @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock)
+ @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)
def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):
"""The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""
channel_id = 31415926535
@@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase):
# Assert that we've made it past `self.is_staff`
is_staff.assert_called_once()
- @helpers.async_test
async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):
"""The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""
test_cases = (
@@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji = self.duck_pond_emoji
for duck_count, should_relay in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_relay=should_relay):
await self.cog.on_raw_reaction_add(payload)
@@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase):
if should_relay:
relay_message.assert_called_once_with(message)
- @helpers.async_test
async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):
"""The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""
checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji)
@@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase):
(constants.DuckPond.threshold + 1, True),
)
for duck_count, should_re_add_checkmark in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):
await self.cog.on_raw_reaction_remove(payload)
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 4496a2ae0..3c26374f5 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator)
+ cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators)
def setUp(self):
"""Sets up fresh objects for each test."""
@@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase):
"""Test if the `role_info` command correctly returns the `moderator_role`."""
self.ctx.guild.roles.append(self.moderator_role)
- self.cog.roles_info.can_run = helpers.AsyncMock()
+ self.cog.roles_info.can_run = unittest.mock.AsyncMock()
self.cog.roles_info.can_run.return_value = True
coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
@@ -45,10 +45,9 @@ class InformationCogTests(unittest.TestCase):
_, kwargs = self.ctx.send.call_args
embed = kwargs.pop('embed')
- self.assertEqual(embed.title, "Role information")
+ self.assertEqual(embed.title, "Role information (Total 1 role)")
self.assertEqual(embed.colour, discord.Colour.blurple())
- self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
- self.assertEqual(embed.footer.text, "Total roles: 1")
+ self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
def test_role_info_command(self):
"""Tests the `role info` command."""
@@ -72,7 +71,7 @@ class InformationCogTests(unittest.TestCase):
self.ctx.guild.roles.append([dummy_role, admin_role])
- self.cog.role_info.can_run = helpers.AsyncMock()
+ self.cog.role_info.can_run = unittest.mock.AsyncMock()
self.cog.role_info.can_run.return_value = True
coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
@@ -125,10 +124,10 @@ class InformationCogTests(unittest.TestCase):
)
],
members=[
- *(helpers.MockMember(status='online') for _ in range(2)),
- *(helpers.MockMember(status='idle') for _ in range(1)),
- *(helpers.MockMember(status='dnd') for _ in range(4)),
- *(helpers.MockMember(status='offline') for _ in range(3)),
+ *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
+ *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
+ *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
+ *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
],
member_count=1_234,
icon_url='a-lemon.jpg',
@@ -153,9 +152,9 @@ class InformationCogTests(unittest.TestCase):
**Counts**
Members: {self.ctx.guild.member_count:,}
Roles: {len(self.ctx.guild.roles)}
- Text: 1
- Voice: 1
- Channel categories: 1
+ Category channels: 1
+ Text channels: 1
+ Voice channels: 1
**Members**
{constants.Emojis.status_online} 2
@@ -174,7 +173,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
self.member = helpers.MockMember(id=1234)
@@ -345,10 +344,10 @@ class UserEmbedTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -360,7 +359,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Mr. Hemlock")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -372,7 +371,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -387,8 +386,8 @@ class UserEmbedTests(unittest.TestCase):
self.assertIn("&Admins", embed.description)
self.assertNotIn("&Everyone", embed.description)
- @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock)
- @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
"""The embed should contain expanded infractions and nomination info in mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50))
@@ -423,7 +422,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
@@ -454,7 +453,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -467,7 +466,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
"""The embed should be created with a blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
@@ -477,7 +476,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour.blurple())
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
@@ -521,7 +520,7 @@ class UserCommandTests(unittest.TestCase):
"""A regular user should not be able to use this command outside of bot-commands."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))
@@ -529,11 +528,11 @@ class UserCommandTests(unittest.TestCase):
with self.assertRaises(InChannelCheckFailure, msg=msg):
asyncio.run(self.cog.user_info.callback(self.cog, ctx))
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50))
@@ -542,11 +541,11 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50))
@@ -555,11 +554,11 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))
@@ -568,7 +567,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.moderator)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
new file mode 100644
index 000000000..fd9468829
--- /dev/null
+++ b/tests/bot/cogs/test_snekbox.py
@@ -0,0 +1,354 @@
+import asyncio
+import logging
+import unittest
+from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
+
+from bot.cogs import snekbox
+from bot.cogs.snekbox import Snekbox
+from bot.constants import URLs
+from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser
+
+
+class SnekboxTests(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ """Add mocked bot and cog to the instance."""
+ self.bot = MockBot()
+ self.cog = Snekbox(bot=self.bot)
+
+ async def test_post_eval(self):
+ """Post the eval code to the URLs.snekbox_eval_api endpoint."""
+ resp = MagicMock()
+ resp.json = AsyncMock(return_value="return")
+ self.bot.http_session.post().__aenter__.return_value = resp
+
+ self.assertEqual(await self.cog.post_eval("import random"), "return")
+ self.bot.http_session.post.assert_called_with(
+ URLs.snekbox_eval_api,
+ json={"input": "import random"},
+ raise_for_status=True
+ )
+ resp.json.assert_awaited_once()
+
+ async def test_upload_output_reject_too_long(self):
+ """Reject output longer than MAX_PASTE_LEN."""
+ result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))
+ self.assertEqual(result, "too long to upload")
+
+ async def test_upload_output(self):
+ """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""
+ key = "MarkDiamond"
+ resp = MagicMock()
+ resp.json = AsyncMock(return_value={"key": key})
+ self.bot.http_session.post().__aenter__.return_value = resp
+
+ self.assertEqual(
+ await self.cog.upload_output("My awesome output"),
+ URLs.paste_service.format(key=key)
+ )
+ self.bot.http_session.post.assert_called_with(
+ URLs.paste_service.format(key="documents"),
+ data="My awesome output",
+ raise_for_status=True
+ )
+
+ async def test_upload_output_gracefully_fallback_if_exception_during_request(self):
+ """Output upload gracefully fallback if the upload fail."""
+ resp = MagicMock()
+ resp.json = AsyncMock(side_effect=Exception)
+ self.bot.http_session.post().__aenter__.return_value = resp
+
+ log = logging.getLogger("bot.cogs.snekbox")
+ with self.assertLogs(logger=log, level='ERROR'):
+ await self.cog.upload_output('My awesome output!')
+
+ async def test_upload_output_gracefully_fallback_if_no_key_in_response(self):
+ """Output upload gracefully fallback if there is no key entry in the response body."""
+ self.assertEqual((await self.cog.upload_output('My awesome output!')), None)
+
+ def test_prepare_input(self):
+ cases = (
+ ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'),
+ ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
+ ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),
+ ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'),
+ )
+ for case, expected, testname in cases:
+ with self.subTest(msg=f'Extract code from {testname}.'):
+ self.assertEqual(self.cog.prepare_input(case), expected)
+
+ def test_get_results_message(self):
+ """Return error and message according to the eval result."""
+ cases = (
+ ('ERROR', None, ('Your eval job has failed', 'ERROR')),
+ ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')),
+ ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred'))
+ )
+ for stdout, returncode, expected in cases:
+ with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
+ actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode})
+ self.assertEqual(actual, expected)
+
+ @patch('bot.cogs.snekbox.Signals', side_effect=ValueError)
+ def test_get_results_message_invalid_signal(self, mock_signals: Mock):
+ self.assertEqual(
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ ('Your eval job has completed with return code 127', '')
+ )
+
+ @patch('bot.cogs.snekbox.Signals')
+ def test_get_results_message_valid_signal(self, mock_signals: Mock):
+ mock_signals.return_value.name = 'SIGTEST'
+ self.assertEqual(
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ ('Your eval job has completed with return code 127 (SIGTEST)', '')
+ )
+
+ def test_get_status_emoji(self):
+ """Return emoji according to the eval result."""
+ cases = (
+ (' ', -1, ':warning:'),
+ ('Hello world!', 0, ':white_check_mark:'),
+ ('Invalid beard size', -1, ':x:')
+ )
+ for stdout, returncode, expected in cases:
+ with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
+ actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})
+ self.assertEqual(actual, expected)
+
+ async def test_format_output(self):
+ """Test output formatting."""
+ self.cog.upload_output = AsyncMock(return_value='https://testificate.com/')
+
+ too_many_lines = (
+ '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n'
+ '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)'
+ )
+ too_long_too_many_lines = (
+ "\n".join(
+ f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1)
+ )[:1000] + "\n... (truncated - too long, too many lines)"
+ )
+
+ cases = (
+ ('', ('[No output]', None), 'No output'),
+ ('My awesome output', ('My awesome output', None), 'One line output'),
+ ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'),
+ ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'),
+ (
+ '\u202E\u202E\u202E',
+ ('Code block escape attempt detected; will not output result', None),
+ 'Detect RIGHT-TO-LEFT OVERRIDE'
+ ),
+ (
+ '\u200B\u200B\u200B',
+ ('Code block escape attempt detected; will not output result', None),
+ 'Detect ZERO WIDTH SPACE'
+ ),
+ ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'),
+ (
+ 'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd',
+ (too_many_lines, 'https://testificate.com/'),
+ '12 lines output'
+ ),
+ (
+ 'verylongbeard' * 100,
+ ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'),
+ '1300 characters output'
+ ),
+ (
+ ('verylongbeard' * 10 + '\n') * 15,
+ (too_long_too_many_lines, 'https://testificate.com/'),
+ '15 lines, 1965 characters output'
+ ),
+ )
+ for case, expected, testname in cases:
+ with self.subTest(msg=testname, case=case, expected=expected):
+ self.assertEqual(await self.cog.format_output(case), expected)
+
+ async def test_eval_command_evaluate_once(self):
+ """Test the eval command procedure."""
+ ctx = MockContext()
+ response = MockMessage()
+ self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.send_eval = AsyncMock(return_value=response)
+ self.cog.continue_eval = AsyncMock(return_value=None)
+
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
+ self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.continue_eval.assert_called_once_with(ctx, response)
+
+ async def test_eval_command_evaluate_twice(self):
+ """Test the eval and re-eval command procedure."""
+ ctx = MockContext()
+ response = MockMessage()
+ self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.send_eval = AsyncMock(return_value=response)
+ self.cog.continue_eval = AsyncMock()
+ self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None)
+
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
+ self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.continue_eval.assert_called_with(ctx, response)
+
+ async def test_eval_command_reject_two_eval_at_the_same_time(self):
+ """Test if the eval command rejects an eval if the author already have a running eval."""
+ ctx = MockContext()
+ ctx.author.id = 42
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.send = AsyncMock()
+ self.cog.jobs = (42,)
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
+ )
+
+ async def test_eval_command_call_help(self):
+ """Test if the eval command call the help command if no code is provided."""
+ ctx = MockContext()
+ ctx.invoke = AsyncMock()
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
+ ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval")
+
+ async def test_send_eval(self):
+ """Test the send_eval function."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+
+ self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
+ self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ self.cog.format_output = AsyncMock(return_value=('[No output]', None))
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.format_output.assert_called_once_with('')
+
+ async def test_send_eval_with_paste_link(self):
+ """Test the send_eval function with a too long output that generate a paste link."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+
+ self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
+ self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :yay!: Return code 0.'
+ '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.format_output.assert_called_once_with('Way too long beard')
+
+ async def test_send_eval_with_non_zero_eval(self):
+ """Test the send_eval function with a code returning a non-zero code."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
+ self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
+ self.cog.format_output = AsyncMock() # This function isn't called
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.format_output.assert_not_called()
+
+ @patch("bot.cogs.snekbox.partial")
+ async def test_continue_eval_does_continue(self, partial_mock):
+ """Test that the continue_eval function does continue if required conditions are met."""
+ ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
+ response = MockMessage(delete=AsyncMock())
+ new_msg = MockMessage(content='!e NewCode')
+ self.bot.wait_for.side_effect = ((None, new_msg), None)
+
+ actual = await self.cog.continue_eval(ctx, response)
+ self.assertEqual(actual, 'NewCode')
+ self.bot.wait_for.assert_has_awaits(
+ (
+ call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10),
+ call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
+ )
+ )
+ ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
+ ctx.message.clear_reactions.assert_called_once()
+ response.delete.assert_called_once()
+
+ async def test_continue_eval_does_not_continue(self):
+ ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))
+ self.bot.wait_for.side_effect = asyncio.TimeoutError
+
+ actual = await self.cog.continue_eval(ctx, MockMessage())
+ self.assertEqual(actual, None)
+ ctx.message.clear_reactions.assert_called_once()
+
+ def test_predicate_eval_message_edit(self):
+ """Test the predicate_eval_message_edit function."""
+ msg0 = MockMessage(id=1, content='abc')
+ msg1 = MockMessage(id=2, content='abcdef')
+ msg2 = MockMessage(id=1, content='abcdef')
+
+ cases = (
+ (msg0, msg0, False, 'same ID, same content'),
+ (msg0, msg1, False, 'different ID, different content'),
+ (msg0, msg2, True, 'same ID, different content')
+ )
+ for ctx_msg, new_msg, expected, testname in cases:
+ with self.subTest(msg=f'Messages with {testname} return {expected}'):
+ ctx = MockContext(message=ctx_msg)
+ actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg)
+ self.assertEqual(actual, expected)
+
+ def test_predicate_eval_emoji_reaction(self):
+ """Test the predicate_eval_emoji_reaction function."""
+ valid_reaction = MockReaction(message=MockMessage(id=1))
+ valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI
+ valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))
+ valid_user = MockUser(id=2)
+
+ invalid_reaction_id = MockReaction(message=MockMessage(id=42))
+ invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI
+ invalid_user_id = MockUser(id=42)
+ invalid_reaction_str = MockReaction(message=MockMessage(id=1))
+ invalid_reaction_str.__str__.return_value = ':longbeard:'
+
+ cases = (
+ (invalid_reaction_id, valid_user, False, 'invalid reaction ID'),
+ (valid_reaction, invalid_user_id, False, 'invalid user ID'),
+ (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'),
+ (valid_reaction, valid_user, True, 'matching attributes')
+ )
+ for reaction, user, expected, testname in cases:
+ with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
+ actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user)
+ self.assertEqual(actual, expected)
+
+
+class SnekboxSetupTests(unittest.TestCase):
+ """Tests setup of the `Snekbox` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ snekbox.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index a54b839d7..33d1ec170 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -1,7 +1,7 @@
import asyncio
import logging
import unittest
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
from discord import Colour
@@ -11,7 +11,7 @@ from bot.cogs.token_remover import (
setup as setup_cog,
)
from bot.constants import Channels, Colours, Event, Icons
-from tests.helpers import AsyncMock, MockBot, MockMessage
+from tests.helpers import MockBot, MockMessage
class TokenRemoverTests(unittest.TestCase):
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index e69de29bb..0d570f5a3 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -0,0 +1,76 @@
+import unittest
+from abc import ABCMeta, abstractmethod
+from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
+
+from tests.helpers import MockMessage
+
+
+class DisallowedCase(NamedTuple):
+ """Encapsulation for test cases expected to fail."""
+ recent_messages: List[MockMessage]
+ culprits: Iterable[str]
+ n_violations: int
+
+
+class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):
+ """
+ Abstract class for antispam rule test cases.
+
+ Tests for specific rules should inherit from `RuleTest` and implement
+ `relevant_messages` and `get_report`. Each instance should also set the
+ `apply` and `config` attributes as necessary.
+
+ The execution of test cases can then be delegated to the `run_allowed`
+ and `run_disallowed` methods.
+ """
+
+ apply: Callable # The tested rule's apply function
+ config: Dict[str, int]
+
+ async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to pass."""
+ for recent_messages in cases:
+ last_message = recent_messages[0]
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ config=self.config,
+ ):
+ self.assertIsNone(
+ await self.apply(last_message, recent_messages, self.config)
+ )
+
+ async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to fail."""
+ for case in cases:
+ recent_messages, culprits, n_violations = case
+ last_message = recent_messages[0]
+ relevant_messages = self.relevant_messages(case)
+ desired_output = (
+ self.get_report(case),
+ culprits,
+ relevant_messages,
+ )
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ relevant_messages=relevant_messages,
+ n_violations=n_violations,
+ config=self.config,
+ ):
+ self.assertTupleEqual(
+ await self.apply(last_message, recent_messages, self.config),
+ desired_output,
+ )
+
+ @abstractmethod
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ """Give expected relevant messages for `case`."""
+ raise NotImplementedError # pragma: no cover
+
+ @abstractmethod
+ def get_report(self, case: DisallowedCase) -> str:
+ """Give expected error report for `case`."""
+ raise NotImplementedError # pragma: no cover
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index d7187f315..d7e779221 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -1,98 +1,69 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import attachments
-from tests.helpers import MockMessage, async_test
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_attachments: int
-
-
-def msg(author: str, total_attachments: int) -> MockMessage:
+def make_msg(author: str, total_attachments: int) -> MockMessage:
"""Builds a message with `total_attachments` attachments."""
return MockMessage(author=author, attachments=list(range(total_attachments)))
-class AttachmentRuleTests(unittest.TestCase):
+class AttachmentRuleTests(RuleTest):
"""Tests applying the `attachments` antispam rule."""
def setUp(self):
- self.config = {"max": 5}
+ self.apply = attachments.apply
+ self.config = {"max": 5, "interval": 10}
- @async_test
async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
- [msg("bob", 0), msg("bob", 0), msg("bob", 0)],
- [msg("bob", 2), msg("bob", 2)],
- [msg("bob", 2), msg("alice", 2), msg("bob", 2)],
+ [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)],
+ [make_msg("bob", 2), make_msg("bob", 2)],
+ [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await attachments.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
- Case(
- [msg("bob", 4), msg("bob", 0), msg("bob", 6)],
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
("bob",),
- 10
+ 10,
),
- Case(
- [msg("bob", 4), msg("alice", 6), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
("bob",),
- 6
+ 6,
),
- Case(
- [msg("alice", 6)],
+ DisallowedCase(
+ [make_msg("alice", 6)],
("alice",),
- 6
+ 6,
),
- (
- [msg("alice", 1) for _ in range(6)],
+ DisallowedCase(
+ [make_msg("alice", 1) for _ in range(6)],
("alice",),
- 6
+ 6,
),
)
- for recent_messages, culprit, total_attachments in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and len(msg.attachments) > 0
)
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- total_attachments=total_attachments,
- config=self.config
- ):
- desired_output = (
- f"sent {total_attachments} attachments in {self.config['max']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await attachments.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
new file mode 100644
index 000000000..03682966b
--- /dev/null
+++ b/tests/bot/rules/test_burst.py
@@ -0,0 +1,54 @@
+from typing import Iterable
+
+from bot.rules import burst
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with author set to `author`.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstRuleTests(RuleTest):
+ """Tests the `burst` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst.apply
+ self.config = {"max": 2, "interval": 10}
+
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("bob"), make_msg("bob")],
+ [make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
new file mode 100644
index 000000000..3275143d5
--- /dev/null
+++ b/tests/bot/rules/test_burst_shared.py
@@ -0,0 +1,57 @@
+from typing import Iterable
+
+from bot.rules import burst_shared
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with the passed arg.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstSharedRuleTests(RuleTest):
+ """Tests the `burst_shared` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst_shared.apply
+ self.config = {"max": 2, "interval": 10}
+
+ async def test_allows_messages_within_limit(self):
+ """
+ Cases that do not violate the rule.
+
+ There really isn't more to test here than a single case.
+ """
+ cases = (
+ [make_msg("spongebob"), make_msg("patrick")],
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ {"bob"},
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ {"bob", "alice"},
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return case.recent_messages
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
new file mode 100644
index 000000000..f1e3c76a7
--- /dev/null
+++ b/tests/bot/rules/test_chars.py
@@ -0,0 +1,64 @@
+from typing import Iterable
+
+from bot.rules import chars
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+
+def make_msg(author: str, n_chars: int) -> MockMessage:
+ """Build a message with arbitrary content of `n_chars` length."""
+ return MockMessage(author=author, content="A" * n_chars)
+
+
+class CharsRuleTests(RuleTest):
+ """Tests the `chars` antispam rule."""
+
+ def setUp(self):
+ self.apply = chars.apply
+ self.config = {
+ "max": 20, # Max allowed sum of chars per user
+ "interval": 10,
+ }
+
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of chars within limit."""
+ cases = (
+ [make_msg("bob", 0)],
+ [make_msg("bob", 20)],
+ [make_msg("bob", 15), make_msg("alice", 15)],
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the total amount of chars exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 21)],
+ ("bob",),
+ 21,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 15), make_msg("bob", 15)],
+ ("bob",),
+ 30,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
+ ("alice",),
+ 30,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
new file mode 100644
index 000000000..9a72723e2
--- /dev/null
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -0,0 +1,52 @@
+from typing import Iterable
+
+from bot.rules import discord_emojis
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
+
+
+def make_msg(author: str, n_emojis: int) -> MockMessage:
+ """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis."""
+ return MockMessage(author=author, content=discord_emoji * n_emojis)
+
+
+class DiscordEmojisRuleTests(RuleTest):
+ """Tests for the `discord_emojis` antispam rule."""
+
+ def setUp(self):
+ self.apply = discord_emojis.apply
+ self.config = {"max": 2, "interval": 10}
+
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of discord emojis within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of discord emojis."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
new file mode 100644
index 000000000..9bd886a77
--- /dev/null
+++ b/tests/bot/rules/test_duplicates.py
@@ -0,0 +1,64 @@
+from typing import Iterable
+
+from bot.rules import duplicates
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+
+def make_msg(author: str, content: str) -> MockMessage:
+ """Give a MockMessage instance with `author` and `content` attrs."""
+ return MockMessage(author=author, content=content)
+
+
+class DuplicatesRuleTests(RuleTest):
+ """Tests the `duplicates` antispam rule."""
+
+ def setUp(self):
+ self.apply = duplicates.apply
+ self.config = {"max": 2, "interval": 10}
+
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", "A"), make_msg("alice", "A")],
+ [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate
+ [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with too many duplicate messages from the same author."""
+ cases = (
+ DisallowedCase(
+ [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")],
+ ("alice",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 duplicate messages, but only 3 from bob
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 message from bob, but only 3 duplicates
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and msg.content == last_message.content
+ )
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index 02a5d5501..b091bd9d7 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -1,97 +1,67 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import links
-from tests.helpers import MockMessage, async_test
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_links: int
-
-
-def msg(author: str, total_links: int) -> MockMessage:
+def make_msg(author: str, total_links: int) -> MockMessage:
"""Makes a message with `total_links` links."""
content = " ".join(["https://pydis.com"] * total_links)
return MockMessage(author=author, content=content)
-class LinksTests(unittest.TestCase):
+class LinksTests(RuleTest):
"""Tests applying the `links` rule."""
def setUp(self):
+ self.apply = links.apply
self.config = {
"max": 2,
"interval": 10
}
- @async_test
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await links.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
- @async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
- Case(
- [msg("bob", 1), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 1), make_msg("bob", 2)],
("bob",),
3
),
- Case(
- [msg("alice", 1), msg("alice", 1), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
("alice",),
3
),
- Case(
- [msg("alice", 2), msg("bob", 3), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
("alice",),
3
)
)
- for recent_messages, culprit, total_links in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_links=total_links,
- config=self.config
- ):
- desired_output = (
- f"sent {total_links} links in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await links.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index ad49ead32..6444532f2 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -1,95 +1,65 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import mentions
-from tests.helpers import MockMessage, async_test
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_mentions: int
-
-
-def msg(author: str, total_mentions: int) -> MockMessage:
+def make_msg(author: str, total_mentions: int) -> MockMessage:
"""Makes a message with `total_mentions` mentions."""
return MockMessage(author=author, mentions=list(range(total_mentions)))
-class TestMentions(unittest.TestCase):
+class TestMentions(RuleTest):
"""Tests applying the `mentions` antispam rule."""
def setUp(self):
+ self.apply = mentions.apply
self.config = {
"max": 2,
- "interval": 10
+ "interval": 10,
}
- @async_test
async def test_mentions_within_limit(self):
"""Messages with an allowed amount of mentions."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 1), msg("alice", 2)]
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 1), make_msg("alice", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await mentions.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
- @async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
- Case(
- [msg("bob", 3)],
+ DisallowedCase(
+ [make_msg("bob", 3)],
("bob",),
- 3
+ 3,
),
- Case(
- [msg("alice", 2), msg("alice", 0), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
("alice",),
- 3
+ 3,
),
- Case(
- [msg("bob", 2), msg("alice", 3), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
("bob",),
- 4
+ 4,
)
)
- for recent_messages, culprit, total_mentions in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_mentions=total_mentions,
- cofig=self.config
- ):
- desired_output = (
- f"sent {total_mentions} mentions in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await mentions.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
new file mode 100644
index 000000000..e35377773
--- /dev/null
+++ b/tests/bot/rules/test_newlines.py
@@ -0,0 +1,102 @@
+from typing import Iterable, List
+
+from bot.rules import newlines
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+
+def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
+ """Init a MockMessage instance with `author` and content configured by `newline_groups".
+
+ Configure content by passing a list of ints, where each int `n` will generate
+ a separate group of `n` newlines.
+
+ Example:
+ newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n"
+ """
+ content = " ".join("\n" * n for n in newline_groups)
+ return MockMessage(author=author, content=content)
+
+
+class TotalNewlinesRuleTests(RuleTest):
+ """Tests the `newlines` antispam rule against allowed cases and total newline count violations."""
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {
+ "max": 5, # Max sum of newlines in relevant messages
+ "max_consecutive": 3, # Max newlines in one group, in one message
+ "interval": 10,
+ }
+
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", [])], # Single message with no newlines
+ [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages
+ [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author
+ [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_total(self):
+ """Cases which violate the rule by having too many newlines in total."""
+ cases = (
+ DisallowedCase( # Alice sends a total of 6 newlines (disallowed)
+ [make_msg("alice", [2, 2]), make_msg("alice", [2])],
+ ("alice",),
+ 6,
+ ),
+ DisallowedCase( # Here we test that only alice's newlines count in the sum
+ [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])],
+ ("alice",),
+ 7,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} newlines in {self.config['interval']}s"
+
+
+class GroupNewlinesRuleTests(RuleTest):
+ """
+ Tests the `newlines` antispam rule against max consecutive newline violations.
+
+ As these violations yield a different error report, they require a different
+ `get_report` implementation.
+ """
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
+
+ async def test_disallows_messages_consecutive(self):
+ """Cases which violate the rule due to having too many consecutive newlines."""
+ cases = (
+ DisallowedCase( # Bob sends a group of newlines too large
+ [make_msg("bob", [4])],
+ ("bob",),
+ 4,
+ ),
+ DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed)
+ [make_msg("alice", [1]), make_msg("alice", [4])],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
new file mode 100644
index 000000000..26c05d527
--- /dev/null
+++ b/tests/bot/rules/test_role_mentions.py
@@ -0,0 +1,55 @@
+from typing import Iterable
+
+from bot.rules import role_mentions
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage
+
+
+def make_msg(author: str, n_mentions: int) -> MockMessage:
+ """Build a MockMessage instance with `n_mentions` role mentions."""
+ return MockMessage(author=author, role_mentions=[None] * n_mentions)
+
+
+class RoleMentionsRuleTests(RuleTest):
+ """Tests for the `role_mentions` antispam rule."""
+
+ def setUp(self):
+ self.apply = role_mentions.apply
+ self.config = {"max": 2, "interval": 10}
+
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of role mentions within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of role mentions."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} role mentions in {self.config['interval']}s"
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index 5a88adc5c..99e942813 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -1,13 +1,10 @@
-import logging
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
from bot import api
-from tests.base import LoggingTestCase
-from tests.helpers import async_test
-class APIClientTests(unittest.TestCase):
+class APIClientTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the bot's API client."""
@classmethod
@@ -20,7 +17,6 @@ class APIClientTests(unittest.TestCase):
"""The event loop should not be running by default."""
self.assertFalse(api.loop_is_running())
- @async_test
async def test_loop_is_running_in_async_context(self):
"""The event loop should be running in an async context."""
self.assertTrue(api.loop_is_running())
@@ -34,7 +30,7 @@ class APIClientTests(unittest.TestCase):
self.assertEqual(error.response_text, "")
self.assertIs(error.response, self.error_api_response)
- def test_responde_code_error_string_representation_default_initialization(self):
+ def test_response_code_error_string_representation_default_initialization(self):
"""Test the string representation of `ResponseCodeError` initialized without text or json."""
error = api.ResponseCodeError(response=self.error_api_response)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ")
@@ -76,61 +72,3 @@ class APIClientTests(unittest.TestCase):
response_text=text_data
)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}")
-
-
-class LoggingHandlerTests(LoggingTestCase):
- """Tests the bot's API Log Handler."""
-
- @classmethod
- def setUpClass(cls):
- cls.debug_log_record = logging.LogRecord(
- name='my.logger', level=logging.DEBUG,
- pathname='my/logger.py', lineno=666,
- msg="Lemon wins", args=(),
- exc_info=None
- )
-
- cls.trace_log_record = logging.LogRecord(
- name='my.logger', level=logging.TRACE,
- pathname='my/logger.py', lineno=666,
- msg="This will not be logged", args=(),
- exc_info=None
- )
-
- def setUp(self):
- self.log_handler = api.APILoggingHandler(None)
-
- def test_emit_appends_to_queue_with_stopped_event_loop(self):
- """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running."""
- with patch("bot.api.APILoggingHandler.ship_off") as ship_off:
- # Patch `ship_off` to ease testing against the return value of this coroutine.
- ship_off.return_value = 42
- self.log_handler.emit(self.debug_log_record)
-
- self.assertListEqual(self.log_handler.queue, [42])
-
- def test_emit_ignores_less_than_debug(self):
- """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG."""
- self.log_handler.emit(self.trace_log_record)
- self.assertListEqual(self.log_handler.queue, [])
-
- def test_schedule_queued_tasks_for_empty_queue(self):
- """`APILoggingHandler` should not schedule anything when the queue is empty."""
- with self.assertNotLogs(level=logging.DEBUG):
- self.log_handler.schedule_queued_tasks()
-
- def test_schedule_queued_tasks_for_nonempty_queue(self):
- """`APILoggingHandler` should schedule logs when the queue is not empty."""
- log = logging.getLogger("bot.api")
-
- with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task:
- self.log_handler.queue = [555]
- self.log_handler.schedule_queued_tasks()
- self.assertListEqual(self.log_handler.queue, [])
- create_task.assert_called_once_with(555)
-
- [record] = logs.records
- self.assertEqual(record.message, "Scheduled 1 pending logging tasks.")
- self.assertEqual(record.levelno, logging.DEBUG)
- self.assertEqual(record.name, 'bot.api')
- self.assertIn('via_handler', record.__dict__)
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index b2b78d9dd..1e5ca62ae 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -68,7 +68,7 @@ class ConverterTests(unittest.TestCase):
('👋', "Don't be ridiculous, you can't use that character!"),
('', "Tag names should not be empty, or filled with whitespace."),
(' ', "Tag names should not be empty, or filled with whitespace."),
- ('42', "Tag names can't be numbers."),
+ ('42', "Tag names must contain at least one letter."),
('x' * 128, "Are you insane? That's way too long!"),
)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
deleted file mode 100644
index 58ae2a81a..000000000
--- a/tests/bot/test_utils.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import unittest
-
-from bot import utils
-
-
-class CaseInsensitiveDictTests(unittest.TestCase):
- """Tests for the `CaseInsensitiveDict` container."""
-
- def test_case_insensitive_key_access(self):
- """Tests case insensitive key access and storage."""
- instance = utils.CaseInsensitiveDict()
-
- key = 'LEMON'
- value = 'trees'
-
- instance[key] = value
- self.assertIn(key, instance)
- self.assertEqual(instance.get(key), value)
- self.assertEqual(instance.get(key.casefold()), value)
- self.assertEqual(instance.pop(key.casefold()), value)
- self.assertNotIn(key, instance)
- self.assertNotIn(key.casefold(), instance)
-
- instance.setdefault(key, value)
- del instance[key]
- self.assertNotIn(key, instance)
-
- def test_initialization_from_kwargs(self):
- """Tests creating the dictionary from keyword arguments."""
- instance = utils.CaseInsensitiveDict({'FOO': 'bar'})
- self.assertEqual(instance['foo'], 'bar')
-
- def test_update_from_other_mapping(self):
- """Tests updating the dictionary from another mapping."""
- instance = utils.CaseInsensitiveDict()
- instance.update({'FOO': 'bar'})
- self.assertEqual(instance['foo'], 'bar')
-
-
-class ChunkTests(unittest.TestCase):
- """Tests the `chunk` method."""
-
- def test_empty_chunking(self):
- """Tests chunking on an empty iterable."""
- generator = utils.chunks(iterable=[], size=5)
- self.assertEqual(list(generator), [])
-
- def test_list_chunking(self):
- """Tests chunking a non-empty list."""
- iterable = [1, 2, 3, 4, 5]
- generator = utils.chunks(iterable=iterable, size=2)
- self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
index 69f35f2f5..694d3a40f 100644
--- a/tests/bot/utils/test_time.py
+++ b/tests/bot/utils/test_time.py
@@ -1,12 +1,11 @@
import asyncio
import unittest
from datetime import datetime, timezone
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
from dateutil.relativedelta import relativedelta
from bot.utils import time
-from tests.helpers import AsyncMock
class TimeTests(unittest.TestCase):
@@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase):
for max_units in test_cases:
with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
- self.assertEqual(str(error), 'max_units must be positive')
+ self.assertEqual(str(error.exception), 'max_units must be positive')
def test_parse_rfc1123(self):
"""Testing parse_rfc1123."""
diff --git a/tests/helpers.py b/tests/helpers.py
index 5df796c23..8e13f0f28 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,17 +1,15 @@
from __future__ import annotations
-import asyncio
import collections
-import functools
-import inspect
import itertools
import logging
import unittest.mock
-from typing import Any, Iterable, Optional
+from typing import Iterable, Optional
import discord
from discord.ext.commands import Context
+from bot.api import APIClient
from bot.bot import Bot
@@ -25,21 +23,6 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def async_test(wrapped):
- """
- Run a test case via asyncio.
- Example:
- >>> @async_test
- ... async def lemon_wins():
- ... assert True
- """
-
- @functools.wraps(wrapped)
- def wrapper(*args, **kwargs):
- return asyncio.run(wrapped(*args, **kwargs))
- return wrapper
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.
@@ -68,24 +51,31 @@ class CustomMockMixin:
"""
Provides common functionality for our custom Mock types.
- The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
- function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
- of making sure child mocks are instantiated with the correct class. By default, the mock of the
- children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
- `child_mock_type` on the custom mock inheriting from this mixin.
+ The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock
+ object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the
+ class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional
+ attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The
+ class method `spec_set` can be overwritten with the object that should be uses as the specification
+ for the mock.
+
+ Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to
+ implement custom behavior.
"""
child_mock_type = unittest.mock.MagicMock
discord_id = itertools.count(0)
+ spec_set = None
+ additional_spec_asyncs = None
- def __init__(self, spec_set: Any = None, **kwargs):
+ def __init__(self, **kwargs):
name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually.
- super().__init__(spec_set=spec_set, **kwargs)
+ super().__init__(spec_set=self.spec_set, **kwargs)
+
+ if self.additional_spec_asyncs:
+ self._spec_asyncs.extend(self.additional_spec_asyncs)
if name:
self.name = name
- if spec_set:
- self._extract_coroutine_methods_from_spec_instance(spec_set)
def _get_child_mock(self, **kw):
"""
@@ -99,7 +89,16 @@ class CustomMockMixin:
This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
"""
- klass = self.child_mock_type
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return unittest.mock.AsyncMock(**kw)
+
+ _type = type(self)
+ if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics:
+ # Any asynchronous magic becomes an AsyncMock
+ klass = unittest.mock.AsyncMock
+ else:
+ klass = self.child_mock_type
if self._mock_sealed:
attribute = "." + kw["name"] if "name" in kw else "()"
@@ -108,95 +107,6 @@ class CustomMockMixin:
return klass(**kw)
- def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
- """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
- for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
- setattr(self, name, AsyncMock())
-
-
-# TODO: Remove me in Python 3.8
-class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock async callables.
-
- Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
- features; this stand-in only overwrites the `__call__` method to an async version.
- """
-
- async def __call__(self, *args, **kwargs):
- return super().__call__(*args, **kwargs)
-
-
-class AsyncIteratorMock:
- """
- A class to mock asynchronous iterators.
-
- This allows async for, which is used in certain Discord.py objects. For example,
- an async iterator is returned by the Reaction.users() method.
- """
-
- def __init__(self, iterable: Iterable = None):
- if iterable is None:
- iterable = []
-
- self.iter = iter(iterable)
- self.iterable = iterable
-
- self.call_count = 0
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- return next(self.iter)
- except StopIteration:
- raise StopAsyncIteration
-
- def __call__(self):
- """
- Keeps track of the number of times an instance has been called.
-
- This is useful, since it typically shows that the iterator has actually been used somewhere after we have
- instantiated the mock for an attribute that normally returns an iterator when called.
- """
- self.call_count += 1
- return self
-
- @property
- def return_value(self):
- """Makes `self.iterable` accessible as self.return_value."""
- return self.iterable
-
- @return_value.setter
- def return_value(self, iterable):
- """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
- self.iter = iter(iterable)
- self.iterable = iterable
-
- def assert_called(self):
- """Asserts if the AsyncIteratorMock instance has been called at least once."""
- if self.call_count == 0:
- raise AssertionError("Expected AsyncIteratorMock to have been called.")
-
- def assert_called_once(self):
- """Asserts if the AsyncIteratorMock instance has been called exactly once."""
- if self.call_count != 1:
- raise AssertionError(
- f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
- )
-
- def assert_not_called(self):
- """Asserts if the AsyncIteratorMock instance has not been called."""
- if self.call_count != 0:
- raise AssertionError(
- f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
- )
-
- def reset_mock(self):
- """Resets the call count, but not the return value or iterator."""
- self.call_count = 0
-
# Create a guild instance to get a realistic Mock of `discord.Guild`
guild_data = {
@@ -247,9 +157,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
For more info, see the `Mocking` section in `tests/README.md`.
"""
+ spec_set = guild_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'members': []}
- super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -268,9 +180,23 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.Role` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = role_instance
+
def __init__(self, **kwargs) -> None:
- default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1}
- super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs))
+ default_kwargs = {
+ 'id': next(self.discord_id),
+ 'name': 'role',
+ 'position': 1,
+ 'colour': discord.Colour(0xdeadbf),
+ 'permissions': discord.Permissions(),
+ }
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
+
+ if isinstance(self.colour, int):
+ self.colour = discord.Colour(self.colour)
+
+ if isinstance(self.permissions, int):
+ self.permissions = discord.Permissions(self.permissions)
if 'mention' not in kwargs:
self.mention = f'&{self.name}'
@@ -293,9 +219,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
Instances of this class will follow the specifications of `discord.Member` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = member_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -316,14 +244,26 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.User` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = user_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"@{self.name}"
+class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock APIClient objects.
+
+ Instances of this class will follow the specifications of `bot.api.APIClient` instances.
+ For more information, see the `MockGuild` docstring.
+ """
+ spec_set = APIClient
+
+
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
bot_instance.http_session = None
@@ -337,14 +277,12 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = bot_instance
+ additional_spec_asyncs = ("wait_for",)
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=bot_instance, **kwargs)
-
- # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
- # and should therefore be awaited. (The documentation calls it a coroutine as well, which
- # is technically incorrect, since it's a regular def.)
- self.wait_for = AsyncMock()
+ super().__init__(**kwargs)
+ self.api_client = MockAPIClient()
# Since calling `create_task` on our MockBot does not actually schedule the coroutine object
# as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
@@ -375,10 +313,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = channel_instance
def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
- super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"#{self.name}"
@@ -417,9 +356,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+ spec_set = context_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=context_instance, **kwargs)
+ super().__init__(**kwargs)
self.bot = kwargs.get('bot', MockBot())
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
@@ -436,8 +376,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Attachment` instances. For
more information, see the `MockGuild` docstring.
"""
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=attachment_instance, **kwargs)
+ spec_set = attachment_instance
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
@@ -447,10 +386,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = message_instance
def __init__(self, **kwargs) -> None:
default_kwargs = {'attachments': []}
- super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -466,9 +406,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = emoji_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=emoji_instance, **kwargs)
+ super().__init__(**kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -482,9 +423,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=partial_emoji_instance, **kwargs)
+ spec_set = partial_emoji_instance
reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
@@ -497,12 +436,19 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = reaction_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=reaction_instance, **kwargs)
+ _users = kwargs.pop("users", [])
+ super().__init__(**kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
- self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+ user_iterator = unittest.mock.AsyncMock()
+ user_iterator.__aiter__.return_value = _users
+ self.users.return_value = user_iterator
+
+ self.__str__.return_value = str(self.emoji)
webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock())
@@ -515,13 +461,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Webhook` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=webhook_instance, **kwargs)
-
- # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
- # as coroutines. That's why we need to set the methods manually.
- self.send = AsyncMock()
- self.edit = AsyncMock()
- self.delete = AsyncMock()
- self.execute = AsyncMock()
+ spec_set = webhook_instance
+ additional_spec_asyncs = ("send", "edit", "delete", "execute")
diff --git a/tests/test_base.py b/tests/test_base.py
index a16e2af8f..a7db4bf3e 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -3,7 +3,11 @@ import unittest
import unittest.mock
-from tests.base import LoggingTestCase, _CaptureLogHandler
+from tests.base import LoggingTestsMixin, _CaptureLogHandler
+
+
+class LoggingTestCase(LoggingTestsMixin, unittest.TestCase):
+ pass
class LoggingTestCaseTests(unittest.TestCase):
@@ -18,24 +22,14 @@ class LoggingTestCaseTests(unittest.TestCase):
try:
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
pass
- except AssertionError:
+ except AssertionError: # pragma: no cover
self.fail("`self.assertNotLogs` raised an AssertionError when it should not!")
- @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs")
- def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs):
- """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly."""
- assertNotLogs.return_value = iter([None])
- assertNotLogs.side_effect = AssertionError
-
- message = "`self.assertNotLogs` raised an AssertionError when it should not!"
- with self.assertRaises(AssertionError, msg=message):
- self.test_assert_not_logs_does_not_raise_with_no_logs()
-
def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self):
"""Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted."""
msg_regex = (
r"1 logs of DEBUG or higher were triggered on root:\n"
- r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">'
+ r'<LogRecord: tests\.test_base, [\d]+, .+[/\\]tests[/\\]test_base\.py, [\d]+, "Log!">'
)
with self.assertRaisesRegex(AssertionError, msg_regex):
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 7894e104a..81285e009 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
import unittest
import unittest.mock
@@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase):
with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"):
asyncio.run(coroutine_object)
+ def test_user_mock_uses_explicitly_passed_mention_attribute(self):
+ """MockUser should use an explicitly passed value for user.mention."""
+ user = helpers.MockUser(mention="hello")
+ self.assertEqual(user.mention, "hello")
+
class MockObjectTests(unittest.TestCase):
"""Tests the mock objects and mixins we've defined."""
@@ -341,65 +345,10 @@ class MockObjectTests(unittest.TestCase):
attribute = getattr(mock, valid_attribute)
self.assertTrue(isinstance(attribute, mock_type.child_mock_type))
- def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self):
- """Test if all coroutine functions are extracted, but not regular methods or attributes."""
- class CoroutineDonor:
- def __init__(self):
- self.some_attribute = 'alpha'
-
- async def first_coroutine():
- """This coroutine function should be extracted."""
-
- async def second_coroutine():
- """This coroutine function should be extracted."""
-
- def regular_method():
- """This regular function should not be extracted."""
-
- class Receiver:
+ def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self):
+ """The CustomMockMixin should mock async magic methods with an AsyncMock."""
+ class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):
pass
- donor = CoroutineDonor()
- receiver = Receiver()
-
- helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor)
-
- self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock)
- self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock)
- self.assertFalse(hasattr(receiver, 'regular_method'))
- self.assertFalse(hasattr(receiver, 'some_attribute'))
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_with_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- spec_set = "pydis"
-
- helpers.CustomMockMixin(spec_set=spec_set)
-
- extract_method_mock.assert_called_once_with(spec_set)
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_without_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- helpers.CustomMockMixin()
-
- extract_method_mock.assert_not_called()
-
- def test_async_mock_provides_coroutine_for_dunder_call(self):
- """Test if AsyncMock objects have a coroutine for their __call__ method."""
- async_mock = helpers.AsyncMock()
- self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__))
-
- coroutine = async_mock()
- self.assertTrue(inspect.iscoroutine(coroutine))
- self.assertIsNotNone(asyncio.run(coroutine))
-
- def test_async_test_decorator_allows_synchronous_call_to_async_def(self):
- """Test if the `async_test` decorator allows an `async def` to be called synchronously."""
- @helpers.async_test
- async def kosayoda():
- return "return value"
-
- self.assertEqual(kosayoda(), "return value")
+ mock = MyMock()
+ self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock)
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
deleted file mode 100644
index 4baa6395c..000000000
--- a/tests/utils/test_time.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import asyncio
-from datetime import datetime, timezone
-from unittest.mock import patch
-
-import pytest
-from dateutil.relativedelta import relativedelta
-
-from bot.utils import time
-from tests.helpers import AsyncMock
-
-
- ('delta', 'precision', 'max_units', 'expected'),
- (
- (relativedelta(days=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
- (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
-
- # Does not abort for unknown units, as the unit name is checked
- # against the attribute of the relativedelta instance.
- (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'),
-
- # Very high maximum units, but it only ever iterates over
- # each value the relativedelta might have.
- (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'),
- )
-)
-def test_humanize_delta(
- delta: relativedelta,
- precision: str,
- max_units: int,
- expected: str
-):
- assert time.humanize_delta(delta, precision, max_units) == expected
-
-
[email protected]('max_units', (-1, 0))
-def test_humanize_delta_raises_for_invalid_max_units(max_units: int):
- with pytest.raises(ValueError, match='max_units must be positive'):
- time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
-
-
- ('stamp', 'expected'),
- (
- ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)),
- )
-)
-def test_parse_rfc1123(stamp: str, expected: str):
- assert time.parse_rfc1123(stamp) == expected
-
-
-@patch('asyncio.sleep', new_callable=AsyncMock)
-def test_wait_until(sleep_patch):
- start = datetime(2019, 1, 1, 0, 0)
- then = datetime(2019, 1, 1, 0, 10)
-
- # No return value
- assert asyncio.run(time.wait_until(then, start)) is None
-
- sleep_patch.assert_called_once_with(10 * 60)