aboutsummaryrefslogtreecommitdiffstats
path: root/tests/bot
diff options
context:
space:
mode:
authorGravatar Jeremiah Boby <[email protected]>2020-02-27 22:16:05 +0000
committerGravatar GitHub <[email protected]>2020-02-27 22:16:05 +0000
commitafeb92010e58569e4910c708fa63f0e808278e26 (patch)
tree1e4f9baa3be386499d52b5373a901b3f2ddf24a8 /tests/bot
parentMerge branch 'master' into spoiler-check (diff)
parentMerge pull request #798 from python-discord/bug/mod/bot-1v/infr-edit-task-cancel (diff)
Merge branch 'master' into spoiler-check
Diffstat (limited to 'tests/bot')
-rw-r--r--tests/bot/cogs/sync/test_base.py412
-rw-r--r--tests/bot/cogs/sync/test_cog.py395
-rw-r--r--tests/bot/cogs/sync/test_roles.py287
-rw-r--r--tests/bot/cogs/sync/test_users.py241
-rw-r--r--tests/bot/cogs/test_duck_pond.py584
-rw-r--r--tests/bot/cogs/test_information.py14
-rw-r--r--tests/bot/cogs/test_security.py11
-rw-r--r--tests/bot/cogs/test_token_remover.py8
-rw-r--r--tests/bot/rules/__init__.py76
-rw-r--r--tests/bot/rules/test_attachments.py91
-rw-r--r--tests/bot/rules/test_burst.py56
-rw-r--r--tests/bot/rules/test_burst_shared.py59
-rw-r--r--tests/bot/rules/test_chars.py66
-rw-r--r--tests/bot/rules/test_discord_emojis.py54
-rw-r--r--tests/bot/rules/test_duplicates.py66
-rw-r--r--tests/bot/rules/test_links.py94
-rw-r--r--tests/bot/rules/test_mentions.py67
-rw-r--r--tests/bot/rules/test_newlines.py105
-rw-r--r--tests/bot/rules/test_role_mentions.py57
-rw-r--r--tests/bot/test_api.py64
-rw-r--r--tests/bot/test_utils.py15
-rw-r--r--tests/bot/utils/test_time.py162
22 files changed, 2585 insertions, 399 deletions
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
new file mode 100644
index 000000000..e6a6f9688
--- /dev/null
+++ b/tests/bot/cogs/sync/test_base.py
@@ -0,0 +1,412 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs.sync.syncers import Syncer, _Diff
+from tests import helpers
+
+
+class TestSyncer(Syncer):
+ """Syncer subclass with mocks for abstract methods for testing purposes."""
+
+ name = "test"
+ _get_diff = helpers.AsyncMock()
+ _sync = helpers.AsyncMock()
+
+
+class SyncerBaseTests(unittest.TestCase):
+ """Tests for the syncer base class."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ def test_instantiation_fails_without_abstract_methods(self):
+ """The class must have abstract methods implemented."""
+ with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"):
+ Syncer(self.bot)
+
+
+class SyncerSendPromptTests(unittest.TestCase):
+ """Tests for sending the sync confirmation prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+
+ def mock_get_channel(self):
+ """Fixture to return a mock channel and message for when `get_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ mock_channel.send.return_value = mock_message
+ self.bot.get_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ def mock_fetch_channel(self):
+ """Fixture to return a mock channel and message for when `fetch_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ self.bot.get_channel.return_value = None
+ mock_channel.send.return_value = mock_message
+ self.bot.fetch_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ @helpers.async_test
+ async def test_send_prompt_edits_and_returns_message(self):
+ """The given message should be edited to display the prompt and then should be returned."""
+ msg = helpers.MockMessage()
+ ret_val = await self.syncer._send_prompt(msg)
+
+ msg.edit.assert_called_once()
+ self.assertIn("content", msg.edit.call_args[1])
+ self.assertEqual(ret_val, msg)
+
+ @helpers.async_test
+ async def test_send_prompt_gets_dev_core_channel(self):
+ """The dev-core channel should be retrieved if an extant message isn't given."""
+ subtests = (
+ (self.bot.get_channel, self.mock_get_channel),
+ (self.bot.fetch_channel, self.mock_fetch_channel),
+ )
+
+ for method, mock_ in subtests:
+ with self.subTest(method=method, msg=mock_.__name__):
+ mock_()
+ await self.syncer._send_prompt()
+
+ method.assert_called_once_with(constants.Channels.devcore)
+
+ @helpers.async_test
+ async def test_send_prompt_returns_None_if_channel_fetch_fails(self):
+ """None should be returned if there's an HTTPException when fetching the channel."""
+ self.bot.get_channel.return_value = None
+ self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
+
+ ret_val = await self.syncer._send_prompt()
+
+ self.assertIsNone(ret_val)
+
+ @helpers.async_test
+ async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
+ """A new message mentioning core devs should be sent and returned if message isn't given."""
+ for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
+ with self.subTest(msg=mock_.__name__):
+ mock_channel, mock_message = mock_()
+ ret_val = await self.syncer._send_prompt()
+
+ mock_channel.send.assert_called_once()
+ self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
+ self.assertEqual(ret_val, mock_message)
+
+ @helpers.async_test
+ async def test_send_prompt_adds_reactions(self):
+ """The message should have reactions for confirmation added."""
+ extant_message = helpers.MockMessage()
+ subtests = (
+ (extant_message, lambda: (None, extant_message)),
+ (None, self.mock_get_channel),
+ (None, self.mock_fetch_channel),
+ )
+
+ for message_arg, mock_ in subtests:
+ subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__
+
+ with self.subTest(msg=subtest_msg):
+ _, mock_message = mock_()
+ await self.syncer._send_prompt(message_arg)
+
+ calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS]
+ mock_message.add_reaction.assert_has_calls(calls)
+
+
+class SyncerConfirmationTests(unittest.TestCase):
+ """Tests for waiting for a sync confirmation reaction on the prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+ self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer)
+
+ @staticmethod
+ def get_message_reaction(emoji):
+ """Fixture to return a mock message an reaction from the given `emoji`."""
+ message = helpers.MockMessage()
+ reaction = helpers.MockReaction(emoji=emoji, message=message)
+
+ return message, reaction
+
+ def test_reaction_check_for_valid_emoji_and_authors(self):
+ """Should return True if authors are identical or are a bot and a core dev, respectively."""
+ user_subtests = (
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMember(id=77),
+ "identical users",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "bot author and core-dev reactor",
+ ),
+ )
+
+ for emoji in self.syncer._REACTION_EMOJIS:
+ for author, user, msg in user_subtests:
+ with self.subTest(author=author, user=user, emoji=emoji, msg=msg):
+ message, reaction = self.get_message_reaction(emoji)
+ ret_val = self.syncer._reaction_check(author, message, reaction, user)
+
+ self.assertTrue(ret_val)
+
+ def test_reaction_check_for_invalid_reactions(self):
+ """Should return False for invalid reaction events."""
+ valid_emoji = self.syncer._REACTION_EMOJIS[0]
+ subtests = (
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "users are not identical",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43),
+ "reactor lacks the core-dev role",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ "reactor is a bot",
+ ),
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMessage(id=95),
+ helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)),
+ helpers.MockMember(id=77),
+ "messages are not identical",
+ ),
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction("InVaLiD"),
+ helpers.MockMember(id=77),
+ "emoji is invalid",
+ ),
+ )
+
+ for *args, msg in subtests:
+ kwargs = dict(zip(("author", "message", "reaction", "user"), args))
+ with self.subTest(**kwargs, msg=msg):
+ ret_val = self.syncer._reaction_check(*args)
+ self.assertFalse(ret_val)
+
+ @helpers.async_test
+ async def test_wait_for_confirmation(self):
+ """The message should always be edited and only return True if the emoji is a check mark."""
+ subtests = (
+ (constants.Emojis.check_mark, True, None),
+ ("InVaLiD", False, None),
+ (None, False, TimeoutError),
+ )
+
+ for emoji, ret_val, side_effect in subtests:
+ for bot in (True, False):
+ with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot):
+ # Set up mocks
+ message = helpers.MockMessage()
+ member = helpers.MockMember(bot=bot)
+
+ self.bot.wait_for.reset_mock()
+ self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None)
+ self.bot.wait_for.side_effect = side_effect
+
+ # Call the function
+ actual_return = await self.syncer._wait_for_confirmation(member, message)
+
+ # Perform assertions
+ self.bot.wait_for.assert_called_once()
+ self.assertIn("reaction_add", self.bot.wait_for.call_args[0])
+
+ message.edit.assert_called_once()
+ kwargs = message.edit.call_args[1]
+ self.assertIn("content", kwargs)
+
+ # Core devs should only be mentioned if the author is a bot.
+ if bot:
+ self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+ else:
+ self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+
+ self.assertIs(actual_return, ret_val)
+
+
+class SyncerSyncTests(unittest.TestCase):
+ """Tests for main function orchestrating the sync."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
+ self.syncer = TestSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_respects_confirmation_result(self):
+ """The sync should abort if confirmation fails and continue if confirmed."""
+ mock_message = helpers.MockMessage()
+ subtests = (
+ (True, mock_message),
+ (False, None),
+ )
+
+ for confirmed, message in subtests:
+ with self.subTest(confirmed=confirmed):
+ self.syncer._sync.reset_mock()
+ self.syncer._get_diff.reset_mock()
+
+ diff = _Diff({1, 2, 3}, {4, 5}, None)
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = helpers.AsyncMock(
+ return_value=(confirmed, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+
+ if confirmed:
+ self.syncer._sync.assert_called_once_with(diff)
+ else:
+ self.syncer._sync.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_diff_size(self):
+ """The diff size should be correctly calculated."""
+ subtests = (
+ (6, _Diff({1, 2}, {3, 4}, {5, 6})),
+ (5, _Diff({1, 2, 3}, None, {4, 5})),
+ (0, _Diff(None, None, None)),
+ (0, _Diff(set(), set(), set())),
+ )
+
+ for size, diff in subtests:
+ with self.subTest(size=size, diff=diff):
+ self.syncer._get_diff.reset_mock()
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
+
+ @helpers.async_test
+ async def test_sync_message_edited(self):
+ """The message should be edited if one was sent, even if the sync has an API error."""
+ subtests = (
+ (None, None, False),
+ (helpers.MockMessage(), None, True),
+ (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True),
+ )
+
+ for message, side_effect, should_edit in subtests:
+ with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
+ self.syncer._sync.side_effect = side_effect
+ self.syncer._get_confirmation_result = helpers.AsyncMock(
+ return_value=(True, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ if should_edit:
+ message.edit.assert_called_once()
+ self.assertIn("content", message.edit.call_args[1])
+
+ @helpers.async_test
+ async def test_sync_confirmation_context_redirect(self):
+ """If ctx is given, a new message should be sent and author should be ctx's author."""
+ mock_member = helpers.MockMember()
+ subtests = (
+ (None, self.bot.user, None),
+ (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()),
+ )
+
+ for ctx, author, message in subtests:
+ with self.subTest(ctx=ctx, author=author, message=message):
+ if ctx is not None:
+ ctx.send.return_value = message
+
+ self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild, ctx)
+
+ if ctx is not None:
+ ctx.send.assert_called_once()
+
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author)
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ @helpers.async_test
+ async def test_confirmation_result_small_diff(self):
+ """Should always return True and the given message if the diff size is too small."""
+ author = helpers.MockMember()
+ expected_message = helpers.MockMessage()
+
+ for size in (3, 2):
+ with self.subTest(size=size):
+ self.syncer._send_prompt = helpers.AsyncMock()
+ self.syncer._wait_for_confirmation = helpers.AsyncMock()
+
+ coro = self.syncer._get_confirmation_result(size, author, expected_message)
+ result, actual_message = await coro
+
+ self.assertTrue(result)
+ self.assertEqual(actual_message, expected_message)
+ self.syncer._send_prompt.assert_not_called()
+ self.syncer._wait_for_confirmation.assert_not_called()
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ @helpers.async_test
+ async def test_confirmation_result_large_diff(self):
+ """Should return True if confirmed and False if _send_prompt fails or aborted."""
+ author = helpers.MockMember()
+ mock_message = helpers.MockMessage()
+
+ subtests = (
+ (True, mock_message, True, "confirmed"),
+ (False, None, False, "_send_prompt failed"),
+ (False, mock_message, False, "aborted"),
+ )
+
+ for expected_result, expected_message, confirmed, msg in subtests:
+ with self.subTest(msg=msg):
+ self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message)
+ self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed)
+
+ coro = self.syncer._get_confirmation_result(4, author)
+ actual_result, actual_message = await coro
+
+ self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None
+ self.assertIs(actual_result, expected_result)
+ self.assertEqual(actual_message, expected_message)
+
+ if expected_message:
+ self.syncer._wait_for_confirmation.assert_called_once_with(
+ author, expected_message
+ )
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
new file mode 100644
index 000000000..98c9afc0d
--- /dev/null
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -0,0 +1,395 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs import sync
+from bot.cogs.sync.syncers import Syncer
+from tests import helpers
+from tests.base import CommandTestCase
+
+
+class MockSyncer(helpers.CustomMockMixin, mock.MagicMock):
+ """
+ A MagicMock subclass to mock Syncer objects.
+
+ Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer`
+ instances. For more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=Syncer, **kwargs)
+
+
+class SyncExtensionTests(unittest.TestCase):
+ """Tests for the sync extension."""
+
+ @staticmethod
+ def test_extension_setup():
+ """The Sync cog should be added."""
+ bot = helpers.MockBot()
+ sync.setup(bot)
+ bot.add_cog.assert_called_once()
+
+
+class SyncCogTestCase(unittest.TestCase):
+ """Base class for Sync cog tests. Sets up patches for syncers."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ # These patch the type. When the type is called, a MockSyncer instanced is returned.
+ # MockSyncer is needed so that our custom AsyncMock is used.
+ # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.
+ self.role_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.RoleSyncer",
+ new=mock.MagicMock(return_value=MockSyncer())
+ )
+ self.user_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.UserSyncer",
+ new=mock.MagicMock(return_value=MockSyncer())
+ )
+ self.RoleSyncer = self.role_syncer_patcher.start()
+ self.UserSyncer = self.user_syncer_patcher.start()
+
+ self.cog = sync.Sync(self.bot)
+
+ def tearDown(self):
+ self.role_syncer_patcher.stop()
+ self.user_syncer_patcher.stop()
+
+ @staticmethod
+ def response_error(status: int) -> ResponseCodeError:
+ """Fixture to return a ResponseCodeError with the given status code."""
+ response = mock.MagicMock()
+ response.status = status
+
+ return ResponseCodeError(response)
+
+
+class SyncCogTests(SyncCogTestCase):
+ """Tests for the Sync cog."""
+
+ @mock.patch.object(sync.Sync, "sync_guild")
+ def test_sync_cog_init(self, sync_guild):
+ """Should instantiate syncers and run a sync for the guild."""
+ # Reset because a Sync cog was already instantiated in setUp.
+ self.RoleSyncer.reset_mock()
+ self.UserSyncer.reset_mock()
+ self.bot.loop.create_task.reset_mock()
+
+ mock_sync_guild_coro = mock.MagicMock()
+ sync_guild.return_value = mock_sync_guild_coro
+
+ sync.Sync(self.bot)
+
+ self.RoleSyncer.assert_called_once_with(self.bot)
+ self.UserSyncer.assert_called_once_with(self.bot)
+ sync_guild.assert_called_once_with()
+ self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
+
+ @helpers.async_test
+ async def test_sync_cog_sync_guild(self):
+ """Roles and users should be synced only if a guild is successfully retrieved."""
+ for guild in (helpers.MockGuild(), None):
+ with self.subTest(guild=guild):
+ self.bot.reset_mock()
+ self.cog.role_syncer.reset_mock()
+ self.cog.user_syncer.reset_mock()
+
+ self.bot.get_guild = mock.MagicMock(return_value=guild)
+
+ await self.cog.sync_guild()
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(constants.Guild.id)
+
+ if guild is None:
+ self.cog.role_syncer.sync.assert_not_called()
+ self.cog.user_syncer.sync.assert_not_called()
+ else:
+ self.cog.role_syncer.sync.assert_called_once_with(guild)
+ self.cog.user_syncer.sync.assert_called_once_with(guild)
+
+ async def patch_user_helper(self, side_effect: BaseException) -> None:
+ """Helper to set a side effect for bot.api_client.patch and then assert it is called."""
+ self.bot.api_client.patch.reset_mock(side_effect=True)
+ self.bot.api_client.patch.side_effect = side_effect
+
+ user_id, updated_information = 5, {"key": 123}
+ await self.cog.patch_user(user_id, updated_information)
+
+ self.bot.api_client.patch.assert_called_once_with(
+ f"bot/users/{user_id}",
+ json=updated_information,
+ )
+
+ @helpers.async_test
+ async def test_sync_cog_patch_user(self):
+ """A PATCH request should be sent and 404 errors ignored."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ await self.patch_user_helper(side_effect)
+
+ @helpers.async_test
+ async def test_sync_cog_patch_user_non_404(self):
+ """A PATCH request should be sent and the error raised if it's not a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.patch_user_helper(self.response_error(500))
+
+
+class SyncCogListenerTests(SyncCogTestCase):
+ """Tests for the listeners of the Sync cog."""
+
+ def setUp(self):
+ super().setUp()
+ self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user)
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_create(self):
+ """A POST request should be sent with the new role's data."""
+ self.assertTrue(self.cog.on_guild_role_create.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ role = helpers.MockRole(**role_data)
+ await self.cog.on_guild_role_create(role)
+
+ self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data)
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_delete(self):
+ """A DELETE request should be sent."""
+ self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__)
+
+ role = helpers.MockRole(id=99)
+ await self.cog.on_guild_role_delete(role)
+
+ self.bot.api_client.delete.assert_called_once_with("bot/roles/99")
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_update(self):
+ """A PUT request should be sent if the colour, name, permissions, or position changes."""
+ self.assertTrue(self.cog.on_guild_role_update.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ subtests = (
+ (True, ("colour", "name", "permissions", "position")),
+ (False, ("hoist", "mentionable")),
+ )
+
+ for should_put, attributes in subtests:
+ for attribute in attributes:
+ with self.subTest(should_put=should_put, changed_attribute=attribute):
+ self.bot.api_client.put.reset_mock()
+
+ after_role_data = role_data.copy()
+ after_role_data[attribute] = 876
+
+ before_role = helpers.MockRole(**role_data)
+ after_role = helpers.MockRole(**after_role_data)
+
+ await self.cog.on_guild_role_update(before_role, after_role)
+
+ if should_put:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/roles/{after_role.id}",
+ json=after_role_data
+ )
+ else:
+ self.bot.api_client.put.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_remove(self):
+ """Member should patched to set in_guild as False."""
+ self.assertTrue(self.cog.on_member_remove.__cog_listener__)
+
+ member = helpers.MockMember()
+ await self.cog.on_member_remove(member)
+
+ self.cog.patch_user.assert_called_once_with(
+ member.id,
+ updated_information={"in_guild": False}
+ )
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_update_roles(self):
+ """Members should be patched if their roles have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ # Roles are intentionally unsorted.
+ before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)]
+ before_member = helpers.MockMember(roles=before_roles)
+ after_member = helpers.MockMember(roles=before_roles[1:])
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ data = {"roles": sorted(role.id for role in after_member.roles)}
+ self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data)
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_update_other(self):
+ """Members should not be patched if other attributes have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ subtests = (
+ ("activities", discord.Game("Pong"), discord.Game("Frogger")),
+ ("nick", "old nick", "new nick"),
+ ("status", discord.Status.online, discord.Status.offline),
+ )
+
+ for attribute, old_value, new_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ before_member = helpers.MockMember(**{attribute: old_value})
+ after_member = helpers.MockMember(**{attribute: new_value})
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ self.cog.patch_user.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_user_update(self):
+ """A user should be patched only if the name, discriminator, or avatar changes."""
+ self.assertTrue(self.cog.on_user_update.__cog_listener__)
+
+ before_data = {
+ "name": "old name",
+ "discriminator": "1234",
+ "avatar": "old avatar",
+ "bot": False,
+ }
+
+ subtests = (
+ (True, "name", "name", "new name", "new name"),
+ (True, "discriminator", "discriminator", "8765", 8765),
+ (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),
+ (False, "bot", "bot", True, True),
+ )
+
+ for should_patch, attribute, api_field, value, api_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ after_data = before_data.copy()
+ after_data[attribute] = value
+ before_user = helpers.MockUser(**before_data)
+ after_user = helpers.MockUser(**after_data)
+
+ await self.cog.on_user_update(before_user, after_user)
+
+ if should_patch:
+ self.cog.patch_user.assert_called_once()
+
+ # Don't care if *all* keys are present; only the changed one is required
+ call_args = self.cog.patch_user.call_args
+ self.assertEqual(call_args[0][0], after_user.id)
+ self.assertIn("updated_information", call_args[1])
+
+ updated_information = call_args[1]["updated_information"]
+ self.assertIn(api_field, updated_information)
+ self.assertEqual(updated_information[api_field], api_value)
+ else:
+ self.cog.patch_user.assert_not_called()
+
+ async def on_member_join_helper(self, side_effect: Exception) -> dict:
+ """
+ Helper to set `side_effect` for on_member_join and assert a PUT request was sent.
+
+ The request data for the mock member is returned. All exceptions will be re-raised.
+ """
+ member = helpers.MockMember(
+ discriminator="1234",
+ roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)],
+ )
+
+ data = {
+ "avatar_hash": member.avatar,
+ "discriminator": int(member.discriminator),
+ "id": member.id,
+ "in_guild": True,
+ "name": member.name,
+ "roles": sorted(role.id for role in member.roles)
+ }
+
+ self.bot.api_client.put.reset_mock(side_effect=True)
+ self.bot.api_client.put.side_effect = side_effect
+
+ try:
+ await self.cog.on_member_join(member)
+ except Exception:
+ raise
+ finally:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/users/{member.id}",
+ json=data
+ )
+
+ return data
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_join(self):
+ """Should PUT user's data or POST it if the user doesn't exist."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ self.bot.api_client.post.reset_mock()
+ data = await self.on_member_join_helper(side_effect)
+
+ if side_effect:
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=data)
+ else:
+ self.bot.api_client.post.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_join_non_404(self):
+ """ResponseCodeError should be re-raised if status code isn't a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.on_member_join_helper(self.response_error(500))
+
+ self.bot.api_client.post.assert_not_called()
+
+
+class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
+ """Tests for the commands in the Sync cog."""
+
+ @helpers.async_test
+ async def test_sync_roles_command(self):
+ """sync() should be called on the RoleSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_roles_command.callback(self.cog, ctx)
+
+ self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ @helpers.async_test
+ async def test_sync_users_command(self):
+ """sync() should be called on the UserSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_users_command.callback(self.cog, ctx)
+
+ self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ def test_commands_require_admin(self):
+ """The sync commands should only run if the author has the administrator permission."""
+ cmds = (
+ self.cog.sync_group,
+ self.cog.sync_roles_command,
+ self.cog.sync_users_command,
+ )
+
+ for cmd in cmds:
+ with self.subTest(cmd=cmd):
+ self.assertHasPermissionsCheck(cmd, {"administrator": True})
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
index 27ae27639..14fb2577a 100644
--- a/tests/bot/cogs/sync/test_roles.py
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -1,126 +1,165 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import Role, get_roles_for_sync
-
-
-class GetRolesForSyncTests(unittest.TestCase):
- """Tests constructing the roles to synchronize with the site."""
-
- def test_get_roles_for_sync_empty_return_for_equal_roles(self):
- """No roles should be synced when no diff is found."""
- api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), set(), set())
- )
-
- def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self):
- """Roles to be synced are returned when non-ID attributes differ."""
- api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), guild_roles, set())
- )
-
- def test_get_roles_only_returns_roles_that_require_update(self):
- """Roles that require an update should be returned as the second tuple element."""
- api_roles = {
- Role(id=41, name='old name', colour=33, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
- guild_roles = {
- Role(id=41, name='new name', colour=35, permissions=0x8, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_new_roles_in_first_tuple_element(self):
- """Newly created roles are returned as the first tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=2)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=2)},
- set(),
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_update_and_new_roles(self):
- """Newly created and updated roles should be returned together."""
- api_roles = {
- Role(id=41, name='old name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='new name', colour=40, permissions=0x16, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=3)},
- {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_delete(self):
- """Roles to be deleted should be returned as the third tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- set(),
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
-
- def test_get_roles_returns_roles_to_delete_update_and_new_roles(self):
- """When roles were added, updated, and removed, all of them are returned properly."""
- api_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- Role(id=71, name='to update', colour=99, permissions=0x9, position=3),
- }
- guild_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=81, name='to create', colour=99, permissions=0x9, position=4),
- Role(id=71, name='updated', colour=101, permissions=0x5, position=3),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)},
- {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)},
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
+import discord
+
+from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role
+from tests import helpers
+
+
+def fake_role(**kwargs):
+ """Fixture to return a dictionary representing a role with default values set."""
+ kwargs.setdefault("id", 9)
+ kwargs.setdefault("name", "fake role")
+ kwargs.setdefault("colour", 7)
+ kwargs.setdefault("permissions", 0)
+ kwargs.setdefault("position", 55)
+
+ return kwargs
+
+
+class RoleSyncerDiffTests(unittest.TestCase):
+ """Tests for determining differences between roles in the DB and roles in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*roles):
+ """Fixture to return a guild object with the given roles."""
+ guild = helpers.MockGuild()
+ guild.roles = []
+
+ for role in roles:
+ mock_role = helpers.MockRole(**role)
+ mock_role.colour = discord.Colour(role["colour"])
+ mock_role.permissions = discord.Permissions(role["permissions"])
+ guild.roles.append(mock_role)
+
+ return guild
+
+ @helpers.async_test
+ async def test_empty_diff_for_identical_roles(self):
+ """No differences should be found if the roles in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_updated_roles(self):
+ """Only updated roles should be added to the 'updated' set of the diff."""
+ updated_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]
+ guild = self.get_guild(updated_role, fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_Role(**updated_role)}, set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_roles(self):
+ """Only new roles should be added to the 'created' set of the diff."""
+ new_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role(), new_role)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new_role)}, set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_deleted_roles(self):
+ """Only deleted roles should be added to the 'deleted' set of the diff."""
+ deleted_role = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [fake_role(), deleted_role]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), {_Role(**deleted_role)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_updated_and_deleted_roles(self):
+ """When roles are added, updated, and removed, all of them are returned properly."""
+ new = fake_role(id=41, name="new")
+ updated = fake_role(id=71, name="updated")
+ deleted = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [
+ fake_role(),
+ fake_role(id=71, name="updated name"),
+ deleted,
+ ]
+ guild = self.get_guild(fake_role(), new, updated)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class RoleSyncerSyncTests(unittest.TestCase):
+ """Tests for the API requests that sync roles."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_created_roles(self):
+ """Only POST requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(role_tuples, set(), set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/roles", json=role) for role in roles]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(roles))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_updated_roles(self):
+ """Only PUT requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), role_tuples, set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_deleted_roles(self):
+ """Only DELETE requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), set(), role_tuples)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]
+ self.bot.api_client.delete.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.delete.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.put.assert_not_called()
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index ccaf67490..421bf6bb6 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -1,84 +1,169 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import User, get_users_for_sync
+from bot.cogs.sync.syncers import UserSyncer, _Diff, _User
+from tests import helpers
def fake_user(**kwargs):
- kwargs.setdefault('id', 43)
- kwargs.setdefault('name', 'bob the test man')
- kwargs.setdefault('discriminator', 1337)
- kwargs.setdefault('avatar_hash', None)
- kwargs.setdefault('roles', (666,))
- kwargs.setdefault('in_guild', True)
- return User(**kwargs)
-
-
-class GetUsersForSyncTests(unittest.TestCase):
- """Tests constructing the users to synchronize with the site."""
-
- def test_get_users_for_sync_returns_nothing_for_empty_params(self):
- """When no users are given, none are returned."""
- self.assertEqual(
- get_users_for_sync({}, {}),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_nothing_for_equal_users(self):
- """When no users are updated, none are returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self):
- """When a non-ID-field differs, the user to update is returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(name='new fancy name')}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(name='new fancy name')})
- )
-
- def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self):
- """When new users join the guild, they are returned as the first tuple element."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(), 63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, set())
- )
-
- def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self):
+ """Fixture to return a dictionary representing a user with default values set."""
+ kwargs.setdefault("id", 43)
+ kwargs.setdefault("name", "bob the test man")
+ kwargs.setdefault("discriminator", 1337)
+ kwargs.setdefault("avatar_hash", None)
+ kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("in_guild", True)
+
+ return kwargs
+
+
+class UserSyncerDiffTests(unittest.TestCase):
+ """Tests for determining differences between users in the DB and users in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*members):
+ """Fixture to return a guild object with the given members."""
+ guild = helpers.MockGuild()
+ guild.members = []
+
+ for member in members:
+ member = member.copy()
+ member["avatar"] = member.pop("avatar_hash")
+ del member["in_guild"]
+
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+
+ guild.members.append(mock_member)
+
+ return guild
+
+ @helpers.async_test
+ async def test_empty_diff_for_no_users(self):
+ """When no users are given, an empty diff should be returned."""
+ guild = self.get_guild()
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_empty_diff_for_identical_users(self):
+ """No differences should be found if the users in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_updated_users(self):
+ """Only updated users should be added to the 'updated' set of the diff."""
+ updated_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ guild = self.get_guild(updated_user, fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**updated_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_users(self):
+ """Only new users should be added to the 'created' set of the diff."""
+ new_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user(), new_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- api_users = {43: fake_user(), 63: fake_user(id=63)}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(id=63, in_guild=False)})
- )
-
- def test_get_users_for_sync_updates_and_creates_users_as_needed(self):
- """When one user left and another one was updated, both are returned."""
- api_users = {43: fake_user()}
- guild_users = {63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, {fake_user(in_guild=False)})
- )
-
- def test_get_users_for_sync_does_not_duplicate_update_users(self):
- """When the API knows a user the guild doesn't, nothing is performed."""
- api_users = {43: fake_user(in_guild=False)}
- guild_users = {}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_updated_and_leaving_users(self):
+ """When users are added, updated, and removed, all of them are returned properly."""
+ new_user = fake_user(id=99, name="new")
+ updated_user = fake_user(id=55, name="updated")
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ guild = self.get_guild(fake_user(), new_user, updated_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_empty_diff_for_db_users_not_in_guild(self):
+ """When the DB knows a user the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class UserSyncerSyncTests(unittest.TestCase):
+ """Tests for the API requests that sync users."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_created_users(self):
+ """Only POST requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(user_tuples, set(), None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/users", json=user) for user in users]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(users))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_updated_users(self):
+ """Only PUT requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(set(), user_tuples, None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(users))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
new file mode 100644
index 000000000..5b0a3b8c3
--- /dev/null
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -0,0 +1,584 @@
+import asyncio
+import logging
+import typing
+import unittest
+from unittest.mock import MagicMock, patch
+
+import discord
+
+from bot import constants
+from bot.cogs import duck_pond
+from tests import base
+from tests import helpers
+
+MODULE_PATH = "bot.cogs.duck_pond"
+
+
+class DuckPondTests(base.LoggingTestCase):
+ """Tests for DuckPond functionality."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Sets up the objects that only have to be initialized once."""
+ cls.nonstaff_member = helpers.MockMember(name="Non-staffer")
+
+ cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0])
+ cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role])
+
+ cls.checkmark_emoji = "\N{White Heavy Check Mark}"
+ cls.thumbs_up_emoji = "\N{Thumbs Up Sign}"
+ cls.unicode_duck_emoji = "\N{Duck}"
+ cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0])
+ cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123)
+
+ def setUp(self):
+ """Sets up the objects that need to be refreshed before each test."""
+ self.bot = helpers.MockBot(user=helpers.MockMember(id=46692))
+ self.cog = duck_pond.DuckPond(bot=self.bot)
+
+ def test_duck_pond_correctly_initializes(self):
+ """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`."""
+ bot = helpers.MockBot()
+ cog = MagicMock()
+
+ duck_pond.DuckPond.__init__(cog, bot)
+
+ self.assertEqual(cog.bot, bot)
+ self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond)
+ bot.loop.create_loop.called_once_with(cog.fetch_webhook())
+
+ def test_fetch_webhook_succeeds_without_connectivity_issues(self):
+ """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute."""
+ self.bot.fetch_webhook.return_value = "dummy webhook"
+ self.cog.webhook_id = 1
+
+ asyncio.run(self.cog.fetch_webhook())
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.fetch_webhook.assert_called_once_with(1)
+ self.assertEqual(self.cog.webhook, "dummy webhook")
+
+ def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self):
+ """The `fetch_webhook` method should log an exception when it fails to fetch the webhook."""
+ self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.")
+ self.cog.webhook_id = 1
+
+ log = logging.getLogger('bot.cogs.duck_pond')
+ with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
+ asyncio.run(self.cog.fetch_webhook())
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.fetch_webhook.assert_called_once_with(1)
+
+ self.assertEqual(len(log_watcher.records), 1)
+
+ record = log_watcher.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+
+ def test_is_staff_returns_correct_values_based_on_instance_passed(self):
+ """The `is_staff` method should return correct values based on the instance passed."""
+ test_cases = (
+ (helpers.MockUser(name="User instance"), False),
+ (helpers.MockMember(name="Member instance without staff role"), False),
+ (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True)
+ )
+
+ for user, expected_return in test_cases:
+ actual_return = self.cog.is_staff(user)
+ with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):
+ self.assertEqual(expected_return, actual_return)
+
+ @helpers.async_test
+ async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):
+ """The `has_green_checkmark` method should only return `True` if one is present."""
+ test_cases = (
+ (
+ "No reactions", helpers.MockMessage(), False
+ ),
+ (
+ "No green check mark reactions",
+ helpers.MockMessage(reactions=[
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]),
+ helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user])
+ ]),
+ False
+ ),
+ (
+ "Green check mark reaction, but not from our bot",
+ helpers.MockMessage(reactions=[
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]),
+ helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member])
+ ]),
+ False
+ ),
+ (
+ "Green check mark reaction, with one from the bot",
+ helpers.MockMessage(reactions=[
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]),
+ helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user])
+ ]),
+ True
+ )
+ )
+
+ for description, message, expected_return in test_cases:
+ actual_return = await self.cog.has_green_checkmark(message)
+ with self.subTest(
+ test_case=description,
+ expected_return=expected_return,
+ actual_return=actual_return
+ ):
+ self.assertEqual(expected_return, actual_return)
+
+ def test_send_webhook_correctly_passes_on_arguments(self):
+ """The `send_webhook` method should pass the arguments to the webhook correctly."""
+ self.cog.webhook = helpers.MockAsyncWebhook()
+
+ content = "fake content"
+ username = "fake username"
+ avatar_url = "fake avatar_url"
+ embed = "fake embed"
+
+ asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed))
+
+ self.cog.webhook.send.assert_called_once_with(
+ content=content,
+ username=username,
+ avatar_url=avatar_url,
+ embed=embed
+ )
+
+ def test_send_webhook_logs_when_sending_message_fails(self):
+ """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly."""
+ self.cog.webhook = helpers.MockAsyncWebhook()
+ self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.")
+
+ log = logging.getLogger('bot.cogs.duck_pond')
+ with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
+ asyncio.run(self.cog.send_webhook())
+
+ self.assertEqual(len(log_watcher.records), 1)
+
+ record = log_watcher.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+
+ def _get_reaction(
+ self,
+ emoji: typing.Union[str, helpers.MockEmoji],
+ staff: int = 0,
+ nonstaff: int = 0
+ ) -> helpers.MockReaction:
+ staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)]
+ nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]
+ return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers)
+
+ @helpers.async_test
+ async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):
+ """The `count_ducks` method should return the number of unique staffers who gave a duck."""
+ test_cases = (
+ # Simple test cases
+ # A message without reactions should return 0
+ (
+ "No reactions",
+ helpers.MockMessage(),
+ 0
+ ),
+ # A message with a non-duck reaction from a non-staffer should return 0
+ (
+ "Non-duck reaction from non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]),
+ 0
+ ),
+ # A message with a non-duck reaction from a staffer should return 0
+ (
+ "Non-duck reaction from staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]),
+ 0
+ ),
+ # A message with a non-duck reaction from a non-staffer and staffer should return 0
+ (
+ "Non-duck reaction from staffer + non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]),
+ 0
+ ),
+ # A message with a unicode duck reaction from a non-staffer should return 0
+ (
+ "Unicode Duck Reaction from non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]),
+ 0
+ ),
+ # A message with a unicode duck reaction from a staffer should return 1
+ (
+ "Unicode Duck Reaction from staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]),
+ 1
+ ),
+ # A message with a unicode duck reaction from a non-staffer and staffer should return 1
+ (
+ "Unicode Duck Reaction from staffer + non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]),
+ 1
+ ),
+ # A message with a duckpond duck reaction from a non-staffer should return 0
+ (
+ "Duckpond Duck Reaction from non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]),
+ 0
+ ),
+ # A message with a duckpond duck reaction from a staffer should return 1
+ (
+ "Duckpond Duck Reaction from staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]),
+ 1
+ ),
+ # A message with a duckpond duck reaction from a non-staffer and staffer should return 1
+ (
+ "Duckpond Duck Reaction from staffer + non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]),
+ 1
+ ),
+
+ # Complex test cases
+ # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3
+ (
+ "Duckpond Duck Reaction from 3 staffers + 2 non-staffers",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]),
+ 3
+ ),
+ # A staffer with multiple duck reactions only counts once
+ (
+ "Two different duck reactions from the same staffer",
+ helpers.MockMessage(
+ reactions=[
+ helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]),
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]),
+ ]
+ ),
+ 1
+ ),
+ # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif)
+ (
+ "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]),
+ 0
+ ),
+ # We correctly sum when multiple reactions are provided.
+ (
+ "Duckpond Duck Reaction from 3 staffers + 2 non-staffers",
+ helpers.MockMessage(
+ reactions=[
+ self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2),
+ self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9),
+ ]
+ ),
+ 3 + 4
+ ),
+ )
+
+ for description, message, expected_count in test_cases:
+ actual_count = await self.cog.count_ducks(message)
+ with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):
+ self.assertEqual(expected_count, actual_count)
+
+ @helpers.async_test
+ async def test_relay_message_correctly_relays_content_and_attachments(self):
+ """The `relay_message` method should correctly relay message content and attachments."""
+ send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook"
+ send_attachments_path = f"{MODULE_PATH}.send_attachments"
+
+ self.cog.webhook = helpers.MockAsyncWebhook()
+
+ test_values = (
+ (helpers.MockMessage(clean_content="", attachments=[]), False, False),
+ (helpers.MockMessage(clean_content="message", attachments=[]), True, False),
+ (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True),
+ (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True),
+ )
+
+ for message, expect_webhook_call, expect_attachment_call in test_values:
+ with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook:
+ with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments:
+ with self.subTest(clean_content=message.clean_content, attachments=message.attachments):
+ await self.cog.relay_message(message)
+
+ self.assertEqual(expect_webhook_call, send_webhook.called)
+ self.assertEqual(expect_attachment_call, send_attachments.called)
+
+ message.add_reaction.assert_called_once_with(self.checkmark_emoji)
+
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
+ @helpers.async_test
+ async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):
+ """The `relay_message` method should handle irretrievable attachments."""
+ message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
+ side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), ""))
+
+ self.cog.webhook = helpers.MockAsyncWebhook()
+ log = logging.getLogger("bot.cogs.duck_pond")
+
+ for side_effect in side_effects:
+ send_attachments.side_effect = side_effect
+ with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook:
+ with self.subTest(side_effect=type(side_effect).__name__):
+ with self.assertNotLogs(logger=log, level=logging.ERROR):
+ await self.cog.relay_message(message)
+
+ self.assertEqual(send_webhook.call_count, 2)
+
+ @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock)
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
+ @helpers.async_test
+ async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):
+ """The `relay_message` method should handle irretrievable attachments."""
+ message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
+
+ self.cog.webhook = helpers.MockAsyncWebhook()
+ log = logging.getLogger("bot.cogs.duck_pond")
+
+ side_effect = discord.HTTPException(MagicMock(), "")
+ send_attachments.side_effect = side_effect
+ with self.subTest(side_effect=type(side_effect).__name__):
+ with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
+ await self.cog.relay_message(message)
+
+ send_webhook.assert_called_once_with(
+ content=message.clean_content,
+ username=message.author.display_name,
+ avatar_url=message.author.avatar_url
+ )
+
+ self.assertEqual(len(log_watcher.records), 1)
+
+ record = log_watcher.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+
+ def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str):
+ """Creates a mock `on_raw_reaction_add` payload with the specified emoji data."""
+ payload = MagicMock(name=label)
+ payload.emoji.is_custom_emoji.return_value = is_custom_emoji
+ payload.emoji.id = id_
+ payload.emoji.name = emoji_name
+ return payload
+
+ @helpers.async_test
+ async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):
+ """The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""
+ test_values = (
+ # Custom Emojis
+ (
+ self._mock_payload(
+ label="Custom Duckpond Emoji",
+ is_custom_emoji=True,
+ id_=constants.DuckPond.custom_emojis[0],
+ emoji_name=""
+ ),
+ True
+ ),
+ (
+ self._mock_payload(
+ label="Custom Non-Duckpond Emoji",
+ is_custom_emoji=True,
+ id_=123,
+ emoji_name=""
+ ),
+ False
+ ),
+ # Unicode Emojis
+ (
+ self._mock_payload(
+ label="Unicode Duck Emoji",
+ is_custom_emoji=False,
+ id_=1,
+ emoji_name=self.unicode_duck_emoji
+ ),
+ True
+ ),
+ (
+ self._mock_payload(
+ label="Unicode Non-Duck Emoji",
+ is_custom_emoji=False,
+ id_=1,
+ emoji_name=self.thumbs_up_emoji
+ ),
+ False
+ ),
+ )
+
+ for payload, expected_return in test_values:
+ actual_return = self.cog._payload_has_duckpond_emoji(payload)
+ with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return):
+ self.assertEqual(expected_return, actual_return)
+
+ @patch(f"{MODULE_PATH}.discord.utils.get")
+ @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False))
+ def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get):
+ """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji."""
+ self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock())))
+
+ # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check
+ utils_get.assert_not_called()
+
+ def _raw_reaction_mocks(self, channel_id, message_id, user_id):
+ """Sets up mocks for tests of the `on_raw_reaction_add` event listener."""
+ channel = helpers.MockTextChannel(id=channel_id)
+ self.bot.get_all_channels.return_value = (channel,)
+
+ message = helpers.MockMessage(id=message_id)
+
+ channel.fetch_message.return_value = message
+
+ member = helpers.MockMember(id=user_id, roles=[self.staff_role])
+ message.guild.members = (member,)
+
+ payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id)
+
+ return channel, message, member, payload
+
+ @helpers.async_test
+ async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):
+ """The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""
+ channel_id = 1234
+ message_id = 2345
+ user_id = 3456
+
+ channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id)
+
+ test_cases = (
+ ("non-staff member", helpers.MockMember(id=user_id)),
+ ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)),
+ )
+
+ payload.emoji = self.duck_pond_emoji
+
+ for description, member in test_cases:
+ message.guild.members = (member, )
+ with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark:
+ checkmark.side_effect = AssertionError(
+ "Expected method to return before calling `self.has_green_checkmark`."
+ )
+ self.assertIsNone(await self.cog.on_raw_reaction_add(payload))
+
+ # Check that we did make it past the payload checks
+ channel.fetch_message.assert_called_once()
+ channel.fetch_message.reset_mock()
+
+ @patch(f"{MODULE_PATH}.DuckPond.is_staff")
+ @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock)
+ def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):
+ """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""
+ channel_id = 31415926535
+ message_id = 27182818284
+ user_id = 16180339887
+
+ channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id)
+
+ payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji)
+ payload.emoji.is_custom_emoji.return_value = False
+
+ message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])]
+
+ is_staff.return_value = True
+ count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`")
+
+ self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload)))
+
+ # Assert that we've made it past `self.is_staff`
+ is_staff.assert_called_once()
+
+ @helpers.async_test
+ async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):
+ """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""
+ test_cases = (
+ (constants.DuckPond.threshold - 1, False),
+ (constants.DuckPond.threshold, True),
+ (constants.DuckPond.threshold + 1, True),
+ )
+
+ channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5)
+
+ payload.emoji = self.duck_pond_emoji
+
+ for duck_count, should_relay in test_cases:
+ with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ count_ducks.return_value = duck_count
+ with self.subTest(duck_count=duck_count, should_relay=should_relay):
+ await self.cog.on_raw_reaction_add(payload)
+
+ # Confirm that we've made it past counting
+ count_ducks.assert_called_once()
+
+ # Did we relay a message?
+ has_relayed = relay_message.called
+ self.assertEqual(has_relayed, should_relay)
+
+ if should_relay:
+ relay_message.assert_called_once_with(message)
+
+ @helpers.async_test
+ async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):
+ """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""
+ checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji)
+
+ message = helpers.MockMessage(id=1234)
+
+ channel = helpers.MockTextChannel(id=98765)
+ channel.fetch_message.return_value = message
+
+ self.bot.get_all_channels.return_value = (channel, )
+
+ payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark)
+
+ test_cases = (
+ (constants.DuckPond.threshold - 1, False),
+ (constants.DuckPond.threshold, True),
+ (constants.DuckPond.threshold + 1, True),
+ )
+ for duck_count, should_re_add_checkmark in test_cases:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ count_ducks.return_value = duck_count
+ with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):
+ await self.cog.on_raw_reaction_remove(payload)
+
+ # Check if we fetched the message
+ channel.fetch_message.assert_called_once_with(message.id)
+
+ # Check if we actually counted the number of ducks
+ count_ducks.assert_called_once_with(message)
+
+ has_re_added_checkmark = message.add_reaction.called
+ self.assertEqual(should_re_add_checkmark, has_re_added_checkmark)
+
+ if should_re_add_checkmark:
+ message.add_reaction.assert_called_once_with(self.checkmark_emoji)
+ message.add_reaction.reset_mock()
+
+ # reset mocks
+ channel.fetch_message.reset_mock()
+ message.reset_mock()
+
+ def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self):
+ """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis."""
+ channel = helpers.MockTextChannel(id=98765)
+
+ channel.fetch_message.side_effect = AssertionError(
+ "Expected method to return before calling `channel.fetch_message`"
+ )
+
+ self.bot.get_all_channels.return_value = (channel, )
+
+ payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id)
+
+ self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload)))
+
+ channel.fetch_message.assert_not_called()
+
+
+class DuckPondSetupTests(unittest.TestCase):
+ """Tests setup of the `DuckPond` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = helpers.MockBot()
+ duck_pond.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 4496a2ae0..deae7ebad 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -125,10 +125,10 @@ class InformationCogTests(unittest.TestCase):
)
],
members=[
- *(helpers.MockMember(status='online') for _ in range(2)),
- *(helpers.MockMember(status='idle') for _ in range(1)),
- *(helpers.MockMember(status='dnd') for _ in range(4)),
- *(helpers.MockMember(status='offline') for _ in range(3)),
+ *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
+ *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
+ *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
+ *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
],
member_count=1_234,
icon_url='a-lemon.jpg',
@@ -153,9 +153,9 @@ class InformationCogTests(unittest.TestCase):
**Counts**
Members: {self.ctx.guild.member_count:,}
Roles: {len(self.ctx.guild.roles)}
- Text: 1
- Voice: 1
- Channel categories: 1
+ Category channels: 1
+ Text channels: 1
+ Voice channels: 1
**Members**
{constants.Emojis.status_online} 2
diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py
index efa7a50b1..9d1a62f7e 100644
--- a/tests/bot/cogs/test_security.py
+++ b/tests/bot/cogs/test_security.py
@@ -1,4 +1,3 @@
-import logging
import unittest
from unittest.mock import MagicMock
@@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase):
"""Tests loading the `Security` cog."""
def test_security_cog_load(self):
- """Cog loading logs a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = MagicMock()
- with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm:
- security.setup(bot)
- bot.add_cog.assert_called_once()
-
- [line] = cm.output
- self.assertIn("Cog loaded: Security", line)
+ security.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index 3276cf5a5..a54b839d7 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase):
"""Tests setup of the `TokenRemover` cog."""
def test_setup(self):
- """Setup of the cog should log a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = MockBot()
- with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm:
- setup_cog(bot)
-
- [line] = cm.output
+ setup_cog(bot)
bot.add_cog.assert_called_once()
- self.assertIn("Cog loaded: TokenRemover", line)
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index e69de29bb..36c986fe1 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -0,0 +1,76 @@
+import unittest
+from abc import ABCMeta, abstractmethod
+from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
+
+from tests.helpers import MockMessage
+
+
+class DisallowedCase(NamedTuple):
+ """Encapsulation for test cases expected to fail."""
+ recent_messages: List[MockMessage]
+ culprits: Iterable[str]
+ n_violations: int
+
+
+class RuleTest(unittest.TestCase, metaclass=ABCMeta):
+ """
+ Abstract class for antispam rule test cases.
+
+ Tests for specific rules should inherit from `RuleTest` and implement
+ `relevant_messages` and `get_report`. Each instance should also set the
+ `apply` and `config` attributes as necessary.
+
+ The execution of test cases can then be delegated to the `run_allowed`
+ and `run_disallowed` methods.
+ """
+
+ apply: Callable # The tested rule's apply function
+ config: Dict[str, int]
+
+ async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to pass."""
+ for recent_messages in cases:
+ last_message = recent_messages[0]
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ config=self.config,
+ ):
+ self.assertIsNone(
+ await self.apply(last_message, recent_messages, self.config)
+ )
+
+ async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to fail."""
+ for case in cases:
+ recent_messages, culprits, n_violations = case
+ last_message = recent_messages[0]
+ relevant_messages = self.relevant_messages(case)
+ desired_output = (
+ self.get_report(case),
+ culprits,
+ relevant_messages,
+ )
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ relevant_messages=relevant_messages,
+ n_violations=n_violations,
+ config=self.config,
+ ):
+ self.assertTupleEqual(
+ await self.apply(last_message, recent_messages, self.config),
+ desired_output,
+ )
+
+ @abstractmethod
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ """Give expected relevant messages for `case`."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_report(self, case: DisallowedCase) -> str:
+ """Give expected error report for `case`."""
+ raise NotImplementedError
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index 4bb0acf7c..e54b4b5b8 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -1,52 +1,71 @@
-import asyncio
-import unittest
-from dataclasses import dataclass
-from typing import Any, List
+from typing import Iterable
from bot.rules import attachments
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
-# Using `MagicMock` sadly doesn't work for this usecase
-# since it's __eq__ compares the MagicMock's ID. We just
-# want to compare the actual attributes we set.
-@dataclass
-class FakeMessage:
- author: str
- attachments: List[Any]
+def make_msg(author: str, total_attachments: int) -> MockMessage:
+ """Builds a message with `total_attachments` attachments."""
+ return MockMessage(author=author, attachments=list(range(total_attachments)))
-def msg(total_attachments: int) -> FakeMessage:
- return FakeMessage(author='lemon', attachments=list(range(total_attachments)))
+class AttachmentRuleTests(RuleTest):
+ """Tests applying the `attachments` antispam rule."""
+ def setUp(self):
+ self.apply = attachments.apply
+ self.config = {"max": 5, "interval": 10}
-class AttachmentRuleTests(unittest.TestCase):
- """Tests applying the `attachment` antispam rule."""
-
- def test_allows_messages_without_too_many_attachments(self):
+ @async_test
+ async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
- (msg(0), msg(0), msg(0)),
- (msg(2), msg(2)),
- (msg(0),),
+ [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)],
+ [make_msg("bob", 2), make_msg("bob", 2)],
+ [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
)
- for last_message, *recent_messages in cases:
- with self.subTest(last_message=last_message, recent_messages=recent_messages):
- coro = attachments.apply(last_message, recent_messages, {'max': 5})
- self.assertIsNone(asyncio.run(coro))
+ await self.run_allowed(cases)
- def test_disallows_messages_with_too_many_attachments(self):
+ @async_test
+ async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
- ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10),
- ((msg(6),), [msg(6)], 6),
- ((msg(1),) * 6, [msg(1)] * 6, 6),
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
+ ("bob",),
+ 10,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
+ ("bob",),
+ 6,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 6)],
+ ("alice",),
+ 6,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 1) for _ in range(6)],
+ ("alice",),
+ 6,
+ ),
)
- for messages, relevant_messages, total in cases:
- with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total):
- last_message, *recent_messages = messages
- coro = attachments.apply(last_message, recent_messages, {'max': 5})
- self.assertEqual(
- asyncio.run(coro),
- (f"sent {total} attachments in 5s", ('lemon',), relevant_messages)
- )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and len(msg.attachments) > 0
+ )
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
new file mode 100644
index 000000000..72f0be0c7
--- /dev/null
+++ b/tests/bot/rules/test_burst.py
@@ -0,0 +1,56 @@
+from typing import Iterable
+
+from bot.rules import burst
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with author set to `author`.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstRuleTests(RuleTest):
+ """Tests the `burst` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("bob"), make_msg("bob")],
+ [make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
new file mode 100644
index 000000000..47367a5f8
--- /dev/null
+++ b/tests/bot/rules/test_burst_shared.py
@@ -0,0 +1,59 @@
+from typing import Iterable
+
+from bot.rules import burst_shared
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with the passed arg.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstSharedRuleTests(RuleTest):
+ """Tests the `burst_shared` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst_shared.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """
+ Cases that do not violate the rule.
+
+ There really isn't more to test here than a single case.
+ """
+ cases = (
+ [make_msg("spongebob"), make_msg("patrick")],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ {"bob"},
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ {"bob", "alice"},
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return case.recent_messages
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
new file mode 100644
index 000000000..7cc36f49e
--- /dev/null
+++ b/tests/bot/rules/test_chars.py
@@ -0,0 +1,66 @@
+from typing import Iterable
+
+from bot.rules import chars
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, n_chars: int) -> MockMessage:
+ """Build a message with arbitrary content of `n_chars` length."""
+ return MockMessage(author=author, content="A" * n_chars)
+
+
+class CharsRuleTests(RuleTest):
+ """Tests the `chars` antispam rule."""
+
+ def setUp(self):
+ self.apply = chars.apply
+ self.config = {
+ "max": 20, # Max allowed sum of chars per user
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of chars within limit."""
+ cases = (
+ [make_msg("bob", 0)],
+ [make_msg("bob", 20)],
+ [make_msg("bob", 15), make_msg("alice", 15)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the total amount of chars exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 21)],
+ ("bob",),
+ 21,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 15), make_msg("bob", 15)],
+ ("bob",),
+ 30,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
+ ("alice",),
+ 30,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
new file mode 100644
index 000000000..0239b0b00
--- /dev/null
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -0,0 +1,54 @@
+from typing import Iterable
+
+from bot.rules import discord_emojis
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
+
+
+def make_msg(author: str, n_emojis: int) -> MockMessage:
+ """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis."""
+ return MockMessage(author=author, content=discord_emoji * n_emojis)
+
+
+class DiscordEmojisRuleTests(RuleTest):
+ """Tests for the `discord_emojis` antispam rule."""
+
+ def setUp(self):
+ self.apply = discord_emojis.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of discord emojis within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of discord emojis."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
new file mode 100644
index 000000000..59e0fb6ef
--- /dev/null
+++ b/tests/bot/rules/test_duplicates.py
@@ -0,0 +1,66 @@
+from typing import Iterable
+
+from bot.rules import duplicates
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, content: str) -> MockMessage:
+ """Give a MockMessage instance with `author` and `content` attrs."""
+ return MockMessage(author=author, content=content)
+
+
+class DuplicatesRuleTests(RuleTest):
+ """Tests the `duplicates` antispam rule."""
+
+ def setUp(self):
+ self.apply = duplicates.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", "A"), make_msg("alice", "A")],
+ [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate
+ [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with too many duplicate messages from the same author."""
+ cases = (
+ DisallowedCase(
+ [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")],
+ ("alice",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 duplicate messages, but only 3 from bob
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 message from bob, but only 3 duplicates
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and msg.content == last_message.content
+ )
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index be832843b..3c3f90e5f 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -1,32 +1,21 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import links
-from tests.helpers import async_test
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
-class FakeMessage(NamedTuple):
- author: str
- content: str
-
-
-class Case(NamedTuple):
- recent_messages: List[FakeMessage]
- relevant_messages: Tuple[FakeMessage]
- culprit: Tuple[str]
- total_links: int
-
-
-def msg(author: str, total_links: int) -> FakeMessage:
- """Makes a message with *total_links* links."""
+def make_msg(author: str, total_links: int) -> MockMessage:
+ """Makes a message with `total_links` links."""
content = " ".join(["https://pydis.com"] * total_links)
- return FakeMessage(author=author, content=content)
+ return MockMessage(author=author, content=content)
-class LinksTests(unittest.TestCase):
+class LinksTests(RuleTest):
"""Tests applying the `links` rule."""
def setUp(self):
+ self.apply = links.apply
self.config = {
"max": 2,
"interval": 10
@@ -36,66 +25,45 @@ class LinksTests(unittest.TestCase):
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await links.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
- Case(
- [msg("bob", 1), msg("bob", 2)],
- (msg("bob", 1), msg("bob", 2)),
+ DisallowedCase(
+ [make_msg("bob", 1), make_msg("bob", 2)],
("bob",),
3
),
- Case(
- [msg("alice", 1), msg("alice", 1), msg("alice", 1)],
- (msg("alice", 1), msg("alice", 1), msg("alice", 1)),
+ DisallowedCase(
+ [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
("alice",),
3
),
- Case(
- [msg("alice", 2), msg("bob", 3), msg("alice", 1)],
- (msg("alice", 2), msg("alice", 1)),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
("alice",),
3
)
)
- for recent_messages, relevant_messages, culprit, total_links in cases:
- last_message = recent_messages[0]
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_links=total_links,
- config=self.config
- ):
- desired_output = (
- f"sent {total_links} links in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await links.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
new file mode 100644
index 000000000..ebcdabac6
--- /dev/null
+++ b/tests/bot/rules/test_mentions.py
@@ -0,0 +1,67 @@
+from typing import Iterable
+
+from bot.rules import mentions
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, total_mentions: int) -> MockMessage:
+ """Makes a message with `total_mentions` mentions."""
+ return MockMessage(author=author, mentions=list(range(total_mentions)))
+
+
+class TestMentions(RuleTest):
+ """Tests applying the `mentions` antispam rule."""
+
+ def setUp(self):
+ self.apply = mentions.apply
+ self.config = {
+ "max": 2,
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_mentions_within_limit(self):
+ """Messages with an allowed amount of mentions."""
+ cases = (
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 1), make_msg("alice", 2)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_mentions_exceeding_limit(self):
+ """Messages with a higher than allowed amount of mentions."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
+ ("alice",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
+ ("bob",),
+ 4,
+ )
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
new file mode 100644
index 000000000..d61c4609d
--- /dev/null
+++ b/tests/bot/rules/test_newlines.py
@@ -0,0 +1,105 @@
+from typing import Iterable, List
+
+from bot.rules import newlines
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
+ """Init a MockMessage instance with `author` and content configured by `newline_groups".
+
+ Configure content by passing a list of ints, where each int `n` will generate
+ a separate group of `n` newlines.
+
+ Example:
+ newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n"
+ """
+ content = " ".join("\n" * n for n in newline_groups)
+ return MockMessage(author=author, content=content)
+
+
+class TotalNewlinesRuleTests(RuleTest):
+ """Tests the `newlines` antispam rule against allowed cases and total newline count violations."""
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {
+ "max": 5, # Max sum of newlines in relevant messages
+ "max_consecutive": 3, # Max newlines in one group, in one message
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", [])], # Single message with no newlines
+ [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages
+ [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author
+ [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_total(self):
+ """Cases which violate the rule by having too many newlines in total."""
+ cases = (
+ DisallowedCase( # Alice sends a total of 6 newlines (disallowed)
+ [make_msg("alice", [2, 2]), make_msg("alice", [2])],
+ ("alice",),
+ 6,
+ ),
+ DisallowedCase( # Here we test that only alice's newlines count in the sum
+ [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])],
+ ("alice",),
+ 7,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} newlines in {self.config['interval']}s"
+
+
+class GroupNewlinesRuleTests(RuleTest):
+ """
+ Tests the `newlines` antispam rule against max consecutive newline violations.
+
+ As these violations yield a different error report, they require a different
+ `get_report` implementation.
+ """
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
+
+ @async_test
+ async def test_disallows_messages_consecutive(self):
+ """Cases which violate the rule due to having too many consecutive newlines."""
+ cases = (
+ DisallowedCase( # Bob sends a group of newlines too large
+ [make_msg("bob", [4])],
+ ("bob",),
+ 4,
+ ),
+ DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed)
+ [make_msg("alice", [1]), make_msg("alice", [4])],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
new file mode 100644
index 000000000..b339cccf7
--- /dev/null
+++ b/tests/bot/rules/test_role_mentions.py
@@ -0,0 +1,57 @@
+from typing import Iterable
+
+from bot.rules import role_mentions
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, n_mentions: int) -> MockMessage:
+ """Build a MockMessage instance with `n_mentions` role mentions."""
+ return MockMessage(author=author, role_mentions=[None] * n_mentions)
+
+
+class RoleMentionsRuleTests(RuleTest):
+ """Tests for the `role_mentions` antispam rule."""
+
+ def setUp(self):
+ self.apply = role_mentions.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of role mentions within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of role mentions."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} role mentions in {self.config['interval']}s"
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index 5a88adc5c..bdfcc73e4 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -1,9 +1,7 @@
-import logging
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
from bot import api
-from tests.base import LoggingTestCase
from tests.helpers import async_test
@@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase):
self.assertEqual(error.response_text, "")
self.assertIs(error.response, self.error_api_response)
- def test_responde_code_error_string_representation_default_initialization(self):
+ def test_response_code_error_string_representation_default_initialization(self):
"""Test the string representation of `ResponseCodeError` initialized without text or json."""
error = api.ResponseCodeError(response=self.error_api_response)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ")
@@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase):
response_text=text_data
)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}")
-
-
-class LoggingHandlerTests(LoggingTestCase):
- """Tests the bot's API Log Handler."""
-
- @classmethod
- def setUpClass(cls):
- cls.debug_log_record = logging.LogRecord(
- name='my.logger', level=logging.DEBUG,
- pathname='my/logger.py', lineno=666,
- msg="Lemon wins", args=(),
- exc_info=None
- )
-
- cls.trace_log_record = logging.LogRecord(
- name='my.logger', level=logging.TRACE,
- pathname='my/logger.py', lineno=666,
- msg="This will not be logged", args=(),
- exc_info=None
- )
-
- def setUp(self):
- self.log_handler = api.APILoggingHandler(None)
-
- def test_emit_appends_to_queue_with_stopped_event_loop(self):
- """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running."""
- with patch("bot.api.APILoggingHandler.ship_off") as ship_off:
- # Patch `ship_off` to ease testing against the return value of this coroutine.
- ship_off.return_value = 42
- self.log_handler.emit(self.debug_log_record)
-
- self.assertListEqual(self.log_handler.queue, [42])
-
- def test_emit_ignores_less_than_debug(self):
- """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG."""
- self.log_handler.emit(self.trace_log_record)
- self.assertListEqual(self.log_handler.queue, [])
-
- def test_schedule_queued_tasks_for_empty_queue(self):
- """`APILoggingHandler` should not schedule anything when the queue is empty."""
- with self.assertNotLogs(level=logging.DEBUG):
- self.log_handler.schedule_queued_tasks()
-
- def test_schedule_queued_tasks_for_nonempty_queue(self):
- """`APILoggingHandler` should schedule logs when the queue is not empty."""
- log = logging.getLogger("bot.api")
-
- with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task:
- self.log_handler.queue = [555]
- self.log_handler.schedule_queued_tasks()
- self.assertListEqual(self.log_handler.queue, [])
- create_task.assert_called_once_with(555)
-
- [record] = logs.records
- self.assertEqual(record.message, "Scheduled 1 pending logging tasks.")
- self.assertEqual(record.levelno, logging.DEBUG)
- self.assertEqual(record.name, 'bot.api')
- self.assertIn('via_handler', record.__dict__)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
index 58ae2a81a..d7bcc3ba6 100644
--- a/tests/bot/test_utils.py
+++ b/tests/bot/test_utils.py
@@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):
instance = utils.CaseInsensitiveDict()
instance.update({'FOO': 'bar'})
self.assertEqual(instance['foo'], 'bar')
-
-
-class ChunkTests(unittest.TestCase):
- """Tests the `chunk` method."""
-
- def test_empty_chunking(self):
- """Tests chunking on an empty iterable."""
- generator = utils.chunks(iterable=[], size=5)
- self.assertEqual(list(generator), [])
-
- def test_list_chunking(self):
- """Tests chunking a non-empty list."""
- iterable = [1, 2, 3, 4, 5]
- generator = utils.chunks(iterable=iterable, size=2)
- self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
new file mode 100644
index 000000000..69f35f2f5
--- /dev/null
+++ b/tests/bot/utils/test_time.py
@@ -0,0 +1,162 @@
+import asyncio
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from dateutil.relativedelta import relativedelta
+
+from bot.utils import time
+from tests.helpers import AsyncMock
+
+
+class TimeTests(unittest.TestCase):
+ """Test helper functions in bot.utils.time."""
+
+ def test_humanize_delta_handle_unknown_units(self):
+ """humanize_delta should be able to handle unknown units, and will not abort."""
+ # Does not abort for unknown units, as the unit name is checked
+ # against the attribute of the relativedelta instance.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours')
+
+ def test_humanize_delta_handle_high_units(self):
+ """humanize_delta should be able to handle very high units."""
+ # Very high maximum units, but it only ever iterates over
+ # each value the relativedelta might have.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours')
+
+ def test_humanize_delta_should_normal_usage(self):
+ """Testing humanize delta."""
+ test_cases = (
+ (relativedelta(days=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
+ (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
+ )
+
+ for delta, precision, max_units, expected in test_cases:
+ with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected):
+ self.assertEqual(time.humanize_delta(delta, precision, max_units), expected)
+
+ def test_humanize_delta_raises_for_invalid_max_units(self):
+ """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units."""
+ test_cases = (-1, 0)
+
+ for max_units in test_cases:
+ with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
+ time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
+ self.assertEqual(str(error), 'max_units must be positive')
+
+ def test_parse_rfc1123(self):
+ """Testing parse_rfc1123."""
+ self.assertEqual(
+ time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'),
+ datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)
+ )
+
+ def test_format_infraction(self):
+ """Testing format_infraction."""
+ self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01')
+
+ @patch('asyncio.sleep', new_callable=AsyncMock)
+ def test_wait_until(self, mock):
+ """Testing wait_until."""
+ start = datetime(2019, 1, 1, 0, 0)
+ then = datetime(2019, 1, 1, 0, 10)
+
+ # No return value
+ self.assertIs(asyncio.run(time.wait_until(then, start)), None)
+
+ mock.assert_called_once_with(10 * 60)
+
+ def test_format_infraction_with_duration_none_expiry(self):
+ """format_infraction_with_duration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that date_from and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_custom_units(self):
+ """format_infraction_with_duration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6,
+ '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20,
+ '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)')
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_normal_usage(self):
+ """format_infraction_with_duration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2,
+ '2019-11-23 23:59 (9 minutes and 55 seconds)'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_until_expiration_with_duration_none_expiry(self):
+ """until_expiration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that now and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_with_duration_custom_units(self):
+ """until_expiration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes')
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_normal_usage(self):
+ """until_expiration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)