aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar Matteo Bertucci <[email protected]>2019-12-13 23:11:25 +0100
committerGravatar GitHub <[email protected]>2019-12-13 23:11:25 +0100
commitbf6a576608228a4cefb10150b3c3e082ab7ccf0a (patch)
tree2d572b5780e6b26dea2b11d3d749ad6c83e4ffb6 /tests
parentSpecify assertion to be a tuple comparison (diff)
parentUse OAuth to be Reddit API compliant (#696) (diff)
Merge branch 'master' into unittest-mentions
Diffstat (limited to 'tests')
-rw-r--r--tests/README.md1
-rw-r--r--tests/bot/cogs/test_duck_pond.py584
-rw-r--r--tests/bot/cogs/test_security.py11
-rw-r--r--tests/bot/cogs/test_token_remover.py8
-rw-r--r--tests/bot/utils/test_time.py162
-rw-r--r--tests/helpers.py146
6 files changed, 894 insertions, 18 deletions
diff --git a/tests/README.md b/tests/README.md
index 6ab9bc93e..d052de2f6 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -15,6 +15,7 @@ We are using the following modules and packages for our unit tests:
To ensure the results you obtain on your personal machine are comparable to those generated in the Azure pipeline, please make sure to run your tests with the virtual environment defined by our [Pipfile](/Pipfile). To run your tests with `pipenv`, we've provided two "scripts" shortcuts:
- `pipenv run test` will run `unittest` with `coverage.py`
+- `pipenv run test path/to/test.py` will run a specific test.
- `pipenv run report` will generate a coverage report of the tests you've run with `pipenv run test`. If you append the `-m` flag to this command, the report will include the lines and branches not covered by tests in addition to the test coverage report.
If you want a coverage report, make sure to run the tests with `pipenv run test` *first*.
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
new file mode 100644
index 000000000..d07b2bce1
--- /dev/null
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -0,0 +1,584 @@
+import asyncio
+import logging
+import typing
+import unittest
+from unittest.mock import MagicMock, patch
+
+import discord
+
+from bot import constants
+from bot.cogs import duck_pond
+from tests import base
+from tests import helpers
+
+MODULE_PATH = "bot.cogs.duck_pond"
+
+
+class DuckPondTests(base.LoggingTestCase):
+ """Tests for DuckPond functionality."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Sets up the objects that only have to be initialized once."""
+ cls.nonstaff_member = helpers.MockMember(name="Non-staffer")
+
+ cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0])
+ cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role])
+
+ cls.checkmark_emoji = "\N{White Heavy Check Mark}"
+ cls.thumbs_up_emoji = "\N{Thumbs Up Sign}"
+ cls.unicode_duck_emoji = "\N{Duck}"
+ cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0])
+ cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123)
+
+ def setUp(self):
+ """Sets up the objects that need to be refreshed before each test."""
+ self.bot = helpers.MockBot(user=helpers.MockMember(id=46692))
+ self.cog = duck_pond.DuckPond(bot=self.bot)
+
+ def test_duck_pond_correctly_initializes(self):
+ """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`."""
+ bot = helpers.MockBot()
+ cog = MagicMock()
+
+ duck_pond.DuckPond.__init__(cog, bot)
+
+ self.assertEqual(cog.bot, bot)
+ self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond)
+ bot.loop.create_loop.called_once_with(cog.fetch_webhook())
+
+ def test_fetch_webhook_succeeds_without_connectivity_issues(self):
+ """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute."""
+ self.bot.fetch_webhook.return_value = "dummy webhook"
+ self.cog.webhook_id = 1
+
+ asyncio.run(self.cog.fetch_webhook())
+
+ self.bot.wait_until_ready.assert_called_once()
+ self.bot.fetch_webhook.assert_called_once_with(1)
+ self.assertEqual(self.cog.webhook, "dummy webhook")
+
+ def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self):
+ """The `fetch_webhook` method should log an exception when it fails to fetch the webhook."""
+ self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.")
+ self.cog.webhook_id = 1
+
+ log = logging.getLogger('bot.cogs.duck_pond')
+ with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
+ asyncio.run(self.cog.fetch_webhook())
+
+ self.bot.wait_until_ready.assert_called_once()
+ self.bot.fetch_webhook.assert_called_once_with(1)
+
+ self.assertEqual(len(log_watcher.records), 1)
+
+ record = log_watcher.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+
+ def test_is_staff_returns_correct_values_based_on_instance_passed(self):
+ """The `is_staff` method should return correct values based on the instance passed."""
+ test_cases = (
+ (helpers.MockUser(name="User instance"), False),
+ (helpers.MockMember(name="Member instance without staff role"), False),
+ (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True)
+ )
+
+ for user, expected_return in test_cases:
+ actual_return = self.cog.is_staff(user)
+ with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):
+ self.assertEqual(expected_return, actual_return)
+
+ @helpers.async_test
+ async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):
+ """The `has_green_checkmark` method should only return `True` if one is present."""
+ test_cases = (
+ (
+ "No reactions", helpers.MockMessage(), False
+ ),
+ (
+ "No green check mark reactions",
+ helpers.MockMessage(reactions=[
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]),
+ helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user])
+ ]),
+ False
+ ),
+ (
+ "Green check mark reaction, but not from our bot",
+ helpers.MockMessage(reactions=[
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]),
+ helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member])
+ ]),
+ False
+ ),
+ (
+ "Green check mark reaction, with one from the bot",
+ helpers.MockMessage(reactions=[
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]),
+ helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user])
+ ]),
+ True
+ )
+ )
+
+ for description, message, expected_return in test_cases:
+ actual_return = await self.cog.has_green_checkmark(message)
+ with self.subTest(
+ test_case=description,
+ expected_return=expected_return,
+ actual_return=actual_return
+ ):
+ self.assertEqual(expected_return, actual_return)
+
+ def test_send_webhook_correctly_passes_on_arguments(self):
+ """The `send_webhook` method should pass the arguments to the webhook correctly."""
+ self.cog.webhook = helpers.MockAsyncWebhook()
+
+ content = "fake content"
+ username = "fake username"
+ avatar_url = "fake avatar_url"
+ embed = "fake embed"
+
+ asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed))
+
+ self.cog.webhook.send.assert_called_once_with(
+ content=content,
+ username=username,
+ avatar_url=avatar_url,
+ embed=embed
+ )
+
+ def test_send_webhook_logs_when_sending_message_fails(self):
+ """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly."""
+ self.cog.webhook = helpers.MockAsyncWebhook()
+ self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.")
+
+ log = logging.getLogger('bot.cogs.duck_pond')
+ with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
+ asyncio.run(self.cog.send_webhook())
+
+ self.assertEqual(len(log_watcher.records), 1)
+
+ record = log_watcher.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+
+ def _get_reaction(
+ self,
+ emoji: typing.Union[str, helpers.MockEmoji],
+ staff: int = 0,
+ nonstaff: int = 0
+ ) -> helpers.MockReaction:
+ staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)]
+ nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]
+ return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers)
+
+ @helpers.async_test
+ async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):
+ """The `count_ducks` method should return the number of unique staffers who gave a duck."""
+ test_cases = (
+ # Simple test cases
+ # A message without reactions should return 0
+ (
+ "No reactions",
+ helpers.MockMessage(),
+ 0
+ ),
+ # A message with a non-duck reaction from a non-staffer should return 0
+ (
+ "Non-duck reaction from non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]),
+ 0
+ ),
+ # A message with a non-duck reaction from a staffer should return 0
+ (
+ "Non-duck reaction from staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]),
+ 0
+ ),
+ # A message with a non-duck reaction from a non-staffer and staffer should return 0
+ (
+ "Non-duck reaction from staffer + non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]),
+ 0
+ ),
+ # A message with a unicode duck reaction from a non-staffer should return 0
+ (
+ "Unicode Duck Reaction from non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]),
+ 0
+ ),
+ # A message with a unicode duck reaction from a staffer should return 1
+ (
+ "Unicode Duck Reaction from staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]),
+ 1
+ ),
+ # A message with a unicode duck reaction from a non-staffer and staffer should return 1
+ (
+ "Unicode Duck Reaction from staffer + non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]),
+ 1
+ ),
+ # A message with a duckpond duck reaction from a non-staffer should return 0
+ (
+ "Duckpond Duck Reaction from non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]),
+ 0
+ ),
+ # A message with a duckpond duck reaction from a staffer should return 1
+ (
+ "Duckpond Duck Reaction from staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]),
+ 1
+ ),
+ # A message with a duckpond duck reaction from a non-staffer and staffer should return 1
+ (
+ "Duckpond Duck Reaction from staffer + non-staffer",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]),
+ 1
+ ),
+
+ # Complex test cases
+ # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3
+ (
+ "Duckpond Duck Reaction from 3 staffers + 2 non-staffers",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]),
+ 3
+ ),
+ # A staffer with multiple duck reactions only counts once
+ (
+ "Two different duck reactions from the same staffer",
+ helpers.MockMessage(
+ reactions=[
+ helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]),
+ helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]),
+ ]
+ ),
+ 1
+ ),
+ # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif)
+ (
+ "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers",
+ helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]),
+ 0
+ ),
+ # We correctly sum when multiple reactions are provided.
+ (
+ "Duckpond Duck Reaction from 3 staffers + 2 non-staffers",
+ helpers.MockMessage(
+ reactions=[
+ self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2),
+ self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9),
+ ]
+ ),
+ 3 + 4
+ ),
+ )
+
+ for description, message, expected_count in test_cases:
+ actual_count = await self.cog.count_ducks(message)
+ with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):
+ self.assertEqual(expected_count, actual_count)
+
+ @helpers.async_test
+ async def test_relay_message_correctly_relays_content_and_attachments(self):
+ """The `relay_message` method should correctly relay message content and attachments."""
+ send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook"
+ send_attachments_path = f"{MODULE_PATH}.send_attachments"
+
+ self.cog.webhook = helpers.MockAsyncWebhook()
+
+ test_values = (
+ (helpers.MockMessage(clean_content="", attachments=[]), False, False),
+ (helpers.MockMessage(clean_content="message", attachments=[]), True, False),
+ (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True),
+ (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True),
+ )
+
+ for message, expect_webhook_call, expect_attachment_call in test_values:
+ with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook:
+ with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments:
+ with self.subTest(clean_content=message.clean_content, attachments=message.attachments):
+ await self.cog.relay_message(message)
+
+ self.assertEqual(expect_webhook_call, send_webhook.called)
+ self.assertEqual(expect_attachment_call, send_attachments.called)
+
+ message.add_reaction.assert_called_once_with(self.checkmark_emoji)
+
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
+ @helpers.async_test
+ async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):
+ """The `relay_message` method should handle irretrievable attachments."""
+ message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
+ side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), ""))
+
+ self.cog.webhook = helpers.MockAsyncWebhook()
+ log = logging.getLogger("bot.cogs.duck_pond")
+
+ for side_effect in side_effects:
+ send_attachments.side_effect = side_effect
+ with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook:
+ with self.subTest(side_effect=type(side_effect).__name__):
+ with self.assertNotLogs(logger=log, level=logging.ERROR):
+ await self.cog.relay_message(message)
+
+ self.assertEqual(send_webhook.call_count, 2)
+
+ @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock)
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
+ @helpers.async_test
+ async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):
+ """The `relay_message` method should handle irretrievable attachments."""
+ message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
+
+ self.cog.webhook = helpers.MockAsyncWebhook()
+ log = logging.getLogger("bot.cogs.duck_pond")
+
+ side_effect = discord.HTTPException(MagicMock(), "")
+ send_attachments.side_effect = side_effect
+ with self.subTest(side_effect=type(side_effect).__name__):
+ with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
+ await self.cog.relay_message(message)
+
+ send_webhook.assert_called_once_with(
+ content=message.clean_content,
+ username=message.author.display_name,
+ avatar_url=message.author.avatar_url
+ )
+
+ self.assertEqual(len(log_watcher.records), 1)
+
+ record = log_watcher.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+
+ def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str):
+ """Creates a mock `on_raw_reaction_add` payload with the specified emoji data."""
+ payload = MagicMock(name=label)
+ payload.emoji.is_custom_emoji.return_value = is_custom_emoji
+ payload.emoji.id = id_
+ payload.emoji.name = emoji_name
+ return payload
+
+ @helpers.async_test
+ async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):
+ """The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""
+ test_values = (
+ # Custom Emojis
+ (
+ self._mock_payload(
+ label="Custom Duckpond Emoji",
+ is_custom_emoji=True,
+ id_=constants.DuckPond.custom_emojis[0],
+ emoji_name=""
+ ),
+ True
+ ),
+ (
+ self._mock_payload(
+ label="Custom Non-Duckpond Emoji",
+ is_custom_emoji=True,
+ id_=123,
+ emoji_name=""
+ ),
+ False
+ ),
+ # Unicode Emojis
+ (
+ self._mock_payload(
+ label="Unicode Duck Emoji",
+ is_custom_emoji=False,
+ id_=1,
+ emoji_name=self.unicode_duck_emoji
+ ),
+ True
+ ),
+ (
+ self._mock_payload(
+ label="Unicode Non-Duck Emoji",
+ is_custom_emoji=False,
+ id_=1,
+ emoji_name=self.thumbs_up_emoji
+ ),
+ False
+ ),
+ )
+
+ for payload, expected_return in test_values:
+ actual_return = self.cog._payload_has_duckpond_emoji(payload)
+ with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return):
+ self.assertEqual(expected_return, actual_return)
+
+ @patch(f"{MODULE_PATH}.discord.utils.get")
+ @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False))
+ def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get):
+ """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji."""
+ self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock())))
+
+ # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check
+ utils_get.assert_not_called()
+
+ def _raw_reaction_mocks(self, channel_id, message_id, user_id):
+ """Sets up mocks for tests of the `on_raw_reaction_add` event listener."""
+ channel = helpers.MockTextChannel(id=channel_id)
+ self.bot.get_all_channels.return_value = (channel,)
+
+ message = helpers.MockMessage(id=message_id)
+
+ channel.fetch_message.return_value = message
+
+ member = helpers.MockMember(id=user_id, roles=[self.staff_role])
+ message.guild.members = (member,)
+
+ payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id)
+
+ return channel, message, member, payload
+
+ @helpers.async_test
+ async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):
+ """The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""
+ channel_id = 1234
+ message_id = 2345
+ user_id = 3456
+
+ channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id)
+
+ test_cases = (
+ ("non-staff member", helpers.MockMember(id=user_id)),
+ ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)),
+ )
+
+ payload.emoji = self.duck_pond_emoji
+
+ for description, member in test_cases:
+ message.guild.members = (member, )
+ with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark:
+ checkmark.side_effect = AssertionError(
+ "Expected method to return before calling `self.has_green_checkmark`."
+ )
+ self.assertIsNone(await self.cog.on_raw_reaction_add(payload))
+
+ # Check that we did make it past the payload checks
+ channel.fetch_message.assert_called_once()
+ channel.fetch_message.reset_mock()
+
+ @patch(f"{MODULE_PATH}.DuckPond.is_staff")
+ @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock)
+ def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):
+ """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""
+ channel_id = 31415926535
+ message_id = 27182818284
+ user_id = 16180339887
+
+ channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id)
+
+ payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji)
+ payload.emoji.is_custom_emoji.return_value = False
+
+ message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])]
+
+ is_staff.return_value = True
+ count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`")
+
+ self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload)))
+
+ # Assert that we've made it past `self.is_staff`
+ is_staff.assert_called_once()
+
+ @helpers.async_test
+ async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):
+ """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""
+ test_cases = (
+ (constants.DuckPond.threshold - 1, False),
+ (constants.DuckPond.threshold, True),
+ (constants.DuckPond.threshold + 1, True),
+ )
+
+ channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5)
+
+ payload.emoji = self.duck_pond_emoji
+
+ for duck_count, should_relay in test_cases:
+ with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ count_ducks.return_value = duck_count
+ with self.subTest(duck_count=duck_count, should_relay=should_relay):
+ await self.cog.on_raw_reaction_add(payload)
+
+ # Confirm that we've made it past counting
+ count_ducks.assert_called_once()
+
+ # Did we relay a message?
+ has_relayed = relay_message.called
+ self.assertEqual(has_relayed, should_relay)
+
+ if should_relay:
+ relay_message.assert_called_once_with(message)
+
+ @helpers.async_test
+ async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):
+ """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""
+ checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji)
+
+ message = helpers.MockMessage(id=1234)
+
+ channel = helpers.MockTextChannel(id=98765)
+ channel.fetch_message.return_value = message
+
+ self.bot.get_all_channels.return_value = (channel, )
+
+ payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark)
+
+ test_cases = (
+ (constants.DuckPond.threshold - 1, False),
+ (constants.DuckPond.threshold, True),
+ (constants.DuckPond.threshold + 1, True),
+ )
+ for duck_count, should_re_add_checkmark in test_cases:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ count_ducks.return_value = duck_count
+ with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):
+ await self.cog.on_raw_reaction_remove(payload)
+
+ # Check if we fetched the message
+ channel.fetch_message.assert_called_once_with(message.id)
+
+ # Check if we actually counted the number of ducks
+ count_ducks.assert_called_once_with(message)
+
+ has_re_added_checkmark = message.add_reaction.called
+ self.assertEqual(should_re_add_checkmark, has_re_added_checkmark)
+
+ if should_re_add_checkmark:
+ message.add_reaction.assert_called_once_with(self.checkmark_emoji)
+ message.add_reaction.reset_mock()
+
+ # reset mocks
+ channel.fetch_message.reset_mock()
+ message.reset_mock()
+
+ def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self):
+ """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis."""
+ channel = helpers.MockTextChannel(id=98765)
+
+ channel.fetch_message.side_effect = AssertionError(
+ "Expected method to return before calling `channel.fetch_message`"
+ )
+
+ self.bot.get_all_channels.return_value = (channel, )
+
+ payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id)
+
+ self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload)))
+
+ channel.fetch_message.assert_not_called()
+
+
+class DuckPondSetupTests(unittest.TestCase):
+ """Tests setup of the `DuckPond` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = helpers.MockBot()
+ duck_pond.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py
index efa7a50b1..9d1a62f7e 100644
--- a/tests/bot/cogs/test_security.py
+++ b/tests/bot/cogs/test_security.py
@@ -1,4 +1,3 @@
-import logging
import unittest
from unittest.mock import MagicMock
@@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase):
"""Tests loading the `Security` cog."""
def test_security_cog_load(self):
- """Cog loading logs a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = MagicMock()
- with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm:
- security.setup(bot)
- bot.add_cog.assert_called_once()
-
- [line] = cm.output
- self.assertIn("Cog loaded: Security", line)
+ security.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index 3276cf5a5..a54b839d7 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase):
"""Tests setup of the `TokenRemover` cog."""
def test_setup(self):
- """Setup of the cog should log a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = MockBot()
- with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm:
- setup_cog(bot)
-
- [line] = cm.output
+ setup_cog(bot)
bot.add_cog.assert_called_once()
- self.assertIn("Cog loaded: TokenRemover", line)
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
new file mode 100644
index 000000000..69f35f2f5
--- /dev/null
+++ b/tests/bot/utils/test_time.py
@@ -0,0 +1,162 @@
+import asyncio
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from dateutil.relativedelta import relativedelta
+
+from bot.utils import time
+from tests.helpers import AsyncMock
+
+
+class TimeTests(unittest.TestCase):
+ """Test helper functions in bot.utils.time."""
+
+ def test_humanize_delta_handle_unknown_units(self):
+ """humanize_delta should be able to handle unknown units, and will not abort."""
+ # Does not abort for unknown units, as the unit name is checked
+ # against the attribute of the relativedelta instance.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours')
+
+ def test_humanize_delta_handle_high_units(self):
+ """humanize_delta should be able to handle very high units."""
+ # Very high maximum units, but it only ever iterates over
+ # each value the relativedelta might have.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours')
+
+ def test_humanize_delta_should_normal_usage(self):
+ """Testing humanize delta."""
+ test_cases = (
+ (relativedelta(days=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
+ (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
+ )
+
+ for delta, precision, max_units, expected in test_cases:
+ with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected):
+ self.assertEqual(time.humanize_delta(delta, precision, max_units), expected)
+
+ def test_humanize_delta_raises_for_invalid_max_units(self):
+ """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units."""
+ test_cases = (-1, 0)
+
+ for max_units in test_cases:
+ with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
+ time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
+ self.assertEqual(str(error), 'max_units must be positive')
+
+ def test_parse_rfc1123(self):
+ """Testing parse_rfc1123."""
+ self.assertEqual(
+ time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'),
+ datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)
+ )
+
+ def test_format_infraction(self):
+ """Testing format_infraction."""
+ self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01')
+
+ @patch('asyncio.sleep', new_callable=AsyncMock)
+ def test_wait_until(self, mock):
+ """Testing wait_until."""
+ start = datetime(2019, 1, 1, 0, 0)
+ then = datetime(2019, 1, 1, 0, 10)
+
+ # No return value
+ self.assertIs(asyncio.run(time.wait_until(then, start)), None)
+
+ mock.assert_called_once_with(10 * 60)
+
+ def test_format_infraction_with_duration_none_expiry(self):
+ """format_infraction_with_duration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that date_from and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_custom_units(self):
+ """format_infraction_with_duration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6,
+ '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20,
+ '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)')
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_normal_usage(self):
+ """format_infraction_with_duration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2,
+ '2019-11-23 23:59 (9 minutes and 55 seconds)'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_until_expiration_with_duration_none_expiry(self):
+ """until_expiration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that now and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_with_duration_custom_units(self):
+ """until_expiration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes')
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_normal_usage(self):
+ """until_expiration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
diff --git a/tests/helpers.py b/tests/helpers.py
index 8a14aeef4..5df796c23 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -10,7 +10,9 @@ import unittest.mock
from typing import Any, Iterable, Optional
import discord
-from discord.ext.commands import Bot, Context
+from discord.ext.commands import Context
+
+from bot.bot import Bot
for logger in logging.Logger.manager.loggerDict.values():
@@ -120,8 +122,80 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
features; this stand-in only overwrites the `__call__` method to an async version.
"""
+
async def __call__(self, *args, **kwargs):
- return super(AsyncMock, self).__call__(*args, **kwargs)
+ return super().__call__(*args, **kwargs)
+
+
+class AsyncIteratorMock:
+ """
+ A class to mock asynchronous iterators.
+
+ This allows async for, which is used in certain Discord.py objects. For example,
+ an async iterator is returned by the Reaction.users() method.
+ """
+
+ def __init__(self, iterable: Iterable = None):
+ if iterable is None:
+ iterable = []
+
+ self.iter = iter(iterable)
+ self.iterable = iterable
+
+ self.call_count = 0
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ return next(self.iter)
+ except StopIteration:
+ raise StopAsyncIteration
+
+ def __call__(self):
+ """
+ Keeps track of the number of times an instance has been called.
+
+ This is useful, since it typically shows that the iterator has actually been used somewhere after we have
+ instantiated the mock for an attribute that normally returns an iterator when called.
+ """
+ self.call_count += 1
+ return self
+
+ @property
+ def return_value(self):
+ """Makes `self.iterable` accessible as self.return_value."""
+ return self.iterable
+
+ @return_value.setter
+ def return_value(self, iterable):
+ """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
+ self.iter = iter(iterable)
+ self.iterable = iterable
+
+ def assert_called(self):
+ """Asserts if the AsyncIteratorMock instance has been called at least once."""
+ if self.call_count == 0:
+ raise AssertionError("Expected AsyncIteratorMock to have been called.")
+
+ def assert_called_once(self):
+ """Asserts if the AsyncIteratorMock instance has been called exactly once."""
+ if self.call_count != 1:
+ raise AssertionError(
+ f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
+ )
+
+ def assert_not_called(self):
+ """Asserts if the AsyncIteratorMock instance has not been called."""
+ if self.call_count != 0:
+ raise AssertionError(
+ f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
+ )
+
+ def reset_mock(self):
+ """Resets the call count, but not the return value or iterator."""
+ self.call_count = 0
# Create a guild instance to get a realistic Mock of `discord.Guild`
@@ -220,7 +294,7 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
information, see the `MockGuild` docstring.
"""
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
- default_kwargs = {'name': 'member', 'id': next(self.discord_id)}
+ default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
@@ -231,6 +305,25 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
self.mention = f"@{self.name}"
+# Create a User instance to get a realistic Mock of `discord.User`
+user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock())
+
+
+class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
+ """
+ A Mock subclass to mock User objects.
+
+ Instances of this class will follow the specifications of `discord.User` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
+ super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+
+ if 'mention' not in kwargs:
+ self.mention = f"@{self.name}"
+
+
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
bot_instance.http_session = None
@@ -244,6 +337,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=bot_instance, **kwargs)
@@ -281,6 +375,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+
def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
@@ -322,6 +417,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=context_instance, **kwargs)
self.bot = kwargs.get('bot', MockBot())
@@ -330,6 +426,20 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
self.channel = kwargs.get('channel', MockTextChannel())
+attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock())
+
+
+class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Attachment objects.
+
+ Instances of this class will follow the specifications of `discord.Attachment` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=attachment_instance, **kwargs)
+
+
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
"""
A MagicMock subclass to mock Message objects.
@@ -337,8 +447,10 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=message_instance, **kwargs)
+ default_kwargs = {'attachments': []}
+ super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -354,6 +466,7 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=emoji_instance, **kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -369,6 +482,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=partial_emoji_instance, **kwargs)
@@ -383,7 +497,31 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=reaction_instance, **kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
+ self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+
+webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock())
+
+
+class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter.
+
+ Instances of this class will follow the specifications of `discord.Webhook` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=webhook_instance, **kwargs)
+
+ # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
+ # as coroutines. That's why we need to set the methods manually.
+ self.send = AsyncMock()
+ self.edit = AsyncMock()
+ self.delete = AsyncMock()
+ self.execute = AsyncMock()