From d3f4673c1a1c3f5213840e756c5f35f7c70d46f6 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sun, 23 Feb 2020 12:39:25 +0100 Subject: Use mixin-composition not inheritance for LoggingTestCase We used inheritence to add additional logging assertion methods to unittest's TestCase class. However, with the introduction of the new IsolatedAsyncioTestCase this extension strategy means we'd have to create multiple child classes to be able to use the extended functionality in all of the TestCase variants. Since that leads to undesirable code reuse and an inheritance relationship is not at all needed, I've switched to a mixin-composition based approach that allows the user to extend the functionality of any TestCase variant with a mixin where needed. --- tests/base.py | 10 +++++++--- tests/bot/cogs/test_duck_pond.py | 2 +- tests/test_base.py | 18 ++++++------------ 3 files changed, 14 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/base.py b/tests/base.py index 029a249ed..21a57716a 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,5 +1,4 @@ import logging -import unittest from contextlib import contextmanager @@ -16,8 +15,13 @@ class _CaptureLogHandler(logging.Handler): self.records.append(record) -class LoggingTestCase(unittest.TestCase): - """TestCase subclass that adds more logging assertion tools.""" +class LoggingTestsMixin: + """ + A mixin that defines additional test methods for logging behavior. + + This mixin relies on the availability of the `fail` attribute defined by the + test classes included in Python's unittest method to signal test failure. + """ @contextmanager def assertNotLogs(self, logger=None, level=None, msg=None): diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..320cbd5c5 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -14,7 +14,7 @@ from tests import helpers MODULE_PATH = "bot.cogs.duck_pond" -class DuckPondTests(base.LoggingTestCase): +class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): """Tests for DuckPond functionality.""" @classmethod diff --git a/tests/test_base.py b/tests/test_base.py index a16e2af8f..23abb1dfd 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -3,7 +3,11 @@ import unittest import unittest.mock -from tests.base import LoggingTestCase, _CaptureLogHandler +from tests.base import LoggingTestsMixin, _CaptureLogHandler + + +class LoggingTestCase(LoggingTestsMixin): + pass class LoggingTestCaseTests(unittest.TestCase): @@ -18,19 +22,9 @@ class LoggingTestCaseTests(unittest.TestCase): try: with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): pass - except AssertionError: + except AssertionError: # pragma: no cover self.fail("`self.assertNotLogs` raised an AssertionError when it should not!") - @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs") - def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs): - """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly.""" - assertNotLogs.return_value = iter([None]) - assertNotLogs.side_effect = AssertionError - - message = "`self.assertNotLogs` raised an AssertionError when it should not!" - with self.assertRaises(AssertionError, msg=message): - self.test_assert_not_logs_does_not_raise_with_no_logs() - def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self): """Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted.""" msg_regex = ( -- cgit v1.2.3 From 135d6daa4804574935cd788c5baec656765f484b Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sun, 23 Feb 2020 13:05:10 +0100 Subject: Use IsolatedAsyncioTestCase instead of async_test Since we upgraded to Python 3.8, we can now use the new IsolatedAsyncioTestCase test class to use coroutine-based test methods instead of our own, custom async_test decorator. I have changed the base class for all of our test classes that use coroutine-based test methods and removed the now obsolete decorator from our helpers. --- tests/bot/cogs/test_duck_pond.py | 11 +---------- tests/bot/rules/__init__.py | 2 +- tests/bot/rules/test_attachments.py | 4 +--- tests/bot/rules/test_burst.py | 4 +--- tests/bot/rules/test_burst_shared.py | 4 +--- tests/bot/rules/test_chars.py | 4 +--- tests/bot/rules/test_discord_emojis.py | 4 +--- tests/bot/rules/test_duplicates.py | 4 +--- tests/bot/rules/test_links.py | 4 +--- tests/bot/rules/test_mentions.py | 4 +--- tests/bot/rules/test_newlines.py | 5 +---- tests/bot/rules/test_role_mentions.py | 4 +--- tests/bot/test_api.py | 4 +--- tests/helpers.py | 17 ----------------- tests/test_helpers.py | 8 -------- 15 files changed, 13 insertions(+), 70 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 320cbd5c5..6406f0737 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -14,7 +14,7 @@ from tests import helpers MODULE_PATH = "bot.cogs.duck_pond" -class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): """Tests for DuckPond functionality.""" @classmethod @@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): self.assertEqual(expected_return, actual_return) - @helpers.async_test async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): """The `has_green_checkmark` method should only return `True` if one is present.""" test_cases = ( @@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - @helpers.async_test async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): """The `count_ducks` method should return the number of unique staffers who gave a duck.""" test_cases = ( @@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): self.assertEqual(expected_count, actual_count) - @helpers.async_test async def test_relay_message_correctly_relays_content_and_attachments(self): """The `relay_message` method should correctly relay message content and attachments.""" send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" @@ -307,7 +304,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): message.add_reaction.assert_called_once_with(self.checkmark_emoji) @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) - @helpers.async_test async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -327,7 +323,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) - @helpers.async_test async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): payload.emoji.name = emoji_name return payload - @helpers.async_test async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" test_values = ( @@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): return channel, message, member, payload - @helpers.async_test async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" channel_id = 1234 @@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): # Assert that we've made it past `self.is_staff` is_staff.assert_called_once() - @helpers.async_test async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" test_cases = ( @@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.TestCase): if should_relay: relay_message.assert_called_once_with(message) - @helpers.async_test async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index 36c986fe1..0233e7939 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple): n_violations: int -class RuleTest(unittest.TestCase, metaclass=ABCMeta): +class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): """ Abstract class for antispam rule test cases. diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index e54b4b5b8..d7e779221 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import attachments from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, total_attachments: int) -> MockMessage: @@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest): self.apply = attachments.apply self.config = {"max": 5, "interval": 10} - @async_test async def test_allows_messages_without_too_many_attachments(self): """Messages without too many attachments are allowed as-is.""" cases = ( @@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_with_too_many_attachments(self): """Messages with too many attachments trigger the rule.""" cases = ( diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py index 72f0be0c7..03682966b 100644 --- a/tests/bot/rules/test_burst.py +++ b/tests/bot/rules/test_burst.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import burst from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest): self.apply = burst.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases which do not violate the rule.""" cases = ( @@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases where the amount of messages exceeds the limit, triggering the rule.""" cases = ( diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py index 47367a5f8..3275143d5 100644 --- a/tests/bot/rules/test_burst_shared.py +++ b/tests/bot/rules/test_burst_shared.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import burst_shared from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str) -> MockMessage: @@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest): self.apply = burst_shared.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """ Cases that do not violate the rule. @@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases where the amount of messages exceeds the limit, triggering the rule.""" cases = ( diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py index 7cc36f49e..f1e3c76a7 100644 --- a/tests/bot/rules/test_chars.py +++ b/tests/bot/rules/test_chars.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import chars from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, n_chars: int) -> MockMessage: @@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest): "interval": 10, } - @async_test async def test_allows_messages_within_limit(self): """Cases with a total amount of chars within limit.""" cases = ( @@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases where the total amount of chars exceeds the limit, triggering the rule.""" cases = ( diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 0239b0b00..9a72723e2 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import discord_emojis from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> @@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest): self.apply = discord_emojis.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases with a total amount of discord emojis within limit.""" cases = ( @@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases with more than the allowed amount of discord emojis.""" cases = ( diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py index 59e0fb6ef..9bd886a77 100644 --- a/tests/bot/rules/test_duplicates.py +++ b/tests/bot/rules/test_duplicates.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import duplicates from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, content: str) -> MockMessage: @@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest): self.apply = duplicates.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases which do not violate the rule.""" cases = ( @@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases with too many duplicate messages from the same author.""" cases = ( diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 3c3f90e5f..b091bd9d7 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import links from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, total_links: int) -> MockMessage: @@ -21,7 +21,6 @@ class LinksTests(RuleTest): "interval": 10 } - @async_test async def test_links_within_limit(self): """Messages with an allowed amount of links.""" cases = ( @@ -34,7 +33,6 @@ class LinksTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_links_exceeding_limit(self): """Messages with a a higher than allowed amount of links.""" cases = ( diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ebcdabac6..6444532f2 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, total_mentions: int) -> MockMessage: @@ -20,7 +20,6 @@ class TestMentions(RuleTest): "interval": 10, } - @async_test async def test_mentions_within_limit(self): """Messages with an allowed amount of mentions.""" cases = ( @@ -32,7 +31,6 @@ class TestMentions(RuleTest): await self.run_allowed(cases) - @async_test async def test_mentions_exceeding_limit(self): """Messages with a higher than allowed amount of mentions.""" cases = ( diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py index d61c4609d..e35377773 100644 --- a/tests/bot/rules/test_newlines.py +++ b/tests/bot/rules/test_newlines.py @@ -2,7 +2,7 @@ from typing import Iterable, List from bot.rules import newlines from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, newline_groups: List[int]) -> MockMessage: @@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest): "interval": 10, } - @async_test async def test_allows_messages_within_limit(self): """Cases which do not violate the rule.""" cases = ( @@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_total(self): """Cases which violate the rule by having too many newlines in total.""" cases = ( @@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest): self.apply = newlines.apply self.config = {"max": 5, "max_consecutive": 3, "interval": 10} - @async_test async def test_disallows_messages_consecutive(self): """Cases which violate the rule due to having too many consecutive newlines.""" cases = ( diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py index b339cccf7..26c05d527 100644 --- a/tests/bot/rules/test_role_mentions.py +++ b/tests/bot/rules/test_role_mentions.py @@ -2,7 +2,7 @@ from typing import Iterable from bot.rules import role_mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage, async_test +from tests.helpers import MockMessage def make_msg(author: str, n_mentions: int) -> MockMessage: @@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest): self.apply = role_mentions.apply self.config = {"max": 2, "interval": 10} - @async_test async def test_allows_messages_within_limit(self): """Cases with a total amount of role mentions within limit.""" cases = ( @@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest): await self.run_allowed(cases) - @async_test async def test_disallows_messages_beyond_limit(self): """Cases with more than the allowed amount of role mentions.""" cases = ( diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index bdfcc73e4..99e942813 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -2,10 +2,9 @@ import unittest from unittest.mock import MagicMock from bot import api -from tests.helpers import async_test -class APIClientTests(unittest.TestCase): +class APIClientTests(unittest.IsolatedAsyncioTestCase): """Tests for the bot's API client.""" @classmethod @@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase): """The event loop should not be running by default.""" self.assertFalse(api.loop_is_running()) - @async_test async def test_loop_is_running_in_async_context(self): """The event loop should be running in an async context.""" self.assertTrue(api.loop_is_running()) diff --git a/tests/helpers.py b/tests/helpers.py index 5df796c23..01752a791 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,8 +1,6 @@ from __future__ import annotations -import asyncio import collections -import functools import inspect import itertools import logging @@ -25,21 +23,6 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -def async_test(wrapped): - """ - Run a test case via asyncio. - Example: - >>> @async_test - ... async def lemon_wins(): - ... assert True - """ - - @functools.wraps(wrapped) - def wrapper(*args, **kwargs): - return asyncio.run(wrapped(*args, **kwargs)) - return wrapper - - class HashableMixin(discord.mixins.EqualityComparable): """ Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 7894e104a..fe39df308 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -395,11 +395,3 @@ class MockObjectTests(unittest.TestCase): coroutine = async_mock() self.assertTrue(inspect.iscoroutine(coroutine)) self.assertIsNotNone(asyncio.run(coroutine)) - - def test_async_test_decorator_allows_synchronous_call_to_async_def(self): - """Test if the `async_test` decorator allows an `async def` to be called synchronously.""" - @helpers.async_test - async def kosayoda(): - return "return value" - - self.assertEqual(kosayoda(), "return value") -- cgit v1.2.3 From b6500eb967ae4856d4d65d7946b1e341c093eedd Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sun, 23 Feb 2020 13:41:27 +0100 Subject: Remove lingering pytest test_time.py file I forgot to remove one pytest test file during the migration from pytest to unittest. Since we have sinced added a unittest version of the same file, I've now removed the lingering pytest file. --- tests/utils/test_time.py | 62 ------------------------------------------------ 1 file changed, 62 deletions(-) delete mode 100644 tests/utils/test_time.py (limited to 'tests') diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py deleted file mode 100644 index 4baa6395c..000000000 --- a/tests/utils/test_time.py +++ /dev/null @@ -1,62 +0,0 @@ -import asyncio -from datetime import datetime, timezone -from unittest.mock import patch - -import pytest -from dateutil.relativedelta import relativedelta - -from bot.utils import time -from tests.helpers import AsyncMock - - -@pytest.mark.parametrize( - ('delta', 'precision', 'max_units', 'expected'), - ( - (relativedelta(days=2), 'seconds', 1, '2 days'), - (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), - (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), - (relativedelta(days=2, hours=2), 'days', 2, '2 days'), - - # Does not abort for unknown units, as the unit name is checked - # against the attribute of the relativedelta instance. - (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), - - # Very high maximum units, but it only ever iterates over - # each value the relativedelta might have. - (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), - ) -) -def test_humanize_delta( - delta: relativedelta, - precision: str, - max_units: int, - expected: str -): - assert time.humanize_delta(delta, precision, max_units) == expected - - -@pytest.mark.parametrize('max_units', (-1, 0)) -def test_humanize_delta_raises_for_invalid_max_units(max_units: int): - with pytest.raises(ValueError, match='max_units must be positive'): - time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - - -@pytest.mark.parametrize( - ('stamp', 'expected'), - ( - ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), - ) -) -def test_parse_rfc1123(stamp: str, expected: str): - assert time.parse_rfc1123(stamp) == expected - - -@patch('asyncio.sleep', new_callable=AsyncMock) -def test_wait_until(sleep_patch): - start = datetime(2019, 1, 1, 0, 0) - then = datetime(2019, 1, 1, 0, 10) - - # No return value - assert asyncio.run(time.wait_until(then, start)) is None - - sleep_patch.assert_called_once_with(10 * 60) -- cgit v1.2.3 From ea64d7cc6defa759fc1c7f1631a7ae9b8073cc29 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sun, 23 Feb 2020 20:53:45 +0100 Subject: Use unittest's AsyncMock instead of our AsyncMock Python 3.8 introduced an `unittest.mock.AsyncMock` class that can be used to mock coroutines and other types of asynchronous operations like async iterators and async context managers. As we were using our custom, but limited, AsyncMock, I have replaced our mock with unittest's AsyncMock. Since Python 3.8 also introduces a different way of automatically detecting which attributes should be mocked with an AsyncMock, I've changed our CustomMockMixin to use this new method as well. Together with a couple other small changes, this means that our Custom Mocks now use a lazy method of detecting coroutine attributes, which significantly speeds up the test suite. --- tests/bot/cogs/test_duck_pond.py | 22 ++-- tests/bot/cogs/test_information.py | 34 +++---- tests/bot/cogs/test_token_remover.py | 4 +- tests/bot/utils/test_time.py | 3 +- tests/helpers.py | 190 ++++++++++++----------------------- tests/test_helpers.py | 63 ++---------- 6 files changed, 103 insertions(+), 213 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 6406f0737..e164f7544 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -2,7 +2,7 @@ import asyncio import logging import typing import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import discord @@ -293,8 +293,8 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): ) for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: with self.subTest(clean_content=message.clean_content, attachments=message.attachments): await self.cog.relay_message(message) @@ -303,7 +303,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): message.add_reaction.assert_called_once_with(self.checkmark_emoji) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -314,15 +314,15 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): for side_effect in side_effects: send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: + with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook: with self.subTest(side_effect=type(side_effect).__name__): with self.assertNotLogs(logger=log, level=logging.ERROR): await self.cog.relay_message(message) self.assertEqual(send_webhook.call_count, 2) - @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -456,7 +456,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): channel.fetch_message.reset_mock() @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" channel_id = 31415926535 @@ -491,8 +491,8 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): payload.emoji = self.duck_pond_emoji for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: count_ducks.return_value = duck_count with self.subTest(duck_count=duck_count, should_relay=should_relay): await self.cog.on_raw_reaction_add(payload) @@ -526,7 +526,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): (constants.DuckPond.threshold + 1, True), ) for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: count_ducks.return_value = duck_count with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): await self.cog.on_raw_reaction_remove(payload) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index deae7ebad..f5e937356 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase): """Test if the `role_info` command correctly returns the `moderator_role`.""" self.ctx.guild.roles.append(self.moderator_role) - self.cog.roles_info.can_run = helpers.AsyncMock() + self.cog.roles_info.can_run = unittest.mock.AsyncMock() self.cog.roles_info.can_run.return_value = True coroutine = self.cog.roles_info.callback(self.cog, self.ctx) @@ -72,7 +72,7 @@ class InformationCogTests(unittest.TestCase): self.ctx.guild.roles.append([dummy_role, admin_role]) - self.cog.role_info.can_run = helpers.AsyncMock() + self.cog.role_info.can_run = unittest.mock.AsyncMock() self.cog.role_info.can_run.return_value = True coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) @@ -174,7 +174,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): def setUp(self): """Common set-up steps done before for each test.""" self.bot = helpers.MockBot() - self.bot.api_client.get = helpers.AsyncMock() + self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) self.member = helpers.MockMember(id=1234) @@ -345,10 +345,10 @@ class UserEmbedTests(unittest.TestCase): def setUp(self): """Common set-up steps done before for each test.""" self.bot = helpers.MockBot() - self.bot.api_client.get = helpers.AsyncMock() + self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -360,7 +360,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Mr. Hemlock") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -372,7 +372,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -387,8 +387,8 @@ class UserEmbedTests(unittest.TestCase): self.assertIn("&Admins", embed.description) self.assertNotIn("&Everyone", embed.description) - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): """The embed should contain expanded infractions and nomination info in mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -423,7 +423,7 @@ class UserEmbedTests(unittest.TestCase): embed.description ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): """The embed should contain only basic infraction data outside of mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -454,7 +454,7 @@ class UserEmbedTests(unittest.TestCase): embed.description ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -467,7 +467,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): """The embed should be created with a blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() @@ -477,7 +477,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour.blurple()) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() @@ -529,7 +529,7 @@ class UserCommandTests(unittest.TestCase): with self.assertRaises(InChannelCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" constants.STAFF_ROLES = [self.moderator_role.id] @@ -542,7 +542,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): """A user should target itself with `!user` when a `user` argument was not provided.""" constants.STAFF_ROLES = [self.moderator_role.id] @@ -555,7 +555,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): """Staff members should be able to bypass the bot-commands channel restriction.""" constants.STAFF_ROLES = [self.moderator_role.id] @@ -568,7 +568,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.moderator) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_moderators_can_target_another_member(self, create_embed, constants): """A moderator should be able to use `!user` targeting another user.""" constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a54b839d7..33d1ec170 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,7 +1,7 @@ import asyncio import logging import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock from discord import Colour @@ -11,7 +11,7 @@ from bot.cogs.token_remover import ( setup as setup_cog, ) from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock, MockBot, MockMessage +from tests.helpers import MockBot, MockMessage class TokenRemoverTests(unittest.TestCase): diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 69f35f2f5..de5724bca 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,12 +1,11 @@ import asyncio import unittest from datetime import datetime, timezone -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from dateutil.relativedelta import relativedelta from bot.utils import time -from tests.helpers import AsyncMock class TimeTests(unittest.TestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 01752a791..506fe9894 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,11 +1,10 @@ from __future__ import annotations import collections -import inspect import itertools import logging import unittest.mock -from typing import Any, Iterable, Optional +from typing import Iterable, Optional import discord from discord.ext.commands import Context @@ -51,24 +50,31 @@ class CustomMockMixin: """ Provides common functionality for our custom Mock types. - The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine - function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care - of making sure child mocks are instantiated with the correct class. By default, the mock of the - children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute - `child_mock_type` on the custom mock inheriting from this mixin. + The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock + object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the + class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional + attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The + class method `spec_set` can be overwritten with the object that should be uses as the specification + for the mock. + + Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to + implement custom behavior. """ child_mock_type = unittest.mock.MagicMock discord_id = itertools.count(0) + spec_set = None + additional_spec_asyncs = None - def __init__(self, spec_set: Any = None, **kwargs): + def __init__(self, **kwargs): name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually. - super().__init__(spec_set=spec_set, **kwargs) + super().__init__(spec_set=self.spec_set, **kwargs) + + if self.additional_spec_asyncs: + self._spec_asyncs.extend(self.additional_spec_asyncs) if name: self.name = name - if spec_set: - self._extract_coroutine_methods_from_spec_instance(spec_set) def _get_child_mock(self, **kw): """ @@ -82,7 +88,16 @@ class CustomMockMixin: This override will look for an attribute called `child_mock_type` and use that as the type of the child mock. """ - klass = self.child_mock_type + _new_name = kw.get("_new_name") + if _new_name in self.__dict__['_spec_asyncs']: + return unittest.mock.AsyncMock(**kw) + + _type = type(self) + if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics: + # Any asynchronous magic becomes an AsyncMock + klass = unittest.mock.AsyncMock + else: + klass = self.child_mock_type if self._mock_sealed: attribute = "." + kw["name"] if "name" in kw else "()" @@ -91,95 +106,6 @@ class CustomMockMixin: return klass(**kw) - def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None: - """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes.""" - for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction): - setattr(self, name, AsyncMock()) - - -# TODO: Remove me in Python 3.8 -class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): - """ - A MagicMock subclass to mock async callables. - - Python 3.8 will introduce an AsyncMock class in the standard library that will have some more - features; this stand-in only overwrites the `__call__` method to an async version. - """ - - async def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) - - -class AsyncIteratorMock: - """ - A class to mock asynchronous iterators. - - This allows async for, which is used in certain Discord.py objects. For example, - an async iterator is returned by the Reaction.users() method. - """ - - def __init__(self, iterable: Iterable = None): - if iterable is None: - iterable = [] - - self.iter = iter(iterable) - self.iterable = iterable - - self.call_count = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - - def __call__(self): - """ - Keeps track of the number of times an instance has been called. - - This is useful, since it typically shows that the iterator has actually been used somewhere after we have - instantiated the mock for an attribute that normally returns an iterator when called. - """ - self.call_count += 1 - return self - - @property - def return_value(self): - """Makes `self.iterable` accessible as self.return_value.""" - return self.iterable - - @return_value.setter - def return_value(self, iterable): - """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" - self.iter = iter(iterable) - self.iterable = iterable - - def assert_called(self): - """Asserts if the AsyncIteratorMock instance has been called at least once.""" - if self.call_count == 0: - raise AssertionError("Expected AsyncIteratorMock to have been called.") - - def assert_called_once(self): - """Asserts if the AsyncIteratorMock instance has been called exactly once.""" - if self.call_count != 1: - raise AssertionError( - f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." - ) - - def assert_not_called(self): - """Asserts if the AsyncIteratorMock instance has not been called.""" - if self.call_count != 0: - raise AssertionError( - f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." - ) - - def reset_mock(self): - """Resets the call count, but not the return value or iterator.""" - self.call_count = 0 - # Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { @@ -230,9 +156,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): For more info, see the `Mocking` section in `tests/README.md`. """ + spec_set = guild_instance + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'members': []} - super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: @@ -251,9 +179,11 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ + spec_set = role_instance + def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} - super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f'&{self.name}' @@ -276,9 +206,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ + spec_set = member_instance + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} - super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: @@ -299,9 +231,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.User` instances. For more information, see the `MockGuild` docstring. """ + spec_set = user_instance + def __init__(self, **kwargs) -> None: default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} - super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f"@{self.name}" @@ -320,14 +254,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ + spec_set = bot_instance + additional_spec_asyncs = ("wait_for",) def __init__(self, **kwargs) -> None: - super().__init__(spec_set=bot_instance, **kwargs) + super().__init__(**kwargs) # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and # and should therefore be awaited. (The documentation calls it a coroutine as well, which # is technically incorrect, since it's a regular def.) - self.wait_for = AsyncMock() + # self.wait_for = unittest.mock.AsyncMock() # Since calling `create_task` on our MockBot does not actually schedule the coroutine object # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object @@ -358,10 +294,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ + spec_set = channel_instance def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} - super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f"#{self.name}" @@ -400,9 +337,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ + spec_set = context_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=context_instance, **kwargs) + super().__init__(**kwargs) self.bot = kwargs.get('bot', MockBot()) self.guild = kwargs.get('guild', MockGuild()) self.author = kwargs.get('author', MockMember()) @@ -419,8 +357,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Attachment` instances. For more information, see the `MockGuild` docstring. """ - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=attachment_instance, **kwargs) + spec_set = attachment_instance class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -430,10 +367,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ + spec_set = message_instance def __init__(self, **kwargs) -> None: default_kwargs = {'attachments': []} - super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) @@ -449,9 +387,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ + spec_set = emoji_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=emoji_instance, **kwargs) + super().__init__(**kwargs) self.guild = kwargs.get('guild', MockGuild()) @@ -465,9 +404,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=partial_emoji_instance, **kwargs) + spec_set = partial_emoji_instance reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -480,12 +417,17 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ + spec_set = reaction_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=reaction_instance, **kwargs) + _users = kwargs.pop("users", []) + super().__init__(**kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) - self.users = AsyncIteratorMock(kwargs.get('users', [])) + + user_iterator = unittest.mock.AsyncMock() + user_iterator.__aiter__.return_value = _users + self.users.return_value = user_iterator webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) @@ -498,13 +440,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Webhook` instances. For more information, see the `MockGuild` docstring. """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=webhook_instance, **kwargs) - - # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined - # as coroutines. That's why we need to set the methods manually. - self.send = AsyncMock() - self.edit = AsyncMock() - self.delete = AsyncMock() - self.execute = AsyncMock() + spec_set = webhook_instance + additional_spec_asyncs = ("send", "edit", "delete", "execute") diff --git a/tests/test_helpers.py b/tests/test_helpers.py index fe39df308..81285e009 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,4 @@ import asyncio -import inspect import unittest import unittest.mock @@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase): with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"): asyncio.run(coroutine_object) + def test_user_mock_uses_explicitly_passed_mention_attribute(self): + """MockUser should use an explicitly passed value for user.mention.""" + user = helpers.MockUser(mention="hello") + self.assertEqual(user.mention, "hello") + class MockObjectTests(unittest.TestCase): """Tests the mock objects and mixins we've defined.""" @@ -341,57 +345,10 @@ class MockObjectTests(unittest.TestCase): attribute = getattr(mock, valid_attribute) self.assertTrue(isinstance(attribute, mock_type.child_mock_type)) - def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self): - """Test if all coroutine functions are extracted, but not regular methods or attributes.""" - class CoroutineDonor: - def __init__(self): - self.some_attribute = 'alpha' - - async def first_coroutine(): - """This coroutine function should be extracted.""" - - async def second_coroutine(): - """This coroutine function should be extracted.""" - - def regular_method(): - """This regular function should not be extracted.""" - - class Receiver: + def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self): + """The CustomMockMixin should mock async magic methods with an AsyncMock.""" + class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock): pass - donor = CoroutineDonor() - receiver = Receiver() - - helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor) - - self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock) - self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock) - self.assertFalse(hasattr(receiver, 'regular_method')) - self.assertFalse(hasattr(receiver, 'some_attribute')) - - @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) - @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") - def test_custom_mock_mixin_init_with_spec(self, extract_method_mock): - """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" - spec_set = "pydis" - - helpers.CustomMockMixin(spec_set=spec_set) - - extract_method_mock.assert_called_once_with(spec_set) - - @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) - @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") - def test_custom_mock_mixin_init_without_spec(self, extract_method_mock): - """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" - helpers.CustomMockMixin() - - extract_method_mock.assert_not_called() - - def test_async_mock_provides_coroutine_for_dunder_call(self): - """Test if AsyncMock objects have a coroutine for their __call__ method.""" - async_mock = helpers.AsyncMock() - self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__)) - - coroutine = async_mock() - self.assertTrue(inspect.iscoroutine(coroutine)) - self.assertIsNotNone(asyncio.run(coroutine)) + mock = MyMock() + self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock) -- cgit v1.2.3 From f67cb7ac61eee86419d10e23e3fd3c66f1f9312e Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sun, 23 Feb 2020 20:58:20 +0100 Subject: Fix test_time test and ensure coverage One of the test_time methods did not actually assert the exception message it was trying to detect as the assertion statement was contained within the context manager handling the exception. I've moved it out of the context so it actually runs. I've also added a few `praga: no cover` comments for parts that were artifically lowering coverage of the test suite. --- tests/bot/cogs/test_duck_pond.py | 2 +- tests/bot/rules/__init__.py | 4 ++-- tests/bot/utils/test_time.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index e164f7544..7370b8471 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -312,7 +312,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): self.cog.webhook = helpers.MockAsyncWebhook() log = logging.getLogger("bot.cogs.duck_pond") - for side_effect in side_effects: + for side_effect in side_effects: # pragma: no cover send_attachments.side_effect = side_effect with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook: with self.subTest(side_effect=type(side_effect).__name__): diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index 0233e7939..0d570f5a3 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -68,9 +68,9 @@ class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): @abstractmethod def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: """Give expected relevant messages for `case`.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover @abstractmethod def get_report(self, case: DisallowedCase) -> str: """Give expected error report for `case`.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index de5724bca..694d3a40f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -43,7 +43,7 @@ class TimeTests(unittest.TestCase): for max_units in test_cases: with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - self.assertEqual(str(error), 'max_units must be positive') + self.assertEqual(str(error.exception), 'max_units must be positive') def test_parse_rfc1123(self): """Testing parse_rfc1123.""" -- cgit v1.2.3 From c7ffafeedc44fde40e3bd5dae6c95fbabc75a9d9 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Mon, 24 Feb 2020 02:10:22 +0100 Subject: Use realistic mixin implementation Instead of using the mixin class bare, I've now included into a class tha subclasses unittest.TestCase as that's how it's going to be used "in the wild". --- tests/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_base.py b/tests/test_base.py index 23abb1dfd..235a2ee6c 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -6,7 +6,7 @@ import unittest.mock from tests.base import LoggingTestsMixin, _CaptureLogHandler -class LoggingTestCase(LoggingTestsMixin): +class LoggingTestCase(LoggingTestsMixin, unittest.TestCase): pass -- cgit v1.2.3 From b8bd18bd743608ddff47064d0b459edff3da65e3 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Mon, 24 Feb 2020 02:12:02 +0100 Subject: Migrate syncers test suite to Python 3.8 The test suite for the new role/member syncers used the "old"-style test suite with the helpers implemented for Python 3.7. I have migrated it to use the new Python 3.8 asyncio test helpers. --- tests/base.py | 4 ++-- tests/bot/cogs/sync/test_base.py | 45 ++++++++++++++++----------------------- tests/bot/cogs/sync/test_cog.py | 31 ++++++++------------------- tests/bot/cogs/sync/test_roles.py | 12 ++--------- tests/bot/cogs/sync/test_users.py | 13 ++--------- tests/helpers.py | 4 +--- 6 files changed, 34 insertions(+), 75 deletions(-) (limited to 'tests') diff --git a/tests/base.py b/tests/base.py index 21613110e..42174e911 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,4 +1,5 @@ import logging +import unittest from contextlib import contextmanager from typing import Dict @@ -77,10 +78,9 @@ class LoggingTestsMixin: self.fail(msg) -class CommandTestCase(unittest.TestCase): +class CommandTestCase(unittest.IsolatedAsyncioTestCase): """TestCase with additional assertions that are useful for testing Discord commands.""" - @helpers.async_test async def assertHasPermissionsCheck( self, cmd: commands.Command, diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index e6a6f9688..17aa4198b 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -13,8 +13,8 @@ class TestSyncer(Syncer): """Syncer subclass with mocks for abstract methods for testing purposes.""" name = "test" - _get_diff = helpers.AsyncMock() - _sync = helpers.AsyncMock() + _get_diff = mock.AsyncMock() + _sync = mock.AsyncMock() class SyncerBaseTests(unittest.TestCase): @@ -29,7 +29,7 @@ class SyncerBaseTests(unittest.TestCase): Syncer(self.bot) -class SyncerSendPromptTests(unittest.TestCase): +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): """Tests for sending the sync confirmation prompt.""" def setUp(self): @@ -61,7 +61,6 @@ class SyncerSendPromptTests(unittest.TestCase): return mock_channel, mock_message - @helpers.async_test async def test_send_prompt_edits_and_returns_message(self): """The given message should be edited to display the prompt and then should be returned.""" msg = helpers.MockMessage() @@ -71,7 +70,6 @@ class SyncerSendPromptTests(unittest.TestCase): self.assertIn("content", msg.edit.call_args[1]) self.assertEqual(ret_val, msg) - @helpers.async_test async def test_send_prompt_gets_dev_core_channel(self): """The dev-core channel should be retrieved if an extant message isn't given.""" subtests = ( @@ -86,7 +84,6 @@ class SyncerSendPromptTests(unittest.TestCase): method.assert_called_once_with(constants.Channels.devcore) - @helpers.async_test async def test_send_prompt_returns_None_if_channel_fetch_fails(self): """None should be returned if there's an HTTPException when fetching the channel.""" self.bot.get_channel.return_value = None @@ -96,7 +93,6 @@ class SyncerSendPromptTests(unittest.TestCase): self.assertIsNone(ret_val) - @helpers.async_test async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): """A new message mentioning core devs should be sent and returned if message isn't given.""" for mock_ in (self.mock_get_channel, self.mock_fetch_channel): @@ -108,7 +104,6 @@ class SyncerSendPromptTests(unittest.TestCase): self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) self.assertEqual(ret_val, mock_message) - @helpers.async_test async def test_send_prompt_adds_reactions(self): """The message should have reactions for confirmation added.""" extant_message = helpers.MockMessage() @@ -129,7 +124,7 @@ class SyncerSendPromptTests(unittest.TestCase): mock_message.add_reaction.assert_has_calls(calls) -class SyncerConfirmationTests(unittest.TestCase): +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): """Tests for waiting for a sync confirmation reaction on the prompt.""" def setUp(self): @@ -211,7 +206,6 @@ class SyncerConfirmationTests(unittest.TestCase): ret_val = self.syncer._reaction_check(*args) self.assertFalse(ret_val) - @helpers.async_test async def test_wait_for_confirmation(self): """The message should always be edited and only return True if the emoji is a check mark.""" subtests = ( @@ -251,14 +245,13 @@ class SyncerConfirmationTests(unittest.TestCase): self.assertIs(actual_return, ret_val) -class SyncerSyncTests(unittest.TestCase): +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for main function orchestrating the sync.""" def setUp(self): self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) self.syncer = TestSyncer(self.bot) - @helpers.async_test async def test_sync_respects_confirmation_result(self): """The sync should abort if confirmation fails and continue if confirmed.""" mock_message = helpers.MockMessage() @@ -274,7 +267,7 @@ class SyncerSyncTests(unittest.TestCase): diff = _Diff({1, 2, 3}, {4, 5}, None) self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = helpers.AsyncMock( + self.syncer._get_confirmation_result = mock.AsyncMock( return_value=(confirmed, message) ) @@ -289,7 +282,6 @@ class SyncerSyncTests(unittest.TestCase): else: self.syncer._sync.assert_not_called() - @helpers.async_test async def test_sync_diff_size(self): """The diff size should be correctly calculated.""" subtests = ( @@ -303,7 +295,7 @@ class SyncerSyncTests(unittest.TestCase): with self.subTest(size=size, diff=diff): self.syncer._get_diff.reset_mock() self.syncer._get_diff.return_value = diff - self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) guild = helpers.MockGuild() await self.syncer.sync(guild) @@ -312,7 +304,6 @@ class SyncerSyncTests(unittest.TestCase): self.syncer._get_confirmation_result.assert_called_once() self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) - @helpers.async_test async def test_sync_message_edited(self): """The message should be edited if one was sent, even if the sync has an API error.""" subtests = ( @@ -324,7 +315,7 @@ class SyncerSyncTests(unittest.TestCase): for message, side_effect, should_edit in subtests: with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): self.syncer._sync.side_effect = side_effect - self.syncer._get_confirmation_result = helpers.AsyncMock( + self.syncer._get_confirmation_result = mock.AsyncMock( return_value=(True, message) ) @@ -335,7 +326,6 @@ class SyncerSyncTests(unittest.TestCase): message.edit.assert_called_once() self.assertIn("content", message.edit.call_args[1]) - @helpers.async_test async def test_sync_confirmation_context_redirect(self): """If ctx is given, a new message should be sent and author should be ctx's author.""" mock_member = helpers.MockMember() @@ -349,7 +339,10 @@ class SyncerSyncTests(unittest.TestCase): if ctx is not None: ctx.send.return_value = message - self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) guild = helpers.MockGuild() await self.syncer.sync(guild, ctx) @@ -362,16 +355,15 @@ class SyncerSyncTests(unittest.TestCase): self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) @mock.patch.object(constants.Sync, "max_diff", new=3) - @helpers.async_test async def test_confirmation_result_small_diff(self): """Should always return True and the given message if the diff size is too small.""" author = helpers.MockMember() expected_message = helpers.MockMessage() - for size in (3, 2): + for size in (3, 2): # pragma: no cover with self.subTest(size=size): - self.syncer._send_prompt = helpers.AsyncMock() - self.syncer._wait_for_confirmation = helpers.AsyncMock() + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() coro = self.syncer._get_confirmation_result(size, author, expected_message) result, actual_message = await coro @@ -382,7 +374,6 @@ class SyncerSyncTests(unittest.TestCase): self.syncer._wait_for_confirmation.assert_not_called() @mock.patch.object(constants.Sync, "max_diff", new=3) - @helpers.async_test async def test_confirmation_result_large_diff(self): """Should return True if confirmed and False if _send_prompt fails or aborted.""" author = helpers.MockMember() @@ -394,10 +385,10 @@ class SyncerSyncTests(unittest.TestCase): (False, mock_message, False, "aborted"), ) - for expected_result, expected_message, confirmed, msg in subtests: + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover with self.subTest(msg=msg): - self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) - self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) coro = self.syncer._get_confirmation_result(4, author) actual_result, actual_message = await coro diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 98c9afc0d..8c87c0d6b 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -18,12 +18,13 @@ class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` instances. For more information, see the `MockGuild` docstring. """ + spec_set = Syncer def __init__(self, **kwargs) -> None: - super().__init__(spec_set=Syncer, **kwargs) + super().__init__(**kwargs) -class SyncExtensionTests(unittest.TestCase): +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @staticmethod @@ -34,7 +35,7 @@ class SyncExtensionTests(unittest.TestCase): bot.add_cog.assert_called_once() -class SyncCogTestCase(unittest.TestCase): +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): """Base class for Sync cog tests. Sets up patches for syncers.""" def setUp(self): @@ -72,13 +73,13 @@ class SyncCogTestCase(unittest.TestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch.object(sync.Sync, "sync_guild") + @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) def test_sync_cog_init(self, sync_guild): """Should instantiate syncers and run a sync for the guild.""" # Reset because a Sync cog was already instantiated in setUp. self.RoleSyncer.reset_mock() self.UserSyncer.reset_mock() - self.bot.loop.create_task.reset_mock() + self.bot.loop.create_task = mock.MagicMock() mock_sync_guild_coro = mock.MagicMock() sync_guild.return_value = mock_sync_guild_coro @@ -90,7 +91,6 @@ class SyncCogTests(SyncCogTestCase): sync_guild.assert_called_once_with() self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) - @helpers.async_test async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" for guild in (helpers.MockGuild(), None): @@ -126,14 +126,12 @@ class SyncCogTests(SyncCogTestCase): json=updated_information, ) - @helpers.async_test async def test_sync_cog_patch_user(self): """A PATCH request should be sent and 404 errors ignored.""" for side_effect in (None, self.response_error(404)): with self.subTest(side_effect=side_effect): await self.patch_user_helper(side_effect) - @helpers.async_test async def test_sync_cog_patch_user_non_404(self): """A PATCH request should be sent and the error raised if it's not a 404.""" with self.assertRaises(ResponseCodeError): @@ -145,9 +143,8 @@ class SyncCogListenerTests(SyncCogTestCase): def setUp(self): super().setUp() - self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) - @helpers.async_test async def test_sync_cog_on_guild_role_create(self): """A POST request should be sent with the new role's data.""" self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) @@ -164,7 +161,6 @@ class SyncCogListenerTests(SyncCogTestCase): self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) - @helpers.async_test async def test_sync_cog_on_guild_role_delete(self): """A DELETE request should be sent.""" self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) @@ -174,7 +170,6 @@ class SyncCogListenerTests(SyncCogTestCase): self.bot.api_client.delete.assert_called_once_with("bot/roles/99") - @helpers.async_test async def test_sync_cog_on_guild_role_update(self): """A PUT request should be sent if the colour, name, permissions, or position changes.""" self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) @@ -212,7 +207,6 @@ class SyncCogListenerTests(SyncCogTestCase): else: self.bot.api_client.put.assert_not_called() - @helpers.async_test async def test_sync_cog_on_member_remove(self): """Member should patched to set in_guild as False.""" self.assertTrue(self.cog.on_member_remove.__cog_listener__) @@ -225,7 +219,6 @@ class SyncCogListenerTests(SyncCogTestCase): updated_information={"in_guild": False} ) - @helpers.async_test async def test_sync_cog_on_member_update_roles(self): """Members should be patched if their roles have changed.""" self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -240,7 +233,6 @@ class SyncCogListenerTests(SyncCogTestCase): data = {"roles": sorted(role.id for role in after_member.roles)} self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) - @helpers.async_test async def test_sync_cog_on_member_update_other(self): """Members should not be patched if other attributes have changed.""" self.assertTrue(self.cog.on_member_update.__cog_listener__) @@ -262,7 +254,6 @@ class SyncCogListenerTests(SyncCogTestCase): self.cog.patch_user.assert_not_called() - @helpers.async_test async def test_sync_cog_on_user_update(self): """A user should be patched only if the name, discriminator, or avatar changes.""" self.assertTrue(self.cog.on_user_update.__cog_listener__) @@ -341,7 +332,6 @@ class SyncCogListenerTests(SyncCogTestCase): return data - @helpers.async_test async def test_sync_cog_on_member_join(self): """Should PUT user's data or POST it if the user doesn't exist.""" for side_effect in (None, self.response_error(404)): @@ -354,7 +344,6 @@ class SyncCogListenerTests(SyncCogTestCase): else: self.bot.api_client.post.assert_not_called() - @helpers.async_test async def test_sync_cog_on_member_join_non_404(self): """ResponseCodeError should be re-raised if status code isn't a 404.""" with self.assertRaises(ResponseCodeError): @@ -366,7 +355,6 @@ class SyncCogListenerTests(SyncCogTestCase): class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): """Tests for the commands in the Sync cog.""" - @helpers.async_test async def test_sync_roles_command(self): """sync() should be called on the RoleSyncer.""" ctx = helpers.MockContext() @@ -374,7 +362,6 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) - @helpers.async_test async def test_sync_users_command(self): """sync() should be called on the UserSyncer.""" ctx = helpers.MockContext() @@ -382,7 +369,7 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) - def test_commands_require_admin(self): + async def test_commands_require_admin(self): """The sync commands should only run if the author has the administrator permission.""" cmds = ( self.cog.sync_group, @@ -392,4 +379,4 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): for cmd in cmds: with self.subTest(cmd=cmd): - self.assertHasPermissionsCheck(cmd, {"administrator": True}) + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 14fb2577a..79eee98f4 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -18,7 +18,7 @@ def fake_role(**kwargs): return kwargs -class RoleSyncerDiffTests(unittest.TestCase): +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between roles in the DB and roles in the Guild cache.""" def setUp(self): @@ -39,7 +39,6 @@ class RoleSyncerDiffTests(unittest.TestCase): return guild - @helpers.async_test async def test_empty_diff_for_identical_roles(self): """No differences should be found if the roles in the guild and DB are identical.""" self.bot.api_client.get.return_value = [fake_role()] @@ -50,7 +49,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_updated_roles(self): """Only updated roles should be added to the 'updated' set of the diff.""" updated_role = fake_role(id=41, name="new") @@ -63,7 +61,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_roles(self): """Only new roles should be added to the 'created' set of the diff.""" new_role = fake_role(id=41, name="new") @@ -76,7 +73,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_deleted_roles(self): """Only deleted roles should be added to the 'deleted' set of the diff.""" deleted_role = fake_role(id=61, name="deleted") @@ -89,7 +85,6 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_updated_and_deleted_roles(self): """When roles are added, updated, and removed, all of them are returned properly.""" new = fake_role(id=41, name="new") @@ -109,14 +104,13 @@ class RoleSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) -class RoleSyncerSyncTests(unittest.TestCase): +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync roles.""" def setUp(self): self.bot = helpers.MockBot() self.syncer = RoleSyncer(self.bot) - @helpers.async_test async def test_sync_created_roles(self): """Only POST requests should be made with the correct payload.""" roles = [fake_role(id=111), fake_role(id=222)] @@ -132,7 +126,6 @@ class RoleSyncerSyncTests(unittest.TestCase): self.bot.api_client.put.assert_not_called() self.bot.api_client.delete.assert_not_called() - @helpers.async_test async def test_sync_updated_roles(self): """Only PUT requests should be made with the correct payload.""" roles = [fake_role(id=111), fake_role(id=222)] @@ -148,7 +141,6 @@ class RoleSyncerSyncTests(unittest.TestCase): self.bot.api_client.post.assert_not_called() self.bot.api_client.delete.assert_not_called() - @helpers.async_test async def test_sync_deleted_roles(self): """Only DELETE requests should be made with the correct payload.""" roles = [fake_role(id=111), fake_role(id=222)] diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 421bf6bb6..818883012 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -17,7 +17,7 @@ def fake_user(**kwargs): return kwargs -class UserSyncerDiffTests(unittest.TestCase): +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between users in the DB and users in the Guild cache.""" def setUp(self): @@ -42,7 +42,6 @@ class UserSyncerDiffTests(unittest.TestCase): return guild - @helpers.async_test async def test_empty_diff_for_no_users(self): """When no users are given, an empty diff should be returned.""" guild = self.get_guild() @@ -52,7 +51,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_empty_diff_for_identical_users(self): """No differences should be found if the users in the guild and DB are identical.""" self.bot.api_client.get.return_value = [fake_user()] @@ -63,7 +61,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_updated_users(self): """Only updated users should be added to the 'updated' set of the diff.""" updated_user = fake_user(id=99, name="new") @@ -76,7 +73,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_users(self): """Only new users should be added to the 'created' set of the diff.""" new_user = fake_user(id=99, name="new") @@ -89,7 +85,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" leaving_user = fake_user(id=63, in_guild=False) @@ -102,7 +97,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_diff_for_new_updated_and_leaving_users(self): """When users are added, updated, and removed, all of them are returned properly.""" new_user = fake_user(id=99, name="new") @@ -117,7 +111,6 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) - @helpers.async_test async def test_empty_diff_for_db_users_not_in_guild(self): """When the DB knows a user the guild doesn't, no difference is found.""" self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] @@ -129,14 +122,13 @@ class UserSyncerDiffTests(unittest.TestCase): self.assertEqual(actual_diff, expected_diff) -class UserSyncerSyncTests(unittest.TestCase): +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync users.""" def setUp(self): self.bot = helpers.MockBot() self.syncer = UserSyncer(self.bot) - @helpers.async_test async def test_sync_created_users(self): """Only POST requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] @@ -152,7 +144,6 @@ class UserSyncerSyncTests(unittest.TestCase): self.bot.api_client.put.assert_not_called() self.bot.api_client.delete.assert_not_called() - @helpers.async_test async def test_sync_updated_users(self): """Only PUT requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] diff --git a/tests/helpers.py b/tests/helpers.py index 7ae7ed621..8e13f0f28 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -261,9 +261,7 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `bot.api.APIClient` instances. For more information, see the `MockGuild` docstring. """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=APIClient, **kwargs) + spec_set = APIClient # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` -- cgit v1.2.3 From 0de8f42c122a4bf8f0ea84ea481d2f26d718a0c7 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 23 Feb 2020 22:00:34 -0800 Subject: Sync tests: use autospec instead of MockSyncer Autospec supports using AsyncMocks in 3.8 so there's no need to rely on a subclass of CustomMockMixin for the async mocks. --- tests/bot/cogs/sync/test_cog.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 8c87c0d6b..81398c61f 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -11,19 +11,6 @@ from tests import helpers from tests.base import CommandTestCase -class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): - """ - A MagicMock subclass to mock Syncer objects. - - Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` - instances. For more information, see the `MockGuild` docstring. - """ - spec_set = Syncer - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - - class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @@ -41,16 +28,15 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = helpers.MockBot() - # These patch the type. When the type is called, a MockSyncer instanced is returned. - # MockSyncer is needed so that our custom AsyncMock is used. - # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. self.role_syncer_patcher = mock.patch( "bot.cogs.sync.syncers.RoleSyncer", - new=mock.MagicMock(return_value=MockSyncer()) + autospec=Syncer, + spec_set=True ) self.user_syncer_patcher = mock.patch( "bot.cogs.sync.syncers.UserSyncer", - new=mock.MagicMock(return_value=MockSyncer()) + autospec=Syncer, + spec_set=True ) self.RoleSyncer = self.role_syncer_patcher.start() self.UserSyncer = self.user_syncer_patcher.start() -- cgit v1.2.3 From 3574eaa0c903cd8ed862b8bff896ce0a73412321 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Mon, 24 Feb 2020 10:06:09 +0100 Subject: Use MagicMock as return value for _get_diff mock The `_get_diff` method of TestSyncer class is mocked using an AsyncMock object. By default, when an AsyncMock object is called **and awaited**, it returns a child mock of the same time (another AsyncMock) according to the "the child is a like the parent" principle. This means that the _get_diff method will return an AsyncMock unless a different return_value is explicitly provided. Because of that "child is like parent" behavior, this will happen in lines 194-196 of bot.cogs.sync.syncers (annotations added by me): ``` // `diff` will be a child AsyncMock as "child is like parent" diff = await self._get_diff(guild) // `diff._asdict` will be an AsyncMock as "child is like parent" and, // after being called, it will return an unawaited coroutine object // we assign the name `diff_dict`: diff_dict = diff._asdict() // `diff_dict` is still an unawaited coroutine object meaning that it // doesn't have an `items()` method: totals = {k: len(v) for k, v in diff_dict.items() if v is not None} ``` Original, unannotated: https://github.com/python-discord/bot/blob/c81a4d401ea434e98b0a1ece51d3d10f1a3ad226/bot/cogs/sync/syncers.py#L194-L196 This will lead to the following exception when running the tests: ```py ====================================================================== ERROR: test_sync_confirmation_context_redirect (tests.bot.cogs.sync.test_base.SyncerSyncTests) (ctx=None, author=, message=None) If ctx is given, a new message should be sent and author should be ctx's author. ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/sebastiaan/pydis/repositories/bot/tests/bot/cogs/sync/test_base.py", line 348, in test_sync_confirmation_context_redirect await self.syncer.sync(guild, ctx) File "/home/sebastiaan/pydis/repositories/bot/bot/cogs/sync/syncers.py", line 196, in sync totals = {k: len(v) for k, v in diff_dict.items() if v is not None} AttributeError: 'coroutine' object has no attribute 'items' ``` The solution is to assign an explicit return value so the parent mock doesn't "guess" and return an object of its own type. I previously did that by providing a specific `_Diff` object as the return value, but I should have gone with a `MagicMock` to signify that it's not an important return value; it's just something that needs to support/mimic the API we use on it. So that's what this commit adds. --- tests/bot/cogs/sync/test_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index 17aa4198b..d17a27409 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -339,8 +339,8 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): if ctx is not None: ctx.send.return_value = message - diff = _Diff({1, 2, 3}, {4, 5}, None) - self.syncer._get_diff.return_value = diff + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) -- cgit v1.2.3 From c2af442676011eb620593505789be4d34da76ea3 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sat, 29 Feb 2020 17:06:51 +0100 Subject: Migrate snekbox tests to Python 3.8's unittest I've migrated the `tests/test_snekbox.py` file to use the new Python 3.8-style unittests instead of our old style using our custom Async mocks. In particular, I had to make a few changes: - Mocking the async post() context manager correctly Since `ClientSession.post` returns an async context manager when called, we need to make sure to assign the return value to the __aenter__ method of whatever `post()` returns, not of `post` itself (i.e.. when it's not called). - Use the new AsyncMock assert methods `assert_awaited_once` and `assert_awaited_once_with` Objects of the new `unittest.mock.AsyncMock` class have special methods to assert what they were called with that also assert that specific coroutine object was awaited. This means we test two things in one: Whether or not it was called with the right arguments and whether or not the returned coroutine object was then awaited. - Patch `functools.partial` as `partial` objects are compared by identity When you create two partial functions of the same function, you'll end up with two different `partial` objects. Since `partial` objects are compared by identity, you can't compare a `partial` created in a test method to that created in the callable you're trying to test. They will always compare as `False`. Since we're not interested in actually creating `partial` objects, I've just patched `functools.partial` in the namespace of the module we're testing to make sure we can compare them. --- tests/bot/cogs/test_snekbox.py | 68 +++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 41 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 985bc66a1..9cd7f0154 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -1,74 +1,68 @@ import asyncio import logging import unittest -from functools import partial -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch from bot.cogs import snekbox from bot.cogs.snekbox import Snekbox from bot.constants import URLs -from tests.helpers import ( - AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test -) +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser -class SnekboxTests(unittest.TestCase): +class SnekboxTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Add mocked bot and cog to the instance.""" self.bot = MockBot() - - self.mocked_post = MagicMock() - self.mocked_post.json = AsyncMock() - self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post)) - self.cog = Snekbox(bot=self.bot) - @async_test async def test_post_eval(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" - self.mocked_post.json.return_value = {'lemon': 'AI'} + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + self.bot.http_session.post().__aenter__.return_value = resp - self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'}) - self.bot.http_session.post.assert_called_once_with( + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( URLs.snekbox_eval_api, json={"input": "import random"}, raise_for_status=True ) + resp.json.assert_awaited_once() - @async_test async def test_upload_output_reject_too_long(self): """Reject output longer than MAX_PASTE_LEN.""" result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) self.assertEqual(result, "too long to upload") - @async_test async def test_upload_output(self): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - key = "RainbowDash" - self.mocked_post.json.return_value = {"key": key} + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + self.bot.http_session.post().__aenter__.return_value = resp self.assertEqual( await self.cog.upload_output("My awesome output"), URLs.paste_service.format(key=key) ) - self.bot.http_session.post.assert_called_once_with( + self.bot.http_session.post.assert_called_with( URLs.paste_service.format(key="documents"), data="My awesome output", raise_for_status=True ) - @async_test async def test_upload_output_gracefully_fallback_if_exception_during_request(self): """Output upload gracefully fallback if the upload fail.""" - self.mocked_post.json.side_effect = Exception + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + self.bot.http_session.post().__aenter__.return_value = resp + log = logging.getLogger("bot.cogs.snekbox") with self.assertLogs(logger=log, level='ERROR'): await self.cog.upload_output('My awesome output!') - @async_test async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): """Output upload gracefully fallback if there is no key entry in the response body.""" - self.mocked_post.json.return_value = {} self.assertEqual((await self.cog.upload_output('My awesome output!')), None) def test_prepare_input(self): @@ -121,7 +115,6 @@ class SnekboxTests(unittest.TestCase): actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) self.assertEqual(actual, expected) - @async_test async def test_format_output(self): """Test output formatting.""" self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') @@ -172,7 +165,6 @@ class SnekboxTests(unittest.TestCase): with self.subTest(msg=testname, case=case, expected=expected): self.assertEqual(await self.cog.format_output(case), expected) - @async_test async def test_eval_command_evaluate_once(self): """Test the eval command procedure.""" ctx = MockContext() @@ -186,7 +178,6 @@ class SnekboxTests(unittest.TestCase): self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') self.cog.continue_eval.assert_called_once_with(ctx, response) - @async_test async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() @@ -201,7 +192,6 @@ class SnekboxTests(unittest.TestCase): self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') self.cog.continue_eval.assert_called_with(ctx, response) - @async_test async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" ctx = MockContext() @@ -214,7 +204,6 @@ class SnekboxTests(unittest.TestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - @async_test async def test_eval_command_call_help(self): """Test if the eval command call the help command if no code is provided.""" ctx = MockContext() @@ -222,14 +211,13 @@ class SnekboxTests(unittest.TestCase): await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") - @async_test async def test_send_eval(self): """Test the send_eval function.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) + self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') @@ -244,14 +232,13 @@ class SnekboxTests(unittest.TestCase): self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.format_output.assert_called_once_with('') - @async_test async def test_send_eval_with_paste_link(self): """Test the send_eval function with a too long output that generate a paste link.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) + self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') @@ -267,14 +254,12 @@ class SnekboxTests(unittest.TestCase): self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.format_output.assert_called_once_with('Way too long beard') - @async_test async def test_send_eval_with_non_zero_eval(self): """Test the send_eval function with a code returning a non-zero code.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') @@ -289,8 +274,8 @@ class SnekboxTests(unittest.TestCase): self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.format_output.assert_not_called() - @async_test - async def test_continue_eval_does_continue(self): + @patch("bot.cogs.snekbox.partial") + async def test_continue_eval_does_continue(self, partial_mock): """Test that the continue_eval function does continue if required conditions are met.""" ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) response = MockMessage(delete=AsyncMock()) @@ -299,15 +284,16 @@ class SnekboxTests(unittest.TestCase): actual = await self.cog.continue_eval(ctx, response) self.assertEqual(actual, 'NewCode') - self.bot.wait_for.has_calls( - call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10), - call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + self.bot.wait_for.assert_has_awaits( + ( + call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), + call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + ) ) ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) ctx.message.clear_reactions.assert_called_once() response.delete.assert_called_once() - @async_test async def test_continue_eval_does_not_continue(self): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError -- cgit v1.2.3 From fc2224fc047fbbefdc17e3624ebc2854342c59c1 Mon Sep 17 00:00:00 2001 From: "Karlis. S" Date: Sun, 1 Mar 2020 09:35:34 +0200 Subject: !roles Command Test: Applied !roles command changes --- tests/bot/cogs/test_information.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 8443cfe71..bb4ebd9d0 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -45,10 +45,9 @@ class InformationCogTests(unittest.TestCase): _, kwargs = self.ctx.send.call_args embed = kwargs.pop('embed') - self.assertEqual(embed.title, "Role information") + self.assertEqual(embed.title, "Role information (Total 1 roles)") self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - self.assertEqual(embed.footer.text, "Total roles: 1") + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n\n") def test_role_info_command(self): """Tests the `role info` command.""" -- cgit v1.2.3 From 91c6bcd0dfbaad201ee47af2ee7e36e4f372a115 Mon Sep 17 00:00:00 2001 From: "S. Co1" Date: Sun, 1 Mar 2020 14:27:14 -0500 Subject: Modify log test regex to be non-os-specific Previous regex utilized a `/`, which doesn't work for comparing against Windows paths, which use `\` --- tests/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_base.py b/tests/test_base.py index 235a2ee6c..a7db4bf3e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -29,7 +29,7 @@ class LoggingTestCaseTests(unittest.TestCase): """Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted.""" msg_regex = ( r"1 logs of DEBUG or higher were triggered on root:\n" - r'' + r'' ) with self.assertRaisesRegex(AssertionError, msg_regex): with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): -- cgit v1.2.3 From 28bcbf334eb08dfcd35b898b7cb803338664ee61 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 2 Mar 2020 09:42:00 -0800 Subject: Add more pre-commit hooks * Remove trailing whitespaces * Specify error code for a noqa in the free command --- .pre-commit-config.yaml | 23 ++++++++++++++++++++--- CONTRIBUTING.md | 2 +- bot/cogs/free.py | 2 +- tests/README.md | 10 +++++----- 4 files changed, 27 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 860357868..4bb5e7e1c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,27 @@ +exclude: ^\.cache/|\.venv/|\.git/|htmlcov/|logs/ repos: -- repo: local + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.5.0 hooks: - - id: flake8 + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + args: [--unsafe] # Required due to custom constructors (e.g. !ENV) + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + - repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.5.1 + hooks: + - id: python-check-blanket-noqa + - repo: local + hooks: + - id: flake8 name: Flake8 description: This hook runs flake8 within our project's pipenv environment. entry: pipenv run lint language: python types: [python] - require_serial: true \ No newline at end of file + require_serial: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 39f76c7b4..61d11f844 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -43,7 +43,7 @@ To provide a standalone development environment for this project, docker compose When pulling down changes from GitHub, remember to sync your environment using `pipenv sync --dev` to ensure you're using the most up-to-date versions the project's dependencies. ### Type Hinting -[PEP 484](https://www.python.org/dev/peps/pep-0484/) formally specifies type hints for Python functions, added to the Python Standard Library in version 3.5. Type hints are recognized by most modern code editing tools and provide useful insight into both the input and output types of a function, preventing the user from having to go through the codebase to determine these types. +[PEP 484](https://www.python.org/dev/peps/pep-0484/) formally specifies type hints for Python functions, added to the Python Standard Library in version 3.5. Type hints are recognized by most modern code editing tools and provide useful insight into both the input and output types of a function, preventing the user from having to go through the codebase to determine these types. For example: diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 02c02d067..33b55e79a 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -55,7 +55,7 @@ class Free(Cog): msg = messages[seek - 1] # Otherwise get last message else: - msg = await channel.history(limit=1).next() # noqa (False positive) + msg = await channel.history(limit=1).next() # noqa: B305 inactive = (datetime.utcnow() - msg.created_at).seconds if inactive > TIMEOUT: diff --git a/tests/README.md b/tests/README.md index be78821bf..4f62edd68 100644 --- a/tests/README.md +++ b/tests/README.md @@ -83,7 +83,7 @@ TagContentConverter should return correct values for valid input. As we are trying to test our "units" of code independently, we want to make sure that we do not rely objects and data generated by "external" code. If we we did, then we wouldn't know if the failure we're observing was caused by the code we are actually trying to test or something external to it. -However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". +However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). @@ -114,13 +114,13 @@ class BotCogTests(unittest.TestCase): ### Mocking coroutines -By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8. +By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8. ### Special mocks for some `discord.py` types To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. -In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. @@ -144,7 +144,7 @@ Finally, there are some considerations to make when writing tests, both for writ ### Test coverage is a starting point -Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work. +Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work. One problem is that 100% branch coverage may be misleading if we haven't tested our code against all the realistic input it may get in production. For instance, take a look at the following `member_information` function and the test we've written for it: @@ -169,7 +169,7 @@ class FunctionsTests(unittest.TestCase): If you were to run this test, not only would the function pass the test, `coverage.py` will also tell us that the test provides 100% branch coverage for the function. Can you spot the bug the test suite did not catch? -The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`). +The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`). Adding another test would not increase the test coverage we have, but it does ensure that we'll notice that this function can fail with realistic data: -- cgit v1.2.3 From aae928ebc06e7e7a6ed5b5b848464ce95e4ea9d8 Mon Sep 17 00:00:00 2001 From: "S. Co1" Date: Tue, 3 Mar 2020 22:53:19 -0500 Subject: Remove CaseInsensitiveDict This was added by the now-removed Snake cog & is not used elsewhere on bot. --- bot/utils/__init__.py | 57 ------------------------------------------------- tests/bot/test_utils.py | 37 -------------------------------- 2 files changed, 94 deletions(-) delete mode 100644 tests/bot/test_utils.py (limited to 'tests') diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 3e4b15ce4..9b32e515d 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,4 @@ from abc import ABCMeta -from typing import Any, Hashable from discord.ext.commands import CogMeta @@ -8,59 +7,3 @@ class CogABCMeta(CogMeta, ABCMeta): """Metaclass for ABCs meant to be implemented as Cogs.""" pass - - -class CaseInsensitiveDict(dict): - """ - We found this class on StackOverflow. Thanks to m000 for writing it! - - https://stackoverflow.com/a/32888599/4022104 - """ - - @classmethod - def _k(cls, key: Hashable) -> Hashable: - """Return lowered key if a string-like is passed, otherwise pass key straight through.""" - return key.lower() if isinstance(key, str) else key - - def __init__(self, *args, **kwargs): - super(CaseInsensitiveDict, self).__init__(*args, **kwargs) - self._convert_keys() - - def __getitem__(self, key: Hashable) -> Any: - """Case insensitive __setitem__.""" - return super(CaseInsensitiveDict, self).__getitem__(self.__class__._k(key)) - - def __setitem__(self, key: Hashable, value: Any): - """Case insensitive __setitem__.""" - super(CaseInsensitiveDict, self).__setitem__(self.__class__._k(key), value) - - def __delitem__(self, key: Hashable) -> Any: - """Case insensitive __delitem__.""" - return super(CaseInsensitiveDict, self).__delitem__(self.__class__._k(key)) - - def __contains__(self, key: Hashable) -> bool: - """Case insensitive __contains__.""" - return super(CaseInsensitiveDict, self).__contains__(self.__class__._k(key)) - - def pop(self, key: Hashable, *args, **kwargs) -> Any: - """Case insensitive pop.""" - return super(CaseInsensitiveDict, self).pop(self.__class__._k(key), *args, **kwargs) - - def get(self, key: Hashable, *args, **kwargs) -> Any: - """Case insensitive get.""" - return super(CaseInsensitiveDict, self).get(self.__class__._k(key), *args, **kwargs) - - def setdefault(self, key: Hashable, *args, **kwargs) -> Any: - """Case insensitive setdefault.""" - return super(CaseInsensitiveDict, self).setdefault(self.__class__._k(key), *args, **kwargs) - - def update(self, E: Any = None, **F) -> None: - """Case insensitive update.""" - super(CaseInsensitiveDict, self).update(self.__class__(E)) - super(CaseInsensitiveDict, self).update(self.__class__(**F)) - - def _convert_keys(self) -> None: - """Helper method to lowercase all existing string-like keys.""" - for k in list(self.keys()): - v = super(CaseInsensitiveDict, self).pop(k) - self.__setitem__(k, v) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py deleted file mode 100644 index d7bcc3ba6..000000000 --- a/tests/bot/test_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import unittest - -from bot import utils - - -class CaseInsensitiveDictTests(unittest.TestCase): - """Tests for the `CaseInsensitiveDict` container.""" - - def test_case_insensitive_key_access(self): - """Tests case insensitive key access and storage.""" - instance = utils.CaseInsensitiveDict() - - key = 'LEMON' - value = 'trees' - - instance[key] = value - self.assertIn(key, instance) - self.assertEqual(instance.get(key), value) - self.assertEqual(instance.get(key.casefold()), value) - self.assertEqual(instance.pop(key.casefold()), value) - self.assertNotIn(key, instance) - self.assertNotIn(key.casefold(), instance) - - instance.setdefault(key, value) - del instance[key] - self.assertNotIn(key, instance) - - def test_initialization_from_kwargs(self): - """Tests creating the dictionary from keyword arguments.""" - instance = utils.CaseInsensitiveDict({'FOO': 'bar'}) - self.assertEqual(instance['foo'], 'bar') - - def test_update_from_other_mapping(self): - """Tests updating the dictionary from another mapping.""" - instance = utils.CaseInsensitiveDict() - instance.update({'FOO': 'bar'}) - self.assertEqual(instance['foo'], 'bar') -- cgit v1.2.3 From 2c85b2241bd8a1e7ca8290cd385cded97c54f9bb Mon Sep 17 00:00:00 2001 From: "S. Co1" Date: Tue, 3 Mar 2020 22:59:07 -0500 Subject: Update code for pep8-naming compliance --- tests/base.py | 4 ++-- tests/bot/cogs/sync/test_base.py | 2 +- tests/bot/cogs/test_snekbox.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/base.py b/tests/base.py index 42174e911..d99b9ac31 100644 --- a/tests/base.py +++ b/tests/base.py @@ -31,7 +31,7 @@ class LoggingTestsMixin: """ @contextmanager - def assertNotLogs(self, logger=None, level=None, msg=None): + def assertNotLogs(self, logger=None, level=None, msg=None): # noqa: N802 """ Asserts that no logs of `level` and higher were emitted by `logger`. @@ -81,7 +81,7 @@ class LoggingTestsMixin: class CommandTestCase(unittest.IsolatedAsyncioTestCase): """TestCase with additional assertions that are useful for testing Discord commands.""" - async def assertHasPermissionsCheck( + async def assertHasPermissionsCheck( # noqa: N802 self, cmd: commands.Command, permissions: Dict[str, bool], diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index fe0594efe..6ee9dfda6 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -84,7 +84,7 @@ class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): method.assert_called_once_with(constants.Channels.dev_core) - async def test_send_prompt_returns_None_if_channel_fetch_fails(self): + async def test_send_prompt_returns_none_if_channel_fetch_fails(self): """None should be returned if there's an HTTPException when fetching the channel.""" self.bot.get_channel.return_value = None self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 9cd7f0154..fd9468829 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -89,15 +89,15 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(actual, expected) @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_Signals: Mock): + def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( self.cog.get_results_message({'stdout': '', 'returncode': 127}), ('Your eval job has completed with return code 127', '') ) @patch('bot.cogs.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_Signals: Mock): - mock_Signals.return_value.name = 'SIGTEST' + def test_get_results_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = 'SIGTEST' self.assertEqual( self.cog.get_results_message({'stdout': '', 'returncode': 127}), ('Your eval job has completed with return code 127 (SIGTEST)', '') -- cgit v1.2.3 From 5579f2d32d5faadad778d64c50cf6fbefccf4f28 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 09:05:06 +0200 Subject: (Information Cog, !roles command test): Applied empty parameter change. --- tests/bot/cogs/test_information.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index c6fd937b8..7c265bba8 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -47,7 +47,7 @@ class InformationCogTests(unittest.TestCase): self.assertEqual(embed.title, "Role information (Total 1 roles)") self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n\n") + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") def test_role_info_command(self): """Tests the `role info` command.""" -- cgit v1.2.3 From 0b75d3f5e717f99f53522d4224abea6223ef6c84 Mon Sep 17 00:00:00 2001 From: ks123 Date: Thu, 5 Mar 2020 09:08:10 +0200 Subject: (Information Cog, !roles command test): Removed 's' at end of "Total 1 role(s)" due changes in command. --- tests/bot/cogs/test_information.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 7c265bba8..3c26374f5 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -45,7 +45,7 @@ class InformationCogTests(unittest.TestCase): _, kwargs = self.ctx.send.call_args embed = kwargs.pop('embed') - self.assertEqual(embed.title, "Role information (Total 1 roles)") + self.assertEqual(embed.title, "Role information (Total 1 role)") self.assertEqual(embed.colour, discord.Colour.blurple()) self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") -- cgit v1.2.3 From d4253e106771f90a983717a994349d52337b2de9 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 8 Mar 2020 19:40:34 +0100 Subject: Add tests for FirstHash class. --- tests/bot/cogs/moderation/__init__.py | 0 tests/bot/cogs/moderation/test_silence.py | 25 +++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/bot/cogs/moderation/__init__.py create mode 100644 tests/bot/cogs/moderation/test_silence.py (limited to 'tests') diff --git a/tests/bot/cogs/moderation/__init__.py b/tests/bot/cogs/moderation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py new file mode 100644 index 000000000..2a06f5944 --- /dev/null +++ b/tests/bot/cogs/moderation/test_silence.py @@ -0,0 +1,25 @@ +import unittest + +from bot.cogs.moderation.silence import FirstHash + + +class FirstHashTests(unittest.TestCase): + def setUp(self) -> None: + self.test_cases = ( + (FirstHash(0, 4), FirstHash(0, 5)), + (FirstHash("string", None), FirstHash("string", True)) + ) + + def test_hashes_equal(self): + """Check hashes equal with same first item.""" + + for tuple1, tuple2 in self.test_cases: + with self.subTest(tuple1=tuple1, tuple2=tuple2): + self.assertEqual(hash(tuple1), hash(tuple2)) + + def test_eq(self): + """Check objects are equal with same first item.""" + + for tuple1, tuple2 in self.test_cases: + with self.subTest(tuple1=tuple1, tuple2=tuple2): + self.assertTrue(tuple1 == tuple2) -- cgit v1.2.3 From e872176b452ceca1b639ef42d640e18656c7c0c9 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 8 Mar 2020 19:42:18 +0100 Subject: Add test case for Silence cog. --- tests/bot/cogs/moderation/test_silence.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 2a06f5944..1db2b6eec 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -1,6 +1,7 @@ import unittest -from bot.cogs.moderation.silence import FirstHash +from bot.cogs.moderation.silence import FirstHash, Silence +from tests.helpers import MockBot, MockContext class FirstHashTests(unittest.TestCase): @@ -23,3 +24,11 @@ class FirstHashTests(unittest.TestCase): for tuple1, tuple2 in self.test_cases: with self.subTest(tuple1=tuple1, tuple2=tuple2): self.assertTrue(tuple1 == tuple2) + + +class SilenceTests(unittest.TestCase): + def setUp(self) -> None: + + self.bot = MockBot() + self.cog = Silence(self.bot) + self.ctx = MockContext() -- cgit v1.2.3 From 1d83a5752aae483224129ee798e529f3d7d8e132 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 8 Mar 2020 19:42:51 +0100 Subject: Add test for `silence` discord output. --- tests/bot/cogs/moderation/test_silence.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 1db2b6eec..088410bee 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -1,6 +1,10 @@ +import asyncio import unittest +from functools import partial +from unittest import mock from bot.cogs.moderation.silence import FirstHash, Silence +from bot.constants import Emojis from tests.helpers import MockBot, MockContext @@ -32,3 +36,23 @@ class SilenceTests(unittest.TestCase): self.bot = MockBot() self.cog = Silence(self.bot) self.ctx = MockContext() + + def test_silence_sent_correct_discord_message(self): + """Check if proper message was sent when called with duration in channel with previous state.""" + test_cases = ( + ((self.cog, self.ctx, 0.0001), f"{Emojis.check_mark} #channel silenced for 0.0001 minute(s).", True,), + ((self.cog, self.ctx, None), f"{Emojis.check_mark} #channel silenced indefinitely.", True,), + ((self.cog, self.ctx, 5), f"{Emojis.cross_mark} #channel is already silenced.", False,), + ) + for silence_call_args, result_message, _silence_patch_return in test_cases: + with self.subTest( + silence_duration=silence_call_args[-1], + result_message=result_message, + starting_unsilenced_state=_silence_patch_return + ): + with mock.patch( + "bot.cogs.moderation.silence.Silence._silence", + new_callable=partial(mock.AsyncMock, return_value=_silence_patch_return) + ): + asyncio.run(self.cog.silence.callback(*silence_call_args)) + self.ctx.send.call_args.assert_called_once_with(result_message) -- cgit v1.2.3 From cfbe3b9742b5531bdced1d5b099739f01033a6bb Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 8 Mar 2020 22:20:00 +0100 Subject: Add test for `unsilence` discord output. --- tests/bot/cogs/moderation/test_silence.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 088410bee..17420ce7d 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -56,3 +56,12 @@ class SilenceTests(unittest.TestCase): ): asyncio.run(self.cog.silence.callback(*silence_call_args)) self.ctx.send.call_args.assert_called_once_with(result_message) + + def test_unsilence_sent_correct_discord_message(self): + """Check if proper message was sent to `alert_chanel`.""" + with mock.patch( + "bot.cogs.moderation.silence.Silence._unsilence", + new_callable=partial(mock.AsyncMock, return_value=True) + ): + asyncio.run(self.cog.unsilence.callback(self.cog, self.ctx)) + self.ctx.channel.send.call_args.assert_called_once_with(f"{Emojis.check_mark} Unsilenced #channel.") -- cgit v1.2.3 From ee94c38063981ee6770c1d263eab9c0d2e178380 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Mon, 9 Mar 2020 20:41:55 +0100 Subject: Use `patch.object` instead of patch with direct `return_value`. --- tests/bot/cogs/moderation/test_silence.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 17420ce7d..53b3fd388 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -1,6 +1,5 @@ import asyncio import unittest -from functools import partial from unittest import mock from bot.cogs.moderation.silence import FirstHash, Silence @@ -50,18 +49,12 @@ class SilenceTests(unittest.TestCase): result_message=result_message, starting_unsilenced_state=_silence_patch_return ): - with mock.patch( - "bot.cogs.moderation.silence.Silence._silence", - new_callable=partial(mock.AsyncMock, return_value=_silence_patch_return) - ): + with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): asyncio.run(self.cog.silence.callback(*silence_call_args)) self.ctx.send.call_args.assert_called_once_with(result_message) def test_unsilence_sent_correct_discord_message(self): """Check if proper message was sent to `alert_chanel`.""" - with mock.patch( - "bot.cogs.moderation.silence.Silence._unsilence", - new_callable=partial(mock.AsyncMock, return_value=True) - ): + with mock.patch.object(self.cog, "_unsilence", return_value=True): asyncio.run(self.cog.unsilence.callback(self.cog, self.ctx)) self.ctx.channel.send.call_args.assert_called_once_with(f"{Emojis.check_mark} Unsilenced #channel.") -- cgit v1.2.3 From 60814ee9270d4c550047478bf8d4a179d7351696 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 13:09:08 -0700 Subject: Cog tests: create boilerplate for command name tests --- tests/bot/cogs/test_cogs.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 tests/bot/cogs/test_cogs.py (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py new file mode 100644 index 000000000..6f5d07030 --- /dev/null +++ b/tests/bot/cogs/test_cogs.py @@ -0,0 +1,7 @@ +"""Test suite for general tests which apply to all cogs.""" + +import unittest + + +class CommandNameTests(unittest.TestCase): + """Tests for shadowing command names and aliases.""" -- cgit v1.2.3 From d31f7e3f4a4876d51119d5875afa9221b14b285e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 13:10:21 -0700 Subject: Cog tests: add a function to get all commands For tests, ideally creating instances of cogs should be avoided to avoid extra code execution. This function was copied over from discord.py because their function is not a static method, though it still works as one. It was probably just a design decision on their part to not make it static. --- tests/bot/cogs/test_cogs.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 6f5d07030..b128ca123 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -1,7 +1,19 @@ """Test suite for general tests which apply to all cogs.""" +import typing as t import unittest +from discord.ext import commands + class CommandNameTests(unittest.TestCase): """Tests for shadowing command names and aliases.""" + + @staticmethod + def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: + """An iterator that recursively walks through `cog`'s commands and subcommands.""" + for command in cog.__cog_commands__: + if command.parent is None: + yield command + if isinstance(command, commands.GroupMixin): + yield from command.walk_commands() -- cgit v1.2.3 From 85c439fbf78f59ff314f4f9daef1467d486709c3 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 00:01:58 +0100 Subject: Remove unnecessary args from test cases. Needless call args which were constant were kept in the test cases, resulting in redundant code, the args were moved directly into the function call. --- tests/bot/cogs/moderation/test_silence.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 53b3fd388..1341911d5 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -39,18 +39,18 @@ class SilenceTests(unittest.TestCase): def test_silence_sent_correct_discord_message(self): """Check if proper message was sent when called with duration in channel with previous state.""" test_cases = ( - ((self.cog, self.ctx, 0.0001), f"{Emojis.check_mark} #channel silenced for 0.0001 minute(s).", True,), - ((self.cog, self.ctx, None), f"{Emojis.check_mark} #channel silenced indefinitely.", True,), - ((self.cog, self.ctx, 5), f"{Emojis.cross_mark} #channel is already silenced.", False,), + (0.0001, f"{Emojis.check_mark} #channel silenced for 0.0001 minute(s).", True,), + (None, f"{Emojis.check_mark} #channel silenced indefinitely.", True,), + (5, f"{Emojis.cross_mark} #channel is already silenced.", False,), ) - for silence_call_args, result_message, _silence_patch_return in test_cases: + for duration, result_message, _silence_patch_return in test_cases: with self.subTest( - silence_duration=silence_call_args[-1], + silence_duration=duration, result_message=result_message, starting_unsilenced_state=_silence_patch_return ): with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): - asyncio.run(self.cog.silence.callback(*silence_call_args)) + asyncio.run(self.cog.silence.callback(self.cog, self.ctx, duration)) self.ctx.send.call_args.assert_called_once_with(result_message) def test_unsilence_sent_correct_discord_message(self): -- cgit v1.2.3 From adaf456607ba2f2724c6fd34308cd170c81aa651 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 00:56:31 +0100 Subject: Remove channel mentions from output discord messages. With the removal of the channel args, it's no longer necessary to mention the channel in the command output. Tests adjusted accordingly --- bot/cogs/moderation/silence.py | 8 ++++---- tests/bot/cogs/moderation/test_silence.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index 76c5a171d..68cad4062 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -92,13 +92,13 @@ class Silence(commands.Cog): If duration is forever, start a notifier loop that triggers every 15 minutes. """ if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): - await ctx.send(f"{Emojis.cross_mark} {ctx.channel.mention} is already silenced.") + await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") return if duration is None: - await ctx.send(f"{Emojis.check_mark} {ctx.channel.mention} silenced indefinitely.") + await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") return - await ctx.send(f"{Emojis.check_mark} {ctx.channel.mention} silenced for {duration} minute(s).") + await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") await asyncio.sleep(duration*60) await ctx.invoke(self.unsilence) @@ -110,7 +110,7 @@ class Silence(commands.Cog): Unsilence a previously silenced `channel` and remove it from indefinitely muted channels notice if applicable. """ if await self._unsilence(ctx.channel): - await ctx.send(f"{Emojis.check_mark} Unsilenced {ctx.channel.mention}.") + await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: """ diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 1341911d5..6da374a8f 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -39,9 +39,9 @@ class SilenceTests(unittest.TestCase): def test_silence_sent_correct_discord_message(self): """Check if proper message was sent when called with duration in channel with previous state.""" test_cases = ( - (0.0001, f"{Emojis.check_mark} #channel silenced for 0.0001 minute(s).", True,), - (None, f"{Emojis.check_mark} #channel silenced indefinitely.", True,), - (5, f"{Emojis.cross_mark} #channel is already silenced.", False,), + (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), + (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), + (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), ) for duration, result_message, _silence_patch_return in test_cases: with self.subTest( @@ -57,4 +57,4 @@ class SilenceTests(unittest.TestCase): """Check if proper message was sent to `alert_chanel`.""" with mock.patch.object(self.cog, "_unsilence", return_value=True): asyncio.run(self.cog.unsilence.callback(self.cog, self.ctx)) - self.ctx.channel.send.call_args.assert_called_once_with(f"{Emojis.check_mark} Unsilenced #channel.") + self.ctx.send.call_args.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") -- cgit v1.2.3 From 68d43946d1dc6393a4f7b8b4812b5c4787842c12 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 01:28:26 +0100 Subject: Add test for `_silence` method. --- tests/bot/cogs/moderation/test_silence.py | 35 ++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 6da374a8f..6a75db2a0 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -1,10 +1,11 @@ import asyncio import unittest from unittest import mock +from unittest.mock import Mock from bot.cogs.moderation.silence import FirstHash, Silence from bot.constants import Emojis -from tests.helpers import MockBot, MockContext +from tests.helpers import MockBot, MockContext, MockTextChannel class FirstHashTests(unittest.TestCase): @@ -35,6 +36,7 @@ class SilenceTests(unittest.TestCase): self.bot = MockBot() self.cog = Silence(self.bot) self.ctx = MockContext() + self.cog._verified_role = None def test_silence_sent_correct_discord_message(self): """Check if proper message was sent when called with duration in channel with previous state.""" @@ -58,3 +60,34 @@ class SilenceTests(unittest.TestCase): with mock.patch.object(self.cog, "_unsilence", return_value=True): asyncio.run(self.cog.unsilence.callback(self.cog, self.ctx)) self.ctx.send.call_args.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") + + def test_silence_private_for_false(self): + """Permissions are not set and `False` is returned in an already silenced channel.""" + perm_overwrite = Mock(send_messages=False) + channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) + + self.assertFalse(asyncio.run(self.cog._silence(channel, True, None))) + channel.set_permissions.assert_not_called() + + def test_silence_private_silenced_channel(self): + """Channel had `send_message` permissions revoked and was added to `muted_channels`.""" + channel = MockTextChannel() + muted_channels = Mock() + with mock.patch.object(self.cog, "muted_channels", new=muted_channels, create=True): + self.assertTrue(asyncio.run(self.cog._silence(channel, False, None))) + channel.set_permissions.assert_called_once() + self.assertFalse(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) + muted_channels.add.call_args.assert_called_once_with(channel) + + def test_silence_private_notifier(self): + """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" + channel = MockTextChannel() + with mock.patch.object(self.cog, "notifier", create=True): + with self.subTest(persistent=True): + asyncio.run(self.cog._silence(channel, True, None)) + self.cog.notifier.add_channel.assert_called_once() + + with mock.patch.object(self.cog, "notifier", create=True): + with self.subTest(persistent=False): + asyncio.run(self.cog._silence(channel, False, None)) + self.cog.notifier.add_channel.assert_not_called() -- cgit v1.2.3 From fef8c8e8504d8431ae7cad23128733d0b9039c7a Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 02:31:36 +0100 Subject: Use async test case. This allows us to use coroutines with await directly instead of asyncio.run --- tests/bot/cogs/moderation/test_silence.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 6a75db2a0..33ff78ca6 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio import unittest from unittest import mock from unittest.mock import Mock @@ -30,15 +29,14 @@ class FirstHashTests(unittest.TestCase): self.assertTrue(tuple1 == tuple2) -class SilenceTests(unittest.TestCase): +class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: - self.bot = MockBot() self.cog = Silence(self.bot) self.ctx = MockContext() self.cog._verified_role = None - def test_silence_sent_correct_discord_message(self): + async def test_silence_sent_correct_discord_message(self): """Check if proper message was sent when called with duration in channel with previous state.""" test_cases = ( (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), @@ -52,42 +50,42 @@ class SilenceTests(unittest.TestCase): starting_unsilenced_state=_silence_patch_return ): with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): - asyncio.run(self.cog.silence.callback(self.cog, self.ctx, duration)) + await self.cog.silence.callback(self.cog, self.ctx, duration) self.ctx.send.call_args.assert_called_once_with(result_message) - def test_unsilence_sent_correct_discord_message(self): + async def test_unsilence_sent_correct_discord_message(self): """Check if proper message was sent to `alert_chanel`.""" with mock.patch.object(self.cog, "_unsilence", return_value=True): - asyncio.run(self.cog.unsilence.callback(self.cog, self.ctx)) + await self.cog.unsilence.callback(self.cog, self.ctx) self.ctx.send.call_args.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") - def test_silence_private_for_false(self): + async def test_silence_private_for_false(self): """Permissions are not set and `False` is returned in an already silenced channel.""" perm_overwrite = Mock(send_messages=False) channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - self.assertFalse(asyncio.run(self.cog._silence(channel, True, None))) + self.assertFalse(await self.cog._silence(channel, True, None)) channel.set_permissions.assert_not_called() - def test_silence_private_silenced_channel(self): + async def test_silence_private_silenced_channel(self): """Channel had `send_message` permissions revoked and was added to `muted_channels`.""" channel = MockTextChannel() muted_channels = Mock() with mock.patch.object(self.cog, "muted_channels", new=muted_channels, create=True): - self.assertTrue(asyncio.run(self.cog._silence(channel, False, None))) + self.assertTrue(await self.cog._silence(channel, False, None)) channel.set_permissions.assert_called_once() self.assertFalse(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) muted_channels.add.call_args.assert_called_once_with(channel) - def test_silence_private_notifier(self): + async def test_silence_private_notifier(self): """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" channel = MockTextChannel() with mock.patch.object(self.cog, "notifier", create=True): with self.subTest(persistent=True): - asyncio.run(self.cog._silence(channel, True, None)) + await self.cog._silence(channel, True, None) self.cog.notifier.add_channel.assert_called_once() with mock.patch.object(self.cog, "notifier", create=True): with self.subTest(persistent=False): - asyncio.run(self.cog._silence(channel, False, None)) + await self.cog._silence(channel, False, None) self.cog.notifier.add_channel.assert_not_called() -- cgit v1.2.3 From c575beccdbe5e4e715a4b11b378dd969a0327191 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 02:47:07 +0100 Subject: Add tests for `_unsilence` --- tests/bot/cogs/moderation/test_silence.py | 34 ++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 33ff78ca6..acfa3ffb8 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -1,6 +1,6 @@ import unittest from unittest import mock -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock from bot.cogs.moderation.silence import FirstHash, Silence from bot.constants import Emojis @@ -89,3 +89,35 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): with self.subTest(persistent=False): await self.cog._silence(channel, False, None) self.cog.notifier.add_channel.assert_not_called() + + async def test_unsilence_private_for_false(self): + """Permissions are not set and `False` is returned in an unsilenced channel.""" + channel = Mock() + self.assertFalse(await self.cog._unsilence(channel)) + channel.set_permissions.assert_not_called() + + async def test_unsilence_private_unsilenced_channel(self): + """Channel had `send_message` permissions restored""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + with mock.patch.object(self.cog, "notifier", create=True): + self.assertTrue(await self.cog._unsilence(channel)) + channel.set_permissions.assert_called_once() + self.assertTrue(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) + + async def test_unsilence_private_removed_notifier(self): + """Channel was removed from `notifier` on unsilence.""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + with mock.patch.object(self.cog, "notifier", create=True): + await self.cog._unsilence(channel) + self.cog.notifier.remove_channel.call_args.assert_called_once_with(channel) + + async def test_unsilence_private_removed_muted_channel(self): + """Channel was removed from `muted_channels` on unsilence.""" + perm_overwrite = MagicMock(send_messages=False) + channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) + with mock.patch.object(self.cog, "muted_channels", create=True),\ + mock.patch.object(self.cog, "notifier", create=True): # noqa E127 + await self.cog._unsilence(channel) + self.cog.muted_channels.remove.call_args.assert_called_once_with(channel) -- cgit v1.2.3 From a3f07589b215317d6a0fc16d982c3b645fe96151 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 02:54:35 +0100 Subject: Separate tests for permissions and `muted_channels.add` on `_silence`. --- tests/bot/cogs/moderation/test_silence.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index acfa3ffb8..3a513f3a7 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -68,14 +68,11 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel.set_permissions.assert_not_called() async def test_silence_private_silenced_channel(self): - """Channel had `send_message` permissions revoked and was added to `muted_channels`.""" + """Channel had `send_message` permissions revoked.""" channel = MockTextChannel() - muted_channels = Mock() - with mock.patch.object(self.cog, "muted_channels", new=muted_channels, create=True): - self.assertTrue(await self.cog._silence(channel, False, None)) + self.assertTrue(await self.cog._silence(channel, False, None)) channel.set_permissions.assert_called_once() self.assertFalse(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) - muted_channels.add.call_args.assert_called_once_with(channel) async def test_silence_private_notifier(self): """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" @@ -90,6 +87,12 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): await self.cog._silence(channel, False, None) self.cog.notifier.add_channel.assert_not_called() + async def test_silence_private_removed_muted_channel(self): + channel = MockTextChannel() + with mock.patch.object(self.cog, "muted_channels") as muted_channels: + await self.cog._silence(MockTextChannel(), False, None) + muted_channels.add.call_args.assert_called_once_with(channel) + async def test_unsilence_private_for_false(self): """Permissions are not set and `False` is returned in an unsilenced channel.""" channel = Mock() -- cgit v1.2.3 From c72d31f717ac5e755fe3848c99ebf426fcdf6d8b Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 02:58:09 +0100 Subject: Use patch decorators and assign names from `with` patches. --- tests/bot/cogs/moderation/test_silence.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 3a513f3a7..027508661 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -99,28 +99,28 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.assertFalse(await self.cog._unsilence(channel)) channel.set_permissions.assert_not_called() - async def test_unsilence_private_unsilenced_channel(self): + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_unsilenced_channel(self, _): """Channel had `send_message` permissions restored""" perm_overwrite = MagicMock(send_messages=False) channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "notifier", create=True): - self.assertTrue(await self.cog._unsilence(channel)) + self.assertTrue(await self.cog._unsilence(channel)) channel.set_permissions.assert_called_once() self.assertTrue(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) - async def test_unsilence_private_removed_notifier(self): + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_removed_notifier(self, notifier): """Channel was removed from `notifier` on unsilence.""" perm_overwrite = MagicMock(send_messages=False) channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "notifier", create=True): - await self.cog._unsilence(channel) - self.cog.notifier.remove_channel.call_args.assert_called_once_with(channel) + await self.cog._unsilence(channel) + notifier.remove_channel.call_args.assert_called_once_with(channel) - async def test_unsilence_private_removed_muted_channel(self): + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_removed_muted_channel(self, _): """Channel was removed from `muted_channels` on unsilence.""" perm_overwrite = MagicMock(send_messages=False) channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) - with mock.patch.object(self.cog, "muted_channels", create=True),\ - mock.patch.object(self.cog, "notifier", create=True): # noqa E127 + with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._unsilence(channel) - self.cog.muted_channels.remove.call_args.assert_called_once_with(channel) + muted_channels.remove.call_args.assert_called_once_with(channel) -- cgit v1.2.3 From 44967038f39f4ecd1375fb9edff2b972becb5661 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 14:51:15 +0100 Subject: Add test for `cog_unload`. --- tests/bot/cogs/moderation/test_silence.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 027508661..fc2600f5c 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -3,7 +3,7 @@ from unittest import mock from unittest.mock import MagicMock, Mock from bot.cogs.moderation.silence import FirstHash, Silence -from bot.constants import Emojis +from bot.constants import Emojis, Roles from tests.helpers import MockBot, MockContext, MockTextChannel @@ -124,3 +124,19 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._unsilence(channel) muted_channels.remove.call_args.assert_called_once_with(channel) + + @mock.patch("bot.cogs.moderation.silence.asyncio") + @mock.patch.object(Silence, "_mod_alerts_channel", create=True) + def test_cog_unload(self, alert_channel, asyncio_mock): + """Task for sending an alert was created with present `muted_channels`.""" + with mock.patch.object(self.cog, "muted_channels"): + self.cog.cog_unload() + asyncio_mock.create_task.call_args.assert_called_once_with( + alert_channel.send(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") + ) + + @mock.patch("bot.cogs.moderation.silence.asyncio") + def test_cog_unload1(self, asyncio_mock): + """No task created with no channels.""" + self.cog.cog_unload() + asyncio_mock.create_task.assert_not_called() -- cgit v1.2.3 From cb9397ba9ef311917629c8904087c1b3c38cc2d3 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 15:26:03 +0100 Subject: Add test for `cog_check`. --- tests/bot/cogs/moderation/test_silence.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index fc2600f5c..eaf897d1d 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -140,3 +140,10 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): """No task created with no channels.""" self.cog.cog_unload() asyncio_mock.create_task.assert_not_called() + + @mock.patch("bot.cogs.moderation.silence.with_role_check") + @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) + def test_cog_check(self, role_check): + """Role check is called with `MODERATION_ROLES`""" + self.cog.cog_check(self.ctx) + role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) -- cgit v1.2.3 From fbee48ee04dc6b44f97f229549c62cbfd5cef615 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 15:32:06 +0100 Subject: Fix erroneous `assert_called_once_with` calls. `assert_called_once_with` was being tested on call_args which always reported success.st. --- tests/bot/cogs/moderation/test_silence.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index eaf897d1d..4163a9af7 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -51,13 +51,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): ): with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): await self.cog.silence.callback(self.cog, self.ctx, duration) - self.ctx.send.call_args.assert_called_once_with(result_message) + self.ctx.send.assert_called_once_with(result_message) async def test_unsilence_sent_correct_discord_message(self): """Check if proper message was sent to `alert_chanel`.""" with mock.patch.object(self.cog, "_unsilence", return_value=True): await self.cog.unsilence.callback(self.cog, self.ctx) - self.ctx.send.call_args.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") + self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") async def test_silence_private_for_false(self): """Permissions are not set and `False` is returned in an already silenced channel.""" @@ -91,7 +91,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel = MockTextChannel() with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._silence(MockTextChannel(), False, None) - muted_channels.add.call_args.assert_called_once_with(channel) + muted_channels.add.assert_called_once_with(channel) async def test_unsilence_private_for_false(self): """Permissions are not set and `False` is returned in an unsilenced channel.""" @@ -114,7 +114,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): perm_overwrite = MagicMock(send_messages=False) channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) await self.cog._unsilence(channel) - notifier.remove_channel.call_args.assert_called_once_with(channel) + notifier.remove_channel.assert_called_once_with(channel) @mock.patch.object(Silence, "notifier", create=True) async def test_unsilence_private_removed_muted_channel(self, _): @@ -123,7 +123,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._unsilence(channel) - muted_channels.remove.call_args.assert_called_once_with(channel) + muted_channels.remove.assert_called_once_with(channel) @mock.patch("bot.cogs.moderation.silence.asyncio") @mock.patch.object(Silence, "_mod_alerts_channel", create=True) @@ -131,9 +131,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): """Task for sending an alert was created with present `muted_channels`.""" with mock.patch.object(self.cog, "muted_channels"): self.cog.cog_unload() - asyncio_mock.create_task.call_args.assert_called_once_with( - alert_channel.send(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") - ) + asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) + alert_channel.send.called_once_with(f"<@&{Roles.moderators}> chandnels left silenced on cog unload: ") @mock.patch("bot.cogs.moderation.silence.asyncio") def test_cog_unload1(self, asyncio_mock): -- cgit v1.2.3 From 64b27e557acf268a19246b2eb80ad6a743df95f4 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 15:32:24 +0100 Subject: Reset `self.ctx` call history after every subtest. --- tests/bot/cogs/moderation/test_silence.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 4163a9af7..ab2f091ec 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -52,6 +52,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): await self.cog.silence.callback(self.cog, self.ctx, duration) self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() async def test_unsilence_sent_correct_discord_message(self): """Check if proper message was sent to `alert_chanel`.""" -- cgit v1.2.3 From 10428d9a456c7bce533cda53100e4c35930211d6 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 15:33:05 +0100 Subject: Pass created channel instead of new object. Creating a new object caused the assert to fail because different objects were used. --- tests/bot/cogs/moderation/test_silence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index ab2f091ec..23f8a84ab 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -91,7 +91,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_silence_private_removed_muted_channel(self): channel = MockTextChannel() with mock.patch.object(self.cog, "muted_channels") as muted_channels: - await self.cog._silence(MockTextChannel(), False, None) + await self.cog._silence(channel, False, None) muted_channels.add.assert_called_once_with(channel) async def test_unsilence_private_for_false(self): -- cgit v1.2.3 From b2aa9af7f9f1485aa3ae8ed4d029fd2d72ea17ad Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 16:02:35 +0100 Subject: Add tests for `_get_instance_vars`. --- tests/bot/cogs/moderation/test_silence.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 23f8a84ab..c9aa7d84f 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -3,7 +3,7 @@ from unittest import mock from unittest.mock import MagicMock, Mock from bot.cogs.moderation.silence import FirstHash, Silence -from bot.constants import Emojis, Roles +from bot.constants import Channels, Emojis, Guild, Roles from tests.helpers import MockBot, MockContext, MockTextChannel @@ -36,6 +36,33 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext() self.cog._verified_role = None + async def test_instance_vars_got_guild(self): + """Bot got guild after it became available.""" + await self.cog._get_instance_vars() + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(Guild.id) + + async def test_instance_vars_got_role(self): + """Got `Roles.verified` role from guild.""" + await self.cog._get_instance_vars() + guild = self.bot.get_guild() + guild.get_role.assert_called_once_with(Roles.verified) + + async def test_instance_vars_got_channels(self): + """Got channels from bot.""" + await self.cog._get_instance_vars() + self.bot.get_channel.called_once_with(Channels.mod_alerts) + self.bot.get_channel.called_once_with(Channels.mod_log) + + @mock.patch("bot.cogs.moderation.silence.SilenceNotifier") + async def test_instance_vars_got_notifier(self, notifier): + """Notifier was started with channel.""" + mod_log = MockTextChannel() + self.bot.get_channel.side_effect = (None, mod_log) + await self.cog._get_instance_vars() + notifier.assert_called_once_with(mod_log) + self.bot.get_channel.side_effect = None + async def test_silence_sent_correct_discord_message(self): """Check if proper message was sent when called with duration in channel with previous state.""" test_cases = ( -- cgit v1.2.3 From 8ee70ffe645621a6b97172176afe1ac63261df31 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 16:10:38 +0100 Subject: Create test case for `SilenceNotifier` --- tests/bot/cogs/moderation/test_silence.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index c9aa7d84f..fc7734d45 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -2,7 +2,7 @@ import unittest from unittest import mock from unittest.mock import MagicMock, Mock -from bot.cogs.moderation.silence import FirstHash, Silence +from bot.cogs.moderation.silence import FirstHash, Silence, SilenceNotifier from bot.constants import Channels, Emojis, Guild, Roles from tests.helpers import MockBot, MockContext, MockTextChannel @@ -29,6 +29,12 @@ class FirstHashTests(unittest.TestCase): self.assertTrue(tuple1 == tuple2) +class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.alert_channel = MockTextChannel() + self.notifier = SilenceNotifier(self.alert_channel) + + class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot() -- cgit v1.2.3 From fd75f10f3c8a588bd1763873baad08b8f90d58a3 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 17:09:40 +0100 Subject: Add tests for `add_channel`. --- tests/bot/cogs/moderation/test_silence.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index fc7734d45..be5b8e550 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -33,6 +33,27 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.alert_channel = MockTextChannel() self.notifier = SilenceNotifier(self.alert_channel) + self.notifier.stop = self.notifier_stop_mock = Mock() + self.notifier.start = self.notifier_start_mock = Mock() + self.notifier._current_loop = self.current_loop_mock = Mock() + + def test_add_channel_adds_channel(self): + """Channel in FirstHash with current loop is added to internal set.""" + channel = Mock() + with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: + self.notifier.add_channel(channel) + silenced_channels.add.assert_called_with(FirstHash(channel, self.current_loop_mock)) + + def test_add_channel_starts_loop(self): + """Loop is started if `_silenced_channels` was empty.""" + self.notifier.add_channel(Mock()) + self.notifier_start_mock.assert_called_once() + + def test_add_channel_skips_start_with_channels(self): + """Loop start is not called when `_silenced_channels` is not empty.""" + with mock.patch.object(self.notifier, "_silenced_channels"): + self.notifier.add_channel(Mock()) + self.notifier_start_mock.assert_not_called() class SilenceTests(unittest.IsolatedAsyncioTestCase): -- cgit v1.2.3 From d9c904164a9e54750ce8ee36535bceacfc4800f5 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 18:06:53 +0100 Subject: Remove `_current_loop` from setup. --- tests/bot/cogs/moderation/test_silence.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index be5b8e550..2e04dc407 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -35,14 +35,13 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.notifier = SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() self.notifier.start = self.notifier_start_mock = Mock() - self.notifier._current_loop = self.current_loop_mock = Mock() def test_add_channel_adds_channel(self): """Channel in FirstHash with current loop is added to internal set.""" channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.add_channel(channel) - silenced_channels.add.assert_called_with(FirstHash(channel, self.current_loop_mock)) + silenced_channels.add.assert_called_with(FirstHash(channel, self.notifier._current_loop)) def test_add_channel_starts_loop(self): """Loop is started if `_silenced_channels` was empty.""" -- cgit v1.2.3 From 4740c0fcdc6da6f164963fb34715e78c5d586cec Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 18:07:14 +0100 Subject: Add tests for `remove_channel`. --- tests/bot/cogs/moderation/test_silence.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 2e04dc407..c52ca2a2a 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -54,6 +54,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.notifier.add_channel(Mock()) self.notifier_start_mock.assert_not_called() + def test_remove_channel_removes_channel(self): + """Channel in FirstHash is removed from `_silenced_channels`.""" + channel = Mock() + with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: + self.notifier.remove_channel(channel) + silenced_channels.remove.assert_called_with(FirstHash(channel)) + + def test_remove_channel_stops_loop(self): + """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" + with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): + self.notifier.remove_channel(Mock()) + self.notifier_stop_mock.assert_called_once() + + def test_remove_channel_skips_stop_with_channels(self): + """Notifier loop is not stopped if `_silenced_channels` is not empty after remove.""" + self.notifier.remove_channel(Mock()) + self.notifier_stop_mock.assert_not_called() + class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: -- cgit v1.2.3 From 28cf22bcd98d94fa27e80dde4c86c9054b33c538 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Wed, 11 Mar 2020 21:55:09 +0100 Subject: Add tests for `_notifier`. --- tests/bot/cogs/moderation/test_silence.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index c52ca2a2a..d4719159e 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -72,6 +72,25 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): self.notifier.remove_channel(Mock()) self.notifier_stop_mock.assert_not_called() + async def test_notifier_private_sends_alert(self): + """Alert is sent on 15 min intervals.""" + test_cases = (900, 1800, 2700) + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with mock.patch.object(self.notifier, "_current_loop", new=current_loop): + await self.notifier._notifier() + self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") + self.alert_channel.send.reset_mock() + + async def test_notifier_skips_alert(self): + """Alert is skipped on first loop or not an increment of 900.""" + test_cases = (0, 15, 5000) + for current_loop in test_cases: + with self.subTest(current_loop=current_loop): + with mock.patch.object(self.notifier, "_current_loop", new=current_loop): + await self.notifier._notifier() + self.alert_channel.send.assert_not_called() + class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: -- cgit v1.2.3 From d9ed24922f6daa17d625b345cb195e7fae7758cc Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 13:31:44 -0700 Subject: Cog tests: add a function to get all extensions --- tests/bot/cogs/test_cogs.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index b128ca123..386299fb1 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -1,10 +1,15 @@ """Test suite for general tests which apply to all cogs.""" +import importlib +import pkgutil import typing as t import unittest +from types import ModuleType from discord.ext import commands +from bot import cogs + class CommandNameTests(unittest.TestCase): """Tests for shadowing command names and aliases.""" @@ -17,3 +22,9 @@ class CommandNameTests(unittest.TestCase): yield command if isinstance(command, commands.GroupMixin): yield from command.walk_commands() + + @staticmethod + def walk_extensions() -> t.Iterator[ModuleType]: + """Yield imported extensions (modules) from the bot.cogs subpackage.""" + for module in pkgutil.iter_modules(cogs.__path__, "bot.cogs."): + yield importlib.import_module(module.name) -- cgit v1.2.3 From b923c0f844f65275d90e3807aa8e3eadf3920252 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 13:37:56 -0700 Subject: Cog tests: add a function to get all cogs --- tests/bot/cogs/test_cogs.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 386299fb1..4290c279c 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -28,3 +28,10 @@ class CommandNameTests(unittest.TestCase): """Yield imported extensions (modules) from the bot.cogs subpackage.""" for module in pkgutil.iter_modules(cogs.__path__, "bot.cogs."): yield importlib.import_module(module.name) + + @staticmethod + def walk_cogs(extension: ModuleType) -> t.Iterator[commands.Cog]: + """Yield all cogs defined in an extension.""" + for name, cls in extension.__dict__.items(): + if isinstance(cls, commands.Cog): + yield getattr(extension, name) -- cgit v1.2.3 From 0358121687988159cb6754e249eed1ee2d40a783 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 13:56:58 -0700 Subject: Cog tests: add a function to get all qualified names for a cmd --- tests/bot/cogs/test_cogs.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 4290c279c..e28717756 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -35,3 +35,11 @@ class CommandNameTests(unittest.TestCase): for name, cls in extension.__dict__.items(): if isinstance(cls, commands.Cog): yield getattr(extension, name) + + @staticmethod + def get_qualified_names(command: commands.Command) -> t.List[str]: + """Return a list of all qualified names, including aliases, for the `command`.""" + names = [f"{command.full_parent_name} {alias}" for alias in command.aliases] + names.append(command.qualified_name) + + return names -- cgit v1.2.3 From 5419b3e9599e8bb2f519949aa268eb3a4b3adbcc Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 14:17:23 -0700 Subject: Cog tests: add a function to yield all commands This will help reduce nesting in the actual test. --- tests/bot/cogs/test_cogs.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index e28717756..d260b46a7 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -43,3 +43,10 @@ class CommandNameTests(unittest.TestCase): names.append(command.qualified_name) return names + + def get_all_commands(self) -> t.Iterator[commands.Command]: + """Yield all commands for all cogs in all extensions.""" + for extension in self.walk_extensions(): + for cog in self.walk_cogs(extension): + for cmd in self.walk_commands(cog): + yield cmd -- cgit v1.2.3 From 1b4def2c8c0d82fc9738c1e969404e305c91cac9 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 14:31:56 -0700 Subject: Cog tests: fix Cog type check in `walk_cogs` --- tests/bot/cogs/test_cogs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index d260b46a7..75aa1dbf6 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -32,9 +32,9 @@ class CommandNameTests(unittest.TestCase): @staticmethod def walk_cogs(extension: ModuleType) -> t.Iterator[commands.Cog]: """Yield all cogs defined in an extension.""" - for name, cls in extension.__dict__.items(): - if isinstance(cls, commands.Cog): - yield getattr(extension, name) + for obj in extension.__dict__.values(): + if isinstance(obj, type) and issubclass(obj, commands.Cog): + yield obj @staticmethod def get_qualified_names(command: commands.Command) -> t.List[str]: -- cgit v1.2.3 From 78327b9fa7c64a04d527fce582b93210356451fe Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 14:49:08 -0700 Subject: Cog tests: fix duplicate cogs being yielded Have to check the modules are equal to prevent yielding imported cogs. --- tests/bot/cogs/test_cogs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 75aa1dbf6..de0982c93 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -33,7 +33,8 @@ class CommandNameTests(unittest.TestCase): def walk_cogs(extension: ModuleType) -> t.Iterator[commands.Cog]: """Yield all cogs defined in an extension.""" for obj in extension.__dict__.values(): - if isinstance(obj, type) and issubclass(obj, commands.Cog): + is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) + if is_cog and obj.__module__ == extension.__name__: yield obj @staticmethod -- cgit v1.2.3 From bbcdf24a4b5d4f84834bbc8a8da7db2da627541f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 15:02:22 -0700 Subject: Cog tests: fix nested modules not being found * Rename `walk_extensions` to `walk_modules` because some extensions don't consist of a single module --- tests/bot/cogs/test_cogs.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index de0982c93..3a9f07db6 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -24,17 +24,21 @@ class CommandNameTests(unittest.TestCase): yield from command.walk_commands() @staticmethod - def walk_extensions() -> t.Iterator[ModuleType]: - """Yield imported extensions (modules) from the bot.cogs subpackage.""" - for module in pkgutil.iter_modules(cogs.__path__, "bot.cogs."): - yield importlib.import_module(module.name) + def walk_modules() -> t.Iterator[ModuleType]: + """Yield imported modules from the bot.cogs subpackage.""" + def on_error(name: str) -> t.NoReturn: + raise ImportError(name=name) + + for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): + if not module.ispkg: + yield importlib.import_module(module.name) @staticmethod - def walk_cogs(extension: ModuleType) -> t.Iterator[commands.Cog]: + def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: """Yield all cogs defined in an extension.""" - for obj in extension.__dict__.values(): + for obj in module.__dict__.values(): is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) - if is_cog and obj.__module__ == extension.__name__: + if is_cog and obj.__module__ == module.__name__: yield obj @staticmethod @@ -47,7 +51,7 @@ class CommandNameTests(unittest.TestCase): def get_all_commands(self) -> t.Iterator[commands.Command]: """Yield all commands for all cogs in all extensions.""" - for extension in self.walk_extensions(): - for cog in self.walk_cogs(extension): + for module in self.walk_modules(): + for cog in self.walk_cogs(module): for cmd in self.walk_commands(cog): yield cmd -- cgit v1.2.3 From 02fe32879be51b3f202501ea8cdc5314ca3b77b2 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 15:29:19 -0700 Subject: Cog tests: fix duplicate commands being yielded discord.py yields duplicate Command objects for each alias a command has, so the duplicates need to be removed on our end. --- tests/bot/cogs/test_cogs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 3a9f07db6..9d1d4ebea 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -21,7 +21,8 @@ class CommandNameTests(unittest.TestCase): if command.parent is None: yield command if isinstance(command, commands.GroupMixin): - yield from command.walk_commands() + # Annoyingly it returns duplicates for each alias so use a set to fix that + yield from set(command.walk_commands()) @staticmethod def walk_modules() -> t.Iterator[ModuleType]: -- cgit v1.2.3 From 7e7c538435c899f45ff277e05fb59d139f401954 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Mar 2020 15:34:12 -0700 Subject: Cog tests: add a test for duplicate command names & aliases --- tests/bot/cogs/test_cogs.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 9d1d4ebea..616f5f44a 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -4,6 +4,7 @@ import importlib import pkgutil import typing as t import unittest +from collections import defaultdict from types import ModuleType from discord.ext import commands @@ -56,3 +57,19 @@ class CommandNameTests(unittest.TestCase): for cog in self.walk_cogs(module): for cmd in self.walk_commands(cog): yield cmd + + def test_names_dont_shadow(self): + """Names and aliases of commands should be unique.""" + all_names = defaultdict(list) + for cmd in self.get_all_commands(): + func_name = f"{cmd.module}.{cmd.callback.__qualname__}" + + for name in self.get_qualified_names(cmd): + with self.subTest(cmd=func_name, name=name): + if name in all_names: + conflicts = ", ".join(all_names.get(name, "")) + self.fail( + f"Name '{name}' of the command {func_name} conflicts with {conflicts}." + ) + + all_names[name].append(func_name) -- cgit v1.2.3 From f105ae75a98ae3e0295352d7debbc4fe04c73afd Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 13 Mar 2020 17:32:16 -0700 Subject: Cog tests: fix leading space in aliases without parents --- tests/bot/cogs/test_cogs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 616f5f44a..cbd203786 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -46,7 +46,7 @@ class CommandNameTests(unittest.TestCase): @staticmethod def get_qualified_names(command: commands.Command) -> t.List[str]: """Return a list of all qualified names, including aliases, for the `command`.""" - names = [f"{command.full_parent_name} {alias}" for alias in command.aliases] + names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] names.append(command.qualified_name) return names -- cgit v1.2.3 From 4d4975544ffec249aa6cd43d14987c00794caf99 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 13 Mar 2020 17:42:24 -0700 Subject: Cog tests: fix error on import due to discord.ext.tasks.loop The tasks extensions loop requires an event loop to exist. To work around this, it's been mocked. --- tests/bot/cogs/test_cogs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index cbd203786..db559ded6 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -6,6 +6,7 @@ import typing as t import unittest from collections import defaultdict from types import ModuleType +from unittest import mock from discord.ext import commands @@ -31,9 +32,10 @@ class CommandNameTests(unittest.TestCase): def on_error(name: str) -> t.NoReturn: raise ImportError(name=name) - for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): - if not module.ispkg: - yield importlib.import_module(module.name) + with mock.patch("discord.ext.tasks.loop"): + for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): + if not module.ispkg: + yield importlib.import_module(module.name) @staticmethod def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: -- cgit v1.2.3 From 252b385e46ef542203e69f4f6d147dadbcec8f0f Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 15 Mar 2020 16:11:35 +0100 Subject: Use dict instead of a set and custom class. The FirstHash class is no longer necessary with only channels and the current loop in tuples. FirstHash was removed, along with its tests and tests were adjusted for new dict behaviour. --- bot/cogs/moderation/silence.py | 24 +++++------------------- tests/bot/cogs/moderation/test_silence.py | 28 +++------------------------- 2 files changed, 8 insertions(+), 44 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index 8ed1cb28b..5df1fbbc0 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -15,26 +15,12 @@ from bot.utils.checks import with_role_check log = logging.getLogger(__name__) -class FirstHash(tuple): - """Tuple with only first item used for hash and eq.""" - - def __new__(cls, *args): - """Construct tuple from `args`.""" - return super().__new__(cls, args) - - def __hash__(self): - return hash((self[0],)) - - def __eq__(self, other: "FirstHash"): - return self[0] == other[0] - - class SilenceNotifier(tasks.Loop): """Loop notifier for posting notices to `alert_channel` containing added channels.""" def __init__(self, alert_channel: TextChannel): super().__init__(self._notifier, seconds=1, minutes=0, hours=0, count=None, reconnect=True, loop=None) - self._silenced_channels = set() + self._silenced_channels = {} self._alert_channel = alert_channel def add_channel(self, channel: TextChannel) -> None: @@ -42,12 +28,12 @@ class SilenceNotifier(tasks.Loop): if not self._silenced_channels: self.start() log.info("Starting notifier loop.") - self._silenced_channels.add(FirstHash(channel, self._current_loop)) + self._silenced_channels[channel] = self._current_loop def remove_channel(self, channel: TextChannel) -> None: """Remove channel from `_silenced_channels` and stop loop if no channels remain.""" with suppress(KeyError): - self._silenced_channels.remove(FirstHash(channel)) + del self._silenced_channels[channel] if not self._silenced_channels: self.stop() log.info("Stopping notifier loop.") @@ -58,11 +44,11 @@ class SilenceNotifier(tasks.Loop): if self._current_loop and not self._current_loop/60 % 15: log.debug( f"Sending notice with channels: " - f"{', '.join(f'#{channel} ({channel.id})' for channel, _ in self._silenced_channels)}." + f"{', '.join(f'#{channel} ({channel.id})' for channel in self._silenced_channels)}." ) channels_text = ', '.join( f"{channel.mention} for {(self._current_loop-start)//60} min" - for channel, start in self._silenced_channels + for channel, start in self._silenced_channels.items() ) await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index d4719159e..6114fee21 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -2,33 +2,11 @@ import unittest from unittest import mock from unittest.mock import MagicMock, Mock -from bot.cogs.moderation.silence import FirstHash, Silence, SilenceNotifier +from bot.cogs.moderation.silence import Silence, SilenceNotifier from bot.constants import Channels, Emojis, Guild, Roles from tests.helpers import MockBot, MockContext, MockTextChannel -class FirstHashTests(unittest.TestCase): - def setUp(self) -> None: - self.test_cases = ( - (FirstHash(0, 4), FirstHash(0, 5)), - (FirstHash("string", None), FirstHash("string", True)) - ) - - def test_hashes_equal(self): - """Check hashes equal with same first item.""" - - for tuple1, tuple2 in self.test_cases: - with self.subTest(tuple1=tuple1, tuple2=tuple2): - self.assertEqual(hash(tuple1), hash(tuple2)) - - def test_eq(self): - """Check objects are equal with same first item.""" - - for tuple1, tuple2 in self.test_cases: - with self.subTest(tuple1=tuple1, tuple2=tuple2): - self.assertTrue(tuple1 == tuple2) - - class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.alert_channel = MockTextChannel() @@ -41,7 +19,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.add_channel(channel) - silenced_channels.add.assert_called_with(FirstHash(channel, self.notifier._current_loop)) + silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) def test_add_channel_starts_loop(self): """Loop is started if `_silenced_channels` was empty.""" @@ -59,7 +37,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): channel = Mock() with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: self.notifier.remove_channel(channel) - silenced_channels.remove.assert_called_with(FirstHash(channel)) + silenced_channels.__delitem__.assert_called_with(channel) def test_remove_channel_stops_loop(self): """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" -- cgit v1.2.3 From 36c57c6f89a070fbb77a641182e37c788b6de7a0 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 15 Mar 2020 18:21:49 +0100 Subject: Adjust tests for new calling behaviour. `.set_permissions` calls were changed to use kwargs directly instead of an overwrite, this reflects the changes in tests. --- tests/bot/cogs/moderation/test_silence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 6114fee21..b09426fde 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -141,7 +141,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel = MockTextChannel() self.assertTrue(await self.cog._silence(channel, False, None)) channel.set_permissions.assert_called_once() - self.assertFalse(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) + self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) async def test_silence_private_notifier(self): """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" @@ -175,7 +175,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) self.assertTrue(await self.cog._unsilence(channel)) channel.set_permissions.assert_called_once() - self.assertTrue(channel.set_permissions.call_args.kwargs['overwrite'].send_messages) + self.assertTrue(channel.set_permissions.call_args.kwargs['send_messages']) @mock.patch.object(Silence, "notifier", create=True) async def test_unsilence_private_removed_notifier(self, notifier): -- cgit v1.2.3 From 0a2774fadddd18a86822a47599ebc4b76f1e5a7e Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Sun, 15 Mar 2020 18:23:11 +0100 Subject: Set `_get_instance_vars_event` in test's `setUp`. --- tests/bot/cogs/moderation/test_silence.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index b09426fde..c6f1fc1da 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -76,6 +76,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.cog = Silence(self.bot) self.ctx = MockContext() self.cog._verified_role = None + # Set event so command callbacks can continue. + self.cog._get_instance_vars_event.set() async def test_instance_vars_got_guild(self): """Bot got guild after it became available.""" -- cgit v1.2.3 From b8559cc12fa75dd4b4a52697cf5aa313d3c397d0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 16 Mar 2020 10:27:21 -0700 Subject: Cog tests: comment some code for clarification --- tests/bot/cogs/test_cogs.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index db559ded6..39f6492cb 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -19,6 +19,7 @@ class CommandNameTests(unittest.TestCase): @staticmethod def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]: """An iterator that recursively walks through `cog`'s commands and subcommands.""" + # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods. for command in cog.__cog_commands__: if command.parent is None: yield command @@ -32,6 +33,7 @@ class CommandNameTests(unittest.TestCase): def on_error(name: str) -> t.NoReturn: raise ImportError(name=name) + # The mock prevents asyncio.get_event_loop() from being called. with mock.patch("discord.ext.tasks.loop"): for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error): if not module.ispkg: @@ -41,6 +43,7 @@ class CommandNameTests(unittest.TestCase): def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]: """Yield all cogs defined in an extension.""" for obj in module.__dict__.values(): + # Check if it's a class type cause otherwise issubclass() may raise a TypeError. is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog) if is_cog and obj.__module__ == module.__name__: yield obj -- cgit v1.2.3 From 039a04462be58e9d345e32efcae13c8c999776db Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 13:23:44 +0100 Subject: Fix `test_cog_unload` passing tests with invalid values. The first assert - `asyncio_mock.create_task.assert_called_once_with` called `alert_channel`'s send resulting in an extra call. `send` on `alert_channel` was not tested properly because of a typo and a missing assert in the method call. --- tests/bot/cogs/moderation/test_silence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index c6f1fc1da..febfd584b 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -202,8 +202,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): """Task for sending an alert was created with present `muted_channels`.""" with mock.patch.object(self.cog, "muted_channels"): self.cog.cog_unload() + alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - alert_channel.send.called_once_with(f"<@&{Roles.moderators}> chandnels left silenced on cog unload: ") @mock.patch("bot.cogs.moderation.silence.asyncio") def test_cog_unload1(self, asyncio_mock): -- cgit v1.2.3 From 2803c13c477634ceefe3501ad9cb7c76cfecf450 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 17:10:16 +0100 Subject: Rename `cog_unload` tests. Previous names were undescriptive from testing phases. --- tests/bot/cogs/moderation/test_silence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index febfd584b..07a70e7dc 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -198,7 +198,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): @mock.patch("bot.cogs.moderation.silence.asyncio") @mock.patch.object(Silence, "_mod_alerts_channel", create=True) - def test_cog_unload(self, alert_channel, asyncio_mock): + def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): """Task for sending an alert was created with present `muted_channels`.""" with mock.patch.object(self.cog, "muted_channels"): self.cog.cog_unload() @@ -206,7 +206,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) @mock.patch("bot.cogs.moderation.silence.asyncio") - def test_cog_unload1(self, asyncio_mock): + def test_cog_unload_skips_task_start(self, asyncio_mock): """No task created with no channels.""" self.cog.cog_unload() asyncio_mock.create_task.assert_not_called() -- cgit v1.2.3 From 20c41f2c5af6fd716c3e7f15de412f7f16f5ff1e Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 17:15:10 +0100 Subject: Remove one indentation level. Co-authored-by: MarkKoz --- tests/bot/cogs/moderation/test_silence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 07a70e7dc..8b9e30cfe 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -115,9 +115,9 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): ) for duration, result_message, _silence_patch_return in test_cases: with self.subTest( - silence_duration=duration, - result_message=result_message, - starting_unsilenced_state=_silence_patch_return + silence_duration=duration, + result_message=result_message, + starting_unsilenced_state=_silence_patch_return ): with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): await self.cog.silence.callback(self.cog, self.ctx, duration) -- cgit v1.2.3 From d456e40ac97a38ee99561546bcafb6aa94117cb7 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 17:16:31 +0100 Subject: Remove `alert_channel` mention from docstring. After removing the optional channel arg and changing output message channels we're only testing `ctx`'s `send`. --- tests/bot/cogs/moderation/test_silence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 8b9e30cfe..b4a34bbc7 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -125,7 +125,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.ctx.reset_mock() async def test_unsilence_sent_correct_discord_message(self): - """Check if proper message was sent to `alert_chanel`.""" + """Proper reply after a successful unsilence.""" with mock.patch.object(self.cog, "_unsilence", return_value=True): await self.cog.unsilence.callback(self.cog, self.ctx) self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") -- cgit v1.2.3 From 386c93a6f18adbf84691f17b13f5113800a353ae Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 19:40:27 +0100 Subject: Fix test name. `removed` was describing the opposite behaviour. --- tests/bot/cogs/moderation/test_silence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index b4a34bbc7..55193e2f8 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -158,7 +158,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): await self.cog._silence(channel, False, None) self.cog.notifier.add_channel.assert_not_called() - async def test_silence_private_removed_muted_channel(self): + async def test_silence_private_added_muted_channel(self): channel = MockTextChannel() with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._silence(channel, False, None) -- cgit v1.2.3 From dced6fdf5f571b82bc975dd3159af57c6f9a12b3 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 19:41:13 +0100 Subject: Add docstring to test. --- tests/bot/cogs/moderation/test_silence.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 55193e2f8..71541086d 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -159,6 +159,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.cog.notifier.add_channel.assert_not_called() async def test_silence_private_added_muted_channel(self): + """Channel was added to `muted_channels` on silence.""" channel = MockTextChannel() with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._silence(channel, False, None) -- cgit v1.2.3 From c68b943708eaca110ddfa6121872513a422bbef4 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 20:10:50 +0100 Subject: Use set `discard` instead of `remove`. Discard ignores non present values, allowing us to skip the KeyError suppress. --- bot/cogs/moderation/silence.py | 3 +-- tests/bot/cogs/moderation/test_silence.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index 1523baf11..a1446089e 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -141,8 +141,7 @@ class Silence(commands.Cog): await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=True)) log.info(f"Unsilenced channel #{channel} ({channel.id}).") self.notifier.remove_channel(channel) - with suppress(KeyError): - self.muted_channels.remove(channel) + self.muted_channels.discard(channel) return True log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") return False diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 71541086d..eee020455 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -195,7 +195,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) with mock.patch.object(self.cog, "muted_channels") as muted_channels: await self.cog._unsilence(channel) - muted_channels.remove.assert_called_once_with(channel) + muted_channels.discard.assert_called_once_with(channel) @mock.patch("bot.cogs.moderation.silence.asyncio") @mock.patch.object(Silence, "_mod_alerts_channel", create=True) -- cgit v1.2.3 From cd429230fcb18c7101afd931317d37ad142bfe4b Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 20:11:44 +0100 Subject: Add tests ensuring permissions get preserved. --- tests/bot/cogs/moderation/test_silence.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index eee020455..44682a1bd 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -2,6 +2,8 @@ import unittest from unittest import mock from unittest.mock import MagicMock, Mock +from discord import PermissionOverwrite + from bot.cogs.moderation.silence import Silence, SilenceNotifier from bot.constants import Channels, Emojis, Guild, Roles from tests.helpers import MockBot, MockContext, MockTextChannel @@ -145,6 +147,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel.set_permissions.assert_called_once() self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) + async def test_silence_private_preserves_permissions(self): + """Previous permissions were preserved when channel was silenced.""" + channel = MockTextChannel() + # Set up mock channel permission state. + mock_permissions = PermissionOverwrite() + mock_permissions_dict = dict(mock_permissions) + channel.overwrites_for.return_value = mock_permissions + await self.cog._silence(channel, False, None) + new_permissions = channel.set_permissions.call_args.kwargs + # Remove 'send_messages' key because it got changed in the method. + del new_permissions['send_messages'] + del mock_permissions_dict['send_messages'] + self.assertDictEqual(mock_permissions_dict, new_permissions) + async def test_silence_private_notifier(self): """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" channel = MockTextChannel() @@ -197,6 +213,21 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): await self.cog._unsilence(channel) muted_channels.discard.assert_called_once_with(channel) + @mock.patch.object(Silence, "notifier", create=True) + async def test_unsilence_private_preserves_permissions(self, _): + """Previous permissions were preserved when channel was unsilenced.""" + channel = MockTextChannel() + # Set up mock channel permission state. + mock_permissions = PermissionOverwrite(send_messages=False) + mock_permissions_dict = dict(mock_permissions) + channel.overwrites_for.return_value = mock_permissions + await self.cog._unsilence(channel) + new_permissions = channel.set_permissions.call_args.kwargs + # Remove 'send_messages' key because it got changed in the method. + del new_permissions['send_messages'] + del mock_permissions_dict['send_messages'] + self.assertDictEqual(mock_permissions_dict, new_permissions) + @mock.patch("bot.cogs.moderation.silence.asyncio") @mock.patch.object(Silence, "_mod_alerts_channel", create=True) def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): -- cgit v1.2.3 From cefcc575b6faa94fb18f1985f039125d023b2580 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Tue, 17 Mar 2020 22:18:58 +0100 Subject: Add tests for `HushDurationConverter`. --- tests/bot/test_converters.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) (limited to 'tests') diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 1e5ca62ae..ca8cb6825 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -8,6 +8,7 @@ from discord.ext.commands import BadArgument from bot.converters import ( Duration, + HushDurationConverter, ISODateTime, TagContentConverter, TagNameConverter, @@ -271,3 +272,32 @@ class ConverterTests(unittest.TestCase): exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" with self.assertRaises(BadArgument, msg=exception_message): asyncio.run(converter.convert(self.context, datetime_string)) + + def test_hush_duration_converter_for_valid(self): + """HushDurationConverter returns correct value for minutes duration or `"forever"` strings.""" + test_values = ( + ("0", 0), + ("15", 15), + ("10", 10), + ("5m", 5), + ("5M", 5), + ("forever", None), + ) + converter = HushDurationConverter() + for minutes_string, expected_minutes in test_values: + with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes): + converted = asyncio.run(converter.convert(self.context, minutes_string)) + self.assertEqual(expected_minutes, converted) + + def test_hush_duration_converter_for_invalid(self): + """HushDurationConverter raises correct exception for invalid minutes duration strings.""" + test_values = ( + ("16", "Duration must be at most 15 minutes."), + ("10d", "10d is not a valid minutes duration."), + ("-1", "-1 is not a valid minutes duration."), + ) + converter = HushDurationConverter() + for invalid_minutes_string, exception_message in test_values: + with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message): + with self.assertRaisesRegex(BadArgument, exception_message): + asyncio.run(converter.convert(self.context, invalid_minutes_string)) -- cgit v1.2.3 From 430c616ec4ec60a5ddb1e66d3aacc622c9a78ae6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 25 Mar 2020 12:21:57 -0700 Subject: Snekbox tests: test `get_code` Should return 1st arg (or None) if eval cmd in message, otherwise return full content. --- tests/bot/cogs/test_snekbox.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index fd9468829..1fad6904b 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -3,9 +3,11 @@ import logging import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from discord.ext import commands + +from bot import constants from bot.cogs import snekbox from bot.cogs.snekbox import Snekbox -from bot.constants import URLs from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser @@ -23,7 +25,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(await self.cog.post_eval("import random"), "return") self.bot.http_session.post.assert_called_with( - URLs.snekbox_eval_api, + constants.URLs.snekbox_eval_api, json={"input": "import random"}, raise_for_status=True ) @@ -43,10 +45,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual( await self.cog.upload_output("My awesome output"), - URLs.paste_service.format(key=key) + constants.URLs.paste_service.format(key=key) ) self.bot.http_session.post.assert_called_with( - URLs.paste_service.format(key="documents"), + constants.URLs.paste_service.format(key="documents"), data="My awesome output", raise_for_status=True ) @@ -302,6 +304,32 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(actual, None) ctx.message.clear_reactions.assert_called_once() + async def test_get_code(self): + """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" + prefix = constants.Bot.prefix + subtests = ( + (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"), + (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None), + (MagicMock(spec=commands.Command), f"{prefix}tags get foo"), + (None, "print(123)") + ) + + for command, content, *expected_code in subtests: + if not expected_code: + expected_code = content + else: + [expected_code] = expected_code + + with self.subTest(content=content, expected_code=expected_code): + self.bot.get_context.reset_mock() + self.bot.get_context.return_value = MockContext(command=command) + message = MockMessage(content=content) + + actual_code = await self.cog.get_code(message) + + self.bot.get_context.assert_awaited_once_with(message) + self.assertEqual(actual_code, expected_code) + def test_predicate_eval_message_edit(self): """Test the predicate_eval_message_edit function.""" msg0 = MockMessage(id=1, content='abc') -- cgit v1.2.3 From c3e9a290a93c978a4dfec3ab121a0e45147855c8 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 25 Mar 2020 14:08:34 -0700 Subject: Snekbox tests: use `get_code` in `test_continue_eval_does_continue` --- tests/bot/cogs/test_snekbox.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1fad6904b..1dec0ccaf 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -1,7 +1,7 @@ import asyncio import logging import unittest -from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch from discord.ext import commands @@ -281,11 +281,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test that the continue_eval function does continue if required conditions are met.""" ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) response = MockMessage(delete=AsyncMock()) - new_msg = MockMessage(content='!e NewCode') + new_msg = MockMessage() self.bot.wait_for.side_effect = ((None, new_msg), None) + expected = "NewCode" + self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) actual = await self.cog.continue_eval(ctx, response) - self.assertEqual(actual, 'NewCode') + self.cog.get_code.assert_awaited_once_with(new_msg) + self.assertEqual(actual, expected) self.bot.wait_for.assert_has_awaits( ( call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), -- cgit v1.2.3 From 582ddbb1ca8bab2cb883781911f5f35962330995 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Mon, 30 Mar 2020 13:36:34 +0200 Subject: Set unsilence permissions to inherit instead of true The "unsilence" action of the silence/hush command used `send_messages=True` when unsilencing a hushed channel. This had the side effect of also enabling send messages permissions for those with the Muted rule, as an explicit True permission apparently overwrites an explicit False permission, even if the latter was set for a higher top-role. The solution is to revert back to the `Inherit` permission by assigning `None`. This is what we normally use when Developers are allowed to send messages to a channel. --- bot/cogs/moderation/silence.py | 2 +- tests/bot/cogs/moderation/test_silence.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index a1446089e..1ef3967a9 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -138,7 +138,7 @@ class Silence(commands.Cog): """ current_overwrite = channel.overwrites_for(self._verified_role) if current_overwrite.send_messages is False: - await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=True)) + await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) log.info(f"Unsilenced channel #{channel} ({channel.id}).") self.notifier.remove_channel(channel) self.muted_channels.discard(channel) diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 44682a1bd..3fd149f04 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -194,7 +194,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) self.assertTrue(await self.cog._unsilence(channel)) channel.set_permissions.assert_called_once() - self.assertTrue(channel.set_permissions.call_args.kwargs['send_messages']) + self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) @mock.patch.object(Silence, "notifier", create=True) async def test_unsilence_private_removed_notifier(self, notifier): -- cgit v1.2.3 From 7434ed3152e6d3f89babe2fef332983925d04434 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 9 Apr 2020 14:45:32 +0300 Subject: (Syncer Tests): Replaced wrong side effect Replaced `TimeoutError` with `asyncio.TimeoutError`. --- tests/bot/cogs/sync/test_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py index 6ee9dfda6..70aea2bab 100644 --- a/tests/bot/cogs/sync/test_base.py +++ b/tests/bot/cogs/sync/test_base.py @@ -1,3 +1,4 @@ +import asyncio import unittest from unittest import mock @@ -211,7 +212,7 @@ class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): subtests = ( (constants.Emojis.check_mark, True, None), ("InVaLiD", False, None), - (None, False, TimeoutError), + (None, False, asyncio.TimeoutError), ) for emoji, ret_val, side_effect in subtests: -- cgit v1.2.3 From 085decd12867f89a0803806928741fe6dd3c76bb Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 15 Apr 2020 08:18:19 +0300 Subject: (Test Helpers): Added `__ge__` function to `MockRole` for comparing. --- tests/helpers.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 8e13f0f28..227bac95f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -205,6 +205,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position < other.position + def __ge__(self, other): + """Simplified position-based comparisons similar to those of `discord.Role`.""" + return self.position >= other.position + # Create a Member instance to get a realistic Mock of `discord.Member` member_data = {'user': 'lemon', 'roles': [1]} -- cgit v1.2.3 From 81f6efc2f4e9e157e2f7fb9f191ea410af066632 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 16 Apr 2020 11:15:16 +0300 Subject: (Infraction Tests): Created reason shortening tests for ban and kick. --- tests/bot/cogs/moderation/test_infractions.py | 54 +++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/bot/cogs/moderation/test_infractions.py (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py new file mode 100644 index 000000000..39ea93952 --- /dev/null +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -0,0 +1,54 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class ShorteningTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason shortening.""" + + def setUp(self): + self.bot = MockBot() + self.cog = Infractions(self.bot) + self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) + self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) + self.guild = MockGuild(id=4567) + self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + + @patch("bot.cogs.moderation.utils.has_active_infraction") + @patch("bot.cogs.moderation.utils.post_infraction") + async def test_apply_ban_reason_shortening(self, post_infraction_mock, has_active_mock): + """Should truncate reason for `ctx.guild.ban`.""" + has_active_mock.return_value = False + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.bot.get_cog.return_value = AsyncMock() + self.cog.mod_log.ignore = Mock() + + await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) + ban = self.cog.apply_infraction.call_args[0][3] + self.assertEqual( + ban.cr_frame.f_locals["kwargs"]["reason"], + textwrap.shorten("foo bar" * 3000, 512, placeholder=" ...") + ) + # Await ban to avoid warning + await ban + + @patch("bot.cogs.moderation.utils.post_infraction") + async def test_apply_kick_reason_shortening(self, post_infraction_mock) -> None: + """Should truncate reason for `Member.kick`.""" + post_infraction_mock.return_value = {"foo": "bar"} + + self.cog.apply_infraction = AsyncMock() + self.cog.mod_log.ignore = Mock() + + await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) + kick = self.cog.apply_infraction.call_args[0][3] + self.assertEqual( + kick.cr_frame.f_locals["kwargs"]["reason"], + textwrap.shorten("foo bar" * 3000, 512, placeholder="...") + ) + await kick -- cgit v1.2.3 From 216953044a870f2440fe44fcd2f9ca3ee7cf37e9 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 16 Apr 2020 11:30:09 +0300 Subject: (ModLog Tests): Created reason shortening tests for `send_log_message`. --- tests/bot/cogs/moderation/test_modlog.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/bot/cogs/moderation/test_modlog.py (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py new file mode 100644 index 000000000..46e01d2ea --- /dev/null +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.cogs.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): + """Tests for moderation logs.""" + + def setUp(self): + self.bot = MockBot() + self.cog = ModLog(self.bot) + self.channel = MockTextChannel() + + async def test_log_entry_description_shortening(self): + """Should truncate embed description for ModLog entry.""" + self.bot.get_channel.return_value = self.channel + await self.cog.send_log_message( + icon_url="foo", + colour=discord.Colour.blue(), + title="bar", + text="foo bar" * 3000 + ) + embed = self.channel.send.call_args[1]["embed"] + self.assertEqual( + embed.description, ("foo bar" * 3000)[:2046] + "..." + ) -- cgit v1.2.3 From 1a3fa6a395141c4fcdd1d388d6ce3e7bd89bcbf0 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Thu, 16 Apr 2020 13:40:47 +0300 Subject: (Infractions and ModLog Tests): Replaced `shortening` with `truncation`, removed unnecessary type hint and added comment to kick truncation test about awaiting `kick`. --- tests/bot/cogs/moderation/test_infractions.py | 9 +++++---- tests/bot/cogs/moderation/test_modlog.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index 39ea93952..51a8cc645 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -6,8 +6,8 @@ from bot.cogs.moderation.infractions import Infractions from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole -class ShorteningTests(unittest.IsolatedAsyncioTestCase): - """Tests for ban and kick command reason shortening.""" +class TruncationTests(unittest.IsolatedAsyncioTestCase): + """Tests for ban and kick command reason truncation.""" def setUp(self): self.bot = MockBot() @@ -19,7 +19,7 @@ class ShorteningTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.has_active_infraction") @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_ban_reason_shortening(self, post_infraction_mock, has_active_mock): + async def test_apply_ban_reason_truncation(self, post_infraction_mock, has_active_mock): """Should truncate reason for `ctx.guild.ban`.""" has_active_mock.return_value = False post_infraction_mock.return_value = {"foo": "bar"} @@ -38,7 +38,7 @@ class ShorteningTests(unittest.IsolatedAsyncioTestCase): await ban @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_kick_reason_shortening(self, post_infraction_mock) -> None: + async def test_apply_kick_reason_truncation(self, post_infraction_mock): """Should truncate reason for `Member.kick`.""" post_infraction_mock.return_value = {"foo": "bar"} @@ -51,4 +51,5 @@ class ShorteningTests(unittest.IsolatedAsyncioTestCase): kick.cr_frame.f_locals["kwargs"]["reason"], textwrap.shorten("foo bar" * 3000, 512, placeholder="...") ) + # Await kick to avoid warning await kick diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py index 46e01d2ea..d60836474 100644 --- a/tests/bot/cogs/moderation/test_modlog.py +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -14,7 +14,7 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): self.cog = ModLog(self.bot) self.channel = MockTextChannel() - async def test_log_entry_description_shortening(self): + async def test_log_entry_description_truncation(self): """Should truncate embed description for ModLog entry.""" self.bot.get_channel.return_value = self.channel await self.cog.send_log_message( -- cgit v1.2.3 From 1140e9690644e46196a1c8cad900272ffb3ae09a Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Mon, 20 Apr 2020 18:46:30 +0200 Subject: Replace `in_channel` decorator by `in_whitelisted_context` The `in_channel` decorator that served as a factory for `in_channel` checks was replaced by the broaded `in_whitelisted_context` decorator. This means that we can now whitelist commands using channel IDs, category IDs, and/or role IDs. The whitelists will be applied in an "OR" fashion, meaning that as soon as some part of the context happens to be whitelisted, the `predicate` check the decorator produces will return `True`. To reflect that this is now a broader decorator that checks for a whitelisted *context* (as opposed to just whitelisted channels), the exception the predicate raises has been changed to `InWhitelistedContextCheckFailure` to reflect the broader scope of the decorator. I've updated all the commands that used the previous version, `in_channel`, to use the replacement. --- bot/cogs/error_handler.py | 6 +-- bot/cogs/information.py | 10 +++-- bot/cogs/snekbox.py | 11 ++++- bot/cogs/utils.py | 8 +++- bot/cogs/verification.py | 18 +++++--- bot/decorators.py | 84 +++++++++++++++++++++++++------------- tests/bot/cogs/test_information.py | 4 +- 7 files changed, 94 insertions(+), 47 deletions(-) (limited to 'tests') diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index dae283c6a..3f56a9798 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -9,7 +9,7 @@ from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels from bot.converters import TagNameConverter -from bot.decorators import InChannelCheckFailure +from bot.decorators import InWhitelistedContextCheckFailure log = logging.getLogger(__name__) @@ -202,7 +202,7 @@ class ErrorHandler(Cog): * BotMissingRole * BotMissingAnyRole * NoPrivateMessage - * InChannelCheckFailure + * InWhitelistedContextCheckFailure """ bot_missing_errors = ( errors.BotMissingPermissions, @@ -215,7 +215,7 @@ class ErrorHandler(Cog): await ctx.send( f"Sorry, it looks like I don't have the permissions or roles I need to do that." ) - elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)): + elif isinstance(e, (InWhitelistedContextCheckFailure, errors.NoPrivateMessage)): ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") await ctx.send(e) diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 7921a4932..6b3fc0c96 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -12,7 +12,7 @@ from discord.utils import escape_markdown from bot import constants from bot.bot import Bot -from bot.decorators import InChannelCheckFailure, in_channel, with_role +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, with_role from bot.pagination import LinePaginator from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since @@ -152,7 +152,7 @@ class Information(Cog): # Non-staff may only do this in #bot-commands if not with_role_check(ctx, *constants.STAFF_ROLES): if not ctx.channel.id == constants.Channels.bot_commands: - raise InChannelCheckFailure(constants.Channels.bot_commands) + raise InWhitelistedContextCheckFailure(constants.Channels.bot_commands) embed = await self.create_user_embed(ctx, user) @@ -331,7 +331,11 @@ class Information(Cog): @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) @group(invoke_without_command=True) - @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES) + @in_whitelisted_context( + whitelisted_channels=(constants.Channels.bot_commands,), + whitelisted_roles=constants.STAFF_ROLES, + redirect_channel=constants.Channels.bot_commands, + ) async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: """Shows information about the raw API response.""" # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 315383b12..8827cb585 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot from bot.constants import Channels, Roles, URLs -from bot.decorators import in_channel +from bot.decorators import in_whitelisted_context from bot.utils.messages import wait_for_deletion log = logging.getLogger(__name__) @@ -38,6 +38,9 @@ RAW_CODE_REGEX = re.compile( ) MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) SIGKILL = 9 @@ -265,7 +268,11 @@ class Snekbox(Cog): @command(name="eval", aliases=("e",)) @guild_only() - @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) + @in_whitelisted_context( + whitelisted_channels=EVAL_CHANNELS, + whitelisted_roles=EVAL_ROLES, + redirect_channel=Channels.bot_commands, + ) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 3ed471bbf..234ec514d 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -13,7 +13,7 @@ from discord.ext.commands import BadArgument, Cog, Context, command from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES -from bot.decorators import in_channel, with_role +from bot.decorators import in_whitelisted_context, with_role from bot.utils.time import humanize_delta log = logging.getLogger(__name__) @@ -118,7 +118,11 @@ class Utils(Cog): await ctx.message.channel.send(embed=pep_embed) @command() - @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES) + @in_whitelisted_context( + whitelisted_channels=(Channels.bot_commands,), + whitelisted_roles=STAFF_ROLES, + redirect_channel=Channels.bot_commands, + ) async def charinfo(self, ctx: Context, *, characters: str) -> None: """Shows you information on up to 25 unicode characters.""" match = re.match(r"<(a?):(\w+):(\d+)>", characters) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b0a493e68..040f52fbf 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot from bot.cogs.moderation import ModLog -from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, without_role from bot.utils.checks import without_role_check log = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class Verification(Cog): @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) @without_role(constants.Roles.verified) - @in_channel(constants.Channels.verification) + @in_whitelisted_context(whitelisted_channels=(constants.Channels.verification,)) async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Accept our rules and gain access to the rest of the server.""" log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") @@ -138,7 +138,10 @@ class Verification(Cog): await ctx.message.delete() @command(name='subscribe') - @in_channel(constants.Channels.bot_commands) + @in_whitelisted_context( + whitelisted_channels=(constants.Channels.bot_commands,), + redirect_channel=constants.Channels.bot_commands, + ) async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Subscribe to announcement notifications by assigning yourself the role.""" has_role = False @@ -162,7 +165,10 @@ class Verification(Cog): ) @command(name='unsubscribe') - @in_channel(constants.Channels.bot_commands) + @in_whitelisted_context( + whitelisted_channels=(constants.Channels.bot_commands,), + redirect_channel=constants.Channels.bot_commands, + ) async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Unsubscribe from announcement notifications by removing the role from yourself.""" has_role = False @@ -187,8 +193,8 @@ class Verification(Cog): # This cannot be static (must have a __func__ attribute). async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InChannelCheckFailure.""" - if isinstance(error, InChannelCheckFailure): + """Check for & ignore any InWhitelistedContextCheckFailure.""" + if isinstance(error, InWhitelistedContextCheckFailure): error.handled = True @staticmethod diff --git a/bot/decorators.py b/bot/decorators.py index 2d18eaa6a..149564d18 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -3,7 +3,7 @@ import random from asyncio import Lock, sleep from contextlib import suppress from functools import wraps -from typing import Callable, Container, Union +from typing import Callable, Container, Optional, Union from weakref import WeakValueDictionary from discord import Colour, Embed, Member @@ -17,48 +17,74 @@ from bot.utils.checks import with_role_check, without_role_check log = logging.getLogger(__name__) -class InChannelCheckFailure(CheckFailure): - """Raised when a check fails for a message being sent in a whitelisted channel.""" +class InWhitelistedContextCheckFailure(CheckFailure): + """Raised when the `in_whitelist` check fails.""" - def __init__(self, *channels: int): - self.channels = channels - channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) + def __init__(self, redirect_channel: Optional[int] = None): + error_message = "Sorry, but you are not allowed to use that command here." - super().__init__(f"Sorry, but you may only use this command within {channels_str}.") + if redirect_channel: + error_message += f" Please use the <#{redirect_channel}> channel instead." + super().__init__(error_message) + + +def in_whitelisted_context( + *, + whitelisted_channels: Container[int] = (), + whitelisted_categories: Container[int] = (), + whitelisted_roles: Container[int] = (), + redirect_channel: Optional[int] = None, -def in_channel( - *channels: int, - hidden_channels: Container[int] = None, - bypass_roles: Container[int] = None ) -> Callable: """ - Checks that the message is in a whitelisted channel or optionally has a bypass role. + Check if a command was issued in a whitelisted context. + + The whitelists that can be provided are: - Hidden channels are channels which will not be displayed in the InChannelCheckFailure error - message. + - `channels`: a container with channel ids for whitelisted channels + - `categories`: a container with category ids for whitelisted categories + - `roles`: a container with with role ids for whitelisted roles + + An optional `redirect_channel` can be provided to redirect users that are not + authorized to use the command in the current context. If no such channel is + provided, the users are simply told that they are not authorized to use the + command. """ - hidden_channels = hidden_channels or [] - bypass_roles = bypass_roles or [] + if redirect_channel and redirect_channel not in whitelisted_channels: + # It does not make sense for the channel whitelist to not contain the redirection + # channel (if provided). That's why we add the redirection channel to the `channels` + # container if it's not already in it. As we allow any container type to be passed, + # we first create a tuple in order to safely add the redirection channel. + # + # Note: It's possible for the redirect channel to be in a whitelisted category, but + # there's no easy way to check that and as a channel can easily be moved in and out of + # categories, it's probably not wise to rely on its category in any case. + whitelisted_channels = tuple(whitelisted_channels) + (redirect_channel,) def predicate(ctx: Context) -> bool: - """In-channel checker predicate.""" - if ctx.channel.id in channels or ctx.channel.id in hidden_channels: - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The command was used in a whitelisted channel.") + """Check if a command was issued in a whitelisted context.""" + if whitelisted_channels and ctx.channel.id in whitelisted_channels: + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") return True - if bypass_roles: - if any(r.id in bypass_roles for r in ctx.author.roles): - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The command was not used in a whitelisted channel, " - f"but the author had a role to bypass the in_channel check.") - return True + # Only check the category id if we have a category whitelist and the channel has a `category_id` + if ( + whitelisted_categories + and hasattr(ctx.channel, "category_id") + and ctx.channel.category_id in whitelisted_categories + ): + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") + return True - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The in_channel check failed.") + # Only check the roles whitelist if we have one and ensure the author's roles attribute returns + # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). + if whitelisted_roles and any(r.id in whitelisted_roles for r in getattr(ctx.author, "roles", ())): + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") + return True - raise InChannelCheckFailure(*channels) + log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + raise InWhitelistedContextCheckFailure(redirect_channel) return commands.check(predicate) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 3c26374f5..4a36fe030 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,7 +7,7 @@ import discord from bot import constants from bot.cogs import information -from bot.decorators import InChannelCheckFailure +from bot.decorators import InWhitelistedContextCheckFailure from tests import helpers @@ -525,7 +525,7 @@ class UserCommandTests(unittest.TestCase): ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InChannelCheckFailure, msg=msg): + with self.assertRaises(InWhitelistedContextCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) -- cgit v1.2.3 From 00291d7d5f859e4131cb5c94541a90f80f358376 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Mon, 20 Apr 2020 18:53:31 +0200 Subject: Remove vestigial kwargs from MockTextChannel.__init__ --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 8e13f0f28..9001deedf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -315,7 +315,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ spec_set = channel_instance - def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: + def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) -- cgit v1.2.3 From 57e69925af9a941dfe32acc0431a9699eda027f5 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Mon, 20 Apr 2020 18:57:12 +0200 Subject: Add tests for `in_whitelisted_context` decorator I have added tests for the new `in_whitelisted_context` decorator. They work by calling the decorator with different kwargs to generate a specific predicate callable. That callable is then called to assess if it comes to the right conclusion. --- tests/bot/test_decorators.py | 115 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 tests/bot/test_decorators.py (limited to 'tests') diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py new file mode 100644 index 000000000..fae7c0c52 --- /dev/null +++ b/tests/bot/test_decorators.py @@ -0,0 +1,115 @@ +import collections +import unittest +import unittest.mock + +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context +from tests import helpers + + +WhitelistedContextTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx")) + + +class InWhitelistedContextTests(unittest.TestCase): + """Tests for the `in_whitelisted_context` check.""" + + @classmethod + def setUpClass(cls): + """Set up helpers that only need to be defined once.""" + cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456) + cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654) + cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666) + + cls.non_staff_member = helpers.MockMember() + cls.staff_role = helpers.MockRole(id=121212) + cls.staff_member = helpers.MockMember(roles=(cls.staff_role,)) + + cls.whitelisted_channels = (cls.bot_commands.id,) + cls.whitelisted_categories = (cls.help_channel.category_id,) + cls.whitelisted_roles = (cls.staff_role.id,) + + def test_predicate_returns_true_for_whitelisted_context(self): + """The predicate should return `True` if a whitelisted context was passed to it.""" + test_cases = ( + # Commands issued in whitelisted channels by members without whitelisted roles + WhitelistedContextTestCase( + kwargs={"whitelisted_channels": self.whitelisted_channels}, + ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member) + ), + # `redirect_channel` should be added implicitly to the `whitelisted_channels` + WhitelistedContextTestCase( + kwargs={"redirect_channel": self.bot_commands.id}, + ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member) + ), + + # Commands issued in a whitelisted category by members without whitelisted roles + WhitelistedContextTestCase( + kwargs={"whitelisted_categories": self.whitelisted_categories}, + ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member) + ), + + # Command issued by a staff member in a non-whitelisted channel/category + WhitelistedContextTestCase( + kwargs={"whitelisted_roles": self.whitelisted_roles}, + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member) + ), + + # With all kwargs provided + WhitelistedContextTestCase( + kwargs={ + "whitelisted_channels": self.whitelisted_channels, + "whitelisted_categories": self.whitelisted_categories, + "whitelisted_roles": self.whitelisted_roles, + "redirect_channel": self.bot_commands, + }, + ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member) + ), + ) + + for test_case in test_cases: + # patch `commands.check` with a no-op lambda that just returns the predicate passed to it + # so we can test the predicate that was generated from the specified kwargs. + with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): + predicate = in_whitelisted_context(**test_case.kwargs) + + with self.subTest(test_case=test_case): + self.assertTrue(predicate(test_case.ctx)) + + def test_predicate_raises_exception_for_non_whitelisted_context(self): + """The predicate should raise `InWhitelistedContextCheckFailure` for a non-whitelisted context.""" + test_cases = ( + # Failing check with `redirect_channel` + WhitelistedContextTestCase( + kwargs={ + "whitelisted_categories": self.whitelisted_categories, + "whitelisted_channels": self.whitelisted_channels, + "whitelisted_roles": self.whitelisted_roles, + "redirect_channel": self.bot_commands.id, + }, + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member) + ), + + # Failing check without `redirect_channel` + WhitelistedContextTestCase( + kwargs={ + "whitelisted_categories": self.whitelisted_categories, + "whitelisted_channels": self.whitelisted_channels, + "whitelisted_roles": self.whitelisted_roles, + }, + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member) + ), + ) + + for test_case in test_cases: + # Create expected exception message based on whether or not a redirect channel was provided + expected_message = "Sorry, but you are not allowed to use that command here." + if test_case.kwargs.get("redirect_channel"): + expected_message += f" Please use the <#{test_case.kwargs['redirect_channel']}> channel instead." + + # patch `commands.check` with a no-op lambda that just returns the predicate passed to it + # so we can test the predicate that was generated from the specified kwargs. + with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): + predicate = in_whitelisted_context(**test_case.kwargs) + + with self.subTest(test_case=test_case): + with self.assertRaises(InWhitelistedContextCheckFailure, msg=expected_message): + predicate(test_case.ctx) -- cgit v1.2.3 From b20bb7471b8d1d01f217f0620f8597bf1bae4456 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Thu, 23 Apr 2020 15:51:58 +0200 Subject: Simplify `in_whitelisted_context` decorator API The API of the `in_whitelisted_context` decorator was a bit clunky: - The long parameter names frequently required multiline decorators - Despite `#bot-commands` being the defacto default, it needed to be passed - The name of the function, `in_whitelisted_context` is fairly long in itself To shorten the call length of the decorator, the parameter names were shortened by dropping the `whitelisted_` prefix. This means that the parameter names are now just `channels`, `categories`, and `roles`. This already means that all current usages of the decorator are reduced to one line. In addition, `#bot-commands` has now been made the default redirect channel for the decorator. This means that if no `redirect` was passed, users will be redirected to `bot-commands` to use the command. If needed, `None` (or any falsey value) can be passed to disable redirection. Passing another channel id will trigger that channel to be used as the redirection target instead of bot-commands. Finally, the name of the decorator was shortened to `in_whitelist`, which already communicates what it is supposed to do. --- bot/cogs/error_handler.py | 6 ++-- bot/cogs/information.py | 10 ++----- bot/cogs/snekbox.py | 9 ++---- bot/cogs/utils.py | 8 ++--- bot/cogs/verification.py | 18 ++++-------- bot/decorators.py | 49 +++++++++++++++---------------- tests/bot/cogs/test_information.py | 4 +-- tests/bot/test_decorators.py | 60 +++++++++++++++++++------------------- 8 files changed, 72 insertions(+), 92 deletions(-) (limited to 'tests') diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 3f56a9798..b2f4c59f6 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -9,7 +9,7 @@ from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels from bot.converters import TagNameConverter -from bot.decorators import InWhitelistedContextCheckFailure +from bot.decorators import InWhitelistCheckFailure log = logging.getLogger(__name__) @@ -202,7 +202,7 @@ class ErrorHandler(Cog): * BotMissingRole * BotMissingAnyRole * NoPrivateMessage - * InWhitelistedContextCheckFailure + * InWhitelistCheckFailure """ bot_missing_errors = ( errors.BotMissingPermissions, @@ -215,7 +215,7 @@ class ErrorHandler(Cog): await ctx.send( f"Sorry, it looks like I don't have the permissions or roles I need to do that." ) - elif isinstance(e, (InWhitelistedContextCheckFailure, errors.NoPrivateMessage)): + elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") await ctx.send(e) diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 6b3fc0c96..4eb36c340 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -12,7 +12,7 @@ from discord.utils import escape_markdown from bot import constants from bot.bot import Bot -from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, with_role +from bot.decorators import InWhitelistCheckFailure, in_whitelist, with_role from bot.pagination import LinePaginator from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since @@ -152,7 +152,7 @@ class Information(Cog): # Non-staff may only do this in #bot-commands if not with_role_check(ctx, *constants.STAFF_ROLES): if not ctx.channel.id == constants.Channels.bot_commands: - raise InWhitelistedContextCheckFailure(constants.Channels.bot_commands) + raise InWhitelistCheckFailure(constants.Channels.bot_commands) embed = await self.create_user_embed(ctx, user) @@ -331,11 +331,7 @@ class Information(Cog): @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) @group(invoke_without_command=True) - @in_whitelisted_context( - whitelisted_channels=(constants.Channels.bot_commands,), - whitelisted_roles=constants.STAFF_ROLES, - redirect_channel=constants.Channels.bot_commands, - ) + @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES) async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: """Shows information about the raw API response.""" # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 4999074b6..8d4688114 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs -from bot.decorators import in_whitelisted_context +from bot.decorators import in_whitelist from bot.utils.messages import wait_for_deletion log = logging.getLogger(__name__) @@ -269,12 +269,7 @@ class Snekbox(Cog): @command(name="eval", aliases=("e",)) @guild_only() - @in_whitelisted_context( - whitelisted_channels=EVAL_CHANNELS, - whitelisted_categories=EVAL_CATEGORIES, - whitelisted_roles=EVAL_ROLES, - redirect_channel=Channels.bot_commands, - ) + @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 234ec514d..8023eb962 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -13,7 +13,7 @@ from discord.ext.commands import BadArgument, Cog, Context, command from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES -from bot.decorators import in_whitelisted_context, with_role +from bot.decorators import in_whitelist, with_role from bot.utils.time import humanize_delta log = logging.getLogger(__name__) @@ -118,11 +118,7 @@ class Utils(Cog): await ctx.message.channel.send(embed=pep_embed) @command() - @in_whitelisted_context( - whitelisted_channels=(Channels.bot_commands,), - whitelisted_roles=STAFF_ROLES, - redirect_channel=Channels.bot_commands, - ) + @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) async def charinfo(self, ctx: Context, *, characters: str) -> None: """Shows you information on up to 25 unicode characters.""" match = re.match(r"<(a?):(\w+):(\d+)>", characters) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 040f52fbf..388b7a338 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot from bot.cogs.moderation import ModLog -from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, without_role +from bot.decorators import InWhitelistCheckFailure, in_whitelist, without_role from bot.utils.checks import without_role_check log = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class Verification(Cog): @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) @without_role(constants.Roles.verified) - @in_whitelisted_context(whitelisted_channels=(constants.Channels.verification,)) + @in_whitelist(channels=(constants.Channels.verification,)) async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Accept our rules and gain access to the rest of the server.""" log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") @@ -138,10 +138,7 @@ class Verification(Cog): await ctx.message.delete() @command(name='subscribe') - @in_whitelisted_context( - whitelisted_channels=(constants.Channels.bot_commands,), - redirect_channel=constants.Channels.bot_commands, - ) + @in_whitelist(channels=(constants.Channels.bot_commands,)) async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Subscribe to announcement notifications by assigning yourself the role.""" has_role = False @@ -165,10 +162,7 @@ class Verification(Cog): ) @command(name='unsubscribe') - @in_whitelisted_context( - whitelisted_channels=(constants.Channels.bot_commands,), - redirect_channel=constants.Channels.bot_commands, - ) + @in_whitelist(channels=(constants.Channels.bot_commands,)) async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Unsubscribe from announcement notifications by removing the role from yourself.""" has_role = False @@ -193,8 +187,8 @@ class Verification(Cog): # This cannot be static (must have a __func__ attribute). async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InWhitelistedContextCheckFailure.""" - if isinstance(error, InWhitelistedContextCheckFailure): + """Check for & ignore any InWhitelistCheckFailure.""" + if isinstance(error, InWhitelistCheckFailure): error.handled = True @staticmethod diff --git a/bot/decorators.py b/bot/decorators.py index 149564d18..2ee5879f2 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -11,30 +11,34 @@ from discord.errors import NotFound from discord.ext import commands from discord.ext.commands import CheckFailure, Cog, Context -from bot.constants import ERROR_REPLIES, RedirectOutput +from bot.constants import Channels, ERROR_REPLIES, RedirectOutput from bot.utils.checks import with_role_check, without_role_check log = logging.getLogger(__name__) -class InWhitelistedContextCheckFailure(CheckFailure): +class InWhitelistCheckFailure(CheckFailure): """Raised when the `in_whitelist` check fails.""" - def __init__(self, redirect_channel: Optional[int] = None): - error_message = "Sorry, but you are not allowed to use that command here." + def __init__(self, redirect_channel: Optional[int]) -> None: + self.redirect_channel = redirect_channel if redirect_channel: - error_message += f" Please use the <#{redirect_channel}> channel instead." + redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" + else: + redirect_message = "" + + error_message = f"You are not allowed to use that command{redirect_message}." super().__init__(error_message) -def in_whitelisted_context( +def in_whitelist( *, - whitelisted_channels: Container[int] = (), - whitelisted_categories: Container[int] = (), - whitelisted_roles: Container[int] = (), - redirect_channel: Optional[int] = None, + channels: Container[int] = (), + categories: Container[int] = (), + roles: Container[int] = (), + redirect: Optional[int] = Channels.bot_commands, ) -> Callable: """ @@ -46,45 +50,40 @@ def in_whitelisted_context( - `categories`: a container with category ids for whitelisted categories - `roles`: a container with with role ids for whitelisted roles - An optional `redirect_channel` can be provided to redirect users that are not - authorized to use the command in the current context. If no such channel is - provided, the users are simply told that they are not authorized to use the - command. + If the command was invoked in a context that was not whitelisted, the member is either + redirected to the `redirect` channel that was passed (default: #bot-commands) or simply + told that they're not allowed to use this particular command (if `None` was passed). """ - if redirect_channel and redirect_channel not in whitelisted_channels: + if redirect and redirect not in channels: # It does not make sense for the channel whitelist to not contain the redirection - # channel (if provided). That's why we add the redirection channel to the `channels` + # channel (if applicable). That's why we add the redirection channel to the `channels` # container if it's not already in it. As we allow any container type to be passed, # we first create a tuple in order to safely add the redirection channel. # # Note: It's possible for the redirect channel to be in a whitelisted category, but # there's no easy way to check that and as a channel can easily be moved in and out of # categories, it's probably not wise to rely on its category in any case. - whitelisted_channels = tuple(whitelisted_channels) + (redirect_channel,) + channels = tuple(channels) + (redirect,) def predicate(ctx: Context) -> bool: """Check if a command was issued in a whitelisted context.""" - if whitelisted_channels and ctx.channel.id in whitelisted_channels: + if channels and ctx.channel.id in channels: log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") return True # Only check the category id if we have a category whitelist and the channel has a `category_id` - if ( - whitelisted_categories - and hasattr(ctx.channel, "category_id") - and ctx.channel.category_id in whitelisted_categories - ): + if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") return True # Only check the roles whitelist if we have one and ensure the author's roles attribute returns # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). - if whitelisted_roles and any(r.id in whitelisted_roles for r in getattr(ctx.author, "roles", ())): + if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") return True log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") - raise InWhitelistedContextCheckFailure(redirect_channel) + raise InWhitelistCheckFailure(redirect) return commands.check(predicate) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 4a36fe030..6dace1080 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,7 +7,7 @@ import discord from bot import constants from bot.cogs import information -from bot.decorators import InWhitelistedContextCheckFailure +from bot.decorators import InWhitelistCheckFailure from tests import helpers @@ -525,7 +525,7 @@ class UserCommandTests(unittest.TestCase): ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InWhitelistedContextCheckFailure, msg=msg): + with self.assertRaises(InWhitelistCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index fae7c0c52..645051fec 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -2,15 +2,15 @@ import collections import unittest import unittest.mock -from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context +from bot.decorators import InWhitelistCheckFailure, in_whitelist from tests import helpers WhitelistedContextTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx")) -class InWhitelistedContextTests(unittest.TestCase): - """Tests for the `in_whitelisted_context` check.""" +class InWhitelistTests(unittest.TestCase): + """Tests for the `in_whitelist` check.""" @classmethod def setUpClass(cls): @@ -23,43 +23,43 @@ class InWhitelistedContextTests(unittest.TestCase): cls.staff_role = helpers.MockRole(id=121212) cls.staff_member = helpers.MockMember(roles=(cls.staff_role,)) - cls.whitelisted_channels = (cls.bot_commands.id,) - cls.whitelisted_categories = (cls.help_channel.category_id,) - cls.whitelisted_roles = (cls.staff_role.id,) + cls.channels = (cls.bot_commands.id,) + cls.categories = (cls.help_channel.category_id,) + cls.roles = (cls.staff_role.id,) def test_predicate_returns_true_for_whitelisted_context(self): """The predicate should return `True` if a whitelisted context was passed to it.""" test_cases = ( # Commands issued in whitelisted channels by members without whitelisted roles WhitelistedContextTestCase( - kwargs={"whitelisted_channels": self.whitelisted_channels}, + kwargs={"channels": self.channels}, ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member) ), - # `redirect_channel` should be added implicitly to the `whitelisted_channels` + # `redirect` should be added implicitly to the `channels` WhitelistedContextTestCase( - kwargs={"redirect_channel": self.bot_commands.id}, + kwargs={"redirect": self.bot_commands.id}, ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member) ), # Commands issued in a whitelisted category by members without whitelisted roles WhitelistedContextTestCase( - kwargs={"whitelisted_categories": self.whitelisted_categories}, + kwargs={"categories": self.categories}, ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member) ), # Command issued by a staff member in a non-whitelisted channel/category WhitelistedContextTestCase( - kwargs={"whitelisted_roles": self.whitelisted_roles}, + kwargs={"roles": self.roles}, ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member) ), # With all kwargs provided WhitelistedContextTestCase( kwargs={ - "whitelisted_channels": self.whitelisted_channels, - "whitelisted_categories": self.whitelisted_categories, - "whitelisted_roles": self.whitelisted_roles, - "redirect_channel": self.bot_commands, + "channels": self.channels, + "categories": self.categories, + "roles": self.roles, + "redirect": self.bot_commands, }, ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member) ), @@ -69,31 +69,31 @@ class InWhitelistedContextTests(unittest.TestCase): # patch `commands.check` with a no-op lambda that just returns the predicate passed to it # so we can test the predicate that was generated from the specified kwargs. with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): - predicate = in_whitelisted_context(**test_case.kwargs) + predicate = in_whitelist(**test_case.kwargs) with self.subTest(test_case=test_case): self.assertTrue(predicate(test_case.ctx)) def test_predicate_raises_exception_for_non_whitelisted_context(self): - """The predicate should raise `InWhitelistedContextCheckFailure` for a non-whitelisted context.""" + """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context.""" test_cases = ( - # Failing check with `redirect_channel` + # Failing check with `redirect` WhitelistedContextTestCase( kwargs={ - "whitelisted_categories": self.whitelisted_categories, - "whitelisted_channels": self.whitelisted_channels, - "whitelisted_roles": self.whitelisted_roles, - "redirect_channel": self.bot_commands.id, + "categories": self.categories, + "channels": self.channels, + "roles": self.roles, + "redirect": self.bot_commands.id, }, ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member) ), - # Failing check without `redirect_channel` + # Failing check without `redirect` WhitelistedContextTestCase( kwargs={ - "whitelisted_categories": self.whitelisted_categories, - "whitelisted_channels": self.whitelisted_channels, - "whitelisted_roles": self.whitelisted_roles, + "categories": self.categories, + "channels": self.channels, + "roles": self.roles, }, ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member) ), @@ -102,14 +102,14 @@ class InWhitelistedContextTests(unittest.TestCase): for test_case in test_cases: # Create expected exception message based on whether or not a redirect channel was provided expected_message = "Sorry, but you are not allowed to use that command here." - if test_case.kwargs.get("redirect_channel"): - expected_message += f" Please use the <#{test_case.kwargs['redirect_channel']}> channel instead." + if test_case.kwargs.get("redirect"): + expected_message += f" Please use the <#{test_case.kwargs['redirect']}> channel instead." # patch `commands.check` with a no-op lambda that just returns the predicate passed to it # so we can test the predicate that was generated from the specified kwargs. with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): - predicate = in_whitelisted_context(**test_case.kwargs) + predicate = in_whitelist(**test_case.kwargs) with self.subTest(test_case=test_case): - with self.assertRaises(InWhitelistedContextCheckFailure, msg=expected_message): + with self.assertRaises(InWhitelistCheckFailure, msg=expected_message): predicate(test_case.ctx) -- cgit v1.2.3 From f5bb251bbfd92bfe67ee9638f2bf6d054eb30502 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Mon, 27 Apr 2020 16:01:45 +0200 Subject: Exclude never-run lines from coverage --- tests/bot/cogs/test_cogs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 39f6492cb..fdda59a8f 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -31,7 +31,7 @@ class CommandNameTests(unittest.TestCase): def walk_modules() -> t.Iterator[ModuleType]: """Yield imported modules from the bot.cogs subpackage.""" def on_error(name: str) -> t.NoReturn: - raise ImportError(name=name) + raise ImportError(name=name) # pragma: no cover # The mock prevents asyncio.get_event_loop() from being called. with mock.patch("discord.ext.tasks.loop"): @@ -71,7 +71,7 @@ class CommandNameTests(unittest.TestCase): for name in self.get_qualified_names(cmd): with self.subTest(cmd=func_name, name=name): - if name in all_names: + if name in all_names: # pragma: no cover conflicts = ", ".join(all_names.get(name, "")) self.fail( f"Name '{name}' of the command {func_name} conflicts with {conflicts}." -- cgit v1.2.3 From 167f57b9cc78708b7c6b48f64442d7bddce2f75c Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Mon, 27 Apr 2020 16:02:15 +0200 Subject: Add mock for discord.DMChannels --- tests/helpers.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 9001deedf..2b79a6c2a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -323,6 +323,27 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): self.mention = f"#{self.name}" +# Create data for the DMChannel instance +state = unittest.mock.MagicMock() +me = unittest.mock.MagicMock() +dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} +dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) + + +class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): + """ + A MagicMock subclass to mock TextChannel objects. + + Instances of this class will follow the specifications of `discord.TextChannel` instances. For + more information, see the `MockGuild` docstring. + """ + spec_set = dm_channel_instance + + def __init__(self, **kwargs) -> None: + default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + + # Create a Message instance to get a realistic MagicMock of `discord.Message` message_data = { 'id': 1, -- cgit v1.2.3 From d21e5962be961a267cef6ffef4f7d4aaf1114a08 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Mon, 27 Apr 2020 16:03:12 +0200 Subject: Add DMChannel tests for in_whitelist decorator The `in_whitelist` decorator should not fail when a decorated command was called in a DMChannel; it should simply conclude that the user is not allowed to use the command. I've added a test case that uses a DMChannel context with User, not Member, objects. In addition, I've opted to display a test case description in the `subTest`: Simply printing the actual arguments and context is messy and does not actually show you the information you'd like. This description is enough to figure out which test is failing and what the gist of the test is. --- tests/bot/test_decorators.py | 94 +++++++++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 31 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index 645051fec..a17dd3e16 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -2,11 +2,12 @@ import collections import unittest import unittest.mock +from bot import constants from bot.decorators import InWhitelistCheckFailure, in_whitelist from tests import helpers -WhitelistedContextTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx")) +InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) class InWhitelistTests(unittest.TestCase): @@ -18,6 +19,7 @@ class InWhitelistTests(unittest.TestCase): cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456) cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654) cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666) + cls.dm_channel = helpers.MockDMChannel() cls.non_staff_member = helpers.MockMember() cls.staff_role = helpers.MockRole(id=121212) @@ -30,38 +32,35 @@ class InWhitelistTests(unittest.TestCase): def test_predicate_returns_true_for_whitelisted_context(self): """The predicate should return `True` if a whitelisted context was passed to it.""" test_cases = ( - # Commands issued in whitelisted channels by members without whitelisted roles - WhitelistedContextTestCase( + InWhitelistTestCase( kwargs={"channels": self.channels}, - ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member) + ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member), + description="In whitelisted channels by members without whitelisted roles", ), - # `redirect` should be added implicitly to the `channels` - WhitelistedContextTestCase( + InWhitelistTestCase( kwargs={"redirect": self.bot_commands.id}, - ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member) + ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member), + description="`redirect` should be implicitly added to `channels`", ), - - # Commands issued in a whitelisted category by members without whitelisted roles - WhitelistedContextTestCase( + InWhitelistTestCase( kwargs={"categories": self.categories}, - ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member) + ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member), + description="Whitelisted category without whitelisted role", ), - - # Command issued by a staff member in a non-whitelisted channel/category - WhitelistedContextTestCase( + InWhitelistTestCase( kwargs={"roles": self.roles}, - ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member) + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member), + description="Whitelisted role outside of whitelisted channel/category" ), - - # With all kwargs provided - WhitelistedContextTestCase( + InWhitelistTestCase( kwargs={ "channels": self.channels, "categories": self.categories, "roles": self.roles, "redirect": self.bot_commands, }, - ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member) + ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member), + description="Case with all whitelist kwargs used", ), ) @@ -71,45 +70,78 @@ class InWhitelistTests(unittest.TestCase): with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): predicate = in_whitelist(**test_case.kwargs) - with self.subTest(test_case=test_case): + with self.subTest(test_description=test_case.description): self.assertTrue(predicate(test_case.ctx)) def test_predicate_raises_exception_for_non_whitelisted_context(self): """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context.""" test_cases = ( - # Failing check with `redirect` - WhitelistedContextTestCase( + # Failing check with explicit `redirect` + InWhitelistTestCase( kwargs={ "categories": self.categories, "channels": self.channels, "roles": self.roles, "redirect": self.bot_commands.id, }, - ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member) + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), + description="Failing check with an explicit redirect channel", + ), + + # Failing check with implicit `redirect` + InWhitelistTestCase( + kwargs={ + "categories": self.categories, + "channels": self.channels, + "roles": self.roles, + }, + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), + description="Failing check with an implicit redirect channel", ), # Failing check without `redirect` - WhitelistedContextTestCase( + InWhitelistTestCase( + kwargs={ + "categories": self.categories, + "channels": self.channels, + "roles": self.roles, + "redirect": None, + }, + ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), + description="Failing check without a redirect channel", + ), + + # Command issued in DM channel + InWhitelistTestCase( kwargs={ "categories": self.categories, "channels": self.channels, "roles": self.roles, + "redirect": None, }, - ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member) + ctx=helpers.MockContext(channel=self.dm_channel, author=self.dm_channel.me), + description="Commands issued in DM channel should be rejected", ), ) for test_case in test_cases: - # Create expected exception message based on whether or not a redirect channel was provided - expected_message = "Sorry, but you are not allowed to use that command here." - if test_case.kwargs.get("redirect"): - expected_message += f" Please use the <#{test_case.kwargs['redirect']}> channel instead." + if "redirect" not in test_case.kwargs or test_case.kwargs["redirect"] is not None: + # There are two cases in which we have a redirect channel: + # 1. No redirect channel was passed; the default value of `bot_commands` is used + # 2. An explicit `redirect` is set that is "not None" + redirect_channel = test_case.kwargs.get("redirect", constants.Channels.bot_commands) + redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" + else: + # If an explicit `None` was passed for `redirect`, there is no redirect channel + redirect_message = "" + + exception_message = f"You are not allowed to use that command{redirect_message}." # patch `commands.check` with a no-op lambda that just returns the predicate passed to it # so we can test the predicate that was generated from the specified kwargs. with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): predicate = in_whitelist(**test_case.kwargs) - with self.subTest(test_case=test_case): - with self.assertRaises(InWhitelistCheckFailure, msg=expected_message): + with self.subTest(test_description=test_case.description): + with self.assertRaisesRegex(InWhitelistCheckFailure, exception_message): predicate(test_case.ctx) -- cgit v1.2.3 From 96920935f9af6d325a2ff91d197285204b3221c9 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 28 Apr 2020 15:56:15 -0700 Subject: Test for out of range datetime in the Duration converter --- tests/bot/test_converters.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'tests') diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index ca8cb6825..e42bfc7ee 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -198,6 +198,17 @@ class ConverterTests(unittest.TestCase): with self.assertRaises(BadArgument, msg=exception_message): asyncio.run(converter.convert(self.context, invalid_duration)) + @patch("bot.converters.datetime") + def test_duration_converter_out_of_range(self, mock_datetime): + """Duration converter should raise BadArgument if datetime raises a ValueError.""" + mock_datetime.__add__.side_effect = ValueError + mock_datetime.utcnow.return_value = mock_datetime + + duration = f"{datetime.MAXYEAR}y" + exception_message = f"`{duration}` results in a datetime outside the supported range." + with self.assertRaisesRegex(BadArgument, exception_message): + asyncio.run(Duration().convert(self.context, duration)) + def test_isodatetime_converter_for_valid(self): """ISODateTime converter returns correct datetime for valid datetime string.""" test_values = ( -- cgit v1.2.3 From 298389f57166fb5c775e550175c8bb2685fa37ae Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 28 Apr 2020 15:58:29 -0700 Subject: Remove redundant parenthesis from test values --- tests/bot/test_converters.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index e42bfc7ee..51d7affba 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -166,28 +166,28 @@ class ConverterTests(unittest.TestCase): """Duration raises the right exception for invalid duration strings.""" test_values = ( # Units in wrong order - ('1d1w'), - ('1s1y'), + '1d1w', + '1s1y', # Duplicated units - ('1 year 2 years'), - ('1 M 10 minutes'), + '1 year 2 years', + '1 M 10 minutes', # Unknown substrings - ('1MVes'), - ('1y3breads'), + '1MVes', + '1y3breads', # Missing amount - ('ym'), + 'ym', # Incorrect whitespace - (" 1y"), - ("1S "), - ("1y 1m"), + " 1y", + "1S ", + "1y 1m", # Garbage - ('Guido van Rossum'), - ('lemon lemon lemon lemon lemon lemon lemon'), + 'Guido van Rossum', + 'lemon lemon lemon lemon lemon lemon lemon', ) converter = Duration() @@ -262,19 +262,19 @@ class ConverterTests(unittest.TestCase): """ISODateTime converter raises the correct exception for invalid datetime strings.""" test_values = ( # Make sure it doesn't interfere with the Duration converter - ('1Y'), - ('1d'), - ('1H'), + '1Y', + '1d', + '1H', # Check if it fails when only providing the optional time part - ('10:10:10'), - ('10:00'), + '10:10:10', + '10:00', # Invalid date format - ('19-01-01'), + '19-01-01', # Other non-valid strings - ('fisk the tag master'), + 'fisk the tag master', ) converter = ISODateTime() -- cgit v1.2.3 From 837bc230976328df8dabdc6e8be90188b2ff2ff3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 28 Apr 2020 16:01:22 -0700 Subject: Use await instead of asyncio.run in converter tests --- tests/bot/test_converters.py | 55 ++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 28 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 51d7affba..146a8b5fa 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -1,4 +1,3 @@ -import asyncio import datetime import unittest from unittest.mock import MagicMock, patch @@ -16,7 +15,7 @@ from bot.converters import ( ) -class ConverterTests(unittest.TestCase): +class ConverterTests(unittest.IsolatedAsyncioTestCase): """Tests our custom argument converters.""" @classmethod @@ -26,7 +25,7 @@ class ConverterTests(unittest.TestCase): cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') - def test_tag_content_converter_for_valid(self): + async def test_tag_content_converter_for_valid(self): """TagContentConverter should return correct values for valid input.""" test_values = ( ('hello', 'hello'), @@ -35,10 +34,10 @@ class ConverterTests(unittest.TestCase): for content, expected_conversion in test_values: with self.subTest(content=content, expected_conversion=expected_conversion): - conversion = asyncio.run(TagContentConverter.convert(self.context, content)) + conversion = await TagContentConverter.convert(self.context, content) self.assertEqual(conversion, expected_conversion) - def test_tag_content_converter_for_invalid(self): + async def test_tag_content_converter_for_invalid(self): """TagContentConverter should raise the proper exception for invalid input.""" test_values = ( ('', "Tag contents should not be empty, or filled with whitespace."), @@ -48,9 +47,9 @@ class ConverterTests(unittest.TestCase): for value, exception_message in test_values: with self.subTest(tag_content=value, exception_message=exception_message): with self.assertRaises(BadArgument, msg=exception_message): - asyncio.run(TagContentConverter.convert(self.context, value)) + await TagContentConverter.convert(self.context, value) - def test_tag_name_converter_for_valid(self): + async def test_tag_name_converter_for_valid(self): """TagNameConverter should return the correct values for valid tag names.""" test_values = ( ('tracebacks', 'tracebacks'), @@ -60,10 +59,10 @@ class ConverterTests(unittest.TestCase): for name, expected_conversion in test_values: with self.subTest(name=name, expected_conversion=expected_conversion): - conversion = asyncio.run(TagNameConverter.convert(self.context, name)) + conversion = await TagNameConverter.convert(self.context, name) self.assertEqual(conversion, expected_conversion) - def test_tag_name_converter_for_invalid(self): + async def test_tag_name_converter_for_invalid(self): """TagNameConverter should raise the correct exception for invalid tag names.""" test_values = ( ('👋', "Don't be ridiculous, you can't use that character!"), @@ -76,18 +75,18 @@ class ConverterTests(unittest.TestCase): for invalid_name, exception_message in test_values: with self.subTest(invalid_name=invalid_name, exception_message=exception_message): with self.assertRaises(BadArgument, msg=exception_message): - asyncio.run(TagNameConverter.convert(self.context, invalid_name)) + await TagNameConverter.convert(self.context, invalid_name) - def test_valid_python_identifier_for_valid(self): + async def test_valid_python_identifier_for_valid(self): """ValidPythonIdentifier returns valid identifiers unchanged.""" test_values = ('foo', 'lemon') for name in test_values: with self.subTest(identifier=name): - conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name)) + conversion = await ValidPythonIdentifier.convert(self.context, name) self.assertEqual(name, conversion) - def test_valid_python_identifier_for_invalid(self): + async def test_valid_python_identifier_for_invalid(self): """ValidPythonIdentifier raises the proper exception for invalid identifiers.""" test_values = ('nested.stuff', '#####') @@ -95,9 +94,9 @@ class ConverterTests(unittest.TestCase): with self.subTest(identifier=name): exception_message = f'`{name}` is not a valid Python identifier' with self.assertRaises(BadArgument, msg=exception_message): - asyncio.run(ValidPythonIdentifier.convert(self.context, name)) + await ValidPythonIdentifier.convert(self.context, name) - def test_duration_converter_for_valid(self): + async def test_duration_converter_for_valid(self): """Duration returns the correct `datetime` for valid duration strings.""" test_values = ( # Simple duration strings @@ -159,10 +158,10 @@ class ConverterTests(unittest.TestCase): mock_datetime.utcnow.return_value = self.fixed_utc_now with self.subTest(duration=duration, duration_dict=duration_dict): - converted_datetime = asyncio.run(converter.convert(self.context, duration)) + converted_datetime = await converter.convert(self.context, duration) self.assertEqual(converted_datetime, expected_datetime) - def test_duration_converter_for_invalid(self): + async def test_duration_converter_for_invalid(self): """Duration raises the right exception for invalid duration strings.""" test_values = ( # Units in wrong order @@ -196,10 +195,10 @@ class ConverterTests(unittest.TestCase): with self.subTest(invalid_duration=invalid_duration): exception_message = f'`{invalid_duration}` is not a valid duration string.' with self.assertRaises(BadArgument, msg=exception_message): - asyncio.run(converter.convert(self.context, invalid_duration)) + await converter.convert(self.context, invalid_duration) @patch("bot.converters.datetime") - def test_duration_converter_out_of_range(self, mock_datetime): + async def test_duration_converter_out_of_range(self, mock_datetime): """Duration converter should raise BadArgument if datetime raises a ValueError.""" mock_datetime.__add__.side_effect = ValueError mock_datetime.utcnow.return_value = mock_datetime @@ -207,9 +206,9 @@ class ConverterTests(unittest.TestCase): duration = f"{datetime.MAXYEAR}y" exception_message = f"`{duration}` results in a datetime outside the supported range." with self.assertRaisesRegex(BadArgument, exception_message): - asyncio.run(Duration().convert(self.context, duration)) + await Duration().convert(self.context, duration) - def test_isodatetime_converter_for_valid(self): + async def test_isodatetime_converter_for_valid(self): """ISODateTime converter returns correct datetime for valid datetime string.""" test_values = ( # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` @@ -254,11 +253,11 @@ class ConverterTests(unittest.TestCase): for datetime_string, expected_dt in test_values: with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt): - converted_dt = asyncio.run(converter.convert(self.context, datetime_string)) + converted_dt = await converter.convert(self.context, datetime_string) self.assertIsNone(converted_dt.tzinfo) self.assertEqual(converted_dt, expected_dt) - def test_isodatetime_converter_for_invalid(self): + async def test_isodatetime_converter_for_invalid(self): """ISODateTime converter raises the correct exception for invalid datetime strings.""" test_values = ( # Make sure it doesn't interfere with the Duration converter @@ -282,9 +281,9 @@ class ConverterTests(unittest.TestCase): with self.subTest(datetime_string=datetime_string): exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" with self.assertRaises(BadArgument, msg=exception_message): - asyncio.run(converter.convert(self.context, datetime_string)) + await converter.convert(self.context, datetime_string) - def test_hush_duration_converter_for_valid(self): + async def test_hush_duration_converter_for_valid(self): """HushDurationConverter returns correct value for minutes duration or `"forever"` strings.""" test_values = ( ("0", 0), @@ -297,10 +296,10 @@ class ConverterTests(unittest.TestCase): converter = HushDurationConverter() for minutes_string, expected_minutes in test_values: with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes): - converted = asyncio.run(converter.convert(self.context, minutes_string)) + converted = await converter.convert(self.context, minutes_string) self.assertEqual(expected_minutes, converted) - def test_hush_duration_converter_for_invalid(self): + async def test_hush_duration_converter_for_invalid(self): """HushDurationConverter raises correct exception for invalid minutes duration strings.""" test_values = ( ("16", "Duration must be at most 15 minutes."), @@ -311,4 +310,4 @@ class ConverterTests(unittest.TestCase): for invalid_minutes_string, exception_message in test_values: with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message): with self.assertRaisesRegex(BadArgument, exception_message): - asyncio.run(converter.convert(self.context, invalid_minutes_string)) + await converter.convert(self.context, invalid_minutes_string) -- cgit v1.2.3 From 1e4766d9934396a72cc759649049b07e5814004a Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 28 Apr 2020 16:18:34 -0700 Subject: Fix exception message assertions in converter tests The `msg` arg is for displaying a message when the assertion fails. To match against the exception's message, `assertRaisesRegex` must be used. Since all of the messages are meant to be interpreted literally rather than as regex, `re.escape` is used. --- tests/bot/test_converters.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 146a8b5fa..c42111f3f 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -1,4 +1,5 @@ import datetime +import re import unittest from unittest.mock import MagicMock, patch @@ -46,7 +47,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): for value, exception_message in test_values: with self.subTest(tag_content=value, exception_message=exception_message): - with self.assertRaises(BadArgument, msg=exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await TagContentConverter.convert(self.context, value) async def test_tag_name_converter_for_valid(self): @@ -74,7 +75,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): for invalid_name, exception_message in test_values: with self.subTest(invalid_name=invalid_name, exception_message=exception_message): - with self.assertRaises(BadArgument, msg=exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await TagNameConverter.convert(self.context, invalid_name) async def test_valid_python_identifier_for_valid(self): @@ -93,7 +94,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): for name in test_values: with self.subTest(identifier=name): exception_message = f'`{name}` is not a valid Python identifier' - with self.assertRaises(BadArgument, msg=exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await ValidPythonIdentifier.convert(self.context, name) async def test_duration_converter_for_valid(self): @@ -194,7 +195,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): for invalid_duration in test_values: with self.subTest(invalid_duration=invalid_duration): exception_message = f'`{invalid_duration}` is not a valid duration string.' - with self.assertRaises(BadArgument, msg=exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await converter.convert(self.context, invalid_duration) @patch("bot.converters.datetime") @@ -205,7 +206,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): duration = f"{datetime.MAXYEAR}y" exception_message = f"`{duration}` results in a datetime outside the supported range." - with self.assertRaisesRegex(BadArgument, exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await Duration().convert(self.context, duration) async def test_isodatetime_converter_for_valid(self): @@ -280,7 +281,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): for datetime_string in test_values: with self.subTest(datetime_string=datetime_string): exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" - with self.assertRaises(BadArgument, msg=exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await converter.convert(self.context, datetime_string) async def test_hush_duration_converter_for_valid(self): @@ -309,5 +310,5 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): converter = HushDurationConverter() for invalid_minutes_string, exception_message in test_values: with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message): - with self.assertRaisesRegex(BadArgument, exception_message): + with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): await converter.convert(self.context, invalid_minutes_string) -- cgit v1.2.3 From b43379d663a86680f762d20a7bd27a20927d4bfc Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 30 Apr 2020 18:35:03 -0700 Subject: Tests: change avatar_url_as assertion to use static_format --- tests/bot/cogs/test_information.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 6dace1080..b5f928dd6 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -485,7 +485,7 @@ class UserEmbedTests(unittest.TestCase): user.avatar_url_as.return_value = "avatar url" embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - user.avatar_url_as.assert_called_once_with(format="png") + user.avatar_url_as.assert_called_once_with(static_format="png") self.assertEqual(embed.thumbnail.url, "avatar url") -- cgit v1.2.3 From 601ff03823deb842d74f4689fecb68f7ce1693e6 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 18:11:44 +0200 Subject: AntiMalware Tests - Added unittest for message without attachment --- tests/bot/cogs/test_antimalware.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/bot/cogs/test_antimalware.py (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..41ca19e17 --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,20 @@ +import asyncio +import unittest + +from bot.cogs import antimalware +from tests.helpers import MockBot, MockMessage + + +class AntiMalwareCogTests(unittest.TestCase): + """Test the AntiMalware cog.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.bot = MockBot() + self.cog = antimalware.AntiMalware(self.bot) + self.message = MockMessage() + + def test_message_without_attachment(self): + """Messages without attachments should result in no action.""" + coroutine = self.cog.on_message(self.message) + self.assertIsNone(asyncio.run(coroutine)) -- cgit v1.2.3 From 9889f0fdd1ba403ae50ba20be38feca0932d1dda Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 19:06:46 +0200 Subject: AntiMalware Tests - Added unittests for deletion of message and ignoring of dms --- tests/bot/cogs/test_antimalware.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 41ca19e17..ebf3a1277 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,8 +1,9 @@ import asyncio import unittest +from unittest.mock import AsyncMock from bot.cogs import antimalware -from tests.helpers import MockBot, MockMessage +from tests.helpers import MockAttachment, MockBot, MockMessage class AntiMalwareCogTests(unittest.TestCase): @@ -13,8 +14,27 @@ class AntiMalwareCogTests(unittest.TestCase): self.bot = MockBot() self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.message.delete = AsyncMock() def test_message_without_attachment(self): """Messages without attachments should result in no action.""" coroutine = self.cog.on_message(self.message) self.assertIsNone(asyncio.run(coroutine)) + self.message.delete.assert_not_called() + + def test_direct_message_with_attachment(self): + """Direct messages should have no action taken.""" + attachment = MockAttachment(filename="python.asdfsff") + self.message.attachments = [attachment] + self.message.guild = None + coroutine = self.cog.on_message(self.message) + asyncio.run(coroutine) + self.message.delete.assert_not_called() + + def test_message_with_illegal_extension_gets_deleted(self): + """A message containing an illegal extension should send an embed.""" + attachment = MockAttachment(filename="python.asdfsff") + self.message.attachments = [attachment] + coroutine = self.cog.on_message(self.message) + asyncio.run(coroutine) + self.message.delete.assert_called_once() -- cgit v1.2.3 From 90d2ce0e3717d4ddf79eb986e22f7542ca1770e1 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 19:13:45 +0200 Subject: AntiMalware Tests - Added unittest for messages send by staff --- tests/bot/cogs/test_antimalware.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index ebf3a1277..e3fd477fa 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -3,7 +3,8 @@ import unittest from unittest.mock import AsyncMock from bot.cogs import antimalware -from tests.helpers import MockAttachment, MockBot, MockMessage +from bot.constants import Roles +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole class AntiMalwareCogTests(unittest.TestCase): @@ -38,3 +39,13 @@ class AntiMalwareCogTests(unittest.TestCase): coroutine = self.cog.on_message(self.message) asyncio.run(coroutine) self.message.delete.assert_called_once() + + def test_message_send_by_staff(self): + """A message send by a member of staff should be ignored.""" + moderator_role = MockRole(name="Moderator", id=Roles.moderators) + self.message.author.roles.append(moderator_role) + attachment = MockAttachment(filename="python.asdfsff") + self.message.attachments = [attachment] + coroutine = self.cog.on_message(self.message) + asyncio.run(coroutine) + self.message.delete.assert_not_called() -- cgit v1.2.3 From 3913a8eba46bf98bd09e13145da33f7a09f77960 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 19:45:57 +0200 Subject: AntiMalware Tests - Added unittest for the embed for a python file. --- tests/bot/cogs/test_antimalware.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index e3fd477fa..0bb5af943 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import AsyncMock from bot.cogs import antimalware -from bot.constants import Roles +from bot.constants import Roles, URLs from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole @@ -28,16 +28,20 @@ class AntiMalwareCogTests(unittest.TestCase): attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] self.message.guild = None + coroutine = self.cog.on_message(self.message) asyncio.run(coroutine) + self.message.delete.assert_not_called() def test_message_with_illegal_extension_gets_deleted(self): """A message containing an illegal extension should send an embed.""" attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] + coroutine = self.cog.on_message(self.message) asyncio.run(coroutine) + self.message.delete.assert_called_once() def test_message_send_by_staff(self): @@ -46,6 +50,25 @@ class AntiMalwareCogTests(unittest.TestCase): self.message.author.roles.append(moderator_role) attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] + coroutine = self.cog.on_message(self.message) asyncio.run(coroutine) + self.message.delete.assert_not_called() + + def test_python_file_redirect_embed(self): + """A message containing a .python file should result in an embed redirecting the user to our paste site""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + coroutine = self.cog.on_message(self.message) + asyncio.run(coroutine) + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + + self.assertEqual(args[0], f"Hey {self.message.author.mention}!") + self.assertEqual(embed.description, ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" + )) -- cgit v1.2.3 From 19c15d957040b6857a4141e15c32fd0526f9920d Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 20:15:17 +0200 Subject: AntiMalware Tests - Added unittest for messages that were deleted in the meantime. --- tests/bot/cogs/test_antimalware.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 0bb5af943..da5cd9d11 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,6 +1,9 @@ import asyncio +import logging import unittest -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock + +from discord import NotFound from bot.cogs import antimalware from bot.constants import Roles, URLs @@ -72,3 +75,18 @@ class AntiMalwareCogTests(unittest.TestCase): "It looks like you tried to attach a Python file - " f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" )) + + def test_removing_deleted_message_logs(self): + """Removing an already deleted message logs the correct message""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + + coroutine = self.cog.on_message(self.message) + logger = logging.getLogger("bot.cogs.antimalware") + + with self.assertLogs(logger=logger, level="INFO") as logs: + asyncio.run(coroutine) + self.assertIn( + f"INFO:bot.cogs.antimalware:Tried to delete message `{self.message.id}`, but message could not be found.", + logs.output) -- cgit v1.2.3 From 4a0b3ea1ef182ddbbb1f9d731b28768a049a531d Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 20:23:00 +0200 Subject: AntiMalware Tests - Added unittest for cog setup --- tests/bot/cogs/test_antimalware.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index da5cd9d11..67c640d23 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -90,3 +90,13 @@ class AntiMalwareCogTests(unittest.TestCase): self.assertIn( f"INFO:bot.cogs.antimalware:Tried to delete message `{self.message.id}`, but message could not be found.", logs.output) + + +class AntiMalwareSetupTests(unittest.TestCase): + """Tests setup of the `AntiMalware` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + antimalware.setup(bot) + bot.add_cog.assert_called_once() -- cgit v1.2.3 From 3090141f673279f2836cb3aca95397eb9950ad0f Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 20:41:31 +0200 Subject: AntiMalware Tests - Added unittest message deletion log --- tests/bot/cogs/test_antimalware.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 67c640d23..b4e31b5ce 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,14 +1,17 @@ import asyncio import logging import unittest +from os.path import splitext from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import Roles, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Roles, URLs from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole +MODULE = "bot.cogs.antimalware" + class AntiMalwareCogTests(unittest.TestCase): """Test the AntiMalware cog.""" @@ -78,17 +81,38 @@ class AntiMalwareCogTests(unittest.TestCase): def test_removing_deleted_message_logs(self): """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.py") + attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) coroutine = self.cog.on_message(self.message) - logger = logging.getLogger("bot.cogs.antimalware") + logger = logging.getLogger(MODULE) with self.assertLogs(logger=logger, level="INFO") as logs: asyncio.run(coroutine) self.assertIn( - f"INFO:bot.cogs.antimalware:Tried to delete message `{self.message.id}`, but message could not be found.", + f"INFO:{MODULE}:Tried to delete message `{self.message.id}`, but message could not be found.", + logs.output) + + def test_message_with_illegal_attachment_logs(self): + """Deleting a message with an illegal attachment should result in a log.""" + attachment = MockAttachment(filename="python.asdfsff") + self.message.attachments = [attachment] + + coroutine = self.cog.on_message(self.message) + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments} + extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) + blocked_extensions_str = ', '.join(extensions_blocked) + logger = logging.getLogger(MODULE) + + with self.assertLogs(logger=logger, level="INFO") as logs: + asyncio.run(coroutine) + self.assertEqual( + [ + f"INFO:{MODULE}:" + f"User '{self.message.author}' ({self.message.author.id}) " + f"uploaded blacklisted file(s): {blocked_extensions_str}" + ], logs.output) -- cgit v1.2.3 From f0bc9d800dd141b9126c48251a80618e138d61f1 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 20:46:15 +0200 Subject: AntiMalware Tests - Added unittest for valid attachment --- tests/bot/cogs/test_antimalware.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index b4e31b5ce..407fa05c1 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -23,6 +23,15 @@ class AntiMalwareCogTests(unittest.TestCase): self.message = MockMessage() self.message.delete = AsyncMock() + def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should not be deleted""" + attachment = MockAttachment(filename=f"python.{AntiMalwareConfig.whitelist[0]}") + self.message.attachments = [attachment] + + coroutine = self.cog.on_message(self.message) + asyncio.run(coroutine) + self.message.delete.assert_not_called() + def test_message_without_attachment(self): """Messages without attachments should result in no action.""" coroutine = self.cog.on_message(self.message) -- cgit v1.2.3 From 75f6ca6bd9b695a5deb4a4d78311bc63eb2a74d0 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 21:04:47 +0200 Subject: AntiMalware Tests - Added unittest for txt file attachment --- tests/bot/cogs/test_antimalware.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 407fa05c1..eba439afb 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Roles, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, Roles, URLs from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole MODULE = "bot.cogs.antimalware" @@ -21,7 +21,6 @@ class AntiMalwareCogTests(unittest.TestCase): self.bot = MockBot() self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() - self.message.delete = AsyncMock() def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" @@ -88,6 +87,28 @@ class AntiMalwareCogTests(unittest.TestCase): f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" )) + def test_txt_file_redirect_embed(self): + attachment = MockAttachment(filename="python.txt") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + + coroutine = self.cog.on_message(self.message) + asyncio.run(coroutine) + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + cmd_channel = self.bot.get_channel(Channels.bot_commands) + + self.assertEqual(args[0], f"Hey {self.message.author.mention}!") + self.assertEqual(embed.description, ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + f"{cmd_channel.mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" + )) + def test_removing_deleted_message_logs(self): """Removing an already deleted message logs the correct message""" attachment = MockAttachment(filename="python.asdfsff") -- cgit v1.2.3 From c8bf44e30c286b27768601d5a04cd2459f170d4c Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Thu, 7 May 2020 21:29:15 +0200 Subject: AntiMalware Tests - Switched to unittest.IsolatedAsyncioTestCase --- tests/bot/cogs/test_antimalware.py | 48 +++++++++++++++----------------------- 1 file changed, 19 insertions(+), 29 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index eba439afb..6fb7b399e 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,4 +1,3 @@ -import asyncio import logging import unittest from os.path import splitext @@ -13,7 +12,7 @@ from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole MODULE = "bot.cogs.antimalware" -class AntiMalwareCogTests(unittest.TestCase): +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Test the AntiMalware cog.""" def setUp(self): @@ -22,62 +21,56 @@ class AntiMalwareCogTests(unittest.TestCase): self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() - def test_message_with_allowed_attachment(self): + async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" attachment = MockAttachment(filename=f"python.{AntiMalwareConfig.whitelist[0]}") self.message.attachments = [attachment] - coroutine = self.cog.on_message(self.message) - asyncio.run(coroutine) + await self.cog.on_message(self.message) self.message.delete.assert_not_called() - def test_message_without_attachment(self): + async def test_message_without_attachment(self): """Messages without attachments should result in no action.""" - coroutine = self.cog.on_message(self.message) - self.assertIsNone(asyncio.run(coroutine)) + self.assertIsNone(await self.cog.on_message(self.message)) self.message.delete.assert_not_called() - def test_direct_message_with_attachment(self): + async def test_direct_message_with_attachment(self): """Direct messages should have no action taken.""" attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] self.message.guild = None - coroutine = self.cog.on_message(self.message) - asyncio.run(coroutine) + await self.cog.on_message(self.message) self.message.delete.assert_not_called() - def test_message_with_illegal_extension_gets_deleted(self): + async def test_message_with_illegal_extension_gets_deleted(self): """A message containing an illegal extension should send an embed.""" attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] - coroutine = self.cog.on_message(self.message) - asyncio.run(coroutine) + await self.cog.on_message(self.message) self.message.delete.assert_called_once() - def test_message_send_by_staff(self): + async def test_message_send_by_staff(self): """A message send by a member of staff should be ignored.""" moderator_role = MockRole(name="Moderator", id=Roles.moderators) self.message.author.roles.append(moderator_role) attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] - coroutine = self.cog.on_message(self.message) - asyncio.run(coroutine) + await self.cog.on_message(self.message) self.message.delete.assert_not_called() - def test_python_file_redirect_embed(self): + async def test_python_file_redirect_embed(self): """A message containing a .python file should result in an embed redirecting the user to our paste site""" attachment = MockAttachment(filename="python.py") self.message.attachments = [attachment] self.message.channel.send = AsyncMock() - coroutine = self.cog.on_message(self.message) - asyncio.run(coroutine) + await self.cog.on_message(self.message) args, kwargs = self.message.channel.send.call_args embed = kwargs.pop("embed") @@ -87,13 +80,12 @@ class AntiMalwareCogTests(unittest.TestCase): f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" )) - def test_txt_file_redirect_embed(self): + async def test_txt_file_redirect_embed(self): attachment = MockAttachment(filename="python.txt") self.message.attachments = [attachment] self.message.channel.send = AsyncMock() - coroutine = self.cog.on_message(self.message) - asyncio.run(coroutine) + await self.cog.on_message(self.message) args, kwargs = self.message.channel.send.call_args embed = kwargs.pop("embed") cmd_channel = self.bot.get_channel(Channels.bot_commands) @@ -109,34 +101,32 @@ class AntiMalwareCogTests(unittest.TestCase): f"\n\n{URLs.site_schema}{URLs.site_paste}" )) - def test_removing_deleted_message_logs(self): + async def test_removing_deleted_message_logs(self): """Removing an already deleted message logs the correct message""" attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - coroutine = self.cog.on_message(self.message) logger = logging.getLogger(MODULE) with self.assertLogs(logger=logger, level="INFO") as logs: - asyncio.run(coroutine) + await self.cog.on_message(self.message) self.assertIn( f"INFO:{MODULE}:Tried to delete message `{self.message.id}`, but message could not be found.", logs.output) - def test_message_with_illegal_attachment_logs(self): + async def test_message_with_illegal_attachment_logs(self): """Deleting a message with an illegal attachment should result in a log.""" attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] - coroutine = self.cog.on_message(self.message) file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments} extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) blocked_extensions_str = ', '.join(extensions_blocked) logger = logging.getLogger(MODULE) with self.assertLogs(logger=logger, level="INFO") as logs: - asyncio.run(coroutine) + await self.cog.on_message(self.message) self.assertEqual( [ f"INFO:{MODULE}:" -- cgit v1.2.3 From bd9537ba85154ece1dca39ec03d36dd7d39a8388 Mon Sep 17 00:00:00 2001 From: MrGrote Date: Fri, 8 May 2020 22:11:54 +0200 Subject: Update tests/bot/cogs/test_antimalware.py Co-authored-by: Mark --- tests/bot/cogs/test_antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 6fb7b399e..e0aa9d6d2 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -65,7 +65,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.delete.assert_not_called() async def test_python_file_redirect_embed(self): - """A message containing a .python file should result in an embed redirecting the user to our paste site""" + """A message containing a .py file should result in an embed redirecting the user to our paste site""" attachment = MockAttachment(filename="python.py") self.message.attachments = [attachment] self.message.channel.send = AsyncMock() -- cgit v1.2.3 From 847a78a76c08a670e85d926e3afa43e1cc3180f4 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Mon, 11 May 2020 19:41:46 +0200 Subject: AntiMalware Tests - implemented minor feedback --- tests/bot/cogs/test_antimalware.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index e0aa9d6d2..6e06df0a8 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,4 +1,3 @@ -import logging import unittest from os.path import splitext from unittest.mock import AsyncMock, Mock @@ -6,7 +5,7 @@ from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, Roles, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole MODULE = "bot.cogs.antimalware" @@ -31,7 +30,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_message_without_attachment(self): """Messages without attachments should result in no action.""" - self.assertIsNone(await self.cog.on_message(self.message)) + await self.cog.on_message(self.message) self.message.delete.assert_not_called() async def test_direct_message_with_attachment(self): @@ -55,8 +54,8 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_message_send_by_staff(self): """A message send by a member of staff should be ignored.""" - moderator_role = MockRole(name="Moderator", id=Roles.moderators) - self.message.author.roles.append(moderator_role) + staff_role = MockRole(id=STAFF_ROLES[0]) + self.message.author.roles.append(staff_role) attachment = MockAttachment(filename="python.asdfsff") self.message.attachments = [attachment] @@ -71,6 +70,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.channel.send = AsyncMock() await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() args, kwargs = self.message.channel.send.call_args embed = kwargs.pop("embed") @@ -107,13 +107,13 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.attachments = [attachment] self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - logger = logging.getLogger(MODULE) - - with self.assertLogs(logger=logger, level="INFO") as logs: + with self.assertLogs(logger=antimalware.log, level="INFO") as logs: await self.cog.on_message(self.message) + self.message.delete.assert_called_once() self.assertIn( f"INFO:{MODULE}:Tried to delete message `{self.message.id}`, but message could not be found.", - logs.output) + logs.output + ) async def test_message_with_illegal_attachment_logs(self): """Deleting a message with an illegal attachment should result in a log.""" @@ -123,9 +123,8 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments} extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) blocked_extensions_str = ', '.join(extensions_blocked) - logger = logging.getLogger(MODULE) - with self.assertLogs(logger=logger, level="INFO") as logs: + with self.assertLogs(logger=antimalware.log, level="INFO") as logs: await self.cog.on_message(self.message) self.assertEqual( [ @@ -133,7 +132,8 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): f"User '{self.message.author}' ({self.message.author.id}) " f"uploaded blacklisted file(s): {blocked_extensions_str}" ], - logs.output) + logs.output + ) class AntiMalwareSetupTests(unittest.TestCase): -- cgit v1.2.3 From ba71ac5b002dd3e1ee6a916ba2705a7cff697a66 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Mon, 11 May 2020 20:24:20 +0200 Subject: AntiMalware Tests - extracted the method for determining disallowed extensions and added a test for it. --- tests/bot/cogs/test_antimalware.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 6e06df0a8..78ad996f2 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -19,10 +19,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + AntiMalwareConfig.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename=f"python.{AntiMalwareConfig.whitelist[0]}") + attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") self.message.attachments = [attachment] await self.cog.on_message(self.message) @@ -35,7 +36,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_direct_message_with_attachment(self): """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.asdfsff") + attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] self.message.guild = None @@ -45,7 +46,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_illegal_extension_gets_deleted(self): """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.asdfsff") + attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] await self.cog.on_message(self.message) @@ -56,7 +57,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """A message send by a member of staff should be ignored.""" staff_role = MockRole(id=STAFF_ROLES[0]) self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.asdfsff") + attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] await self.cog.on_message(self.message) @@ -103,7 +104,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_removing_deleted_message_logs(self): """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.asdfsff") + attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) @@ -117,7 +118,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_illegal_attachment_logs(self): """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.asdfsff") + attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments} @@ -135,6 +136,22 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): logs.output ) + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + (AntiMalwareConfig.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], [".disallowed"]), + ([".disallowed"], [".disallowed"]), + ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] + disallowed_extensions = self.cog.get_disallowed_extensions(self.message) + self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + class AntiMalwareSetupTests(unittest.TestCase): """Tests setup of the `AntiMalware` cog.""" -- cgit v1.2.3 From ecaddcedab6946ac4650b699a790471ef2a898c9 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Mon, 11 May 2020 20:39:25 +0200 Subject: AntiMalware Tests - added a missing case for no extensions in test_get_disallowed_extensions --- tests/bot/cogs/test_antimalware.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 78ad996f2..7caee6f3c 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -139,6 +139,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): async def test_get_disallowed_extensions(self): """The return value should include all non-whitelisted extensions.""" test_values = ( + ([], []), (AntiMalwareConfig.whitelist, []), ([".first"], []), ([".first", ".disallowed"], [".disallowed"]), -- cgit v1.2.3 From fa467e4ef133186ff462b0178bcab08e8a3d6b2d Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Mon, 11 May 2020 20:58:51 +0200 Subject: AntiMalware Tests - Removed exact log content checks --- tests/bot/cogs/test_antimalware.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 7caee6f3c..a2ce9a740 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,5 +1,4 @@ import unittest -from os.path import splitext from unittest.mock import AsyncMock, Mock from discord import NotFound @@ -108,33 +107,17 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.attachments = [attachment] self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - with self.assertLogs(logger=antimalware.log, level="INFO") as logs: + with self.assertLogs(logger=antimalware.log, level="INFO"): await self.cog.on_message(self.message) self.message.delete.assert_called_once() - self.assertIn( - f"INFO:{MODULE}:Tried to delete message `{self.message.id}`, but message could not be found.", - logs.output - ) async def test_message_with_illegal_attachment_logs(self): """Deleting a message with an illegal attachment should result in a log.""" attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] - file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments} - extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) - blocked_extensions_str = ', '.join(extensions_blocked) - - with self.assertLogs(logger=antimalware.log, level="INFO") as logs: + with self.assertLogs(logger=antimalware.log, level="INFO"): await self.cog.on_message(self.message) - self.assertEqual( - [ - f"INFO:{MODULE}:" - f"User '{self.message.author}' ({self.message.author.id}) " - f"uploaded blacklisted file(s): {blocked_extensions_str}" - ], - logs.output - ) async def test_get_disallowed_extensions(self): """The return value should include all non-whitelisted extensions.""" -- cgit v1.2.3 From d193a93828582965eb361dc6f3185291fff649a7 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:11:39 -0700 Subject: Test on_message_edit of token remover uses on_message --- tests/bot/cogs/test_token_remover.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 33d1ec170..e7b5a9bea 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,6 +1,7 @@ import asyncio import logging import unittest +from unittest import mock from unittest.mock import AsyncMock, MagicMock from discord import Colour @@ -14,7 +15,7 @@ from bot.constants import Channels, Colours, Event, Icons from tests.helpers import MockBot, MockMessage -class TokenRemoverTests(unittest.TestCase): +class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): """Tests the `TokenRemover` cog.""" def setUp(self): @@ -58,6 +59,13 @@ class TokenRemoverTests(unittest.TestCase): self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) self.bot.get_cog.assert_called_once_with('ModLog') + async def test_on_message_edit_uses_on_message(self): + """The edit listener should delegate handling of the message to the normal listener.""" + self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) + + await self.cog.on_message_edit(MockMessage(), self.msg) + self.cog.on_message.assert_awaited_once_with(self.msg) + def test_ignores_bot_messages(self): """When the message event handler is called with a bot message, nothing is done.""" self.msg.author.bot = True @@ -77,7 +85,7 @@ class TokenRemoverTests(unittest.TestCase): for content in ('foo.bar.baz', 'x.y.'): with self.subTest(content=content): self.msg.content = content - coroutine = self.cog.on_message(self.msg) + coroutine = self.cog.is_maybe_token(self.msg) self.assertIsNone(asyncio.run(coroutine)) def test_censors_valid_tokens(self): -- cgit v1.2.3 From 0bfd003dbfc5919220129f984dc043421e535f8c Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:38:12 -0700 Subject: Add a test helper function to patch multiple attributes with autospecs This helper reduces redundancy/boilerplate by setting default values. It also has the consequence of shortening the length of the invocation, which makes it faster to use and easier to read. --- tests/helpers.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 2b79a6c2a..d444cc49d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -23,6 +23,15 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) +def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: + """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" + # Caller's kwargs should take priority and overwrite the defaults. + kwargs = {'spec_set': True, 'autospec': True, **kwargs} + attributes = {attribute: unittest.mock.DEFAULT for attribute in attributes} + + return unittest.mock.patch.multiple(target, **attributes, **kwargs) + + class HashableMixin(discord.mixins.EqualityComparable): """ Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. -- cgit v1.2.3 From e8bd69a6c556d78eca1a1eb2adfa26248273a1cd Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:42:07 -0700 Subject: Test token remover takes action if a token is found --- tests/bot/cogs/test_token_remover.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index e7b5a9bea..e0ec67684 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -12,7 +12,7 @@ from bot.cogs.token_remover import ( setup as setup_cog, ) from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import MockBot, MockMessage +from tests.helpers import MockBot, MockMessage, autospec class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @@ -66,6 +66,18 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): await self.cog.on_message_edit(MockMessage(), self.msg) self.cog.on_message.assert_awaited_once_with(self.msg) + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_takes_action(self, find_token_in_message, take_action): + """Should take action if a valid token is found when a message is sent.""" + cog = TokenRemover(self.bot) + found_token = "foobar" + find_token_in_message.return_value = found_token + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_awaited_once_with(cog, self.msg, found_token) + def test_ignores_bot_messages(self): """When the message event handler is called with a bot message, nothing is done.""" self.msg.author.bot = True -- cgit v1.2.3 From 4cf7996a1d4630ccb05f57569ca62b1798dc7a93 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 14:44:54 -0700 Subject: Test token remover skips messages without tokens --- tests/bot/cogs/test_token_remover.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index e0ec67684..2b377e221 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -78,6 +78,17 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): find_token_in_message.assert_called_once_with(self.msg) take_action.assert_awaited_once_with(cog, self.msg, found_token) + @autospec(TokenRemover, "find_token_in_message", "take_action") + async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): + """Shouldn't take action if a valid token isn't found when a message is sent.""" + cog = TokenRemover(self.bot) + find_token_in_message.return_value = False + + await cog.on_message(self.msg) + + find_token_in_message.assert_called_once_with(self.msg) + take_action.assert_not_awaited() + def test_ignores_bot_messages(self): """When the message event handler is called with a bot message, nothing is done.""" self.msg.author.bot = True -- cgit v1.2.3 From 593e09299c6e4115d41bfd5b074785a5e304a8d0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 15:41:14 -0700 Subject: Allow using arbitrary parameter names with the autospec decorator This gives the caller more flexibility. Sometimes attribute names are too long or they don't follow a naming scheme accepted by the linter. --- tests/helpers.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index d444cc49d..1ab8b455f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -24,12 +24,25 @@ for logger in logging.Logger.manager.loggerDict.values(): def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: - """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" + """ + Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + + To allow for arbitrary parameter names to be used by the decorated function, the patchers have + no attribute names associated with them. As a consequence, it will not be possible to retrieve + mocks by their attribute names when using this as a context manager, + """ # Caller's kwargs should take priority and overwrite the defaults. kwargs = {'spec_set': True, 'autospec': True, **kwargs} attributes = {attribute: unittest.mock.DEFAULT for attribute in attributes} - return unittest.mock.patch.multiple(target, **attributes, **kwargs) + patcher = unittest.mock.patch.multiple(target, **attributes, **kwargs) + + # Unset attribute names to allow arbitrary parameter names for the decorator function. + patcher.attribute_name = None + for additional_patcher in patcher.additional_patchers: + additional_patcher.attribute_name = None + + return patcher class HashableMixin(discord.mixins.EqualityComparable): -- cgit v1.2.3 From b0dd290710799c342240d066abaebbe9e6940b54 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 15:09:22 -0700 Subject: Fix test for token remover ignoring bot messages It's not possible to test this via asserting the return value of `on_message` since it never returns anything. Instead, the actual relevant unit, `find_token_in_message,` should be tested. --- tests/bot/cogs/test_token_remover.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 2b377e221..e8b641101 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -89,11 +89,16 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): find_token_in_message.assert_called_once_with(self.msg) take_action.assert_not_awaited() - def test_ignores_bot_messages(self): - """When the message event handler is called with a bot message, nothing is done.""" + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_ignores_bot_messages(self, token_re): + """The token finder should ignore messages authored by bots.""" + cog = TokenRemover(self.bot) self.msg.author.bot = True - coroutine = self.cog.on_message(self.msg) - self.assertIsNone(asyncio.run(coroutine)) + + return_value = cog.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.findall.assert_not_called() def test_ignores_messages_without_tokens(self): """Messages without anything looking like a token are ignored.""" -- cgit v1.2.3 From 52f0f8a29d7f239c961beaa81881bf4b09da4749 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 15:53:06 -0700 Subject: Test `find_token_in_message` returns None if no matches found --- tests/bot/cogs/test_token_remover.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index e8b641101..5932cf4f0 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -100,6 +100,20 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(return_value) token_re.findall.assert_not_called() + @autospec(TokenRemover, "is_maybe_token") + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): + """None should be returned if the regex matches no tokens in a message.""" + cog = TokenRemover(self.bot) + token_re.findall.return_value = () + self.msg.content = "foobar" + + return_value = cog.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.findall.assert_called_once_with(self.msg.content) + is_maybe_token.assert_not_called() + def test_ignores_messages_without_tokens(self): """Messages without anything looking like a token are ignored.""" for content in ('', 'lemon wins'): -- cgit v1.2.3 From cf658bd58559b2683527443f2908257f197ef0bb Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 16:06:47 -0700 Subject: Test `find_token_in_message` returns the found token --- tests/bot/cogs/test_token_remover.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5932cf4f0..2b946778b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -114,6 +114,30 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): token_re.findall.assert_called_once_with(self.msg.content) is_maybe_token.assert_not_called() + @autospec(TokenRemover, "is_maybe_token") + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_returns_found_token(self, token_re, is_maybe_token): + """The found token should be returned.""" + true_index = 1 + matches = ("foo", "bar", "baz") + side_effects = [False] * len(matches) + side_effects[true_index] = True + + cog = TokenRemover(self.bot) + self.msg.content = "foobar" + token_re.findall.return_value = matches + is_maybe_token.side_effect = side_effects + + return_value = cog.find_token_in_message(self.msg) + + self.assertEqual(return_value, matches[true_index]) + token_re.findall.assert_called_once_with(self.msg.content) + + # assert_has_calls isn't used cause it'd allow for extra calls before or after. + # The function should short-circuit, so nothing past true_index should have been used. + calls = [mock.call(match) for match in matches[:true_index + 1]] + self.assertEqual(is_maybe_token.mock_calls, calls) + def test_ignores_messages_without_tokens(self): """Messages without anything looking like a token are ignored.""" for content in ('', 'lemon wins'): -- cgit v1.2.3 From f92bc80d6bddb5c57c190187adaa528ae44536f6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 16:25:14 -0700 Subject: Test token regex doesn't match invalid tokens --- tests/bot/cogs/test_token_remover.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 2b946778b..b67602eb9 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -8,6 +8,7 @@ from discord import Colour from bot.cogs.token_remover import ( DELETION_MESSAGE_TEMPLATE, + TOKEN_RE, TokenRemover, setup as setup_cog, ) @@ -138,13 +139,30 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): calls = [mock.call(match) for match in matches[:true_index + 1]] self.assertEqual(is_maybe_token.mock_calls, calls) - def test_ignores_messages_without_tokens(self): - """Messages without anything looking like a token are ignored.""" - for content in ('', 'lemon wins'): - with self.subTest(content=content): - self.msg.content = content - coroutine = self.cog.on_message(self.msg) - self.assertIsNone(asyncio.run(coroutine)) + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "'.'.'", + '"."."', + "(.(.(", + ").).)" + ) + + for token in tokens: + with self.subTest(token=token): + results = TOKEN_RE.findall(token) + self.assertEquals(len(results), 0) def test_ignores_messages_with_invalid_tokens(self): """Messages with values that are invalid tokens are ignored.""" -- cgit v1.2.3 From 34b836a8eba0f006c77a7b3f48f7ab14c37d31ee Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 17:47:09 -0700 Subject: Fix autospec decorator when used with multiple attributes The original approach of messing with the `attribute_name` didn't work for reasons I won't discuss here (would require knowledge of patcher internals). The new approach doesn't use patch.multiple but mimics it by applying multiple patch decorators to the function. As a consequence, this can no longer be used as a context manager. --- tests/helpers.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 1ab8b455f..dfbe539ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -24,25 +24,21 @@ for logger in logging.Logger.manager.loggerDict.values(): def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: - """ - Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. - - To allow for arbitrary parameter names to be used by the decorated function, the patchers have - no attribute names associated with them. As a consequence, it will not be possible to retrieve - mocks by their attribute names when using this as a context manager, - """ + """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" # Caller's kwargs should take priority and overwrite the defaults. kwargs = {'spec_set': True, 'autospec': True, **kwargs} - attributes = {attribute: unittest.mock.DEFAULT for attribute in attributes} - - patcher = unittest.mock.patch.multiple(target, **attributes, **kwargs) - - # Unset attribute names to allow arbitrary parameter names for the decorator function. - patcher.attribute_name = None - for additional_patcher in patcher.additional_patchers: - additional_patcher.attribute_name = None - return patcher + # Import the target if it's a string. + # This is to support both object and string targets like patch.multiple. + if type(target) is str: + target = unittest.mock._importer(target) + + def decorator(func): + for attribute in attributes: + patcher = unittest.mock.patch.object(target, attribute, **kwargs) + func = patcher(func) + return func + return decorator class HashableMixin(discord.mixins.EqualityComparable): -- cgit v1.2.3 From 834bd543d1d301bb853e713560a7447dc75f1ab8 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 17:53:40 -0700 Subject: Test `is_maybe_token` returns False for missing parts In practice, this won't ever happen since the regex wouldn't match strings with missing parts. However, the function does check it so may as well test it. It's not necessarily bound to always use inputs from the regex either I suppose. --- tests/bot/cogs/test_token_remover.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index b67602eb9..9e1d96a37 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -164,6 +164,16 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = TOKEN_RE.findall(token) self.assertEquals(len(results), 0) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): + """False should be returned for tokens which do not have all 3 parts.""" + cog = TokenRemover(self.bot) + return_value = cog.is_maybe_token("x.y") + + self.assertFalse(return_value) + valid_user.assert_not_called() + valid_time.assert_not_called() + def test_ignores_messages_with_invalid_tokens(self): """Messages with values that are invalid tokens are ignored.""" for content in ('foo.bar.baz', 'x.y.'): -- cgit v1.2.3 From ab5d194b90a7e068c8ab7171939f471e252ee073 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 18:11:31 -0700 Subject: Test is_maybe_token --- tests/bot/cogs/test_token_remover.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 9e1d96a37..85bbbdf6b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -174,13 +174,30 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): valid_user.assert_not_called() valid_time.assert_not_called() - def test_ignores_messages_with_invalid_tokens(self): - """Messages with values that are invalid tokens are ignored.""" - for content in ('foo.bar.baz', 'x.y.'): - with self.subTest(content=content): - self.msg.content = content - coroutine = self.cog.is_maybe_token(self.msg) - self.assertIsNone(asyncio.run(coroutine)) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + def test_is_maybe_token(self, valid_user, valid_time): + """Should return True if the user ID and timestamp are valid or return False otherwise.""" + cog = TokenRemover(self.bot) + subtests = ( + (False, True, False), + (True, False, False), + (True, True, True), + ) + + for user_return, time_return, expected in subtests: + valid_user.reset_mock() + valid_time.reset_mock() + + with self.subTest(user_return=user_return, time_return=time_return, expected=expected): + valid_user.return_value = user_return + valid_time.return_value = time_return + + actual = cog.is_maybe_token("x.y.z") + self.assertIs(actual, expected) + + valid_user.assert_called_once_with("x") + if user_return: + valid_time.assert_called_once_with("y") def test_censors_valid_tokens(self): """Valid tokens are censored.""" -- cgit v1.2.3 From 4b6fde69a7e193382701dccf80a5471ea7ccea72 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 18:22:31 -0700 Subject: Test token regex matches valid tokens --- tests/bot/cogs/test_token_remover.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 85bbbdf6b..7310b4637 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -164,6 +164,27 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = TOKEN_RE.findall(token) self.assertEquals(len(results), 0) + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, the token's been invalidated. + tokens = ( + "x1.y2.z_3", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8" + ) + + for token in tokens: + with self.subTest(token=token): + results = TOKEN_RE.findall(token) + self.assertIn(token, results) + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + tokens = ["x.y.z", "a.b.c"] + message = f"garbage {tokens[0]} hello {tokens[1]} world" + + results = TOKEN_RE.findall(message) + self.assertEquals(tokens, results) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): """False should be returned for tokens which do not have all 3 parts.""" -- cgit v1.2.3 From d8d8e144adfe4c2de15dbbf4346e2eec548a9f67 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 10 May 2020 18:28:06 -0700 Subject: Correct the return type annotation for the autospec decorator --- tests/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index dfbe539ec..3cd8a63c0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,7 +4,7 @@ import collections import itertools import logging import unittest.mock -from typing import Iterable, Optional +from typing import Callable, Iterable, Optional import discord from discord.ext.commands import Context @@ -23,7 +23,7 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> unittest.mock._patch: +def autospec(target, *attributes: str, **kwargs) -> Callable: """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" # Caller's kwargs should take priority and overwrite the defaults. kwargs = {'spec_set': True, 'autospec': True, **kwargs} -- cgit v1.2.3 From 5b9bf9aba686f570322cb9996dd35d3ab669a162 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:26:16 -0700 Subject: Avoid instantiating the cog when testing static/class methods --- tests/bot/cogs/test_token_remover.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 7310b4637..6a8247070 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -93,10 +93,9 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_ignores_bot_messages(self, token_re): """The token finder should ignore messages authored by bots.""" - cog = TokenRemover(self.bot) self.msg.author.bot = True - return_value = cog.find_token_in_message(self.msg) + return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) token_re.findall.assert_not_called() @@ -105,11 +104,10 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): """None should be returned if the regex matches no tokens in a message.""" - cog = TokenRemover(self.bot) token_re.findall.return_value = () self.msg.content = "foobar" - return_value = cog.find_token_in_message(self.msg) + return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) token_re.findall.assert_called_once_with(self.msg.content) @@ -124,12 +122,11 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): side_effects = [False] * len(matches) side_effects[true_index] = True - cog = TokenRemover(self.bot) self.msg.content = "foobar" token_re.findall.return_value = matches is_maybe_token.side_effect = side_effects - return_value = cog.find_token_in_message(self.msg) + return_value = TokenRemover.find_token_in_message(self.msg) self.assertEqual(return_value, matches[true_index]) token_re.findall.assert_called_once_with(self.msg.content) @@ -188,8 +185,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): """False should be returned for tokens which do not have all 3 parts.""" - cog = TokenRemover(self.bot) - return_value = cog.is_maybe_token("x.y") + return_value = TokenRemover.is_maybe_token("x.y") self.assertFalse(return_value) valid_user.assert_not_called() @@ -198,7 +194,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token(self, valid_user, valid_time): """Should return True if the user ID and timestamp are valid or return False otherwise.""" - cog = TokenRemover(self.bot) subtests = ( (False, True, False), (True, False, False), @@ -213,7 +208,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): valid_user.return_value = user_return valid_time.return_value = time_return - actual = cog.is_maybe_token("x.y.z") + actual = TokenRemover.is_maybe_token("x.y.z") self.assertIs(actual, expected) valid_user.assert_called_once_with("x") -- cgit v1.2.3 From 2127239840085ba523d411899e0b7a188530df07 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:33:05 -0700 Subject: Simplify token remover's message mock * Rely on default values for the author * Set the content to a non-empty string --- tests/bot/cogs/test_token_remover.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 6a8247070..5ca863926 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -26,14 +26,10 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.bot.get_cog.return_value.send_log_message = AsyncMock() self.cog = TokenRemover(bot=self.bot) - self.msg = MockMessage(id=555, content='') - self.msg.author.__str__ = MagicMock() - self.msg.author.__str__.return_value = 'lemon' - self.msg.author.bot = False - self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' - self.msg.author.id = 42 - self.msg.author.mention = '@lemon' + self.msg = MockMessage(id=555, content="hello world") self.msg.channel.mention = "#lemonade-stand" + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + self.msg.author.avatar_url_as.return_value = "picture-lemon.png" def test_is_valid_user_id_is_true_for_numeric_content(self): """A string decoding to numeric characters is a valid user ID.""" @@ -105,7 +101,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): """None should be returned if the regex matches no tokens in a message.""" token_re.findall.return_value = () - self.msg.content = "foobar" return_value = TokenRemover.find_token_in_message(self.msg) @@ -122,7 +117,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): side_effects = [False] * len(matches) side_effects[true_index] = True - self.msg.content = "foobar" token_re.findall.return_value = matches is_maybe_token.side_effect = side_effects -- cgit v1.2.3 From e4790b330da1605573b5d23615bfe62b481e1e04 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:37:59 -0700 Subject: Test token remover's message deletion --- tests/bot/cogs/test_token_remover.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5ca863926..d65ce2ce5 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -209,6 +209,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): if user_return: valid_time.assert_called_once_with("y") + async def test_delete_message(self): + """The message should be deleted, and a message should be sent to the same channel.""" + await TokenRemover.delete_message(self.msg) + + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + def test_censors_valid_tokens(self): """Valid tokens are censored.""" cases = ( -- cgit v1.2.3 From 567a5f9242912d6a3340c088c0ae1a62977a141e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 10:46:02 -0700 Subject: Test TokenRemover.format_log_message --- tests/bot/cogs/test_token_remover.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index d65ce2ce5..f5412e692 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -218,6 +218,22 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) ) + @autospec("bot.cogs.token_remover", "LOG_MESSAGE") + async def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + log_message.format.return_value = "Howdy" + return_value = TokenRemover.format_log_message(self.msg, "MTIz.DN9R_A.xyz") + + self.assertEqual(return_value, log_message.format.return_value) + log_message.format.assert_called_once_with( + author=self.msg.author, + author_id=self.msg.author.id, + channel=self.msg.channel.mention, + user_id="MTIz", + timestamp="DN9R_A", + hmac="xxx", + ) + def test_censors_valid_tokens(self): """Valid tokens are censored.""" cases = ( -- cgit v1.2.3 From f47cbef0b47ef11b8c1fd63076105e4cb7d73601 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:29:28 -0700 Subject: Test TokenRemover.take_action * Remove `bot.get_cog` mocks in `setUp` * Mock the logger cause it's easier to assert logs * Remove subtests * Assert helper functions were called * Create an autospec for ModLog --- tests/bot/cogs/test_token_remover.py | 73 +++++++++++++++--------------------- 1 file changed, 30 insertions(+), 43 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index f5412e692..3546e7964 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,11 +1,10 @@ -import asyncio -import logging import unittest from unittest import mock -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock from discord import Colour +from bot.cogs.moderation import ModLog from bot.cogs.token_remover import ( DELETION_MESSAGE_TEMPLATE, TOKEN_RE, @@ -22,8 +21,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Adds the cog, a bot, and a message to the instance for usage in tests.""" self.bot = MockBot() - self.bot.get_cog.return_value = MagicMock() - self.bot.get_cog.return_value.send_log_message = AsyncMock() self.cog = TokenRemover(bot=self.bot) self.msg = MockMessage(id=555, content="hello world") @@ -234,46 +231,36 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): hmac="xxx", ) - def test_censors_valid_tokens(self): - """Valid tokens are censored.""" - cases = ( - # (content, censored_token) - ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), + @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) + @autospec("bot.cogs.token_remover", "log") + @autospec(TokenRemover, "delete_message", "format_log_message") + async def test_take_action(self, delete_message, format_log_message, logger, mod_log_property): + """Should delete the message and send a mod log.""" + cog = TokenRemover(self.bot) + mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) + token = "MTIz.DN9R_A.xyz" + log_msg = "testing123" + + mod_log_property.return_value = mod_log + format_log_message.return_value = log_msg + + await cog.take_action(self.msg, token) + + delete_message.assert_awaited_once_with(self.msg) + format_log_message.assert_called_once_with(self.msg, token) + logger.debug.assert_called_with(log_msg) + self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") + + mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) + mod_log.send_log_message.assert_called_once_with( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=log_msg, + thumbnail=self.msg.author.avatar_url_as.return_value, + channel_id=Channels.mod_alerts ) - for content, censored_token in cases: - with self.subTest(content=content, censored_token=censored_token): - self.msg.content = content - coroutine = self.cog.on_message(self.msg) - with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm: - self.assertIsNone(asyncio.run(coroutine)) # no return value - - [line] = cm.output - log_message = ( - "Censored a seemingly valid token sent by " - "lemon (`42`) in #lemonade-stand, " - f"token was `{censored_token}`" - ) - self.assertIn(log_message, line) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') - ) - self.bot.get_cog.assert_called_with('ModLog') - self.msg.author.avatar_url_as.assert_called_once_with(static_format='png') - - mod_log = self.bot.get_cog.return_value - mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) - mod_log.send_log_message.assert_called_once_with( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=log_message, - thumbnail='picture-lemon.png', - channel_id=Channels.mod_alerts - ) - class TokenRemoverSetupTests(unittest.TestCase): """Tests setup of the `TokenRemover` cog.""" -- cgit v1.2.3 From 5734a4d84922a9497014dfeb3eba31ad3c57536f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:44:08 -0700 Subject: Refactor `TokenRemoverSetupTests` and add a more thorough test The test now ensures the cog is instantiated and that the instance is passed as an argument to `add_cog`. --- tests/bot/cogs/test_token_remover.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3546e7964..c377de7b2 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -262,11 +262,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): ) -class TokenRemoverSetupTests(unittest.TestCase): - """Tests setup of the `TokenRemover` cog.""" +class TokenRemoverExtensionTests(unittest.TestCase): + """Tests for the token_remover extension.""" - def test_setup(self): - """Setup of the extension should call add_cog.""" + @autospec("bot.cogs.token_remover", "TokenRemover") + def test_extension_setup(self, cog): + """The TokenRemover cog should be added.""" bot = MockBot() setup_cog(bot) + + cog.assert_called_once_with(bot) bot.add_cog.assert_called_once() + self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) -- cgit v1.2.3 From d0303d715d485842a2d5c906099d767d74cf8bd8 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:45:50 -0700 Subject: Replace deprecated assertion methods --- tests/bot/cogs/test_token_remover.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index c377de7b2..aecb51403 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -150,7 +150,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): results = TOKEN_RE.findall(token) - self.assertEquals(len(results), 0) + self.assertEqual(len(results), 0) def test_regex_valid_tokens(self): """Messages that look like tokens should be matched.""" @@ -171,7 +171,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): message = f"garbage {tokens[0]} hello {tokens[1]} world" results = TOKEN_RE.findall(message) - self.assertEquals(tokens, results) + self.assertEqual(tokens, results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): -- cgit v1.2.3 From 862153f2e4ab5b1408719fb2c1abc5143cfb15ce Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:47:40 -0700 Subject: Clean up token remover test imports --- tests/bot/cogs/test_token_remover.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index aecb51403..5cc8c7ad1 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -4,14 +4,10 @@ from unittest.mock import MagicMock from discord import Colour +from bot import constants +from bot.cogs import token_remover from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import ( - DELETION_MESSAGE_TEMPLATE, - TOKEN_RE, - TokenRemover, - setup as setup_cog, -) -from bot.constants import Channels, Colours, Event, Icons +from bot.cogs.token_remover import TokenRemover from tests.helpers import MockBot, MockMessage, autospec @@ -149,7 +145,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): - results = TOKEN_RE.findall(token) + results = token_remover.TOKEN_RE.findall(token) self.assertEqual(len(results), 0) def test_regex_valid_tokens(self): @@ -162,7 +158,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): - results = TOKEN_RE.findall(token) + results = token_remover.TOKEN_RE.findall(token) self.assertIn(token, results) def test_regex_matches_multiple_valid(self): @@ -170,7 +166,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): tokens = ["x.y.z", "a.b.c"] message = f"garbage {tokens[0]} hello {tokens[1]} world" - results = TOKEN_RE.findall(message) + results = token_remover.TOKEN_RE.findall(message) self.assertEqual(tokens, results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") @@ -212,7 +208,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.delete.assert_called_once_with() self.msg.channel.send.assert_called_once_with( - DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) ) @autospec("bot.cogs.token_remover", "LOG_MESSAGE") @@ -251,14 +247,14 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): logger.debug.assert_called_with(log_msg) self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) + mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) mod_log.send_log_message.assert_called_once_with( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), + icon_url=constants.Icons.token_removed, + colour=Colour(constants.Colours.soft_red), title="Token removed!", text=log_msg, thumbnail=self.msg.author.avatar_url_as.return_value, - channel_id=Channels.mod_alerts + channel_id=constants.Channels.mod_alerts ) @@ -269,7 +265,7 @@ class TokenRemoverExtensionTests(unittest.TestCase): def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() - setup_cog(bot) + token_remover.setup(bot) cog.assert_called_once_with(bot) bot.add_cog.assert_called_once() -- cgit v1.2.3 From 4701b0da36c7f42792c0af258b785076237fd661 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 11 May 2020 11:56:15 -0700 Subject: Use subtests for valid ID/timestamp tests and test non-ASCII inputs --- tests/bot/cogs/test_token_remover.py | 43 +++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5cc8c7ad1..f1a56c235 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -24,24 +24,31 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) self.msg.author.avatar_url_as.return_value = "picture-lemon.png" - def test_is_valid_user_id_is_true_for_numeric_content(self): - """A string decoding to numeric characters is a valid user ID.""" - # MTIz = base64(123) - self.assertTrue(TokenRemover.is_valid_user_id('MTIz')) - - def test_is_valid_user_id_is_false_for_alphabetic_content(self): - """A string decoding to alphabetic characters is not a valid user ID.""" - # YWJj = base64(abc) - self.assertFalse(TokenRemover.is_valid_user_id('YWJj')) - - def test_is_valid_timestamp_is_true_for_valid_timestamps(self): - """A string decoding to a valid timestamp should be recognized as such.""" - self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A')) - - def test_is_valid_timestamp_is_false_for_invalid_values(self): - """A string not decoding to a valid timestamp should not be recognized as such.""" - # MTIz = base64(123) - self.assertFalse(TokenRemover.is_valid_timestamp('MTIz')) + def test_is_valid_user_id(self): + """Should correctly discern valid user IDs and ignore non-numeric and non-ASCII IDs.""" + subtests = ( + ("MTIz", True), # base64(123) + ("YWJj", False), # base64(abc) + ("λδµ", False), + ) + + for user_id, is_valid in subtests: + with self.subTest(user_id=user_id, is_valid=is_valid): + result = TokenRemover.is_valid_user_id(user_id) + self.assertIs(result, is_valid) + + def test_is_valid_timestamp(self): + """Should correctly discern valid timestamps.""" + subtests = ( + ("DN9r_A", True), + ("MTIz", False), # base64(123) + ("λδµ", False), + ) + + for timestamp, is_valid in subtests: + with self.subTest(timestamp=timestamp, is_valid=is_valid): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertIs(result, is_valid) def test_mod_log_property(self): """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -- cgit v1.2.3 From ddfe583d0b1e72f98855f628ff01b72c82fa491d Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Mon, 11 May 2020 21:56:39 +0200 Subject: AntiMalware Refactor - Moved embed descriptions into constants, added tests for embed descriptions --- bot/cogs/antimalware.py | 44 ++++++++++++++++++++-------------- tests/bot/cogs/test_antimalware.py | 48 ++++++++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index f5fd5e2d9..ea257442e 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -10,6 +10,27 @@ from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLE log = logging.getLogger(__name__) +PY_EMBED_DESCRIPTION = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" +) + +TXT_EMBED_DESCRIPTION = ( + "**Uh-oh!** It looks like your message got zapped by our spam filter. " + "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" + "• If you attempted to send a message longer than 2000 characters, try shortening your message " + "to fit within the character limit or use a pasting service (see below) \n\n" + "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " + "{cmd_channel_mention} for more information) or use a pasting service like: " + f"\n\n{URLs.site_schema}{URLs.site_paste}" +) + +DISALLOWED_EMBED_DESCRIPTION = ( + "It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + f"We currently allow the following file types: **{', '.join(AntiMalwareConfig.whitelist)}**.\n\n" + "Feel free to ask in {meta_channel_mention} if you think this is a mistake." +) + class AntiMalware(Cog): """Delete messages which contain attachments with non-whitelisted file extensions.""" @@ -34,29 +55,16 @@ class AntiMalware(Cog): blocked_extensions_str = ', '.join(extensions_blocked) if ".py" in extensions_blocked: # Short-circuit on *.py files to provide a pastebin link - embed.description = ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" - ) + embed.description = PY_EMBED_DESCRIPTION elif ".txt" in extensions_blocked: # Work around Discord AutoConversion of messages longer than 2000 chars to .txt cmd_channel = self.bot.get_channel(Channels.bot_commands) - embed.description = ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - f"{cmd_channel.mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" - ) + embed.description = TXT_EMBED_DESCRIPTION.format(cmd_channel_mention=cmd_channel.mention) elif extensions_blocked: - whitelisted_types = ', '.join(AntiMalwareConfig.whitelist) meta_channel = self.bot.get_channel(Channels.meta) - embed.description = ( - f"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " - f"We currently allow the following file types: **{whitelisted_types}**.\n\n" - f"Feel free to ask in {meta_channel.mention} if you think this is a mistake." + embed.description = DISALLOWED_EMBED_DESCRIPTION.format( + blocked_extensions_str=blocked_extensions_str, + meta_channel_mention=meta_channel.mention, ) if embed.description: diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index a2ce9a740..fab063201 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole MODULE = "bot.cogs.antimalware" @@ -63,7 +63,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.delete.assert_not_called() - async def test_python_file_redirect_embed(self): + async def test_python_file_redirect_embed_description(self): """A message containing a .py file should result in an embed redirecting the user to our paste site""" attachment = MockAttachment(filename="python.py") self.message.attachments = [attachment] @@ -74,32 +74,44 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): args, kwargs = self.message.channel.send.call_args embed = kwargs.pop("embed") - self.assertEqual(args[0], f"Hey {self.message.author.mention}!") - self.assertEqual(embed.description, ( - "It looks like you tried to attach a Python file - " - f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" - )) + self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - async def test_txt_file_redirect_embed(self): + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt file should result in the correct embed.""" attachment = MockAttachment(filename="python.txt") self.message.attachments = [attachment] self.message.channel.send = AsyncMock() + antimalware.TXT_EMBED_DESCRIPTION = Mock() + antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() args, kwargs = self.message.channel.send.call_args embed = kwargs.pop("embed") cmd_channel = self.bot.get_channel(Channels.bot_commands) - self.assertEqual(args[0], f"Hey {self.message.author.mention}!") - self.assertEqual(embed.description, ( - "**Uh-oh!** It looks like your message got zapped by our spam filter. " - "We currently don't allow `.txt` attachments, so here are some tips to help you travel safely: \n\n" - "• If you attempted to send a message longer than 2000 characters, try shortening your message " - "to fit within the character limit or use a pasting service (see below) \n\n" - "• If you tried to show someone your code, you can use codeblocks \n(run `!code-blocks` in " - f"{cmd_channel.mention} for more information) or use a pasting service like: " - f"\n\n{URLs.site_schema}{URLs.site_paste}" - )) + self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) + antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + + async def test_other_disallowed_extention_embed_description(self): + """Test the description for a non .py/.txt disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + self.message.channel.send = AsyncMock() + antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + + await self.cog.on_message(self.message) + self.message.channel.send.assert_called_once() + args, kwargs = self.message.channel.send.call_args + embed = kwargs.pop("embed") + meta_channel = self.bot.get_channel(Channels.meta) + + self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) + antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) async def test_removing_deleted_message_logs(self): """Removing an already deleted message logs the correct message""" -- cgit v1.2.3 From 31aff51655d3783bc70f04628f189cf3c3591028 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 13 May 2020 18:58:43 -0700 Subject: Fix a test needlessly being a coroutine --- tests/bot/cogs/test_token_remover.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index f1a56c235..8e743a715 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -219,7 +219,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): ) @autospec("bot.cogs.token_remover", "LOG_MESSAGE") - async def test_format_log_message(self, log_message): + def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" log_message.format.return_value = "Howdy" return_value = TokenRemover.format_log_message(self.msg, "MTIz.DN9R_A.xyz") -- cgit v1.2.3 From 5a48ed0d60ebc9984cae27b19953b50b52df83d9 Mon Sep 17 00:00:00 2001 From: Numerlor <25886452+Numerlor@users.noreply.github.com> Date: Fri, 15 May 2020 19:52:26 +0200 Subject: Change tests to use the new timeout constant --- tests/bot/cogs/test_snekbox.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..ccc090f02 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -291,7 +291,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(actual, expected) self.bot.wait_for.assert_has_awaits( ( - call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), + call( + 'message_edit', + check=partial_mock(snekbox.predicate_eval_message_edit, ctx), + timeout=snekbox.REEVAL_TIMEOUT, + ), call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) ) ) -- cgit v1.2.3 From bf6c113319d47594e103c71f8ff5b0ea48d15b38 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 17 May 2020 03:18:19 +0200 Subject: Test suite for the redis dict. --- tests/bot/utils/test_redis_dict.py | 189 +++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/bot/utils/test_redis_dict.py (limited to 'tests') diff --git a/tests/bot/utils/test_redis_dict.py b/tests/bot/utils/test_redis_dict.py new file mode 100644 index 000000000..f422887ce --- /dev/null +++ b/tests/bot/utils/test_redis_dict.py @@ -0,0 +1,189 @@ +import unittest + +import fakeredis +from redis import DataError + +from bot.utils import RedisDict + +redis_server = fakeredis.FakeServer() +RedisDict._redis = fakeredis.FakeStrictRedis(server=redis_server) + + +class RedisDictTests(unittest.TestCase): + """Tests the RedisDict class from utils.redis_dict.py.""" + + redis = RedisDict() + + def test_class_attribute_namespace(self): + """Test that RedisDict creates a namespace automatically for class attributes.""" + self.assertEqual(self.redis._namespace, "RedisDictTests.redis") + + def test_custom_namespace(self): + """Test that users can set a custom namespaces which never collide.""" + test_cases = ( + (RedisDict("firedog")._namespace, "firedog"), + (RedisDict("firedog")._namespace, "firedog_"), + (RedisDict("firedog")._namespace, "firedog__"), + ) + + for test_case, result in test_cases: + self.assertEqual(test_case, result) + + def test_custom_namespace_takes_precedence(self): + """Test that custom namespaces take precedence over class attribute ones.""" + class LemonJuice: + citrus = RedisDict("citrus") + watercat = RedisDict() + + test_class = LemonJuice() + self.assertEqual(test_class.citrus._namespace, "citrus") + self.assertEqual(test_class.watercat._namespace, "LemonJuice.watercat") + + def test_set_get_item(self): + """Test that users can set and get items from the RedisDict.""" + self.redis['favorite_fruit'] = 'melon' + self.redis['favorite_number'] = 86 + self.assertEqual(self.redis['favorite_fruit'], 'melon') + self.assertEqual(self.redis['favorite_number'], 86) + + def test_set_item_value_types(self): + """Test that setitem rejects values that are not JSON serializable.""" + with self.assertRaises(TypeError): + self.redis['favorite_thing'] = object + self.redis['favorite_stuff'] = RedisDict + + def test_set_item_key_types(self): + """Test that setitem rejects keys that are not strings, ints or floats.""" + fruits = ["lemon", "melon", "apple"] + + with self.assertRaises(DataError): + self.redis[fruits] = "nice" + + def test_get_method(self): + """Test that the .get method works like in a dict.""" + self.redis['favorite_movie'] = 'Code Jam Highlights' + + self.assertEqual(self.redis.get('favorite_movie'), 'Code Jam Highlights') + self.assertEqual(self.redis.get('favorite_youtuber', 'pydis'), 'pydis') + self.assertIsNone(self.redis.get('favorite_dog')) + + def test_membership(self): + """Test that we can reliably use the `in` operator with our RedisDict.""" + self.redis['favorite_country'] = "Burkina Faso" + + self.assertIn('favorite_country', self.redis) + self.assertNotIn('favorite_dentist', self.redis) + + def test_del_item(self): + """Test that users can delete items from the RedisDict.""" + self.redis['favorite_band'] = "Radiohead" + self.assertIn('favorite_band', self.redis) + + del self.redis['favorite_band'] + self.assertNotIn('favorite_band', self.redis) + + def test_iter(self): + """Test that the RedisDict can be iterated.""" + self.redis.clear() + test_cases = ( + ('favorite_turtle', 'Donatello'), + ('second_favorite_turtle', 'Leonardo'), + ('third_favorite_turtle', 'Raphael'), + ) + for key, value in test_cases: + self.redis[key] = value + + # Test regular iteration + for test_case, key in zip(test_cases, self.redis): + value = test_case[1] + self.assertEqual(self.redis[key], value) + + # Test .items iteration + for key, value in self.redis.items(): + self.assertEqual(self.redis[key], value) + + # Test .keys iteration + for test_case, key in zip(test_cases, self.redis.keys()): + value = test_case[1] + self.assertEqual(self.redis[key], value) + + def test_len(self): + """Test that we can get the correct len() from the RedisDict.""" + self.redis.clear() + self.redis['one'] = 1 + self.redis['two'] = 2 + self.redis['three'] = 3 + self.assertEqual(len(self.redis), 3) + + self.redis['four'] = 4 + self.assertEqual(len(self.redis), 4) + + def test_copy(self): + """Test that the .copy method returns a workable dictionary copy.""" + copy = self.redis.copy() + local_copy = dict(self.redis.items()) + self.assertIs(type(copy), dict) + self.assertEqual(copy, local_copy) + + def test_clear(self): + """Test that the .clear method removes the entire hash.""" + self.redis.clear() + self.redis['teddy'] = "with me" + self.redis['in my dreams'] = "you have a weird hat" + self.assertEqual(len(self.redis), 2) + + self.redis.clear() + self.assertEqual(len(self.redis), 0) + + def test_pop(self): + """Test that we can .pop an item from the RedisDict.""" + self.redis.clear() + self.redis['john'] = 'was afraid' + + self.assertEqual(self.redis.pop('john'), 'was afraid') + self.assertEqual(self.redis.pop('pete', 'breakneck'), 'breakneck') + self.assertEqual(len(self.redis), 0) + + def test_popitem(self): + """Test that we can .popitem an item from the RedisDict.""" + self.redis.clear() + self.redis['john'] = 'the revalator' + self.redis['teddy'] = 'big bear' + + self.assertEqual(len(self.redis), 2) + self.assertEqual(self.redis.popitem(), 'big bear') + self.assertEqual(len(self.redis), 1) + + def test_setdefault(self): + """Test that we can .setdefault an item from the RedisDict.""" + self.redis.clear() + self.redis.setdefault('john', 'is yellow and weak') + self.assertEqual(self.redis['john'], 'is yellow and weak') + + with self.assertRaises(TypeError): + self.redis.setdefault('geisha', object) + + def test_update(self): + """Test that we can .update the RedisDict with multiple items.""" + self.redis.clear() + self.redis["reckfried"] = "lona" + self.redis["bel air"] = "prince" + self.redis.update({ + "reckfried": "jona", + "mega": "hungry, though", + }) + + result = { + "reckfried": "jona", + "bel air": "prince", + "mega": "hungry, though", + } + self.assertEqual(self.redis.copy(), result) + + def test_equals(self): + """Test that RedisDicts can be compared with == and !=.""" + new_redis_dict = RedisDict("firedog_the_sequel") + new_new_redis_dict = new_redis_dict + + self.assertEqual(new_redis_dict, new_new_redis_dict) + self.assertNotEqual(new_redis_dict, self.redis) -- cgit v1.2.3 From 4e24e9c43a331ebc0f9b598f4de6c45e04216782 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Sun, 17 May 2020 12:10:56 +0200 Subject: Use `send_help` to invoke command help After the refactoring of the help command, we need to use the built-in method of calling the help command: `Context.send_help`. As an argument, the qualified name (a string containing the full command path, including parents) of the command can be passed. Examples: - await ctx.send_help("reminders edit") This would send a help embed with information on `!reminders edit` to the Context. - await ctx.send_help(ctx.command.qualified_name) This would extract the qualified name of the command, which is the full command path, and send a help embed to Context. - await ctx.send_help() This will send the main "root" help embed to the Context. --- bot/cogs/bot.py | 2 +- bot/cogs/clean.py | 2 +- bot/cogs/defcon.py | 2 +- bot/cogs/error_handler.py | 33 +++++++++++++-------------------- bot/cogs/eval.py | 2 +- bot/cogs/extensions.py | 8 ++++---- bot/cogs/moderation/management.py | 2 +- bot/cogs/off_topic_names.py | 2 +- bot/cogs/reddit.py | 2 +- bot/cogs/reminders.py | 2 +- bot/cogs/site.py | 2 +- bot/cogs/snekbox.py | 2 +- bot/cogs/utils.py | 2 +- bot/cogs/watchchannels/bigbrother.py | 2 +- bot/cogs/watchchannels/talentpool.py | 4 ++-- tests/bot/cogs/test_snekbox.py | 3 +-- 16 files changed, 32 insertions(+), 40 deletions(-) (limited to 'tests') diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index a6929b431..ae829d5c3 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -41,7 +41,7 @@ class BotCog(Cog, name="Bot"): @with_role(Roles.verified) async def botinfo_group(self, ctx: Context) -> None: """Bot informational commands.""" - await ctx.invoke(self.bot.get_command("help"), "bot") + await ctx.send_help("bot") @botinfo_group.command(name='about', aliases=('info',), hidden=True) @with_role(Roles.verified) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index 5cdf0b048..e9bdbf510 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -180,7 +180,7 @@ class Clean(Cog): @with_role(*MODERATION_ROLES) async def clean_group(self, ctx: Context) -> None: """Commands for cleaning messages in channels.""" - await ctx.invoke(self.bot.get_command("help"), "clean") + await ctx.send_help("clean") @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index 56fca002a..71847a441 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -122,7 +122,7 @@ class Defcon(Cog): @with_role(Roles.admins, Roles.owners) async def defcon_group(self, ctx: Context) -> None: """Check the DEFCON status or run a subcommand.""" - await ctx.invoke(self.bot.get_command("help"), "defcon") + await ctx.send_help("defcon") async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: """Providing a structured way to do an defcon action.""" diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index b2f4c59f6..2d6cd85e6 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -2,7 +2,7 @@ import contextlib import logging import typing as t -from discord.ext.commands import Cog, Command, Context, errors +from discord.ext.commands import Cog, Context, errors from sentry_sdk import push_scope from bot.api import ResponseCodeError @@ -79,19 +79,13 @@ class ErrorHandler(Cog): f"{e.__class__.__name__}: {e}" ) - async def get_help_command(self, command: t.Optional[Command]) -> t.Tuple: - """Return the help command invocation args to display help for `command`.""" - parent = None - if command is not None: - parent = command.parent - - # Retrieve the help command for the invoked command. - if parent and command: - return self.bot.get_command("help"), parent.name, command.name - elif command: - return self.bot.get_command("help"), command.name - else: - return self.bot.get_command("help") + @staticmethod + def get_help_command(ctx: Context) -> t.Coroutine: + """Return a prepared `help` command invocation coroutine.""" + if ctx.command: + return ctx.send_help(ctx.command.qualified_name) + + return ctx.send_help() async def try_silence(self, ctx: Context) -> bool: """ @@ -165,20 +159,19 @@ class ErrorHandler(Cog): * ArgumentParsingError: send an error message * Other: send an error message and the help command """ - # TODO: use ctx.send_help() once PR #519 is merged. - help_command = await self.get_help_command(ctx.command) + prepared_help_command = self.get_help_command(ctx) if isinstance(e, errors.MissingRequiredArgument): await ctx.send(f"Missing required argument `{e.param.name}`.") - await ctx.invoke(*help_command) + await prepared_help_command self.bot.stats.incr("errors.missing_required_argument") elif isinstance(e, errors.TooManyArguments): await ctx.send(f"Too many arguments provided.") - await ctx.invoke(*help_command) + await prepared_help_command self.bot.stats.incr("errors.too_many_arguments") elif isinstance(e, errors.BadArgument): await ctx.send(f"Bad argument: {e}\n") - await ctx.invoke(*help_command) + await prepared_help_command self.bot.stats.incr("errors.bad_argument") elif isinstance(e, errors.BadUnionArgument): await ctx.send(f"Bad argument: {e}\n```{e.errors[-1]}```") @@ -188,7 +181,7 @@ class ErrorHandler(Cog): self.bot.stats.incr("errors.argument_parsing_error") else: await ctx.send("Something about your input seems off. Check the arguments:") - await ctx.invoke(*help_command) + await prepared_help_command self.bot.stats.incr("errors.other_user_input_error") @staticmethod diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 52136fc8d..2d52197e8 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -178,7 +178,7 @@ async def func(): # (None,) -> Any async def internal_group(self, ctx: Context) -> None: """Internal commands. Top secret!""" if not ctx.invoked_subcommand: - await ctx.invoke(self.bot.get_command("help"), "internal") + await ctx.send_help("internal") @internal_group.command(name='eval', aliases=('e',)) @with_role(Roles.admins, Roles.owners) diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index fb6cd9aa3..4493046e1 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -65,7 +65,7 @@ class Extensions(commands.Cog): @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) async def extensions_group(self, ctx: Context) -> None: """Load, unload, reload, and list loaded extensions.""" - await ctx.invoke(self.bot.get_command("help"), "extensions") + await ctx.send_help("extensions") @extensions_group.command(name="load", aliases=("l",)) async def load_command(self, ctx: Context, *extensions: Extension) -> None: @@ -75,7 +75,7 @@ class Extensions(commands.Cog): If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. """ # noqa: W605 if not extensions: - await ctx.invoke(self.bot.get_command("help"), "extensions load") + await ctx.send_help("extensions load") return if "*" in extensions or "**" in extensions: @@ -92,7 +92,7 @@ class Extensions(commands.Cog): If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. """ # noqa: W605 if not extensions: - await ctx.invoke(self.bot.get_command("help"), "extensions unload") + await ctx.send_help("extensions unload") return blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) @@ -118,7 +118,7 @@ class Extensions(commands.Cog): If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. """ # noqa: W605 if not extensions: - await ctx.invoke(self.bot.get_command("help"), "extensions reload") + await ctx.send_help("extensions reload") return if "**" in extensions: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 250a24247..5cd59cc07 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -43,7 +43,7 @@ class ModManagement(commands.Cog): @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) async def infraction_group(self, ctx: Context) -> None: """Infraction manipulation commands.""" - await ctx.invoke(self.bot.get_command("help"), "infraction") + await ctx.send_help("infraction") @infraction_group.command(name='edit') async def infraction_edit( diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 81511f99d..829772f65 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -97,7 +97,7 @@ class OffTopicNames(Cog): @with_role(*MODERATION_ROLES) async def otname_group(self, ctx: Context) -> None: """Add or list items from the off-topic channel name rotation.""" - await ctx.invoke(self.bot.get_command("help"), "otname") + await ctx.send_help("otname") @otname_group.command(name='add', aliases=('a',)) @with_role(*MODERATION_ROLES) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 5a7fa100f..07a2497be 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -245,7 +245,7 @@ class Reddit(Cog): @group(name="reddit", invoke_without_command=True) async def reddit_group(self, ctx: Context) -> None: """View the top posts from various subreddits.""" - await ctx.invoke(self.bot.get_command("help"), "reddit") + await ctx.send_help("reddit") @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 8b6457cbb..e2289c75d 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -281,7 +281,7 @@ class Reminders(Scheduler, Cog): @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) async def edit_reminder_group(self, ctx: Context) -> None: """Commands for modifying your current reminders.""" - await ctx.invoke(self.bot.get_command("help"), "reminders", "edit") + await ctx.send_help("reminders edit") @edit_reminder_group.command(name="duration", aliases=("time",)) async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: diff --git a/bot/cogs/site.py b/bot/cogs/site.py index 853e29568..c17761a2b 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -21,7 +21,7 @@ class Site(Cog): @group(name="site", aliases=("s",), invoke_without_command=True) async def site_group(self, ctx: Context) -> None: """Commands for getting info about our website.""" - await ctx.invoke(self.bot.get_command("help"), "site") + await ctx.send_help("site") @site_group.command(name="home", aliases=("about",)) async def site_main(self, ctx: Context) -> None: diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 8d4688114..5de978758 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -289,7 +289,7 @@ class Snekbox(Cog): return if not code: # None or empty string - await ctx.invoke(self.bot.get_command("help"), "eval") + await ctx.send_help("eval") return log.info(f"Received code from {ctx.author} for evaluation:\n{code}") diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 89d556f58..7350dc2ba 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -55,7 +55,7 @@ class Utils(Cog): if pep_number.isdigit(): pep_number = int(pep_number) else: - await ctx.invoke(self.bot.get_command("help"), "pep") + await ctx.send_help("pep") return # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index 903c87f85..37f2d2b9d 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -30,7 +30,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): @with_role(*MODERATION_ROLES) async def bigbrother_group(self, ctx: Context) -> None: """Monitors users by relaying their messages to the Big Brother watch channel.""" - await ctx.invoke(self.bot.get_command("help"), "bigbrother") + await ctx.send_help("bigbrother") @bigbrother_group.command(name='watched', aliases=('all', 'list')) @with_role(*MODERATION_ROLES) diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index ad0c51fa6..b8473963d 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -34,7 +34,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): @with_role(*MODERATION_ROLES) async def nomination_group(self, ctx: Context) -> None: """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" - await ctx.invoke(self.bot.get_command("help"), "talentpool") + await ctx.send_help("talentpool") @nomination_group.command(name='watched', aliases=('all', 'list')) @with_role(*MODERATION_ROLES) @@ -173,7 +173,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): @with_role(*MODERATION_ROLES) async def nomination_edit_group(self, ctx: Context) -> None: """Commands to edit nominations.""" - await ctx.invoke(self.bot.get_command("help"), "talentpool", "edit") + await ctx.send_help("talentpool edit") @nomination_edit_group.command(name='reason') @with_role(*MODERATION_ROLES) diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..190d41d66 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -209,9 +209,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_eval_command_call_help(self): """Test if the eval command call the help command if no code is provided.""" ctx = MockContext() - ctx.invoke = AsyncMock() await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') - ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") + ctx.send_help.assert_called_once_with("eval") async def test_send_eval(self): """Test the send_eval function.""" -- cgit v1.2.3 From 87cec1a863213aa23a07d29ca928766d382ee732 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Sun, 17 May 2020 12:37:34 +0200 Subject: Use `Command`-object for `send_help` As @mathsman5133 pointed out, it's better to use the `Command`-instance we typically already have in the current context than to rely on parsing the qualified name again. The invocation is now done as: `await ctx.send_help(ctx.command)` --- bot/cogs/bot.py | 2 +- bot/cogs/clean.py | 2 +- bot/cogs/defcon.py | 2 +- bot/cogs/error_handler.py | 2 +- bot/cogs/eval.py | 2 +- bot/cogs/extensions.py | 8 ++++---- bot/cogs/moderation/management.py | 2 +- bot/cogs/off_topic_names.py | 2 +- bot/cogs/reddit.py | 2 +- bot/cogs/reminders.py | 2 +- bot/cogs/site.py | 2 +- bot/cogs/snekbox.py | 2 +- bot/cogs/utils.py | 2 +- bot/cogs/watchchannels/bigbrother.py | 2 +- bot/cogs/watchchannels/talentpool.py | 4 ++-- tests/bot/cogs/test_snekbox.py | 4 ++-- 16 files changed, 21 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index ae829d5c3..f6aea51c5 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -41,7 +41,7 @@ class BotCog(Cog, name="Bot"): @with_role(Roles.verified) async def botinfo_group(self, ctx: Context) -> None: """Bot informational commands.""" - await ctx.send_help("bot") + await ctx.send_help(ctx.command) @botinfo_group.command(name='about', aliases=('info',), hidden=True) @with_role(Roles.verified) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index e9bdbf510..b5d9132cb 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -180,7 +180,7 @@ class Clean(Cog): @with_role(*MODERATION_ROLES) async def clean_group(self, ctx: Context) -> None: """Commands for cleaning messages in channels.""" - await ctx.send_help("clean") + await ctx.send_help(ctx.command) @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index 71847a441..25b0a6ad5 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -122,7 +122,7 @@ class Defcon(Cog): @with_role(Roles.admins, Roles.owners) async def defcon_group(self, ctx: Context) -> None: """Check the DEFCON status or run a subcommand.""" - await ctx.send_help("defcon") + await ctx.send_help(ctx.command) async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None: """Providing a structured way to do an defcon action.""" diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 2d6cd85e6..23d1eed82 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -83,7 +83,7 @@ class ErrorHandler(Cog): def get_help_command(ctx: Context) -> t.Coroutine: """Return a prepared `help` command invocation coroutine.""" if ctx.command: - return ctx.send_help(ctx.command.qualified_name) + return ctx.send_help(ctx.command) return ctx.send_help() diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 2d52197e8..eb8bfb1cf 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -178,7 +178,7 @@ async def func(): # (None,) -> Any async def internal_group(self, ctx: Context) -> None: """Internal commands. Top secret!""" if not ctx.invoked_subcommand: - await ctx.send_help("internal") + await ctx.send_help(ctx.command) @internal_group.command(name='eval', aliases=('e',)) @with_role(Roles.admins, Roles.owners) diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index 4493046e1..365f198ff 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -65,7 +65,7 @@ class Extensions(commands.Cog): @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) async def extensions_group(self, ctx: Context) -> None: """Load, unload, reload, and list loaded extensions.""" - await ctx.send_help("extensions") + await ctx.send_help(ctx.command) @extensions_group.command(name="load", aliases=("l",)) async def load_command(self, ctx: Context, *extensions: Extension) -> None: @@ -75,7 +75,7 @@ class Extensions(commands.Cog): If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. """ # noqa: W605 if not extensions: - await ctx.send_help("extensions load") + await ctx.send_help(ctx.command) return if "*" in extensions or "**" in extensions: @@ -92,7 +92,7 @@ class Extensions(commands.Cog): If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. """ # noqa: W605 if not extensions: - await ctx.send_help("extensions unload") + await ctx.send_help(ctx.command) return blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) @@ -118,7 +118,7 @@ class Extensions(commands.Cog): If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. """ # noqa: W605 if not extensions: - await ctx.send_help("extensions reload") + await ctx.send_help(ctx.command) return if "**" in extensions: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 5cd59cc07..edfdfd9e2 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -43,7 +43,7 @@ class ModManagement(commands.Cog): @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True) async def infraction_group(self, ctx: Context) -> None: """Infraction manipulation commands.""" - await ctx.send_help("infraction") + await ctx.send_help(ctx.command) @infraction_group.command(name='edit') async def infraction_edit( diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 829772f65..201579a0b 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -97,7 +97,7 @@ class OffTopicNames(Cog): @with_role(*MODERATION_ROLES) async def otname_group(self, ctx: Context) -> None: """Add or list items from the off-topic channel name rotation.""" - await ctx.send_help("otname") + await ctx.send_help(ctx.command) @otname_group.command(name='add', aliases=('a',)) @with_role(*MODERATION_ROLES) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 07a2497be..5f2aec7a5 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -245,7 +245,7 @@ class Reddit(Cog): @group(name="reddit", invoke_without_command=True) async def reddit_group(self, ctx: Context) -> None: """View the top posts from various subreddits.""" - await ctx.send_help("reddit") + await ctx.send_help(ctx.command) @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index e2289c75d..c242d2920 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -281,7 +281,7 @@ class Reminders(Scheduler, Cog): @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) async def edit_reminder_group(self, ctx: Context) -> None: """Commands for modifying your current reminders.""" - await ctx.send_help("reminders edit") + await ctx.send_help(ctx.command) @edit_reminder_group.command(name="duration", aliases=("time",)) async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: diff --git a/bot/cogs/site.py b/bot/cogs/site.py index c17761a2b..7fc2a9c34 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -21,7 +21,7 @@ class Site(Cog): @group(name="site", aliases=("s",), invoke_without_command=True) async def site_group(self, ctx: Context) -> None: """Commands for getting info about our website.""" - await ctx.send_help("site") + await ctx.send_help(ctx.command) @site_group.command(name="home", aliases=("about",)) async def site_main(self, ctx: Context) -> None: diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 5de978758..c2782b9c8 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -289,7 +289,7 @@ class Snekbox(Cog): return if not code: # None or empty string - await ctx.send_help("eval") + await ctx.send_help(ctx.command) return log.info(f"Received code from {ctx.author} for evaluation:\n{code}") diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 7350dc2ba..6b59d37c8 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -55,7 +55,7 @@ class Utils(Cog): if pep_number.isdigit(): pep_number = int(pep_number) else: - await ctx.send_help("pep") + await ctx.send_help(ctx.command) return # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index 37f2d2b9d..e4fb173e0 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -30,7 +30,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"): @with_role(*MODERATION_ROLES) async def bigbrother_group(self, ctx: Context) -> None: """Monitors users by relaying their messages to the Big Brother watch channel.""" - await ctx.send_help("bigbrother") + await ctx.send_help(ctx.command) @bigbrother_group.command(name='watched', aliases=('all', 'list')) @with_role(*MODERATION_ROLES) diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index b8473963d..9a85c68c2 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -34,7 +34,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): @with_role(*MODERATION_ROLES) async def nomination_group(self, ctx: Context) -> None: """Highlights the activity of helper nominees by relaying their messages to the talent pool channel.""" - await ctx.send_help("talentpool") + await ctx.send_help(ctx.command) @nomination_group.command(name='watched', aliases=('all', 'list')) @with_role(*MODERATION_ROLES) @@ -173,7 +173,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"): @with_role(*MODERATION_ROLES) async def nomination_edit_group(self, ctx: Context) -> None: """Commands to edit nominations.""" - await ctx.send_help("talentpool edit") + await ctx.send_help(ctx.command) @nomination_edit_group.command(name='reason') @with_role(*MODERATION_ROLES) diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 190d41d66..8490b02ca 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -208,9 +208,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_eval_command_call_help(self): """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext() + ctx = MockContext(command="sentinel") await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with("eval") + ctx.send_help.assert_called_once_with("sentinel") async def test_send_eval(self): """Test the send_eval function.""" -- cgit v1.2.3 From 21916ad9c19a326eb8406ea751e5fd9f80e9d912 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 10:42:11 +0300 Subject: ModLog Tests: Fix truncation tests docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Leon Sandøy --- tests/bot/cogs/moderation/test_modlog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py index d60836474..b5ad21a09 100644 --- a/tests/bot/cogs/moderation/test_modlog.py +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -15,7 +15,7 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): self.channel = MockTextChannel() async def test_log_entry_description_truncation(self): - """Should truncate embed description for ModLog entry.""" + """Test that embed description for ModLog entry is truncated.""" self.bot.get_channel.return_value = self.channel await self.cog.send_log_message( icon_url="foo", -- cgit v1.2.3 From 1432e5ba36fc09c7233e5be4745f540c2c4af792 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 10:47:59 +0300 Subject: Infraction Tests: Small fixes - Remove unnecessary space from placeholder - Rename `has_active_infraction` to `get_active_infraction` --- tests/bot/cogs/moderation/test_infractions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index 51a8cc645..2b1ff5728 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -17,11 +17,11 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.guild = MockGuild(id=4567) self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) - @patch("bot.cogs.moderation.utils.has_active_infraction") + @patch("bot.cogs.moderation.utils.get_active_infraction") @patch("bot.cogs.moderation.utils.post_infraction") - async def test_apply_ban_reason_truncation(self, post_infraction_mock, has_active_mock): + async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" - has_active_mock.return_value = False + get_active_mock.return_value = 'foo' post_infraction_mock.return_value = {"foo": "bar"} self.cog.apply_infraction = AsyncMock() @@ -32,7 +32,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): ban = self.cog.apply_infraction.call_args[0][3] self.assertEqual( ban.cr_frame.f_locals["kwargs"]["reason"], - textwrap.shorten("foo bar" * 3000, 512, placeholder=" ...") + textwrap.shorten("foo bar" * 3000, 512, placeholder="...") ) # Await ban to avoid warning await ban -- cgit v1.2.3 From 5989bcfefa244eb05f37b76d1e1df2f45e5782fa Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 10:54:49 +0300 Subject: ModLog Tests: Fix embed description truncate test --- tests/bot/cogs/moderation/test_modlog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py index b5ad21a09..f2809f40a 100644 --- a/tests/bot/cogs/moderation/test_modlog.py +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -25,5 +25,5 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): ) embed = self.channel.send.call_args[1]["embed"] self.assertEqual( - embed.description, ("foo bar" * 3000)[:2046] + "..." + embed.description, ("foo bar" * 3000)[:2045] + "..." ) -- cgit v1.2.3 From 874cb001df91ea8223385dd2b32ab4e3c280e183 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 10:57:07 +0300 Subject: Infr. Tests: Add more content to await comment --- tests/bot/cogs/moderation/test_infractions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index 2b1ff5728..f8f340c2e 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -34,7 +34,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): ban.cr_frame.f_locals["kwargs"]["reason"], textwrap.shorten("foo bar" * 3000, 512, placeholder="...") ) - # Await ban to avoid warning + # Await ban to avoid not awaited coroutine warning await ban @patch("bot.cogs.moderation.utils.post_infraction") @@ -51,5 +51,5 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): kick.cr_frame.f_locals["kwargs"]["reason"], textwrap.shorten("foo bar" * 3000, 512, placeholder="...") ) - # Await kick to avoid warning + # Await kick to avoid not awaited coroutine warning await kick -- cgit v1.2.3 From e9bd09d90c5acf61caa955533f406851e1a65aec Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 10:59:11 +0300 Subject: Infr. Tests: Replace `str` with `dict` To allow `.get`, I had to replace `str` return value with `dict` --- tests/bot/cogs/moderation/test_infractions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index f8f340c2e..139209749 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -21,7 +21,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_infraction") async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = 'foo' + get_active_mock.return_value = {"foo": "bar"} post_infraction_mock.return_value = {"foo": "bar"} self.cog.apply_infraction = AsyncMock() -- cgit v1.2.3 From d9730e41b3144862fdd9c221d160a40144a7c881 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 11:02:37 +0300 Subject: Infr. Test: Replace `get_active_mock` return value Replace `{"foo": "bar"}` with `{"id": 1}` --- tests/bot/cogs/moderation/test_infractions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index 139209749..925439bf3 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -21,7 +21,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_infraction") async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = {"foo": "bar"} + get_active_mock.return_value = {"id": 1} post_infraction_mock.return_value = {"foo": "bar"} self.cog.apply_infraction = AsyncMock() -- cgit v1.2.3 From a1b6d147befd4043acdddc00667d3bda94cc76ad Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Wed, 20 May 2020 11:15:09 +0300 Subject: Infr Tests: Make `get_active_infraction` return `None` --- tests/bot/cogs/moderation/test_infractions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index 925439bf3..5548d9f68 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -21,7 +21,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): @patch("bot.cogs.moderation.utils.post_infraction") async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): """Should truncate reason for `ctx.guild.ban`.""" - get_active_mock.return_value = {"id": 1} + get_active_mock.return_value = None post_infraction_mock.return_value = {"foo": "bar"} self.cog.apply_infraction = AsyncMock() -- cgit v1.2.3 From 57fe4bf893e94289b5b6f7158ff2d6b92b1e3fae Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Fri, 22 May 2020 22:56:25 +0200 Subject: Set up async testbed --- bot/bot.py | 3 +- bot/utils/redis_cache.py | 13 ++- tests/bot/utils/test_redis_cache.py | 128 ++++++++++++++++++++++++ tests/bot/utils/test_redis_dict.py | 189 ------------------------------------ 4 files changed, 135 insertions(+), 198 deletions(-) create mode 100644 tests/bot/utils/test_redis_cache.py delete mode 100644 tests/bot/utils/test_redis_dict.py (limited to 'tests') diff --git a/bot/bot.py b/bot/bot.py index f55eec5bb..8a3805989 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -45,8 +45,7 @@ class Bot(commands.Bot): # will effectively disable stats. statsd_url = "127.0.0.1" - asyncio.create_task(self._create_redis_session()) - + self.loop.create_task(self._create_redis_session()) self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot") async def _create_redis_session(self) -> None: diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 467f16767..483bbc2cd 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -30,8 +30,8 @@ class RedisCache: def __init__(self) -> None: """Raise a NotImplementedError if `__set_name__` hasn't been run.""" - if not self._namespace: - raise NotImplementedError("RedisCache must be a class attribute.") + self._namespace = None + self.bot = None def _set_namespace(self, namespace: str) -> None: """Try to set the namespace, but do not permit collisions.""" @@ -47,8 +47,7 @@ class RedisCache: Called automatically when this class is constructed inside a class as an attribute. """ - if not self._has_custom_namespace: - self._set_namespace(f"{owner.__name__}.{attribute_name}") + self._set_namespace(f"{owner.__name__}.{attribute_name}") def __get__(self, instance: RedisCache, owner: Any) -> RedisCache: """Fetch the Bot instance, we need it for the redis pool.""" @@ -106,9 +105,9 @@ class RedisCache: async def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: """Get the item, remove it from the cache, and provide a default if not found.""" - value = await self.get(key, default) - await self.delete(key) - return value + # value = await self.get(key, default) + # await self.delete(key) + # return value async def update(self) -> None: """Update the Redis cache with multiple values.""" diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py new file mode 100644 index 000000000..f6344803f --- /dev/null +++ b/tests/bot/utils/test_redis_cache.py @@ -0,0 +1,128 @@ +import asyncio +import unittest +from unittest.mock import MagicMock + +import fakeredis.aioredis + +from bot.bot import Bot +from bot.utils import RedisCache + + +class RedisCacheTests(unittest.IsolatedAsyncioTestCase): + """Tests the RedisDict class from utils.redis_dict.py.""" + + redis = RedisCache() + + async def asyncSetUp(self): # noqa: N802 - this special method can't be all lowercase + """Sets up the objects that only have to be initialized once.""" + self.bot = MagicMock( + spec=Bot, + redis_session=await fakeredis.aioredis.create_redis_pool(), + _redis_ready=asyncio.Event(), + ) + self.bot._redis_ready.set() + + def test_class_attribute_namespace(self): + """Test that RedisDict creates a namespace automatically for class attributes.""" + self.assertEqual(self.redis._namespace, "RedisCacheTests.redis") + # Test that errors are raised when this isn't true. + + # def test_set_get_item(self): + # """Test that users can set and get items from the RedisDict.""" + # self.redis['favorite_fruit'] = 'melon' + # self.redis['favorite_number'] = 86 + # self.assertEqual(self.redis['favorite_fruit'], 'melon') + # self.assertEqual(self.redis['favorite_number'], 86) + # + # def test_set_item_types(self): + # """Test that setitem rejects keys and values that are not strings, ints or floats.""" + # fruits = ["lemon", "melon", "apple"] + # + # with self.assertRaises(DataError): + # self.redis[fruits] = "nice" + # + # def test_contains(self): + # """Test that we can reliably use the `in` operator with our RedisDict.""" + # self.redis['favorite_country'] = "Burkina Faso" + # + # self.assertIn('favorite_country', self.redis) + # self.assertNotIn('favorite_dentist', self.redis) + # + # def test_items(self): + # """Test that the RedisDict can be iterated.""" + # self.redis.clear() + # test_cases = ( + # ('favorite_turtle', 'Donatello'), + # ('second_favorite_turtle', 'Leonardo'), + # ('third_favorite_turtle', 'Raphael'), + # ) + # for key, value in test_cases: + # self.redis[key] = value + # + # # Test regular iteration + # for test_case, key in zip(test_cases, self.redis): + # value = test_case[1] + # self.assertEqual(self.redis[key], value) + # + # # Test .items iteration + # for key, value in self.redis.items(): + # self.assertEqual(self.redis[key], value) + # + # # Test .keys iteration + # for test_case, key in zip(test_cases, self.redis.keys()): + # value = test_case[1] + # self.assertEqual(self.redis[key], value) + # + # def test_length(self): + # """Test that we can get the correct len() from the RedisDict.""" + # self.redis.clear() + # self.redis['one'] = 1 + # self.redis['two'] = 2 + # self.redis['three'] = 3 + # self.assertEqual(len(self.redis), 3) + # + # self.redis['four'] = 4 + # self.assertEqual(len(self.redis), 4) + # + # def test_to_dict(self): + # """Test that the .copy method returns a workable dictionary copy.""" + # copy = self.redis.copy() + # local_copy = dict(self.redis.items()) + # self.assertIs(type(copy), dict) + # self.assertEqual(copy, local_copy) + # + # def test_clear(self): + # """Test that the .clear method removes the entire hash.""" + # self.redis.clear() + # self.redis['teddy'] = "with me" + # self.redis['in my dreams'] = "you have a weird hat" + # self.assertEqual(len(self.redis), 2) + # + # self.redis.clear() + # self.assertEqual(len(self.redis), 0) + # + # def test_pop(self): + # """Test that we can .pop an item from the RedisDict.""" + # self.redis.clear() + # self.redis['john'] = 'was afraid' + # + # self.assertEqual(self.redis.pop('john'), 'was afraid') + # self.assertEqual(self.redis.pop('pete', 'breakneck'), 'breakneck') + # self.assertEqual(len(self.redis), 0) + # + # def test_update(self): + # """Test that we can .update the RedisDict with multiple items.""" + # self.redis.clear() + # self.redis["reckfried"] = "lona" + # self.redis["bel air"] = "prince" + # self.redis.update({ + # "reckfried": "jona", + # "mega": "hungry, though", + # }) + # + # result = { + # "reckfried": "jona", + # "bel air": "prince", + # "mega": "hungry, though", + # } + # self.assertEqual(self.redis.copy(), result) diff --git a/tests/bot/utils/test_redis_dict.py b/tests/bot/utils/test_redis_dict.py deleted file mode 100644 index f422887ce..000000000 --- a/tests/bot/utils/test_redis_dict.py +++ /dev/null @@ -1,189 +0,0 @@ -import unittest - -import fakeredis -from redis import DataError - -from bot.utils import RedisDict - -redis_server = fakeredis.FakeServer() -RedisDict._redis = fakeredis.FakeStrictRedis(server=redis_server) - - -class RedisDictTests(unittest.TestCase): - """Tests the RedisDict class from utils.redis_dict.py.""" - - redis = RedisDict() - - def test_class_attribute_namespace(self): - """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.redis._namespace, "RedisDictTests.redis") - - def test_custom_namespace(self): - """Test that users can set a custom namespaces which never collide.""" - test_cases = ( - (RedisDict("firedog")._namespace, "firedog"), - (RedisDict("firedog")._namespace, "firedog_"), - (RedisDict("firedog")._namespace, "firedog__"), - ) - - for test_case, result in test_cases: - self.assertEqual(test_case, result) - - def test_custom_namespace_takes_precedence(self): - """Test that custom namespaces take precedence over class attribute ones.""" - class LemonJuice: - citrus = RedisDict("citrus") - watercat = RedisDict() - - test_class = LemonJuice() - self.assertEqual(test_class.citrus._namespace, "citrus") - self.assertEqual(test_class.watercat._namespace, "LemonJuice.watercat") - - def test_set_get_item(self): - """Test that users can set and get items from the RedisDict.""" - self.redis['favorite_fruit'] = 'melon' - self.redis['favorite_number'] = 86 - self.assertEqual(self.redis['favorite_fruit'], 'melon') - self.assertEqual(self.redis['favorite_number'], 86) - - def test_set_item_value_types(self): - """Test that setitem rejects values that are not JSON serializable.""" - with self.assertRaises(TypeError): - self.redis['favorite_thing'] = object - self.redis['favorite_stuff'] = RedisDict - - def test_set_item_key_types(self): - """Test that setitem rejects keys that are not strings, ints or floats.""" - fruits = ["lemon", "melon", "apple"] - - with self.assertRaises(DataError): - self.redis[fruits] = "nice" - - def test_get_method(self): - """Test that the .get method works like in a dict.""" - self.redis['favorite_movie'] = 'Code Jam Highlights' - - self.assertEqual(self.redis.get('favorite_movie'), 'Code Jam Highlights') - self.assertEqual(self.redis.get('favorite_youtuber', 'pydis'), 'pydis') - self.assertIsNone(self.redis.get('favorite_dog')) - - def test_membership(self): - """Test that we can reliably use the `in` operator with our RedisDict.""" - self.redis['favorite_country'] = "Burkina Faso" - - self.assertIn('favorite_country', self.redis) - self.assertNotIn('favorite_dentist', self.redis) - - def test_del_item(self): - """Test that users can delete items from the RedisDict.""" - self.redis['favorite_band'] = "Radiohead" - self.assertIn('favorite_band', self.redis) - - del self.redis['favorite_band'] - self.assertNotIn('favorite_band', self.redis) - - def test_iter(self): - """Test that the RedisDict can be iterated.""" - self.redis.clear() - test_cases = ( - ('favorite_turtle', 'Donatello'), - ('second_favorite_turtle', 'Leonardo'), - ('third_favorite_turtle', 'Raphael'), - ) - for key, value in test_cases: - self.redis[key] = value - - # Test regular iteration - for test_case, key in zip(test_cases, self.redis): - value = test_case[1] - self.assertEqual(self.redis[key], value) - - # Test .items iteration - for key, value in self.redis.items(): - self.assertEqual(self.redis[key], value) - - # Test .keys iteration - for test_case, key in zip(test_cases, self.redis.keys()): - value = test_case[1] - self.assertEqual(self.redis[key], value) - - def test_len(self): - """Test that we can get the correct len() from the RedisDict.""" - self.redis.clear() - self.redis['one'] = 1 - self.redis['two'] = 2 - self.redis['three'] = 3 - self.assertEqual(len(self.redis), 3) - - self.redis['four'] = 4 - self.assertEqual(len(self.redis), 4) - - def test_copy(self): - """Test that the .copy method returns a workable dictionary copy.""" - copy = self.redis.copy() - local_copy = dict(self.redis.items()) - self.assertIs(type(copy), dict) - self.assertEqual(copy, local_copy) - - def test_clear(self): - """Test that the .clear method removes the entire hash.""" - self.redis.clear() - self.redis['teddy'] = "with me" - self.redis['in my dreams'] = "you have a weird hat" - self.assertEqual(len(self.redis), 2) - - self.redis.clear() - self.assertEqual(len(self.redis), 0) - - def test_pop(self): - """Test that we can .pop an item from the RedisDict.""" - self.redis.clear() - self.redis['john'] = 'was afraid' - - self.assertEqual(self.redis.pop('john'), 'was afraid') - self.assertEqual(self.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(len(self.redis), 0) - - def test_popitem(self): - """Test that we can .popitem an item from the RedisDict.""" - self.redis.clear() - self.redis['john'] = 'the revalator' - self.redis['teddy'] = 'big bear' - - self.assertEqual(len(self.redis), 2) - self.assertEqual(self.redis.popitem(), 'big bear') - self.assertEqual(len(self.redis), 1) - - def test_setdefault(self): - """Test that we can .setdefault an item from the RedisDict.""" - self.redis.clear() - self.redis.setdefault('john', 'is yellow and weak') - self.assertEqual(self.redis['john'], 'is yellow and weak') - - with self.assertRaises(TypeError): - self.redis.setdefault('geisha', object) - - def test_update(self): - """Test that we can .update the RedisDict with multiple items.""" - self.redis.clear() - self.redis["reckfried"] = "lona" - self.redis["bel air"] = "prince" - self.redis.update({ - "reckfried": "jona", - "mega": "hungry, though", - }) - - result = { - "reckfried": "jona", - "bel air": "prince", - "mega": "hungry, though", - } - self.assertEqual(self.redis.copy(), result) - - def test_equals(self): - """Test that RedisDicts can be compared with == and !=.""" - new_redis_dict = RedisDict("firedog_the_sequel") - new_new_redis_dict = new_redis_dict - - self.assertEqual(new_redis_dict, new_new_redis_dict) - self.assertNotEqual(new_redis_dict, self.redis) -- cgit v1.2.3 From fd6f3d30b4c67f9a81346bb142d4696948fa2812 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 22 May 2020 15:40:50 -0700 Subject: Fix assertion for `create_task` in duck pond tests The assertion wasn't using the assertion method. Furthermore, it was testing a non-existent function `create_loop` rather than `create_task`. --- tests/bot/cogs/test_duck_pond.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 7e6bfc748..a8c0107c6 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): self.assertEqual(cog.bot, bot) self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_loop.called_once_with(cog.fetch_webhook()) + bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) def test_fetch_webhook_succeeds_without_connectivity_issues(self): """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" -- cgit v1.2.3 From 45e6f8dba869a367b01d99a596bd3355802d1fbe Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 22 May 2020 15:44:04 -0700 Subject: Improve aiohttp context manager mocking in snekbox tests I'm not sure how it even managed to work before. It was calling the `post` coroutine (without specifying a URL) and then changing `__aenter__`. Now, a separate mock is created for the context manager and the `post` simply returns that mocked context manager. --- tests/bot/cogs/test_snekbox.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..84b273a7d 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -21,7 +21,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() resp.json = AsyncMock(return_value="return") - self.bot.http_session.post().__aenter__.return_value = resp + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager self.assertEqual(await self.cog.post_eval("import random"), "return") self.bot.http_session.post.assert_called_with( @@ -41,7 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): key = "MarkDiamond" resp = MagicMock() resp.json = AsyncMock(return_value={"key": key}) - self.bot.http_session.post().__aenter__.return_value = resp + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager self.assertEqual( await self.cog.upload_output("My awesome output"), @@ -57,7 +63,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Output upload gracefully fallback if the upload fail.""" resp = MagicMock() resp.json = AsyncMock(side_effect=Exception) - self.bot.http_session.post().__aenter__.return_value = resp + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager log = logging.getLogger("bot.cogs.snekbox") with self.assertLogs(logger=log, level='ERROR'): -- cgit v1.2.3 From 6aed2f6b69b79b5a7e5f327819d026e7a63a7dab Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 22 May 2020 16:15:23 -0700 Subject: Fix unawaited coro warning when instantiating Bot for MockBot's spec The fix is to mock the loop and pass it to the Bot. It will then set it as `self.loop` rather than trying to get an event loop from asyncio. The `create_task` patch has been moved to this loop mock rather than being done in MockBot to ensure that it applies to anything calling it when instantiating the Bot. --- tests/helpers.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 2b79a6c2a..2efeff7db 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,6 +4,7 @@ import collections import itertools import logging import unittest.mock +from asyncio import AbstractEventLoop from typing import Iterable, Optional import discord @@ -264,10 +265,16 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): spec_set = APIClient -# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` -bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) -bot_instance.http_session = None -bot_instance.api_client = None +def _get_mock_loop() -> unittest.mock.Mock: + """Return a mocked asyncio.AbstractEventLoop.""" + loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) + + # Since calling `create_task` on our MockBot does not actually schedule the coroutine object + # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object + # to prevent "has not been awaited"-warnings. + loop.create_task.side_effect = lambda coroutine: coroutine.close() + + return loop class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -277,17 +284,14 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ - spec_set = bot_instance + spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) additional_spec_asyncs = ("wait_for",) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.api_client = MockAPIClient() - # Since calling `create_task` on our MockBot does not actually schedule the coroutine object - # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object - # to prevent "has not been awaited"-warnings. - self.loop.create_task.side_effect = lambda coroutine: coroutine.close() + self.loop = _get_mock_loop() + self.api_client = MockAPIClient(loop=self.loop) # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` -- cgit v1.2.3 From 1ad7833d800918efca06e5d6b2fbafdb0d757009 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 22 May 2020 16:23:12 -0700 Subject: Properly mock the redis pool in MockBot Because some of the redis pool/connection methods return futures rather than being coroutines, the redis pool had to be mocked using the CustomMockMixin so it could take advantage of `additional_spec_asyncs` to use AsyncMocks for these future-returning methods. --- tests/bot/utils/test_redis_cache.py | 12 +++--------- tests/helpers.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index f6344803f..991225481 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -1,11 +1,9 @@ -import asyncio import unittest -from unittest.mock import MagicMock import fakeredis.aioredis -from bot.bot import Bot from bot.utils import RedisCache +from tests import helpers class RedisCacheTests(unittest.IsolatedAsyncioTestCase): @@ -15,12 +13,8 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): # noqa: N802 - this special method can't be all lowercase """Sets up the objects that only have to be initialized once.""" - self.bot = MagicMock( - spec=Bot, - redis_session=await fakeredis.aioredis.create_redis_pool(), - _redis_ready=asyncio.Event(), - ) - self.bot._redis_ready.set() + self.bot = helpers.MockBot() + self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" diff --git a/tests/helpers.py b/tests/helpers.py index 2efeff7db..33d4f787c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,7 @@ import unittest.mock from asyncio import AbstractEventLoop from typing import Iterable, Optional +import aioredis.abc import discord from discord.ext.commands import Context @@ -265,6 +266,17 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): spec_set = APIClient +class MockRedisPool(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock an aioredis connection pool. + + Instances of this class will follow the specifications of `aioredis.abc.AbcPool` instances. + For more information, see the `MockGuild` docstring. + """ + spec_set = aioredis.abc.AbcPool + additional_spec_asyncs = ("execute", "execute_pubsub") + + def _get_mock_loop() -> unittest.mock.Mock: """Return a mocked asyncio.AbstractEventLoop.""" loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) @@ -293,6 +305,10 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.loop = _get_mock_loop() self.api_client = MockAPIClient(loop=self.loop) + # fakeredis can't be used cause it'd require awaiting a coroutine to create the pool, + # which cannot be done here in __init__. + self.redis_session = MockRedisPool() + # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` channel_data = { -- cgit v1.2.3 From d8f1634ab68b2cd480d57c8b9da8834866b5c9cc Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 22 May 2020 16:25:10 -0700 Subject: Use autospecced mocks in MockBot for the stats and aiohttp session This will help catch anything that tries to get/set an attribute/method which doesn't exist. It'll also catch missing/too many parameters being passed to methods. --- tests/helpers.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 33d4f787c..d226be3f0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,9 +9,11 @@ from typing import Iterable, Optional import aioredis.abc import discord +from aiohttp import ClientSession from discord.ext.commands import Context from bot.api import APIClient +from bot.async_stats import AsyncStatsClient from bot.bot import Bot @@ -304,6 +306,8 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.loop = _get_mock_loop() self.api_client = MockAPIClient(loop=self.loop) + self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) + self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) # fakeredis can't be used cause it'd require awaiting a coroutine to create the pool, # which cannot be done here in __init__. -- cgit v1.2.3 From eb63fb02a49bf1979afd04a1350304edf00d3a56 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 23 May 2020 02:06:27 +0200 Subject: Finish .set and .get, and add tests. The .set and .get will accept ints, floats, and strings. These will be converted into "typestrings", which is basically just a simple format that's been invented for this object. For example, an int looks like `b"i|2423"`. Note how it is still stored as a bytestring (like everything in Redis), but because of this prefix we are able to coerce it into the type we want on the way out of the db. --- bot/utils/redis_cache.py | 72 ++++++++++++++++++++++++++++++------- tests/bot/utils/test_redis_cache.py | 36 +++++++++++++------ tests/helpers.py | 2 +- 3 files changed, 85 insertions(+), 25 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 483bbc2cd..24f2f2e03 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -1,12 +1,10 @@ from __future__ import annotations -from enum import Enum -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union +from typing import Any, AsyncIterator, Dict, Optional, Union from bot.bot import Bot -ValidRedisKey = Union[str, int, float] -JSONSerializableType = Optional[Union[str, float, bool, Dict, List, Tuple, Enum]] +ValidRedisType = Union[str, int, float] class RedisCache: @@ -41,7 +39,39 @@ class RedisCache: self._namespaces.append(namespace) self._namespace = namespace - def __set_name__(self, owner: object, attribute_name: str) -> None: + @staticmethod + def _to_typestring(value: ValidRedisType) -> str: + """Turn a valid Redis type into a typestring.""" + if isinstance(value, float): + return f"f|{value}" + elif isinstance(value, int): + return f"i|{value}" + elif isinstance(value, str): + return f"s|{value}" + + @staticmethod + def _from_typestring(value: str) -> ValidRedisType: + """Turn a valid Redis type into a typestring.""" + if value.startswith("f|"): + return float(value[2:]) + if value.startswith("i|"): + return int(value[2:]) + if value.startswith("s|"): + return value[2:] + + async def _validate_cache(self) -> None: + """Validate that the RedisCache is ready to be used.""" + if self.bot is None: + raise RuntimeError("Critical error: RedisCache has no `Bot` instance.") + + if self._namespace is None: + raise RuntimeError( + "Critical error: RedisCache has no namespace. " + "Did you initialize this object as a class attribute?" + ) + await self.bot._redis_ready.wait() + + def __set_name__(self, owner: Any, attribute_name: str) -> None: """ Set the namespace to Class.attribute_name. @@ -54,8 +84,11 @@ class RedisCache: if self.bot: return self + if self._namespace is None: + raise RuntimeError("RedisCache must be a class attribute.") + if instance is None: - raise NotImplementedError("You must create an instance of RedisCache to use it.") + raise RuntimeError("You must create an instance of RedisCache to use it.") for attribute in vars(instance).values(): if isinstance(attribute, Bot): @@ -69,19 +102,32 @@ class RedisCache: """Return a beautiful representation of this object instance.""" return f"RedisCache(namespace={self._namespace!r})" - async def set(self, key: ValidRedisKey, value: JSONSerializableType) -> None: + async def set(self, key: ValidRedisType, value: ValidRedisType) -> None: """Store an item in the Redis cache.""" - # await self._redis.hset(self._namespace, key, value) + await self._validate_cache() + + # Convert to a typestring and then set it + key = self._to_typestring(key) + value = self._to_typestring(value) + await self._redis.hset(self._namespace, key, value) - async def get(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: + async def get(self, key: ValidRedisType, default: Optional[ValidRedisType] = None) -> ValidRedisType: """Get an item from the Redis cache.""" - # value = await self._redis.hget(self._namespace, key) + await self._validate_cache() + key = self._to_typestring(key) + value = await self._redis.hget(self._namespace, key) + + if value is None: + return default + else: + value = self._from_typestring(value.decode("utf-8")) + return value - async def delete(self, key: ValidRedisKey) -> None: + async def delete(self, key: ValidRedisType) -> None: """Delete an item from the Redis cache.""" # await self._redis.hdel(self._namespace, key) - async def contains(self, key: ValidRedisKey) -> bool: + async def contains(self, key: ValidRedisType) -> bool: """Check if a key exists in the Redis cache.""" # return await self._redis.hexists(self._namespace, key) @@ -103,7 +149,7 @@ class RedisCache: """Deletes the entire hash from the Redis cache.""" # await self._redis.delete(self._namespace) - async def pop(self, key: ValidRedisKey, default: Optional[JSONSerializableType] = None) -> JSONSerializableType: + async def pop(self, key: ValidRedisType, default: Optional[ValidRedisType] = None) -> ValidRedisType: """Get the item, remove it from the cache, and provide a default if not found.""" # value = await self.get(key, default) # await self.delete(key) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 991225481..ad38bfde0 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -7,27 +7,41 @@ from tests import helpers class RedisCacheTests(unittest.IsolatedAsyncioTestCase): - """Tests the RedisDict class from utils.redis_dict.py.""" + """Tests the RedisCache class from utils.redis_dict.py.""" redis = RedisCache() - async def asyncSetUp(self): # noqa: N802 - this special method can't be all lowercase + async def asyncSetUp(self): # noqa: N802 """Sets up the objects that only have to be initialized once.""" self.bot = helpers.MockBot() self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - def test_class_attribute_namespace(self): + async def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" self.assertEqual(self.redis._namespace, "RedisCacheTests.redis") - # Test that errors are raised when this isn't true. - # def test_set_get_item(self): - # """Test that users can set and get items from the RedisDict.""" - # self.redis['favorite_fruit'] = 'melon' - # self.redis['favorite_number'] = 86 - # self.assertEqual(self.redis['favorite_fruit'], 'melon') - # self.assertEqual(self.redis['favorite_number'], 86) - # + # Test that errors are raised when not assigned as a class attribute + bad_cache = RedisCache() + + with self.assertRaises(RuntimeError): + await bad_cache.set("test", "me_up_deadman") + + async def test_set_get_item(self): + """Test that users can set and get items from the RedisDict.""" + test_cases = ( + ('favorite_fruit', 'melon'), + ('favorite_number', 86), + ('favorite_fraction', 86.54) + ) + + # Test that we can get and set different types. + for test in test_cases: + await self.redis.set(*test) + self.assertEqual(await self.redis.get(test[0]), test[1]) + + # Test that .get allows a default value + self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + # def test_set_item_types(self): # """Test that setitem rejects keys and values that are not strings, ints or floats.""" # fruits = ["lemon", "melon", "apple"] diff --git a/tests/helpers.py b/tests/helpers.py index d226be3f0..2b176db79 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -299,7 +299,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): For more information, see the `MockGuild` docstring. """ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) - additional_spec_asyncs = ("wait_for",) + additional_spec_asyncs = ("wait_for", "_redis_ready") def __init__(self, **kwargs) -> None: super().__init__(**kwargs) -- cgit v1.2.3 From 387bf5c6b6a21e25c4fc690fb992b6b3e4c165a6 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 23 May 2020 11:36:12 +0200 Subject: Complete asyncified test suite for RedisCache This commit just alters existing code to work with the new interface, and with async. All tests are passing successfully. --- tests/bot/utils/test_redis_cache.py | 206 ++++++++++++++++++++---------------- 1 file changed, 112 insertions(+), 94 deletions(-) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index ad38bfde0..d257e91d9 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -16,16 +16,24 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): self.bot = helpers.MockBot() self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - async def test_class_attribute_namespace(self): + def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" self.assertEqual(self.redis._namespace, "RedisCacheTests.redis") - # Test that errors are raised when not assigned as a class attribute + async def test_class_attribute_required(self): + """Test that errors are raised when not assigned as a class attribute.""" bad_cache = RedisCache() + self.assertIs(bad_cache._namespace, None) with self.assertRaises(RuntimeError): await bad_cache.set("test", "me_up_deadman") + def test_namespace_collision(self): + """Test that we prevent colliding namespaces.""" + bad_cache = RedisCache() + bad_cache._set_namespace("RedisCacheTests.redis") + self.assertEqual(bad_cache._namespace, "RedisCacheTests.redis_") + async def test_set_get_item(self): """Test that users can set and get items from the RedisDict.""" test_cases = ( @@ -42,95 +50,105 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # Test that .get allows a default value self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw") - # def test_set_item_types(self): - # """Test that setitem rejects keys and values that are not strings, ints or floats.""" - # fruits = ["lemon", "melon", "apple"] - # - # with self.assertRaises(DataError): - # self.redis[fruits] = "nice" - # - # def test_contains(self): - # """Test that we can reliably use the `in` operator with our RedisDict.""" - # self.redis['favorite_country'] = "Burkina Faso" - # - # self.assertIn('favorite_country', self.redis) - # self.assertNotIn('favorite_dentist', self.redis) - # - # def test_items(self): - # """Test that the RedisDict can be iterated.""" - # self.redis.clear() - # test_cases = ( - # ('favorite_turtle', 'Donatello'), - # ('second_favorite_turtle', 'Leonardo'), - # ('third_favorite_turtle', 'Raphael'), - # ) - # for key, value in test_cases: - # self.redis[key] = value - # - # # Test regular iteration - # for test_case, key in zip(test_cases, self.redis): - # value = test_case[1] - # self.assertEqual(self.redis[key], value) - # - # # Test .items iteration - # for key, value in self.redis.items(): - # self.assertEqual(self.redis[key], value) - # - # # Test .keys iteration - # for test_case, key in zip(test_cases, self.redis.keys()): - # value = test_case[1] - # self.assertEqual(self.redis[key], value) - # - # def test_length(self): - # """Test that we can get the correct len() from the RedisDict.""" - # self.redis.clear() - # self.redis['one'] = 1 - # self.redis['two'] = 2 - # self.redis['three'] = 3 - # self.assertEqual(len(self.redis), 3) - # - # self.redis['four'] = 4 - # self.assertEqual(len(self.redis), 4) - # - # def test_to_dict(self): - # """Test that the .copy method returns a workable dictionary copy.""" - # copy = self.redis.copy() - # local_copy = dict(self.redis.items()) - # self.assertIs(type(copy), dict) - # self.assertEqual(copy, local_copy) - # - # def test_clear(self): - # """Test that the .clear method removes the entire hash.""" - # self.redis.clear() - # self.redis['teddy'] = "with me" - # self.redis['in my dreams'] = "you have a weird hat" - # self.assertEqual(len(self.redis), 2) - # - # self.redis.clear() - # self.assertEqual(len(self.redis), 0) - # - # def test_pop(self): - # """Test that we can .pop an item from the RedisDict.""" - # self.redis.clear() - # self.redis['john'] = 'was afraid' - # - # self.assertEqual(self.redis.pop('john'), 'was afraid') - # self.assertEqual(self.redis.pop('pete', 'breakneck'), 'breakneck') - # self.assertEqual(len(self.redis), 0) - # - # def test_update(self): - # """Test that we can .update the RedisDict with multiple items.""" - # self.redis.clear() - # self.redis["reckfried"] = "lona" - # self.redis["bel air"] = "prince" - # self.redis.update({ - # "reckfried": "jona", - # "mega": "hungry, though", - # }) - # - # result = { - # "reckfried": "jona", - # "bel air": "prince", - # "mega": "hungry, though", - # } - # self.assertEqual(self.redis.copy(), result) + async def test_set_item_type(self): + """Test that .set rejects keys and values that are not strings, ints or floats.""" + fruits = ["lemon", "melon", "apple"] + + with self.assertRaises(TypeError): + await self.redis.set(fruits, "nice") + + async def test_delete_item(self): + """Test that .delete allows us to delete stuff from the RedisCache.""" + # Add an item and verify that it gets added + await self.redis.set("internet", "firetruck") + self.assertEqual(await self.redis.get("internet"), "firetruck") + + # Delete that item and verify that it gets deleted + await self.redis.delete("internet") + self.assertIs(await self.redis.get("internet"), None) + + async def test_contains(self): + """Test that we can check membership with .contains.""" + await self.redis.set('favorite_country', "Burkina Faso") + + self.assertIs(await self.redis.contains('favorite_country'), True) + self.assertIs(await self.redis.contains('favorite_dentist'), False) + + async def test_items(self): + """Test that the RedisDict can be iterated.""" + await self.redis.clear() + + # Set up our test cases in the Redis cache + test_cases = [ + ('favorite_turtle', 'Donatello'), + ('second_favorite_turtle', 'Leonardo'), + ('third_favorite_turtle', 'Raphael'), + ] + for key, value in test_cases: + await self.redis.set(key, value) + + # Consume the AsyncIterator into a regular list, easier to compare that way. + redis_items = [item async for item in self.redis.items()] + + # These sequences are probably in the same order now, but probably + # isn't good enough for tests. Let's not rely on .hgetall always + # returning things in sequence, and just sort both lists to be safe. + redis_items = sorted(redis_items) + test_cases = sorted(test_cases) + + # If these are equal now, everything works fine. + self.assertSequenceEqual(test_cases, redis_items) + + async def test_length(self): + """Test that we can get the correct .length from the RedisDict.""" + await self.redis.clear() + await self.redis.set('one', 1) + await self.redis.set('two', 2) + await self.redis.set('three', 3) + self.assertEqual(await self.redis.length(), 3) + + await self.redis.set('four', 4) + self.assertEqual(await self.redis.length(), 4) + + async def test_to_dict(self): + """Test that the .copy method returns a workable dictionary copy.""" + copy = await self.redis.to_dict() + local_copy = {key: value async for key, value in self.redis.items()} + self.assertIs(type(copy), dict) + self.assertDictEqual(copy, local_copy) + + async def test_clear(self): + """Test that the .clear method removes the entire hash.""" + await self.redis.clear() + await self.redis.set('teddy', 'with me') + await self.redis.set('in my dreams', 'you have a weird hat') + self.assertEqual(await self.redis.length(), 2) + + await self.redis.clear() + self.assertEqual(await self.redis.length(), 0) + + async def test_pop(self): + """Test that we can .pop an item from the RedisDict.""" + await self.redis.clear() + await self.redis.set('john', 'was afraid') + + self.assertEqual(await self.redis.pop('john'), 'was afraid') + self.assertEqual(await self.redis.pop('pete', 'breakneck'), 'breakneck') + self.assertEqual(await self.redis.length(), 0) + + async def test_update(self): + """Test that we can .update the RedisDict with multiple items.""" + await self.redis.clear() + await self.redis.set("reckfried", "lona") + await self.redis.set("bel air", "prince") + await self.redis.update({ + "reckfried": "jona", + "mega": "hungry, though", + }) + + result = { + "reckfried": "jona", + "bel air": "prince", + "mega": "hungry, though", + } + self.assertDictEqual(await self.redis.to_dict(), result) -- cgit v1.2.3 From aa0bb028ed889d93376981213673053a540e137c Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 23 May 2020 14:30:56 +0200 Subject: Fix typo in test_to_dict docstring --- tests/bot/utils/test_redis_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index d257e91d9..2ce57499a 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -111,7 +111,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(await self.redis.length(), 4) async def test_to_dict(self): - """Test that the .copy method returns a workable dictionary copy.""" + """Test that the .to_dict method returns a workable dictionary copy.""" copy = await self.redis.to_dict() local_copy = {key: value async for key, value in self.redis.items()} self.assertIs(type(copy), dict) -- cgit v1.2.3 From 5120717a47c07812d1631cf0905ff3062e139487 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 23 May 2020 15:17:57 +0200 Subject: DRY approach to typestring prefix resolution Thanks to @kwzrd for this idea, basically we're making a constant with the typestring prefixes and iterating that in all our converters. These converter functions will also now raise TypeErrors if we try to convert something that isn't in this constants list. I've also added a new test that tests this functionality. --- bot/utils/redis_cache.py | 54 ++++++++++++++++++++++++++----------- tests/bot/utils/test_redis_cache.py | 21 +++++++++++++++ 2 files changed, 60 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 6831be157..1ec1b9fea 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -4,6 +4,11 @@ from typing import Any, AsyncIterator, Dict, Optional, Union from bot.bot import Bot +TYPESTRING_PREFIXES = ( + ("f|", float), + ("i|", int), + ("s|", str), +) ValidRedisType = Union[str, int, float] @@ -78,26 +83,45 @@ class RedisCache: self._namespace = namespace @staticmethod - def _to_typestring(value: ValidRedisType) -> str: - """Turn a valid Redis type into a typestring.""" - if isinstance(value, float): - return f"f|{value}" - elif isinstance(value, int): - return f"i|{value}" - elif isinstance(value, str): - return f"s|{value}" + def _valid_typestring_types() -> str: + """ + Creates a nice, readable list of valid types for typestrings, useful for error messages. + + This will be dynamically updated if we change the TYPESTRING_PREFIXES constant up top. + """ + valid_types = ", ".join([str(_type).split("'")[1] for _, _type in TYPESTRING_PREFIXES]) + valid_types = ", and ".join(valid_types.rsplit(", ", 1)) + return valid_types @staticmethod - def _from_typestring(value: Union[bytes, str]) -> ValidRedisType: + def _valid_typestring_prefixes() -> str: + """ + Creates a nice, readable list of valid prefixes for typestrings, useful for error messages. + + This will be dynamically updated if we change the TYPESTRING_PREFIXES constant up top. + """ + valid_prefixes = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_PREFIXES]) + valid_prefixes = ", and ".join(valid_prefixes.rsplit(", ", 1)) + return valid_prefixes + + def _to_typestring(self, value: ValidRedisType) -> str: + """Turn a valid Redis type into a typestring.""" + for prefix, _type in TYPESTRING_PREFIXES: + if isinstance(value, _type): + return f"{prefix}{value}" + raise TypeError(f"RedisCache._from_typestring only supports the types {self._valid_typestring_types()}.") + + def _from_typestring(self, value: Union[bytes, str]) -> ValidRedisType: """Turn a typestring into a valid Redis type.""" + # Stuff that comes out of Redis will be bytestrings, so let's decode those. if isinstance(value, bytes): value = value.decode('utf-8') - if value.startswith("f|"): - return float(value[2:]) - if value.startswith("i|"): - return int(value[2:]) - if value.startswith("s|"): - return value[2:] + + # Now we convert our unicode string back into the type it originally was. + for prefix, _type in TYPESTRING_PREFIXES: + if value.startswith(prefix): + return _type(value[2:]) + raise TypeError(f"RedisCache._to_typestring only supports the prefixes {self._valid_typestring_prefixes()}.") def _dict_from_typestring(self, dictionary: Dict) -> Dict: """Turns all contents of a dict into valid Redis types.""" diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 2ce57499a..150195726 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -152,3 +152,24 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): "mega": "hungry, though", } self.assertDictEqual(await self.redis.to_dict(), result) + + def test_typestring_conversion(self): + """Test the typestring-related helper functions.""" + conversion_tests = ( + (12, "i|12"), + (12.4, "f|12.4"), + ("cowabunga", "s|cowabunga"), + ) + + # Test conversion to typestring + for _input, expected in conversion_tests: + self.assertEqual(self.redis._to_typestring(_input), expected) + + # Test conversion from typestrings + for _input, expected in conversion_tests: + self.assertEqual(self.redis._from_typestring(expected), _input) + + # Test that exceptions are raised on invalid input + with self.assertRaises(TypeError): + self.redis._to_typestring(["internet"]) + self.redis._from_typestring("o|firedog") -- cgit v1.2.3 From a52a13020f3468c671cb549052a9c8e303ae9d8c Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 23 May 2020 10:27:32 -0700 Subject: Remove redis session mock from MockBot It's not feasible to mock it because all the commands return futures rather than being coroutines, so they cannot automatically be turned into AsyncMocks. Furthermore, no code should ever use the redis session directly besides RedisCache. Since the tests for RedisCache already use fakeredis, there's no use in trying to mock redis in MockBot. --- tests/helpers.py | 16 ---------------- 1 file changed, 16 deletions(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 2b176db79..5ad826156 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,6 @@ import unittest.mock from asyncio import AbstractEventLoop from typing import Iterable, Optional -import aioredis.abc import discord from aiohttp import ClientSession from discord.ext.commands import Context @@ -268,17 +267,6 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): spec_set = APIClient -class MockRedisPool(CustomMockMixin, unittest.mock.MagicMock): - """ - A MagicMock subclass to mock an aioredis connection pool. - - Instances of this class will follow the specifications of `aioredis.abc.AbcPool` instances. - For more information, see the `MockGuild` docstring. - """ - spec_set = aioredis.abc.AbcPool - additional_spec_asyncs = ("execute", "execute_pubsub") - - def _get_mock_loop() -> unittest.mock.Mock: """Return a mocked asyncio.AbstractEventLoop.""" loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) @@ -309,10 +297,6 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) - # fakeredis can't be used cause it'd require awaiting a coroutine to create the pool, - # which cannot be done here in __init__. - self.redis_session = MockRedisPool() - # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` channel_data = { -- cgit v1.2.3 From b2009d5304beba4829b7727ca154bb6a0d1cd50a Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 24 May 2020 11:42:33 +0200 Subject: Make .items return ItemsView instead of AsyncIter There really was no compelling reason why this method should return an AsyncIterator or than that `async for items in cache.items()` has nice readability, but there were a few concerns. One is a concern about race conditions raised by @SebastiaanZ, and @MarkKoz raised a concern that it was misleading to have an AsyncIterator that only "pretended" to be lazy. To address these concerns, I've refactored it to return a regular ItemsView instead. I also improved the docstring, and fixed the relevant tests. --- bot/utils/redis_cache.py | 28 +++++++++++++++++++++------- tests/bot/utils/test_redis_cache.py | 4 ++-- 2 files changed, 23 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 558ab33a7..fb9a534bd 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, AsyncIterator, Dict, Optional, Union +from typing import Any, Dict, ItemsView, Optional, Union from bot.bot import Bot @@ -237,12 +237,26 @@ class RedisCache: key = self._to_typestring(key) return await self._redis.hexists(self._namespace, key) - async def items(self) -> AsyncIterator: - """Iterate all the items in the Redis cache.""" + async def items(self) -> ItemsView: + """ + Fetch all the key/value pairs in the cache. + + Returns a normal ItemsView, like you would get from dict.items(). + + Keep in mind that these items are just a _copy_ of the data in the + RedisCache - any changes you make to them will not be reflected + into the RedisCache itself. If you want to change these, you need + to make a .set call. + + Example: + items = await my_cache.items() + for key, value in items: + # Iterate like a normal dictionary + """ await self._validate_cache() - data = await self._redis.hgetall(self._namespace) # Get all the keys - for key, value in self._dict_from_typestring(data).items(): - yield key, value + return self._dict_from_typestring( + await self._redis.hgetall(self._namespace) + ).items() async def length(self) -> int: """Return the number of items in the Redis cache.""" @@ -251,7 +265,7 @@ class RedisCache: async def to_dict(self) -> Dict: """Convert to dict and return.""" - return {key: value async for key, value in self.items()} + return {key: value for key, value in await self.items()} async def clear(self) -> None: """Deletes the entire hash from the Redis cache.""" diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 150195726..6e12002ed 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -88,7 +88,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): await self.redis.set(key, value) # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item async for item in self.redis.items()] + redis_items = [item for item in await self.redis.items()] # These sequences are probably in the same order now, but probably # isn't good enough for tests. Let's not rely on .hgetall always @@ -113,7 +113,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_to_dict(self): """Test that the .to_dict method returns a workable dictionary copy.""" copy = await self.redis.to_dict() - local_copy = {key: value async for key, value in self.redis.items()} + local_copy = {key: value for key, value in await self.redis.items()} self.assertIs(type(copy), dict) self.assertDictEqual(copy, local_copy) -- cgit v1.2.3 From 01bedcadf762262eef0a2b406faf66cdc16a5c85 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 24 May 2020 13:04:41 +0200 Subject: Add .increment and .decrement methods. Sometimes, we just want to store a counter in the cache. In this case, it is convenient to have a single method that will allow us to increment or decrement this counter. These methods allow you to decrement or increment floats and integers by an specified amount. By default, it'll increment or decrement by 1. Since this involves several API requests, we create an asyncio.Lock so that we don't end up with race conditions. --- bot/utils/redis_cache.py | 35 +++++++++++++++++++++++++++++++++++ tests/bot/utils/test_redis_cache.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index fb9a534bd..290fae1a0 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any, Dict, ItemsView, Optional, Union from bot.bot import Bot @@ -77,6 +78,7 @@ class RedisCache: """Initialize the RedisCache.""" self._namespace = None self.bot = None + self.increment_lock = asyncio.Lock() def _set_namespace(self, namespace: str) -> None: """Try to set the namespace, but do not permit collisions.""" @@ -287,3 +289,36 @@ class RedisCache: """Update the Redis cache with multiple values.""" await self._validate_cache() await self._redis.hmset_dict(self._namespace, self._dict_to_typestring(items)) + + async def increment(self, key: RedisType, amount: Optional[int, float] = 1) -> None: + """ + Increment the value by `amount`. + + This works for both floats and ints, but will raise a TypeError + if you try to do it for any other type of value. + + This also supports negative amounts, although it would provide better + readability to use .decrement() for that. + """ + # Since this has several API calls, we need a lock to prevent race conditions + async with self.increment_lock: + value = await self.get(key) + + # Can't increment a non-existing value + if value is None: + raise RuntimeError("The provided key does not exist!") + + # If it does exist, and it's an int or a float, increment and set it. + if isinstance(value, int) or isinstance(value, float): + value += amount + await self.set(key, value) + else: + raise TypeError("You may only increment or decrement values that are integers or floats.") + + async def decrement(self, key: RedisType, amount: Optional[int, float] = 1) -> None: + """ + Decrement the value by `amount`. + + Basically just does the opposite of .increment. + """ + await self.increment(key, -amount) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 6e12002ed..dbbaef018 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -173,3 +173,37 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): with self.assertRaises(TypeError): self.redis._to_typestring(["internet"]) self.redis._from_typestring("o|firedog") + + async def test_increment_decrement(self): + """Test .increment and .decrement methods.""" + await self.redis.set("entropic", 5) + await self.redis.set("disentropic", 12.5) + + # Test default increment + await self.redis.increment("entropic") + self.assertEqual(await self.redis.get("entropic"), 6) + + # Test default decrement + await self.redis.decrement("entropic") + self.assertEqual(await self.redis.get("entropic"), 5) + + # Test float increment with float + await self.redis.increment("disentropic", 2.0) + self.assertEqual(await self.redis.get("disentropic"), 14.5) + + # Test float increment with int + await self.redis.increment("disentropic", 2) + self.assertEqual(await self.redis.get("disentropic"), 16.5) + + # Test negative increments, because why not. + await self.redis.increment("entropic", -5) + self.assertEqual(await self.redis.get("entropic"), 0) + + # Negative decrements? Sure. + await self.redis.decrement("entropic", -5) + self.assertEqual(await self.redis.get("entropic"), 5) + + # What about if we use a negative float to decrement an int? + # This should convert the type into a float. + await self.redis.decrement("entropic", -2.5) + self.assertEqual(await self.redis.get("entropic"), 7.5) -- cgit v1.2.3 From c5e6e8f796265ee6faebdd3d02c839972cd028a9 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 24 May 2020 13:49:20 +0200 Subject: MockBot needs to be aware of redis_ready Forgot to update the additional_spec_asyncs when changing the name of this Bot attribute to be public. --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 5ad826156..13283339b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -287,7 +287,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): For more information, see the `MockGuild` docstring. """ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) - additional_spec_asyncs = ("wait_for", "_redis_ready") + additional_spec_asyncs = ("wait_for", "redis_ready") def __init__(self, **kwargs) -> None: super().__init__(**kwargs) -- cgit v1.2.3 From ad8b1fa455e141074daec5047682e82ed96db1f5 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 24 May 2020 19:09:45 +0200 Subject: Improve error and error testing for increment Changed a RuntimeError to a KeyError (thanks @MarkKoz), and also added some tests to ensure that the right errors are raised whenever this method is used incorrectly. --- bot/utils/redis_cache.py | 2 +- tests/bot/utils/test_redis_cache.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 5fc34d464..b91d663f3 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -341,7 +341,7 @@ class RedisCache: # Can't increment a non-existing value if value is None: log.exception("Attempt to increment/decrement value for non-existent key.") - raise RuntimeError("The provided key does not exist!") + raise KeyError("The provided key does not exist!") # If it does exist, and it's an int or a float, increment and set it. if isinstance(value, int) or isinstance(value, float): diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index dbbaef018..7405487ed 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -207,3 +207,11 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # This should convert the type into a float. await self.redis.decrement("entropic", -2.5) self.assertEqual(await self.redis.get("entropic"), 7.5) + + # Let's test that they raise the right errors + with self.assertRaises(KeyError): + await self.redis.increment("doesn't_exist!") + + await self.redis.set("stringthing", "stringthing") + with self.assertRaises(TypeError): + await self.redis.increment("stringthing") -- cgit v1.2.3 From 856cecbd2354d4cbdbace5a39b7eb9e3d3bf23c7 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 24 May 2020 19:29:13 -0700 Subject: Add support for Union type annotations for constants Note that `Optional[x]` is just an alias for `Union[None, x]` so this effectively supports `Optional` too. This was especially troublesome because the redis password must be unset/None in order to avoid authentication, but the test would complain that `None` isn't a `str`. Setting to an empty string would pass the test but then make redis authenticate and fail. --- bot/constants.py | 14 +++++++------- tests/bot/test_constants.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/bot/constants.py b/bot/constants.py index 75d394b6a..145ae54db 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -15,7 +15,7 @@ import os from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional import yaml @@ -198,7 +198,7 @@ class Bot(metaclass=YAMLGetter): prefix: str token: str - sentry_dsn: str + sentry_dsn: Optional[str] class Redis(metaclass=YAMLGetter): @@ -207,7 +207,7 @@ class Redis(metaclass=YAMLGetter): host: str port: int - password: str + password: Optional[str] use_fakeredis: bool # If this is True, Bot will use fakeredis.aioredis @@ -459,7 +459,7 @@ class Guild(metaclass=YAMLGetter): class Keys(metaclass=YAMLGetter): section = "keys" - site_api: str + site_api: Optional[str] class URLs(metaclass=YAMLGetter): @@ -502,8 +502,8 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list - client_id: str - secret: str + client_id: Optional[str] + secret: Optional[str] class Wolfram(metaclass=YAMLGetter): @@ -511,7 +511,7 @@ class Wolfram(metaclass=YAMLGetter): user_limit_day: int guild_limit_day: int - key: str + key: Optional[str] class AntiSpam(metaclass=YAMLGetter): diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index dae7c066c..db9a9bcb0 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -1,4 +1,5 @@ import inspect +import typing import unittest from bot import constants @@ -8,7 +9,7 @@ class ConstantsTests(unittest.TestCase): """Tests for our constants.""" def test_section_configuration_matches_type_specification(self): - """The section annotations should match the actual types of the sections.""" + """"The section annotations should match the actual types of the sections.""" sections = ( cls @@ -19,8 +20,14 @@ class ConstantsTests(unittest.TestCase): for name, annotation in section.__annotations__.items(): with self.subTest(section=section, name=name, annotation=annotation): value = getattr(section, name) + annotation_args = typing.get_args(annotation) - if getattr(annotation, '_name', None) in ('Dict', 'List'): - self.skipTest("Cannot validate containers yet.") - - self.assertIsInstance(value, annotation) + if not annotation_args: + self.assertIsInstance(value, annotation) + else: + origin = typing.get_origin(annotation) + if origin is typing.Union: + is_instance = any(isinstance(value, arg) for arg in annotation_args) + self.assertTrue(is_instance) + else: + self.skipTest(f"Validating type {annotation} is unsupported.") -- cgit v1.2.3 From 0ede719d7beb36f476ac26f948ab940882978476 Mon Sep 17 00:00:00 2001 From: Jannes Jonkers Date: Mon, 25 May 2020 20:44:35 +0200 Subject: AntiMalware tests - Switched from monkeypatch to unittest.patch --- tests/bot/cogs/test_antimalware.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index fab063201..f219fc1ba 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, patch from discord import NotFound @@ -10,6 +10,7 @@ from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole MODULE = "bot.cogs.antimalware" +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Test the AntiMalware cog.""" @@ -18,7 +19,6 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() - AntiMalwareConfig.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" -- cgit v1.2.3 From 9b9aa9b2adbdcd0e0b8c4f4ad38f112a9566fa2f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 12:03:09 -0700 Subject: Support validating collection types for constants This is a simple validation that only check the type of the collection. It does not validate the types inside the collection because that has proven to be quite complex. --- tests/bot/test_constants.py | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index db9a9bcb0..2937b6189 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -5,6 +5,31 @@ import unittest from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: + """ + Return True if `value` is an instance of the type represented by `annotation`. + + This doesn't account for things like Unions or checking for homogenous types in collections. + """ + origin = typing.get_origin(annotation) + + # This is done in case a bare e.g. `typing.List` is used. + # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. + # `get_origin()` does this normalisation for us. + type_ = annotation if origin is None else origin + + return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: + """Return True if `value` is an instance of any type in `types`.""" + for type_ in types: + if is_annotation_instance(value, type_): + return True + + return False + + class ConstantsTests(unittest.TestCase): """Tests for our constants.""" @@ -20,14 +45,13 @@ class ConstantsTests(unittest.TestCase): for name, annotation in section.__annotations__.items(): with self.subTest(section=section, name=name, annotation=annotation): value = getattr(section, name) + origin = typing.get_origin(annotation) annotation_args = typing.get_args(annotation) + failure_msg = f"{value} is not an instance of {annotation}" - if not annotation_args: - self.assertIsInstance(value, annotation) + if origin is typing.Union: + is_instance = is_any_instance(value, annotation_args) + self.assertTrue(is_instance, failure_msg) else: - origin = typing.get_origin(annotation) - if origin is typing.Union: - is_instance = any(isinstance(value, arg) for arg in annotation_args) - self.assertTrue(is_instance) - else: - self.skipTest(f"Validating type {annotation} is unsupported.") + is_instance = is_annotation_instance(value, annotation) + self.assertTrue(is_instance, failure_msg) -- cgit v1.2.3 From 87d42add019e8ef1bad5d9593f6ed5a803e4d153 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 12:04:50 -0700 Subject: Improve output of section name in config validation subtests --- tests/bot/test_constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index 2937b6189..f10d6fbe8 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -43,7 +43,7 @@ class ConstantsTests(unittest.TestCase): ) for section in sections: for name, annotation in section.__annotations__.items(): - with self.subTest(section=section, name=name, annotation=annotation): + with self.subTest(section=section.__name__, name=name, annotation=annotation): value = getattr(section, name) origin = typing.get_origin(annotation) annotation_args = typing.get_args(annotation) -- cgit v1.2.3 From 47886501fb7d030f1cb91c69413058e3ffcb76bf Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 20:47:32 -0700 Subject: Test token regex won't match non-base64 characters --- tests/bot/cogs/test_token_remover.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 8e743a715..dbea5ad1b 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -144,10 +144,9 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): "x..z", " . . ", "\n.\n.\n", - "'.'.'", - '"."."', - "(.(.(", - ").).)" + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", ) for token in tokens: -- cgit v1.2.3 From e76099d48b9a895c48e58c5f5489886f4191eeb6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 20:50:30 -0700 Subject: Add more valid tokens to test the regex with --- tests/bot/cogs/test_token_remover.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index dbea5ad1b..6a280f358 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -156,10 +156,12 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def test_regex_valid_tokens(self): """Messages that look like tokens should be matched.""" - # Don't worry, the token's been invalidated. + # Don't worry, these tokens have been invalidated. tokens = ( - "x1.y2.z_3", - "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8" + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", ) for token in tokens: -- cgit v1.2.3 From a8a216d0803b67a330ae092a17bea563f5012275 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 21:02:24 -0700 Subject: Fix valid token regex test It was broken due to the addition of groups. Rather than returning the full match, `findall` returns groups if any exist. The test was comparing a tuple of groups to the token string, which was of course failing. Now `fullmatch` is used cause it's simpler - just check for `None` and don't worry about iterating matches to search. --- tests/bot/cogs/test_token_remover.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 6a280f358..518bf91ca 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -166,8 +166,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): for token in tokens: with self.subTest(token=token): - results = token_remover.TOKEN_RE.findall(token) - self.assertIn(token, results) + results = token_remover.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") def test_regex_matches_multiple_valid(self): """Should support multiple matches in the middle of a string.""" -- cgit v1.2.3 From 19cc849d4c70bc3e792460ad712aa308fa500462 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 21:07:21 -0700 Subject: Fix multiple match text for token regex It has to account for the addition of groups. It's easiest to compare the entire string so `finditer` is used to return re.Match objects; the tuples of `findall` would be cumbersome. Also threw in a change to use `assertCountEqual` cause the order doesn't really matter. --- tests/bot/cogs/test_token_remover.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 518bf91ca..2ecfae2bd 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -174,8 +174,9 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): tokens = ["x.y.z", "a.b.c"] message = f"garbage {tokens[0]} hello {tokens[1]} world" - results = token_remover.TOKEN_RE.findall(message) - self.assertEqual(tokens, results) + results = token_remover.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual(tokens, results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): -- cgit v1.2.3 From 300f8c093edea03855d94be179c64c328ec842ac Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 25 May 2020 21:09:04 -0700 Subject: Use real token values for testing multiple matches in regex --- tests/bot/cogs/test_token_remover.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 2ecfae2bd..971bc93fc 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -171,12 +171,13 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): def test_regex_matches_multiple_valid(self): """Should support multiple matches in the middle of a string.""" - tokens = ["x.y.z", "a.b.c"] - message = f"garbage {tokens[0]} hello {tokens[1]} world" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" results = token_remover.TOKEN_RE.finditer(message) results = [match[0] for match in results] - self.assertCountEqual(tokens, results) + self.assertCountEqual((token_1, token_2), results) @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): -- cgit v1.2.3 From 1ab34dd48fce2de70db1fb2dd6da06f752460829 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Tue, 26 May 2020 19:06:57 +0200 Subject: Add a test for RuntimeErrors. This just tests that the various RuntimeErrors are reachable - that includes the error about not having a bot instance, the one about not being a class attribute, and the one about not having instantiated the class. This test addresses a concern raised by @MarkKoz in a review. I've decided not to test that actual contents of these RuntimeErrors, because I believe that sort of testing is a bit too brittle. It shouldn't break a test just to change the content of an error string. --- tests/bot/utils/test_redis_cache.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 7405487ed..1b05ae350 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -215,3 +215,25 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): await self.redis.set("stringthing", "stringthing") with self.assertRaises(TypeError): await self.redis.increment("stringthing") + + async def test_exceptions_raised(self): + """Testing that the various RuntimeErrors are reachable.""" + class MyCog: + cache = RedisCache() + + def __init__(self): + self.other_cache = RedisCache() + + cog = MyCog() + + # Raises "No Bot instance" + with self.assertRaises(RuntimeError): + await cog.cache.get("john") + + # Raises "RedisCache has no namespace" + with self.assertRaises(RuntimeError): + await cog.other_cache.get("was") + + # Raises "You must access the RedisCache instance through the cog instance" + with self.assertRaises(RuntimeError): + await MyCog.cache.get("afraid") -- cgit v1.2.3 From d9190d997538f49c0a1b53d63a15bada3c89297f Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 07:32:16 +0200 Subject: Refactor the in_whitelist deco to a check. We're moving the actual predicate into the `utils.checks` folder, just like we're doing with most of the other decorators. This is to allow us the flexibility to use it as a pure check, not only as a decorator. This commit doesn't actually change any functionality, just moves it around. --- bot/decorators.py | 54 +++-------------------------- bot/utils/checks.py | 81 ++++++++++++++++++++++++++++++++++++++++++-- tests/bot/test_decorators.py | 4 +-- 3 files changed, 86 insertions(+), 53 deletions(-) (limited to 'tests') diff --git a/bot/decorators.py b/bot/decorators.py index 306f0830c..1e77afe60 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -9,37 +9,20 @@ from weakref import WeakValueDictionary from discord import Colour, Embed, Member from discord.errors import NotFound from discord.ext import commands -from discord.ext.commands import CheckFailure, Cog, Context +from discord.ext.commands import Cog, Context from bot.constants import Channels, ERROR_REPLIES, RedirectOutput -from bot.utils.checks import with_role_check, without_role_check +from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check log = logging.getLogger(__name__) -class InWhitelistCheckFailure(CheckFailure): - """Raised when the `in_whitelist` check fails.""" - - def __init__(self, redirect_channel: Optional[int]) -> None: - self.redirect_channel = redirect_channel - - if redirect_channel: - redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" - else: - redirect_message = "" - - error_message = f"You are not allowed to use that command{redirect_message}." - - super().__init__(error_message) - - def in_whitelist( *, channels: Container[int] = (), categories: Container[int] = (), roles: Container[int] = (), redirect: Optional[int] = Channels.bot_commands, - ) -> Callable: """ Check if a command was issued in a whitelisted context. @@ -54,36 +37,9 @@ def in_whitelist( redirected to the `redirect` channel that was passed (default: #bot-commands) or simply told that they're not allowed to use this particular command (if `None` was passed). """ - if redirect and redirect not in channels: - # It does not make sense for the channel whitelist to not contain the redirection - # channel (if applicable). That's why we add the redirection channel to the `channels` - # container if it's not already in it. As we allow any container type to be passed, - # we first create a tuple in order to safely add the redirection channel. - # - # Note: It's possible for the redirect channel to be in a whitelisted category, but - # there's no easy way to check that and as a channel can easily be moved in and out of - # categories, it's probably not wise to rely on its category in any case. - channels = tuple(channels) + (redirect,) - def predicate(ctx: Context) -> bool: - """Check if a command was issued in a whitelisted context.""" - if channels and ctx.channel.id in channels: - log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") - return True - - # Only check the category id if we have a category whitelist and the channel has a `category_id` - if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: - log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") - return True - - # Only check the roles whitelist if we have one and ensure the author's roles attribute returns - # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). - if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): - log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") - return True - - log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") - raise InWhitelistCheckFailure(redirect) + """Check if command was issued in a whitelisted context.""" + return in_whitelist_check(ctx, channels, categories, roles, redirect) return commands.check(predicate) @@ -121,7 +77,7 @@ def locked() -> Callable: embed = Embed() embed.colour = Colour.red() - log.debug(f"User tried to invoke a locked command.") + log.debug("User tried to invoke a locked command.") embed.description = ( "You're already using this command. Please wait until it is done before you use it again." ) diff --git a/bot/utils/checks.py b/bot/utils/checks.py index db56c347c..63568b29e 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,12 +1,89 @@ import datetime import logging -from typing import Callable, Iterable +from typing import Callable, Container, Iterable, Optional -from discord.ext.commands import BucketType, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping +from discord.ext.commands import ( + BucketType, + CheckFailure, + Cog, + Command, + CommandOnCooldown, + Context, + Cooldown, + CooldownMapping, +) + +from bot import constants log = logging.getLogger(__name__) +class InWhitelistCheckFailure(CheckFailure): + """Raised when the `in_whitelist` check fails.""" + + def __init__(self, redirect_channel: Optional[int]) -> None: + self.redirect_channel = redirect_channel + + if redirect_channel: + redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" + else: + redirect_message = "" + + error_message = f"You are not allowed to use that command{redirect_message}." + + super().__init__(error_message) + + +def in_whitelist_check( + ctx: Context, + channels: Container[int] = (), + categories: Container[int] = (), + roles: Container[int] = (), + redirect: Optional[int] = constants.Channels.bot_commands, +) -> bool: + """ + Check if a command was issued in a whitelisted context. + + The whitelists that can be provided are: + + - `channels`: a container with channel ids for whitelisted channels + - `categories`: a container with category ids for whitelisted categories + - `roles`: a container with with role ids for whitelisted roles + + If the command was invoked in a context that was not whitelisted, the member is either + redirected to the `redirect` channel that was passed (default: #bot-commands) or simply + told that they're not allowed to use this particular command (if `None` was passed). + """ + if redirect and redirect not in channels: + # It does not make sense for the channel whitelist to not contain the redirection + # channel (if applicable). That's why we add the redirection channel to the `channels` + # container if it's not already in it. As we allow any container type to be passed, + # we first create a tuple in order to safely add the redirection channel. + # + # Note: It's possible for the redirect channel to be in a whitelisted category, but + # there's no easy way to check that and as a channel can easily be moved in and out of + # categories, it's probably not wise to rely on its category in any case. + channels = tuple(channels) + (redirect,) + + if channels and ctx.channel.id in channels: + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") + return True + + # Only check the category id if we have a category whitelist and the channel has a `category_id` + if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") + return True + + # Only check the roles whitelist if we have one and ensure the author's roles attribute returns + # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). + if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") + return True + + log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + raise InWhitelistCheckFailure(redirect) + + def with_role_check(ctx: Context, *role_ids: int) -> bool: """Returns True if the user has any one of the roles in role_ids.""" if not ctx.guild: # Return False in a DM diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index a17dd3e16..3d450caa0 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -3,10 +3,10 @@ import unittest import unittest.mock from bot import constants -from bot.decorators import InWhitelistCheckFailure, in_whitelist +from bot.decorators import in_whitelist +from bot.utils.checks import InWhitelistCheckFailure from tests import helpers - InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) -- cgit v1.2.3 From d310f42080278b35914bf5785fa322b97627c45f Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 07:42:08 +0200 Subject: Find + change all InWhitelistCheckFailure imports --- bot/cogs/error_handler.py | 6 +++--- bot/cogs/information.py | 4 ++-- bot/cogs/verification.py | 4 ++-- tests/bot/cogs/test_information.py | 3 +-- 4 files changed, 8 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 23d1eed82..5de961116 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -9,7 +9,7 @@ from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels from bot.converters import TagNameConverter -from bot.decorators import InWhitelistCheckFailure +from bot.utils.checks import InWhitelistCheckFailure log = logging.getLogger(__name__) @@ -166,7 +166,7 @@ class ErrorHandler(Cog): await prepared_help_command self.bot.stats.incr("errors.missing_required_argument") elif isinstance(e, errors.TooManyArguments): - await ctx.send(f"Too many arguments provided.") + await ctx.send("Too many arguments provided.") await prepared_help_command self.bot.stats.incr("errors.too_many_arguments") elif isinstance(e, errors.BadArgument): @@ -206,7 +206,7 @@ class ErrorHandler(Cog): if isinstance(e, bot_missing_errors): ctx.bot.stats.incr("errors.bot_permission_error") await ctx.send( - f"Sorry, it looks like I don't have the permissions or roles I need to do that." + "Sorry, it looks like I don't have the permissions or roles I need to do that." ) elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") diff --git a/bot/cogs/information.py b/bot/cogs/information.py index ef2f308ca..f0eb3a1ea 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -12,9 +12,9 @@ from discord.utils import escape_markdown from bot import constants from bot.bot import Bot -from bot.decorators import InWhitelistCheckFailure, in_whitelist, with_role +from bot.decorators import in_whitelist, with_role from bot.pagination import LinePaginator -from bot.utils.checks import cooldown_with_role_bypass, with_role_check +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since log = logging.getLogger(__name__) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 77e8b5706..99be3cdaa 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,8 +9,8 @@ from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot from bot.cogs.moderation import ModLog -from bot.decorators import InWhitelistCheckFailure, in_whitelist, without_role -from bot.utils.checks import without_role_check +from bot.decorators import in_whitelist, without_role +from bot.utils.checks import InWhitelistCheckFailure, without_role_check log = logging.getLogger(__name__) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index b5f928dd6..aca6b594f 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,10 +7,9 @@ import discord from bot import constants from bot.cogs import information -from bot.decorators import InWhitelistCheckFailure +from bot.utils.checks import InWhitelistCheckFailure from tests import helpers - COG_PATH = "bot.cogs.information.Information" -- cgit v1.2.3 From 8e0cdb258ea6e0f25977d18336a2e07b20b5d1ee Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 09:42:57 +0200 Subject: Fix failing tests related to avatar_hash --- tests/bot/cogs/sync/test_cog.py | 3 --- tests/bot/cogs/sync/test_users.py | 2 -- 2 files changed, 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 81398c61f..14fd909c4 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -247,14 +247,12 @@ class SyncCogListenerTests(SyncCogTestCase): before_data = { "name": "old name", "discriminator": "1234", - "avatar": "old avatar", "bot": False, } subtests = ( (True, "name", "name", "new name", "new name"), (True, "discriminator", "discriminator", "8765", 8765), - (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), (False, "bot", "bot", True, True), ) @@ -295,7 +293,6 @@ class SyncCogListenerTests(SyncCogTestCase): ) data = { - "avatar_hash": member.avatar, "discriminator": int(member.discriminator), "id": member.id, "in_guild": True, diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 818883012..002a947ad 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -10,7 +10,6 @@ def fake_user(**kwargs): kwargs.setdefault("id", 43) kwargs.setdefault("name", "bob the test man") kwargs.setdefault("discriminator", 1337) - kwargs.setdefault("avatar_hash", None) kwargs.setdefault("roles", (666,)) kwargs.setdefault("in_guild", True) @@ -32,7 +31,6 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): for member in members: member = member.copy() - member["avatar"] = member.pop("avatar_hash") del member["in_guild"] mock_member = helpers.MockMember(**member) -- cgit v1.2.3 From 35a1de37307b1745c061e490be4e96c8467de212 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 12:21:58 +0200 Subject: Clear cache in asyncSetUp instead of tests. --- tests/bot/utils/test_redis_cache.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 1b05ae350..900a6d035 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -15,6 +15,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): """Sets up the objects that only have to be initialized once.""" self.bot = helpers.MockBot() self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() + await self.redis.clear() def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" @@ -76,8 +77,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_items(self): """Test that the RedisDict can be iterated.""" - await self.redis.clear() - # Set up our test cases in the Redis cache test_cases = [ ('favorite_turtle', 'Donatello'), @@ -101,7 +100,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_length(self): """Test that we can get the correct .length from the RedisDict.""" - await self.redis.clear() await self.redis.set('one', 1) await self.redis.set('two', 2) await self.redis.set('three', 3) @@ -119,7 +117,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_clear(self): """Test that the .clear method removes the entire hash.""" - await self.redis.clear() await self.redis.set('teddy', 'with me') await self.redis.set('in my dreams', 'you have a weird hat') self.assertEqual(await self.redis.length(), 2) @@ -129,7 +126,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_pop(self): """Test that we can .pop an item from the RedisDict.""" - await self.redis.clear() await self.redis.set('john', 'was afraid') self.assertEqual(await self.redis.pop('john'), 'was afraid') @@ -138,7 +134,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_update(self): """Test that we can .update the RedisDict with multiple items.""" - await self.redis.clear() await self.redis.set("reckfried", "lona") await self.redis.set("bel air", "prince") await self.redis.update({ -- cgit v1.2.3 From b18930735e05e09ba615cb54fe1dbdfd41bb0f81 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 12:51:40 +0200 Subject: Refactor .increment and add lock test. The way we were doing the asyncio.Lock() stuff for increment was slightly problematic. @aeros has adviced us that it's better to just initialize the lock as None in __init__, and then initialize it inside the first coroutine that uses it instead. This ensures that the correct loop gets attached to the lock, so we don't end up getting errors like this one: RuntimeError: got Future attached to a different loop This happens because the lock and the actual calling coroutines aren't on the same loop. When creating a new test, test_increment_lock, we discovered that we needed a small refactor here and also in the test class to make this new test pass. So, now we're creating a DummyCog for every test method, and this will ensure the loop streams never cross. Cause we all know we must never cross the streams. --- bot/utils/redis_cache.py | 11 ++- tests/bot/utils/test_redis_cache.py | 163 ++++++++++++++++++++++-------------- 2 files changed, 109 insertions(+), 65 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 895a12da4..33e5d5852 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -81,7 +81,7 @@ class RedisCache: """Initialize the RedisCache.""" self._namespace = None self.bot = None - self._increment_lock = asyncio.Lock() + self._increment_lock = None def _set_namespace(self, namespace: str) -> None: """Try to set the namespace, but do not permit collisions.""" @@ -345,6 +345,15 @@ class RedisCache: """ log.trace(f"Attempting to increment/decrement the value with the key {key} by {amount}.") + # We initialize the lock here, because we need to ensure we get it + # running on the same loop as the calling coroutine. + # + # If we initialized the lock in the __init__, the loop that the coroutine this method + # would be called from might not exist yet, and so the lock would be on a different + # loop, which would raise RuntimeErrors. + if self._increment_lock is None: + self._increment_lock = asyncio.Lock() + # Since this has several API calls, we need a lock to prevent race conditions async with self._increment_lock: value = await self.get(key) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 900a6d035..efd168dac 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -1,3 +1,4 @@ +import asyncio import unittest import fakeredis.aioredis @@ -9,17 +10,30 @@ from tests import helpers class RedisCacheTests(unittest.IsolatedAsyncioTestCase): """Tests the RedisCache class from utils.redis_dict.py.""" - redis = RedisCache() - async def asyncSetUp(self): # noqa: N802 """Sets up the objects that only have to be initialized once.""" self.bot = helpers.MockBot() self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - await self.redis.clear() + + # Okay, so this is necessary so that we can create a clean new + # class for every test method, and we want that because it will + # ensure we get a fresh loop, which is necessary for test_increment_lock + # to be able to pass. + class DummyCog: + """A dummy cog, for dummies.""" + + redis = RedisCache() + + def __init__(self, bot: helpers.MockBot): + self.bot = bot + + self.cog = DummyCog(self.bot) + + await self.cog.redis.clear() def test_class_attribute_namespace(self): """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.redis._namespace, "RedisCacheTests.redis") + self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") async def test_class_attribute_required(self): """Test that errors are raised when not assigned as a class attribute.""" @@ -31,9 +45,13 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): def test_namespace_collision(self): """Test that we prevent colliding namespaces.""" - bad_cache = RedisCache() - bad_cache._set_namespace("RedisCacheTests.redis") - self.assertEqual(bad_cache._namespace, "RedisCacheTests.redis_") + bob_cache_1 = RedisCache() + bob_cache_1._set_namespace("BobRoss") + self.assertEqual(bob_cache_1._namespace, "BobRoss") + + bob_cache_2 = RedisCache() + bob_cache_2._set_namespace("BobRoss") + self.assertEqual(bob_cache_2._namespace, "BobRoss_") async def test_set_get_item(self): """Test that users can set and get items from the RedisDict.""" @@ -45,35 +63,35 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # Test that we can get and set different types. for test in test_cases: - await self.redis.set(*test) - self.assertEqual(await self.redis.get(test[0]), test[1]) + await self.cog.redis.set(*test) + self.assertEqual(await self.cog.redis.get(test[0]), test[1]) # Test that .get allows a default value - self.assertEqual(await self.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") async def test_set_item_type(self): """Test that .set rejects keys and values that are not strings, ints or floats.""" fruits = ["lemon", "melon", "apple"] with self.assertRaises(TypeError): - await self.redis.set(fruits, "nice") + await self.cog.redis.set(fruits, "nice") async def test_delete_item(self): """Test that .delete allows us to delete stuff from the RedisCache.""" # Add an item and verify that it gets added - await self.redis.set("internet", "firetruck") - self.assertEqual(await self.redis.get("internet"), "firetruck") + await self.cog.redis.set("internet", "firetruck") + self.assertEqual(await self.cog.redis.get("internet"), "firetruck") # Delete that item and verify that it gets deleted - await self.redis.delete("internet") - self.assertIs(await self.redis.get("internet"), None) + await self.cog.redis.delete("internet") + self.assertIs(await self.cog.redis.get("internet"), None) async def test_contains(self): """Test that we can check membership with .contains.""" - await self.redis.set('favorite_country', "Burkina Faso") + await self.cog.redis.set('favorite_country', "Burkina Faso") - self.assertIs(await self.redis.contains('favorite_country'), True) - self.assertIs(await self.redis.contains('favorite_dentist'), False) + self.assertIs(await self.cog.redis.contains('favorite_country'), True) + self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) async def test_items(self): """Test that the RedisDict can be iterated.""" @@ -84,10 +102,10 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): ('third_favorite_turtle', 'Raphael'), ] for key, value in test_cases: - await self.redis.set(key, value) + await self.cog.redis.set(key, value) # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item for item in await self.redis.items()] + redis_items = [item for item in await self.cog.redis.items()] # These sequences are probably in the same order now, but probably # isn't good enough for tests. Let's not rely on .hgetall always @@ -100,43 +118,43 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): async def test_length(self): """Test that we can get the correct .length from the RedisDict.""" - await self.redis.set('one', 1) - await self.redis.set('two', 2) - await self.redis.set('three', 3) - self.assertEqual(await self.redis.length(), 3) + await self.cog.redis.set('one', 1) + await self.cog.redis.set('two', 2) + await self.cog.redis.set('three', 3) + self.assertEqual(await self.cog.redis.length(), 3) - await self.redis.set('four', 4) - self.assertEqual(await self.redis.length(), 4) + await self.cog.redis.set('four', 4) + self.assertEqual(await self.cog.redis.length(), 4) async def test_to_dict(self): """Test that the .to_dict method returns a workable dictionary copy.""" - copy = await self.redis.to_dict() - local_copy = {key: value for key, value in await self.redis.items()} + copy = await self.cog.redis.to_dict() + local_copy = {key: value for key, value in await self.cog.redis.items()} self.assertIs(type(copy), dict) self.assertDictEqual(copy, local_copy) async def test_clear(self): """Test that the .clear method removes the entire hash.""" - await self.redis.set('teddy', 'with me') - await self.redis.set('in my dreams', 'you have a weird hat') - self.assertEqual(await self.redis.length(), 2) + await self.cog.redis.set('teddy', 'with me') + await self.cog.redis.set('in my dreams', 'you have a weird hat') + self.assertEqual(await self.cog.redis.length(), 2) - await self.redis.clear() - self.assertEqual(await self.redis.length(), 0) + await self.cog.redis.clear() + self.assertEqual(await self.cog.redis.length(), 0) async def test_pop(self): """Test that we can .pop an item from the RedisDict.""" - await self.redis.set('john', 'was afraid') + await self.cog.redis.set('john', 'was afraid') - self.assertEqual(await self.redis.pop('john'), 'was afraid') - self.assertEqual(await self.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(await self.redis.length(), 0) + self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') + self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') + self.assertEqual(await self.cog.redis.length(), 0) async def test_update(self): """Test that we can .update the RedisDict with multiple items.""" - await self.redis.set("reckfried", "lona") - await self.redis.set("bel air", "prince") - await self.redis.update({ + await self.cog.redis.set("reckfried", "lona") + await self.cog.redis.set("bel air", "prince") + await self.cog.redis.update({ "reckfried": "jona", "mega": "hungry, though", }) @@ -146,7 +164,7 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): "bel air": "prince", "mega": "hungry, though", } - self.assertDictEqual(await self.redis.to_dict(), result) + self.assertDictEqual(await self.cog.redis.to_dict(), result) def test_typestring_conversion(self): """Test the typestring-related helper functions.""" @@ -158,58 +176,75 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # Test conversion to typestring for _input, expected in conversion_tests: - self.assertEqual(self.redis._to_typestring(_input), expected) + self.assertEqual(self.cog.redis._to_typestring(_input), expected) # Test conversion from typestrings for _input, expected in conversion_tests: - self.assertEqual(self.redis._from_typestring(expected), _input) + self.assertEqual(self.cog.redis._from_typestring(expected), _input) # Test that exceptions are raised on invalid input with self.assertRaises(TypeError): - self.redis._to_typestring(["internet"]) - self.redis._from_typestring("o|firedog") + self.cog.redis._to_typestring(["internet"]) + self.cog.redis._from_typestring("o|firedog") async def test_increment_decrement(self): """Test .increment and .decrement methods.""" - await self.redis.set("entropic", 5) - await self.redis.set("disentropic", 12.5) + await self.cog.redis.set("entropic", 5) + await self.cog.redis.set("disentropic", 12.5) # Test default increment - await self.redis.increment("entropic") - self.assertEqual(await self.redis.get("entropic"), 6) + await self.cog.redis.increment("entropic") + self.assertEqual(await self.cog.redis.get("entropic"), 6) # Test default decrement - await self.redis.decrement("entropic") - self.assertEqual(await self.redis.get("entropic"), 5) + await self.cog.redis.decrement("entropic") + self.assertEqual(await self.cog.redis.get("entropic"), 5) # Test float increment with float - await self.redis.increment("disentropic", 2.0) - self.assertEqual(await self.redis.get("disentropic"), 14.5) + await self.cog.redis.increment("disentropic", 2.0) + self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) # Test float increment with int - await self.redis.increment("disentropic", 2) - self.assertEqual(await self.redis.get("disentropic"), 16.5) + await self.cog.redis.increment("disentropic", 2) + self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) # Test negative increments, because why not. - await self.redis.increment("entropic", -5) - self.assertEqual(await self.redis.get("entropic"), 0) + await self.cog.redis.increment("entropic", -5) + self.assertEqual(await self.cog.redis.get("entropic"), 0) # Negative decrements? Sure. - await self.redis.decrement("entropic", -5) - self.assertEqual(await self.redis.get("entropic"), 5) + await self.cog.redis.decrement("entropic", -5) + self.assertEqual(await self.cog.redis.get("entropic"), 5) # What about if we use a negative float to decrement an int? # This should convert the type into a float. - await self.redis.decrement("entropic", -2.5) - self.assertEqual(await self.redis.get("entropic"), 7.5) + await self.cog.redis.decrement("entropic", -2.5) + self.assertEqual(await self.cog.redis.get("entropic"), 7.5) # Let's test that they raise the right errors with self.assertRaises(KeyError): - await self.redis.increment("doesn't_exist!") + await self.cog.redis.increment("doesn't_exist!") - await self.redis.set("stringthing", "stringthing") + await self.cog.redis.set("stringthing", "stringthing") with self.assertRaises(TypeError): - await self.redis.increment("stringthing") + await self.cog.redis.increment("stringthing") + + async def test_increment_lock(self): + """Test that we can't produce a race condition in .increment.""" + await self.cog.redis.set("test_key", 0) + tasks = [] + + # Increment this a lot in different tasks + for _ in range(100): + task = asyncio.create_task( + self.cog.redis.increment("test_key", 1) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Confirm that the value has been incremented the exact right number of times. + value = await self.cog.redis.get("test_key") + self.assertEqual(value, 100) async def test_exceptions_raised(self): """Testing that the various RuntimeErrors are reachable.""" -- cgit v1.2.3 From db0a384e91a463ff9668ab4f9ea5268aa332ab2d Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 13:27:34 +0200 Subject: Remove the now deprecated in_channel_check. This check was no longer being used anywhere, having been replaced by in_whitelist_check. --- bot/utils/checks.py | 8 -------- tests/bot/utils/test_checks.py | 8 -------- 2 files changed, 16 deletions(-) (limited to 'tests') diff --git a/bot/utils/checks.py b/bot/utils/checks.py index d5ebe4ec9..f0ef36302 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -120,14 +120,6 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool: return check -def in_channel_check(ctx: Context, *channel_ids: int) -> bool: - """Checks if the command was executed inside the list of specified channels.""" - check = ctx.channel.id in channel_ids - log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The result of the in_channel check was {check}.") - return check - - def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, bypass_roles: Iterable[int]) -> Callable: """ diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 9610771e5..d572b6299 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -41,11 +41,3 @@ class ChecksTests(unittest.TestCase): role_id = 42 self.ctx.author.roles.append(MockRole(id=role_id)) self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) - - def test_in_channel_check_for_correct_channel(self): - self.ctx.channel.id = 42 - self.assertTrue(checks.in_channel_check(self.ctx, *[42])) - - def test_in_channel_check_for_incorrect_channel(self): - self.ctx.channel.id = 42 + 10 - self.assertFalse(checks.in_channel_check(self.ctx, *[42])) -- cgit v1.2.3 From 876fae1856f1ad876d74036899739115fd8b86c3 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 13:39:32 +0200 Subject: Add some tests for `in_whitelist_check`. --- tests/bot/utils/test_checks.py | 48 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'tests') diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index d572b6299..de72e5748 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,6 +1,8 @@ import unittest +from unittest.mock import MagicMock from bot.utils import checks +from bot.utils.checks import InWhitelistCheckFailure from tests.helpers import MockContext, MockRole @@ -41,3 +43,49 @@ class ChecksTests(unittest.TestCase): role_id = 42 self.ctx.author.roles.append(MockRole(id=role_id)) self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) + + def test_in_whitelist_check_correct_channel(self): + """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" + channel_id = 3 + self.ctx.channel.id = channel_id + self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id])) + + def test_in_whitelist_check_incorrect_channel(self): + """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match.""" + self.ctx.channel.id = 3 + with self.assertRaises(InWhitelistCheckFailure): + checks.in_whitelist_check(self.ctx, [4]) + + def test_in_whitelist_check_correct_category(self): + """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list.""" + category_id = 3 + self.ctx.channel.category_id = category_id + self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id])) + + def test_in_whitelist_check_incorrect_category(self): + """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match.""" + self.ctx.channel.category_id = 3 + with self.assertRaises(InWhitelistCheckFailure): + checks.in_whitelist_check(self.ctx, categories=[4]) + + def test_in_whitelist_check_correct_role(self): + """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list.""" + self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) + self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6])) + + def test_in_whitelist_check_incorrect_role(self): + """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match.""" + self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) + with self.assertRaises(InWhitelistCheckFailure): + checks.in_whitelist_check(self.ctx, roles=[4]) + + def test_in_whitelist_check_fail_silently(self): + """`in_whitelist_check` test no exception raised if `fail_silently` is `True`""" + self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True)) + + def test_in_whitelist_check_complex(self): + """`in_whitelist_check` test with multiple parameters""" + self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) + self.ctx.channel.category_id = 3 + self.ctx.channel.id = 5 + self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2])) -- cgit v1.2.3 From 4db313e9a7899666f1597094b0d88447c7b64311 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 27 May 2020 20:15:19 +0200 Subject: Floats are no longer permitted as RedisCache keys. Also added a test for this. This is the DRYest approach I could find. It's a little ugly, but I think it's probably good enough. --- bot/utils/redis_cache.py | 116 ++++++++++++++++++++++++------------ tests/bot/utils/test_redis_cache.py | 13 ++-- 2 files changed, 86 insertions(+), 43 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 33e5d5852..afd37f8f8 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -2,26 +2,42 @@ from __future__ import annotations import asyncio import logging -from typing import Any, Dict, ItemsView, Optional, Union +import typing +from typing import Any, Dict, ItemsView, Optional, Tuple, Union from bot.bot import Bot log = logging.getLogger(__name__) -RedisType = Union[str, int, float] -TYPESTRING_PREFIXES = ( +# Type aliases +RedisKeyType = Union[str, int] +RedisValueType = Union[str, int, float] + +# Prefix tuples +PrefixTuple = Tuple[Tuple[str, Any]] +TYPESTRING_VALUE_PREFIXES = ( ("f|", float), ("i|", int), ("s|", str), ) +TYPESTRING_KEY_PREFIXES = ( + ("i|", int), + ("s|", str), +) # Makes a nice list like "float, int, and str" -NICE_TYPE_LIST = ", ".join(str(_type.__name__) for _, _type in TYPESTRING_PREFIXES) -NICE_TYPE_LIST = ", and ".join(NICE_TYPE_LIST.rsplit(", ", 1)) +NICE_VALUE_TYPE_LIST = ", ".join(str(_type.__name__) for _type in typing.get_args(RedisValueType)) +NICE_VALUE_TYPE_LIST = ", and ".join(NICE_VALUE_TYPE_LIST.rsplit(", ", 1)) + +NICE_KEY_TYPE_LIST = ", ".join(str(_type.__name__) for _type in typing.get_args(RedisKeyType)) +NICE_KEY_TYPE_LIST = ", and ".join(NICE_KEY_TYPE_LIST.rsplit(", ", 1)) # Makes a list like "'f|', 'i|', and 's|'" -NICE_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_PREFIXES]) -NICE_PREFIX_LIST = ", and ".join(NICE_PREFIX_LIST.rsplit(", ", 1)) +NICE_VALUE_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_VALUE_PREFIXES]) +NICE_VALUE_PREFIX_LIST = ", and ".join(NICE_VALUE_PREFIX_LIST.rsplit(", ", 1)) + +NICE_KEY_PREFIX_LIST = ", ".join([f"'{prefix}'" for prefix, _ in TYPESTRING_KEY_PREFIXES]) +NICE_KEY_PREFIX_LIST = ", and ".join(NICE_KEY_PREFIX_LIST.rsplit(", ", 1)) class RedisCache: @@ -99,33 +115,57 @@ class RedisCache: self._namespace = namespace @staticmethod - def _to_typestring(value: RedisType) -> str: + def _to_typestring( + key_or_value: Union[RedisKeyType, RedisValueType], + prefixes: PrefixTuple, + nice_type_list: str + ) -> str: """Turn a valid Redis type into a typestring.""" - for prefix, _type in TYPESTRING_PREFIXES: - if isinstance(value, _type): - return f"{prefix}{value}" - raise TypeError(f"RedisCache._from_typestring only supports the types {NICE_TYPE_LIST}.") + for prefix, _type in prefixes: + if isinstance(key_or_value, _type): + return f"{prefix}{key_or_value}" + raise TypeError(f"RedisCache._from_typestring only supports the types {nice_type_list}.") @staticmethod - def _from_typestring(value: Union[bytes, str]) -> RedisType: - """Turn a typestring into a valid Redis type.""" + def _from_typestring( + key_or_value: Union[bytes, str], + prefixes: PrefixTuple, + nice_prefix_list: str, + ) -> Union[RedisKeyType, RedisValueType]: + """Deserialize a typestring into a valid Redis type.""" # Stuff that comes out of Redis will be bytestrings, so let's decode those. - if isinstance(value, bytes): - value = value.decode('utf-8') + if isinstance(key_or_value, bytes): + key_or_value = key_or_value.decode('utf-8') # Now we convert our unicode string back into the type it originally was. - for prefix, _type in TYPESTRING_PREFIXES: - if value.startswith(prefix): - return _type(value[len(prefix):]) - raise TypeError(f"RedisCache._to_typestring only supports the prefixes {NICE_PREFIX_LIST}.") + for prefix, _type in prefixes: + if key_or_value.startswith(prefix): + return _type(key_or_value[len(prefix):]) + raise TypeError(f"RedisCache._to_typestring only supports the prefixes {nice_prefix_list}.") + + def _key_to_typestring(self, key: RedisKeyType) -> str: + """Serialize a RedisKeyType object into a typestring.""" + return self._to_typestring(key, TYPESTRING_KEY_PREFIXES, NICE_KEY_TYPE_LIST) + + def _value_to_typestring(self, value: RedisValueType) -> str: + """Serialize a RedisValueType object into a typestring.""" + return self._to_typestring(value, TYPESTRING_VALUE_PREFIXES, NICE_VALUE_TYPE_LIST) + + def _key_from_typestring(self, key: Union[bytes, str]) -> RedisKeyType: + """Deserialize a RedisKeyType object from a typestring.""" + return self._from_typestring(key, TYPESTRING_KEY_PREFIXES, NICE_KEY_PREFIX_LIST) + + def _value_from_typestring(self, value: Union[bytes, str]) -> RedisValueType: + """Deserialize a RedisValueType object from a typestring.""" + return self._from_typestring(value, TYPESTRING_VALUE_PREFIXES, NICE_VALUE_PREFIX_LIST) def _dict_from_typestring(self, dictionary: Dict) -> Dict: """Turns all contents of a dict into valid Redis types.""" - return {self._from_typestring(key): self._from_typestring(value) for key, value in dictionary.items()} + return {self._key_from_typestring(key): self._value_from_typestring(value) for key, value in dictionary.items()} def _dict_to_typestring(self, dictionary: Dict) -> Dict: """Turns all contents of a dict into typestrings.""" - return {self._to_typestring(key): self._to_typestring(value) for key, value in dictionary.items()} + return {self._key_to_typestring(key): self._value_to_typestring(value) for key, value in dictionary.items()} async def _validate_cache(self) -> None: """Validate that the RedisCache is ready to be used.""" @@ -209,21 +249,21 @@ class RedisCache: """Return a beautiful representation of this object instance.""" return f"RedisCache(namespace={self._namespace!r})" - async def set(self, key: RedisType, value: RedisType) -> None: + async def set(self, key: RedisKeyType, value: RedisValueType) -> None: """Store an item in the Redis cache.""" await self._validate_cache() # Convert to a typestring and then set it - key = self._to_typestring(key) - value = self._to_typestring(value) + key = self._key_to_typestring(key) + value = self._value_to_typestring(value) log.trace(f"Setting {key} to {value}.") await self._redis.hset(self._namespace, key, value) - async def get(self, key: RedisType, default: Optional[RedisType] = None) -> Optional[RedisType]: + async def get(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> Optional[RedisValueType]: """Get an item from the Redis cache.""" await self._validate_cache() - key = self._to_typestring(key) + key = self._key_to_typestring(key) log.trace(f"Attempting to retrieve {key}.") value = await self._redis.hget(self._namespace, key) @@ -232,11 +272,11 @@ class RedisCache: log.trace(f"Value not found, returning default value {default}") return default else: - value = self._from_typestring(value) + value = self._value_from_typestring(value) log.trace(f"Value found, returning value {value}") return value - async def delete(self, key: RedisType) -> None: + async def delete(self, key: RedisKeyType) -> None: """ Delete an item from the Redis cache. @@ -245,19 +285,19 @@ class RedisCache: See https://redis.io/commands/hdel for more info on how this works. """ await self._validate_cache() - key = self._to_typestring(key) + key = self._key_to_typestring(key) log.trace(f"Attempting to delete {key}.") return await self._redis.hdel(self._namespace, key) - async def contains(self, key: RedisType) -> bool: + async def contains(self, key: RedisKeyType) -> bool: """ Check if a key exists in the Redis cache. Return True if the key exists, otherwise False. """ await self._validate_cache() - key = self._to_typestring(key) + key = self._key_to_typestring(key) exists = await self._redis.hexists(self._namespace, key) log.trace(f"Testing if {key} exists in the RedisCache - Result is {exists}") @@ -304,7 +344,7 @@ class RedisCache: log.trace("Clearing the cache of all key/value pairs.") await self._redis.delete(self._namespace) - async def pop(self, key: RedisType, default: Optional[RedisType] = None) -> RedisType: + async def pop(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> RedisValueType: """Get the item, remove it from the cache, and provide a default if not found.""" log.trace(f"Attempting to pop {key}.") value = await self.get(key, default) @@ -317,7 +357,7 @@ class RedisCache: return value - async def update(self, items: Dict[RedisType, RedisType]) -> None: + async def update(self, items: Dict[RedisKeyType, RedisValueType]) -> None: """ Update the Redis cache with multiple values. @@ -326,14 +366,14 @@ class RedisCache: do not exist in the RedisCache, they are created. If they do exist, the values are updated with the new ones from `items`. - Please note that both the keys and the values in the `items` dictionary - must consist of valid RedisTypes - ints, floats, or strings. + Please note that keys and the values in the `items` dictionary + must consist of valid RedisKeyTypes and RedisValueTypes. """ await self._validate_cache() log.trace(f"Updating the cache with the following items:\n{items}") await self._redis.hmset_dict(self._namespace, self._dict_to_typestring(items)) - async def increment(self, key: RedisType, amount: Optional[int, float] = 1) -> None: + async def increment(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None: """ Increment the value by `amount`. @@ -373,7 +413,7 @@ class RedisCache: log.error(error_message) raise TypeError(error_message) - async def decrement(self, key: RedisType, amount: Optional[int, float] = 1) -> None: + async def decrement(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None: """ Decrement the value by `amount`. diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index efd168dac..4f95dff03 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -70,12 +70,15 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") async def test_set_item_type(self): - """Test that .set rejects keys and values that are not strings, ints or floats.""" + """Test that .set rejects keys and values that are not permitted.""" fruits = ["lemon", "melon", "apple"] with self.assertRaises(TypeError): await self.cog.redis.set(fruits, "nice") + with self.assertRaises(TypeError): + await self.cog.redis.set(4.23, "nice") + async def test_delete_item(self): """Test that .delete allows us to delete stuff from the RedisCache.""" # Add an item and verify that it gets added @@ -176,16 +179,16 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): # Test conversion to typestring for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._to_typestring(_input), expected) + self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) # Test conversion from typestrings for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._from_typestring(expected), _input) + self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) # Test that exceptions are raised on invalid input with self.assertRaises(TypeError): - self.cog.redis._to_typestring(["internet"]) - self.cog.redis._from_typestring("o|firedog") + self.cog.redis._value_to_typestring(["internet"]) + self.cog.redis._value_from_typestring("o|firedog") async def test_increment_decrement(self): """Test .increment and .decrement methods.""" -- cgit v1.2.3 From f66a63501fe1ef8fb5390dfbe42ae9f95ea2bc28 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Thu, 28 May 2020 01:29:34 +0200 Subject: Add custom exceptions for each error state. The bot can get into trouble in three distinct ways: - It has no Bot instance - It has no namespace - It has no parent instance. These happen only if you're using it wrong. To make the test more precise, and to add a little bit more readability (RuntimeError could be anything!), we'll introduce some custom exceptions for these three states. This addresses a review comment by @aeros. --- bot/utils/redis_cache.py | 22 +++++++++++++++++----- tests/bot/utils/test_redis_cache.py | 7 ++++--- 2 files changed, 21 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index 979ea5d47..6b3c68979 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -27,6 +27,18 @@ _KEY_PREFIXES = ( ) +class NoBotInstanceError(RuntimeError): + """Raised when RedisCache is created without an available bot instance on the owner class.""" + + +class NoNamespaceError(RuntimeError): + """Raised when RedisCache has no namespace, for example if it is not assigned to a class attribute.""" + + +class NoParentInstanceError(RuntimeError): + """Raised when the parent instance is available, for example if called by accessing the parent class directly.""" + + class RedisCache: """ A simplified interface for a Redis connection. @@ -149,7 +161,7 @@ class RedisCache: "This object must be initialized as a class attribute." ) log.error(error_message) - raise RuntimeError(error_message) + raise NoNamespaceError(error_message) if self.bot is None: error_message = ( @@ -159,7 +171,7 @@ class RedisCache: "the RedisCache inside a class that has a Bot instance attribute." ) log.error(error_message) - raise RuntimeError(error_message) + raise NoBotInstanceError(error_message) await self.bot.redis_ready.wait() @@ -194,7 +206,7 @@ class RedisCache: if self._namespace is None: error_message = "RedisCache must be a class attribute." log.error(error_message) - raise RuntimeError(error_message) + raise NoNamespaceError(error_message) if instance is None: error_message = ( @@ -202,7 +214,7 @@ class RedisCache: "before accessing it using the cog's class object." ) log.error(error_message) - raise RuntimeError(error_message) + raise NoParentInstanceError(error_message) for attribute in vars(instance).values(): if isinstance(attribute, Bot): @@ -217,7 +229,7 @@ class RedisCache: "the RedisCache inside a class that has a Bot instance attribute." ) log.error(error_message) - raise RuntimeError(error_message) + raise NoBotInstanceError(error_message) def __repr__(self) -> str: """Return a beautiful representation of this object instance.""" diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 4f95dff03..8c1a40640 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -4,6 +4,7 @@ import unittest import fakeredis.aioredis from bot.utils import RedisCache +from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError from tests import helpers @@ -260,13 +261,13 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): cog = MyCog() # Raises "No Bot instance" - with self.assertRaises(RuntimeError): + with self.assertRaises(NoBotInstanceError): await cog.cache.get("john") # Raises "RedisCache has no namespace" - with self.assertRaises(RuntimeError): + with self.assertRaises(NoNamespaceError): await cog.other_cache.get("was") # Raises "You must access the RedisCache instance through the cog instance" - with self.assertRaises(RuntimeError): + with self.assertRaises(NoParentInstanceError): await MyCog.cache.get("afraid") -- cgit v1.2.3 From 96db6087254c957fcb8fb45aad7ffcddb46ee839 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 27 May 2020 17:08:18 -0700 Subject: Switch findall to finditer in assertions `find_token_in_message` now uses the latter so the tests should adjust accordingly. --- tests/bot/cogs/test_token_remover.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 971bc93fc..4fff3ab33 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -94,18 +94,18 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) - token_re.findall.assert_not_called() + token_re.finditer.assert_not_called() @autospec(TokenRemover, "is_maybe_token") @autospec("bot.cogs.token_remover", "TOKEN_RE") def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): """None should be returned if the regex matches no tokens in a message.""" - token_re.findall.return_value = () + token_re.finditer.return_value = () return_value = TokenRemover.find_token_in_message(self.msg) self.assertIsNone(return_value) - token_re.findall.assert_called_once_with(self.msg.content) + token_re.finditer.assert_called_once_with(self.msg.content) is_maybe_token.assert_not_called() @autospec(TokenRemover, "is_maybe_token") @@ -123,7 +123,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): return_value = TokenRemover.find_token_in_message(self.msg) self.assertEqual(return_value, matches[true_index]) - token_re.findall.assert_called_once_with(self.msg.content) + token_re.finditer.assert_called_once_with(self.msg.content) # assert_has_calls isn't used cause it'd allow for extra calls before or after. # The function should short-circuit, so nothing past true_index should have been used. -- cgit v1.2.3 From f937032466a4124bacf217d1bfd0af097fc3395d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 27 May 2020 19:31:55 -0700 Subject: Adjust token remover tests to use the Token NamedTuple --- tests/bot/cogs/test_token_remover.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 4fff3ab33..65bc1ee58 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -7,7 +7,7 @@ from discord import Colour from bot import constants from bot.cogs import token_remover from bot.cogs.moderation import ModLog -from bot.cogs.token_remover import TokenRemover +from bot.cogs.token_remover import Token, TokenRemover from tests.helpers import MockBot, MockMessage, autospec @@ -224,17 +224,19 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @autospec("bot.cogs.token_remover", "LOG_MESSAGE") def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") log_message.format.return_value = "Howdy" - return_value = TokenRemover.format_log_message(self.msg, "MTIz.DN9R_A.xyz") + + return_value = TokenRemover.format_log_message(self.msg, token) self.assertEqual(return_value, log_message.format.return_value) log_message.format.assert_called_once_with( author=self.msg.author, author_id=self.msg.author.id, channel=self.msg.channel.mention, - user_id="MTIz", - timestamp="DN9R_A", - hmac="xxx", + user_id=token.user_id, + timestamp=token.timestamp, + hmac="x" * len(token.hmac), ) @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) @@ -244,7 +246,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): """Should delete the message and send a mod log.""" cog = TokenRemover(self.bot) mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) - token = "MTIz.DN9R_A.xyz" + token = mock.create_autospec(Token, spec_set=True, instance=True) log_msg = "testing123" mod_log_property.return_value = mod_log -- cgit v1.2.3 From 12b8f5002807144451a313180c639bf6b4925f2e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 27 May 2020 20:00:33 -0700 Subject: Add more thorough and realistic inputs for token ID and timestamp tests The tests for valid inputs and invalid inputs were split to make them more readable. --- tests/bot/cogs/test_token_remover.py | 70 ++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 65bc1ee58..ffe76865a 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -24,31 +24,65 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) self.msg.author.avatar_url_as.return_value = "picture-lemon.png" - def test_is_valid_user_id(self): - """Should correctly discern valid user IDs and ignore non-numeric and non-ASCII IDs.""" - subtests = ( - ("MTIz", True), # base64(123) - ("YWJj", False), # base64(abc) - ("λδµ", False), + def test_is_valid_user_id_valid(self): + """Should consider user IDs valid if they decode entirely to ASCII digits.""" + ids = ( + "NDcyMjY1OTQzMDYyNDEzMzMy", + "NDc1MDczNjI5Mzk5NTQ3OTA0", + "NDY3MjIzMjMwNjUwNzc3NjQx", ) - for user_id, is_valid in subtests: - with self.subTest(user_id=user_id, is_valid=is_valid): + for user_id in ids: + with self.subTest(user_id=user_id): result = TokenRemover.is_valid_user_id(user_id) - self.assertIs(result, is_valid) + self.assertTrue(result) + + def test_is_valid_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) - def test_is_valid_timestamp(self): - """Should correctly discern valid timestamps.""" - subtests = ( - ("DN9r_A", True), - ("MTIz", False), # base64(123) - ("λδµ", False), + for user_id, msg in ids: + with self.subTest(msg=msg): + result = TokenRemover.is_valid_user_id(user_id) + self.assertFalse(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = TokenRemover.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), ) - for timestamp, is_valid in subtests: - with self.subTest(timestamp=timestamp, is_valid=is_valid): + for timestamp, msg in timestamps: + with self.subTest(msg=msg): result = TokenRemover.is_valid_timestamp(timestamp) - self.assertIs(result, is_valid) + self.assertFalse(result) def test_mod_log_property(self): """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -- cgit v1.2.3 From 67472080fef5c38b21d74daa2178c3f35081b58f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 28 May 2020 19:52:41 -0700 Subject: Remove is_maybe_token tests The function was removed due to redundancy. Therefore, its tests are obsolete. --- tests/bot/cogs/test_token_remover.py | 33 --------------------------------- 1 file changed, 33 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index ffe76865a..5dd12636c 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -213,39 +213,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = [match[0] for match in results] self.assertCountEqual((token_1, token_2), results) - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - def test_is_maybe_token_missing_part_returns_false(self, valid_user, valid_time): - """False should be returned for tokens which do not have all 3 parts.""" - return_value = TokenRemover.is_maybe_token("x.y") - - self.assertFalse(return_value) - valid_user.assert_not_called() - valid_time.assert_not_called() - - @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") - def test_is_maybe_token(self, valid_user, valid_time): - """Should return True if the user ID and timestamp are valid or return False otherwise.""" - subtests = ( - (False, True, False), - (True, False, False), - (True, True, True), - ) - - for user_return, time_return, expected in subtests: - valid_user.reset_mock() - valid_time.reset_mock() - - with self.subTest(user_return=user_return, time_return=time_return, expected=expected): - valid_user.return_value = user_return - valid_time.return_value = time_return - - actual = TokenRemover.is_maybe_token("x.y.z") - self.assertIs(actual, expected) - - valid_user.assert_called_once_with("x") - if user_return: - valid_time.assert_called_once_with("y") - async def test_delete_message(self): """The message should be deleted, and a message should be sent to the same channel.""" await TokenRemover.delete_message(self.msg) -- cgit v1.2.3 From 84cd8235863acc80b7f140309424c33180cc34ea Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 28 May 2020 20:32:48 -0700 Subject: Adjust find_token_in_message tests for the recent cog changes It now supports the changes that switched to finditer, added match groups, and added the Token NamedTuple. It also accounts for the is_maybe_token function being removed. For the sake of simplicity, call assertions on is_valid_user_id and is_valid_timestamp were not made. --- tests/bot/cogs/test_token_remover.py | 39 ++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 5dd12636c..8238e235a 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,4 +1,5 @@ import unittest +from re import Match from unittest import mock from unittest.mock import MagicMock @@ -130,9 +131,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(return_value) token_re.finditer.assert_not_called() - @autospec(TokenRemover, "is_maybe_token") @autospec("bot.cogs.token_remover", "TOKEN_RE") - def test_find_token_no_matches_returns_none(self, token_re, is_maybe_token): + def test_find_token_no_matches(self, token_re): """None should be returned if the regex matches no tokens in a message.""" token_re.finditer.return_value = () @@ -140,30 +140,31 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(return_value) token_re.finditer.assert_called_once_with(self.msg.content) - is_maybe_token.assert_not_called() - @autospec(TokenRemover, "is_maybe_token") + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.token_remover", "Token") @autospec("bot.cogs.token_remover", "TOKEN_RE") - def test_find_token_returns_found_token(self, token_re, is_maybe_token): - """The found token should be returned.""" - true_index = 1 - matches = ("foo", "bar", "baz") - side_effects = [False] * len(matches) - side_effects[true_index] = True - - token_re.findall.return_value = matches - is_maybe_token.side_effect = side_effects + def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """The first match with a valid user ID and timestamp should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True return_value = TokenRemover.find_token_in_message(self.msg) - self.assertEqual(return_value, matches[true_index]) + self.assertEqual(tokens[1], return_value) token_re.finditer.assert_called_once_with(self.msg.content) - # assert_has_calls isn't used cause it'd allow for extra calls before or after. - # The function should short-circuit, so nothing past true_index should have been used. - calls = [mock.call(match) for match in matches[:true_index + 1]] - self.assertEqual(is_maybe_token.mock_calls, calls) - def test_regex_invalid_tokens(self): """Messages without anything looking like a token are not matched.""" tokens = ( -- cgit v1.2.3 From 5930a044b8347019d474a809fc86f89263574ad0 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 28 May 2020 20:33:34 -0700 Subject: Test find_token_in_message returns None for invalid matches This covers the case when a token is matched, but its user ID and timestamp turn out to be invalid. --- tests/bot/cogs/test_token_remover.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 8238e235a..9b4b04ecd 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -165,6 +165,21 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(tokens[1], return_value) token_re.finditer.assert_called_once_with(self.msg.content) + @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp") + @autospec("bot.cogs.token_remover", "Token") + @autospec("bot.cogs.token_remover", "TOKEN_RE") + def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp): + """None should be returned if no matches have valid user IDs or timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + is_valid_id.return_value = False + is_valid_timestamp.return_value = False + + return_value = TokenRemover.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + token_re.finditer.assert_called_once_with(self.msg.content) + def test_regex_invalid_tokens(self): """Messages without anything looking like a token are not matched.""" tokens = ( -- cgit v1.2.3 From f59e63454ffa582765847e8a26d9d97dcd9ff7b2 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 30 May 2020 01:42:02 +0200 Subject: Fix busted test_information test. I wish this test didn't exist. --- tests/bot/cogs/test_information.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index aca6b594f..79c0e0ad3 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -148,14 +148,18 @@ class InformationCogTests(unittest.TestCase): Voice region: {self.ctx.guild.region} Features: {', '.join(self.ctx.guild.features)} - **Counts** - Members: {self.ctx.guild.member_count:,} - Roles: {len(self.ctx.guild.roles)} + **Channel counts** Category channels: 1 Text channels: 1 Voice channels: 1 + Staff channels: 0 + + **Member counts** + Members: {self.ctx.guild.member_count:,} + Staff members: 0 + Roles: {len(self.ctx.guild.roles)} - **Members** + **Member statuses** {constants.Emojis.status_online} 2 {constants.Emojis.status_idle} 1 {constants.Emojis.status_dnd} 4 -- cgit v1.2.3 From 96b026198a4ca2074f4fd7ea68e8a09acd5b38e4 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 30 May 2020 09:34:39 +0300 Subject: Simplify infraction reason truncation tests --- tests/bot/cogs/moderation/test_infractions.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index 5548d9f68..ad3c95958 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -27,15 +27,14 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.cog.apply_infraction = AsyncMock() self.bot.get_cog.return_value = AsyncMock() self.cog.mod_log.ignore = Mock() + self.ctx.guild.ban = Mock() await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) - ban = self.cog.apply_infraction.call_args[0][3] - self.assertEqual( - ban.cr_frame.f_locals["kwargs"]["reason"], - textwrap.shorten("foo bar" * 3000, 512, placeholder="...") + self.ctx.guild.ban.assert_called_once_with( + self.target, + reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), + delete_message_days=0 ) - # Await ban to avoid not awaited coroutine warning - await ban @patch("bot.cogs.moderation.utils.post_infraction") async def test_apply_kick_reason_truncation(self, post_infraction_mock): @@ -44,12 +43,7 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.cog.apply_infraction = AsyncMock() self.cog.mod_log.ignore = Mock() + self.target.kick = Mock() await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) - kick = self.cog.apply_infraction.call_args[0][3] - self.assertEqual( - kick.cr_frame.f_locals["kwargs"]["reason"], - textwrap.shorten("foo bar" * 3000, 512, placeholder="...") - ) - # Await kick to avoid not awaited coroutine warning - await kick + self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) -- cgit v1.2.3 From 323317496310ef474a39d468e273703106e44768 Mon Sep 17 00:00:00 2001 From: ks129 <45097959+ks129@users.noreply.github.com> Date: Sat, 30 May 2020 10:07:21 +0300 Subject: Infr. Tests: Add `apply_infraction` awaiting assertion with args --- tests/bot/cogs/moderation/test_infractions.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py index ad3c95958..da4e92ccc 100644 --- a/tests/bot/cogs/moderation/test_infractions.py +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -35,6 +35,9 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), delete_message_days=0 ) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value + ) @patch("bot.cogs.moderation.utils.post_infraction") async def test_apply_kick_reason_truncation(self, post_infraction_mock): @@ -47,3 +50,6 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + ) -- cgit v1.2.3 From 876b4846f612fe0011cc2e0b498b4df9e54d74cb Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 31 May 2020 19:17:07 +0200 Subject: Add support for bool values in RedisCache We're gonna need this for the help channel handling, and it seems like a reasonable type to support anyway. It requires a tiny bit of special handling, but nothing outrageous. --- bot/utils/redis_cache.py | 14 ++++++++++++-- tests/bot/utils/test_redis_cache.py | 4 +++- 2 files changed, 15 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py index de80cee84..2926e7a89 100644 --- a/bot/utils/redis_cache.py +++ b/bot/utils/redis_cache.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import logging +from distutils.util import strtobool from functools import partialmethod from typing import Any, Dict, ItemsView, Optional, Tuple, Union @@ -11,7 +12,7 @@ log = logging.getLogger(__name__) # Type aliases RedisKeyType = Union[str, int] -RedisValueType = Union[str, int, float] +RedisValueType = Union[str, int, float, bool] RedisKeyOrValue = Union[RedisKeyType, RedisValueType] # Prefix tuples @@ -20,6 +21,7 @@ _VALUE_PREFIXES = ( ("f|", float), ("i|", int), ("s|", str), + ("b|", bool), ) _KEY_PREFIXES = ( ("i|", int), @@ -117,7 +119,8 @@ class RedisCache: def _to_typestring(key_or_value: RedisKeyOrValue, prefixes: _PrefixTuple) -> str: """Turn a valid Redis type into a typestring.""" for prefix, _type in prefixes: - if isinstance(key_or_value, _type): + # isinstance is a bad idea here, because isintance(False, int) == True. + if type(key_or_value) is _type: return f"{prefix}{key_or_value}" raise TypeError(f"RedisCache._to_typestring only supports the following: {prefixes}.") @@ -131,6 +134,13 @@ class RedisCache: # Now we convert our unicode string back into the type it originally was. for prefix, _type in prefixes: if key_or_value.startswith(prefix): + + # For booleans, we need special handling because bool("False") is True. + if prefix == "b|": + value = key_or_value[len(prefix):] + return bool(strtobool(value)) + + # Otherwise we can just convert normally. return _type(key_or_value[len(prefix):]) raise TypeError(f"RedisCache._from_typestring only supports the following: {prefixes}.") diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 8c1a40640..62c411681 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -59,7 +59,9 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): test_cases = ( ('favorite_fruit', 'melon'), ('favorite_number', 86), - ('favorite_fraction', 86.54) + ('favorite_fraction', 86.54), + ('favorite_boolean', False), + ('other_boolean', True), ) # Test that we can get and set different types. -- cgit v1.2.3 From ebbaa6274cfc278c772593b193356aa8bf066de4 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 31 May 2020 14:17:20 -0700 Subject: Remove redis namespace collision test --- tests/bot/utils/test_redis_cache.py | 10 ---------- 1 file changed, 10 deletions(-) (limited to 'tests') diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py index 8c1a40640..e5d6e4078 100644 --- a/tests/bot/utils/test_redis_cache.py +++ b/tests/bot/utils/test_redis_cache.py @@ -44,16 +44,6 @@ class RedisCacheTests(unittest.IsolatedAsyncioTestCase): with self.assertRaises(RuntimeError): await bad_cache.set("test", "me_up_deadman") - def test_namespace_collision(self): - """Test that we prevent colliding namespaces.""" - bob_cache_1 = RedisCache() - bob_cache_1._set_namespace("BobRoss") - self.assertEqual(bob_cache_1._namespace, "BobRoss") - - bob_cache_2 = RedisCache() - bob_cache_2._set_namespace("BobRoss") - self.assertEqual(bob_cache_2._namespace, "BobRoss_") - async def test_set_get_item(self): """Test that users can set and get items from the RedisDict.""" test_cases = ( -- cgit v1.2.3 From 9b3ab7df5ae1ecf95705f2fab7d99fdb36eb98ea Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 2 Jun 2020 19:22:49 -0700 Subject: Token remover: remove the `delete_message` function It's redundant; there's no benefit here in abstracting two lines of code into a function. --- bot/cogs/token_remover.py | 9 ++------- tests/bot/cogs/test_token_remover.py | 19 +++++++------------ 2 files changed, 9 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 46329e207..d55e079e9 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -79,7 +79,8 @@ class TokenRemover(Cog): async def take_action(self, msg: Message, found_token: Token) -> None: """Remove the `msg` containing the `found_token` and send a mod log message.""" self.mod_log.ignore(Event.message_delete, msg.id) - await self.delete_message(msg) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) log_message = self.format_log_message(msg, found_token) log.debug(log_message) @@ -96,12 +97,6 @@ class TokenRemover(Cog): self.bot.stats.incr("tokens.removed_tokens") - @staticmethod - async def delete_message(msg: Message) -> None: - """Remove a `msg` containing a token and send an explanatory message in the same channel.""" - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - @staticmethod def format_log_message(msg: Message, token: Token) -> str: """Return the log message to send for `token` being censored in `msg`.""" diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 9b4b04ecd..a10124d2d 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -229,15 +229,6 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): results = [match[0] for match in results] self.assertCountEqual((token_1, token_2), results) - async def test_delete_message(self): - """The message should be deleted, and a message should be sent to the same channel.""" - await TokenRemover.delete_message(self.msg) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) - ) - @autospec("bot.cogs.token_remover", "LOG_MESSAGE") def test_format_log_message(self, log_message): """Should correctly format the log message with info from the message and token.""" @@ -258,8 +249,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) @autospec("bot.cogs.token_remover", "log") - @autospec(TokenRemover, "delete_message", "format_log_message") - async def test_take_action(self, delete_message, format_log_message, logger, mod_log_property): + @autospec(TokenRemover, "format_log_message") + async def test_take_action(self, format_log_message, logger, mod_log_property): """Should delete the message and send a mod log.""" cog = TokenRemover(self.bot) mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) @@ -271,7 +262,11 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): await cog.take_action(self.msg, token) - delete_message.assert_awaited_once_with(self.msg) + self.msg.delete.assert_called_once_with() + self.msg.channel.send.assert_called_once_with( + token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) + ) + format_log_message.assert_called_once_with(self.msg, token) logger.debug.assert_called_with(log_msg) self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") -- cgit v1.2.3 From be4902cbd66c2f7223608ddbfee4aa4f0e1a011a Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Sat, 6 Jun 2020 22:11:53 +0200 Subject: Test for channel not silenced message --- tests/bot/cogs/moderation/test_silence.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 3fd149f04..ab3d0742a 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -127,10 +127,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.ctx.reset_mock() async def test_unsilence_sent_correct_discord_message(self): - """Proper reply after a successful unsilence.""" - with mock.patch.object(self.cog, "_unsilence", return_value=True): - await self.cog.unsilence.callback(self.cog, self.ctx) - self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") + """Check if proper message was sent when unsilencing channel.""" + test_cases = ( + (True, f"{Emojis.check_mark} unsilenced current channel."), + (False, f"{Emojis.cross_mark} current channel was not silenced.") + ) + for _unsilence_patch_return, result_message in test_cases: + with self.subTest( + starting_silenced_state=_unsilence_patch_return, + result_message=result_message + ): + with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): + await self.cog.unsilence.callback(self.cog, self.ctx) + self.ctx.send.assert_called_once_with(result_message) + self.ctx.reset_mock() async def test_silence_private_for_false(self): """Permissions are not set and `False` is returned in an already silenced channel.""" -- cgit v1.2.3