diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/base.py | 24 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 103 | ||||
-rw-r--r-- | tests/bot/exts/info/test_information.py | 74 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 49 | ||||
-rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 29 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 71 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 103 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 84 | ||||
-rw-r--r-- | tests/bot/rules/test_mentions.py | 131 | ||||
-rw-r--r-- | tests/helpers.py | 24 | ||||
-rw-r--r-- | tests/test_helpers.py | 2 |
11 files changed, 477 insertions, 217 deletions
diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Dict import discord +from async_rediscache import RedisSession from discord.ext import commands from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError from discord.ext.commands import errors from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock(return_value=False) + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() - cog.try_run_eval.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - await setup(bot) + await error_handler.setup(bot) bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index d896b7652..9f5143c01 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -2,6 +2,7 @@ import textwrap import unittest import unittest.mock from datetime import datetime +from textwrap import shorten import discord @@ -573,3 +574,76 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): create_embed.assert_called_once_with(ctx, self.target, False) ctx.send.assert_called_once() + + +class RuleCommandTests(unittest.IsolatedAsyncioTestCase): + """Tests for the `!rule` command.""" + + def setUp(self) -> None: + """Set up steps executed before each test is run.""" + self.bot = helpers.MockBot() + self.cog = information.Information(self.bot) + self.ctx = helpers.MockContext(author=helpers.MockMember(id=1, name="Bellaluma")) + self.full_rules = [ + ( + "First rule", + ["first", "number_one"] + ), + ( + "Second rule", + ["second", "number_two"] + ), + ( + "Third rule", + ["third", "number_three"] + ) + ] + self.bot.api_client.get.return_value = self.full_rules + + async def test_return_none_if_one_rule_number_is_invalid(self): + + test_cases = [ + (('1', '6', '7', '8'), (6, 7, 8)), + (('10', "first"), (10, )), + (("first", 10), (10, )) + ] + + for raw_user_input, extracted_rule_numbers in test_cases: + with self.subTest(identifier=raw_user_input): + invalid = ", ".join( + str(rule_number) for rule_number in extracted_rule_numbers + if rule_number < 1 or rule_number > len(self.full_rules)) + + final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) + + self.assertEqual( + self.ctx.send.call_args, + unittest.mock.call(shorten(":x: Invalid rule indices: " + invalid, 75, placeholder=" ..."))) + self.assertEqual(None, final_rule_numbers) + + async def test_return_correct_rule_numbers(self): + + test_cases = [ + (("1", "2", "first"), {1, 2}), + (("1", "hello", "2", "second"), {1}), + (("second", "third", "unknown", "999"), {2, 3}) + ] + + for raw_user_input, expected_matched_rule_numbers in test_cases: + with self.subTest(identifier=raw_user_input): + final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) + self.assertEqual(expected_matched_rule_numbers, final_rule_numbers) + + async def test_return_default_rules_when_no_input_or_no_match_are_found(self): + test_cases = [ + ((), None), + (("hello", "2", "second"), None), + (("hello", "999"), None), + ] + + for raw_user_input, expected_matched_rule_numbers in test_cases: + with self.subTest(identifier=raw_user_input): + final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) + embed = self.ctx.send.call_args.kwargs['embed'] + self.assertEqual(information.DEFAULT_RULES_DESCRIPTION, embed.description) + self.assertEqual(expected_matched_rule_numbers, final_rule_numbers) diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 052048053..b78328137 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -35,17 +35,20 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.cog.apply_infraction = AsyncMock() self.bot.get_cog.return_value = AsyncMock() self.cog.mod_log.ignore = Mock() - self.ctx.guild.ban = Mock() + self.ctx.guild.ban = AsyncMock() await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) - self.ctx.guild.ban.assert_called_once_with( + self.cog.apply_infraction.assert_awaited_once_with( + self.ctx, {"foo": "bar", "purge": ""}, self.target, ANY + ) + + action = self.cog.apply_infraction.call_args.args[-1] + await action() + self.ctx.guild.ban.assert_awaited_once_with( self.target, reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), delete_message_days=0 ) - self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar", "purge": ""}, self.target, self.ctx.guild.ban.return_value - ) @patch("bot.exts.moderation.infraction._utils.post_infraction") async def test_apply_kick_reason_truncation(self, post_infraction_mock): @@ -54,14 +57,17 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): self.cog.apply_infraction = AsyncMock() self.cog.mod_log.ignore = Mock() - self.target.kick = Mock() + self.target.kick = AsyncMock() await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) - self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) self.cog.apply_infraction.assert_awaited_once_with( - self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value + self.ctx, {"foo": "bar"}, self.target, ANY ) + action = self.cog.apply_infraction.call_args.args[-1] + await action() + self.target.kick.assert_awaited_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) + @patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): @@ -79,19 +85,25 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): """Should call voice mute applying function without expiry.""" self.cog.apply_voice_mute = AsyncMock() self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) - self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry=None) async def test_temporary_voice_mute(self): """Should call voice mute applying function with expiry.""" self.cog.apply_voice_mute = AsyncMock() self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) - self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry="baz") async def test_voice_unmute(self): """Should call infraction pardoning function.""" self.cog.pardon_infraction = AsyncMock() + self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user, pardon_reason="foobar")) + self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user, "foobar") + + async def test_voice_unmute_reasonless(self): + """Should call infraction pardoning function without a pardon reason.""" + self.cog.pardon_infraction = AsyncMock() self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user)) - self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user) + self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user, None) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") @@ -141,8 +153,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): async def action_tester(self, action, reason: str) -> None: """Helper method to test voice mute action.""" - self.assertTrue(inspect.iscoroutine(action)) - await action + self.assertTrue(inspect.iscoroutinefunction(action)) + await action() self.user.move_to.assert_called_once_with(None, reason=ANY) self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason=reason) @@ -189,13 +201,14 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): user = MockUser() await self.cog.voicemute(self.cog, self.ctx, user, reason=None) - post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) + post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, + duration_or_expiry=None) apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY) # Test action action = self.cog.apply_infraction.call_args[0][-1] - self.assertTrue(inspect.iscoroutine(action)) - await action + self.assertTrue(inspect.iscoroutinefunction(action)) + await action() async def test_voice_unmute_user_not_found(self): """Should include info to return dict when user was not found from guild.""" @@ -273,7 +286,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): self.user, "FooBar", purge_days=1, - expires_at=None, + duration_or_expiry=None, ) async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): @@ -285,7 +298,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): self.ctx, self.user, "FooBar", - expires_at=None, + duration_or_expiry=None, ) @patch("bot.exts.moderation.infraction.infractions.Age") diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5cf02033d..29dadf372 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -1,7 +1,7 @@ import unittest from collections import namedtuple from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, patch from botcore.site_api import ResponseCodeError from discord import Embed, Forbidden, HTTPException, NotFound @@ -309,8 +309,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): async def test_normal_post_infraction(self): """Should return response from POST request if there are no errors.""" - now = datetime.now() - payload = { + now = datetime.utcnow() + expected = { "actor": self.ctx.author.id, "hidden": True, "reason": "Test reason", @@ -318,14 +318,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "user": self.member.id, "active": False, "expires_at": now.isoformat(), - "dm_sent": False + "dm_sent": False, } self.ctx.bot.api_client.post.return_value = "foo" actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) - self.assertEqual(actual, "foo") - self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + self.ctx.bot.api_client.post.assert_awaited_once() + + # Since `last_applied` is based on current time, just check if expected is a subset of payload + payload: dict = self.ctx.bot.api_client.post.await_args_list[0].kwargs["json"] + self.assertEqual(payload, payload | expected) async def test_unknown_error_post_infraction(self): """Should send an error message to chat when a non-400 error occurs.""" @@ -349,19 +352,25 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar") async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" - payload = { + expected = { "actor": self.ctx.author.id, "hidden": False, "reason": "Test reason", "type": "mute", "user": self.user.id, "active": True, - "dm_sent": False + "dm_sent": False, } self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] - actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertEqual(actual, "foo") - self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) + await_args = self.bot.api_client.post.await_args_list + self.assertEqual(len(await_args), 2, "Expected 2 awaits") + + # Since `last_applied` is based on current time, just check if expected is a subset of payload + for args in await_args: + payload: dict = args.kwargs["json"] + self.assertEqual(payload, payload | expected) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..53d98360c 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -1,4 +1,5 @@ import asyncio +import datetime import enum import logging import typing as t @@ -8,16 +9,19 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp import discord -from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from bot.utils.time import TimestampFormats, discord_timestamp +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser ) +CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) + class MockAsyncIterable: """ @@ -100,30 +104,45 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_actioned(self): """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_green) self.assertIn("Actioned", embed.footer.text) async def test_make_embed_not_actioned(self): """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.NOT_ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_red) self.assertIn("Rejected", embed.footer.text) async def test_make_embed_content(self): """Incident content appears as embed description.""" - incident = MockMessage(content="this is an incident") + incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME) + + reported_timestamp = discord_timestamp(CURRENT_TIME) + relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE) + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - self.assertEqual(incident.content, embed.description) + self.assertEqual( + f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*", + embed.description + ) async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" file = MagicMock(discord.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return our `file` with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): @@ -135,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_fails(self): """Incident's attachment fails to download, proxy url is linked instead.""" attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return None as if the download failed with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): @@ -270,7 +289,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -279,22 +298,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() - - async def asyncSetUp(self): # noqa: N802 - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): # noqa: N802 - if self.session: - await self.session.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. @@ -365,7 +368,6 @@ class TestCrawlIncidents(TestIncidents): class TestArchive(TestIncidents): """Tests for the `Incidents.archive` coroutine.""" - async def test_archive_webhook_not_found(self): """ Method recovers and returns False when the webhook is not found. @@ -375,7 +377,11 @@ class TestArchive(TestIncidents): """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) self.assertFalse( - await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + await self.cog_instance.archive( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=MagicMock(), + actioned_by=MockMember() + ) ) async def test_archive_relays_incident(self): @@ -391,7 +397,7 @@ class TestArchive(TestIncidents): # Define our own `incident` to be archived incident = MockMessage( content="this is an incident", - author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), + author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this @@ -422,7 +428,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME) await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -521,12 +527,13 @@ class TestProcessEvent(TestIncidents): async def test_process_event_confirmation_task_is_awaited(self): """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" mock_task = AsyncMock() + mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)]) with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), - member=MockMember(roles=[MockRole(id=1)]) + incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME), + member=mock_member ) mock_task.assert_awaited() @@ -545,7 +552,7 @@ class TestProcessEvent(TestIncidents): with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), + incident=MockMessage(id=123, created_at=CURRENT_TIME), member=MockMember(roles=[MockRole(id=1)]) ) except asyncio.TimeoutError: @@ -656,7 +663,7 @@ class TestOnRawReactionAdd(TestIncidents): emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio import itertools import unittest from datetime import datetime, timezone @@ -6,31 +5,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule(): # noqa: N802 - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule(): # noqa: N802 - """Close the fakeredis session.""" - if redis_session: - redis_loop.run_until_complete(redis_session.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -39,8 +22,24 @@ class PatchedDatetime(datetime): now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): + """A base class for Silence tests that correctly sets up the cog and redis.""" + + @autospec(silence, "Scheduler", pass_mocks=False) + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) + self.cog = silence.Silence(self.bot) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest): def setUp(self) -> None: + super().setUp() self.alert_channel = MockTextChannel() self.notifier = silence.SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() @@ -105,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(SilenceTest): """Tests for the general functionality of the Silence cog.""" - @autospec(silence, "Scheduler", pass_mocks=False) - def setUp(self) -> None: - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_cog_load_got_channels(self): """Got channels from bot.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @@ -245,15 +234,9 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(SilenceTest): """Tests for the silence argument parser utility function.""" - def setUp(self): - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) - @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @autospec(silence.Silence, "parse_silence_args") @@ -321,17 +304,19 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase): """Tests for the rescheduling of cached unsilences.""" - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) - def setUp(self): + @autospec(silence, "Scheduler", pass_mocks=False) + def setUp(self) -> None: self.bot = MockBot() self.cog = silence.Silence(self.bot) self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) - with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog.cog_load()) # Populate instance attributes. + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -406,22 +391,14 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(SilenceTest): """Tests for the silence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) + super().setUp() # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - - asyncio.run(self.cog.cog_load()) # Populate instance attributes. - self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( send_messages=True, @@ -679,24 +656,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest): """Tests for the unsilence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) - - overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) - self.cog.previous_overwrites = overwrites_cache - - asyncio.run(self.cog.cog_load()) # Populate instance attributes. + super().setUp() self.cog.scheduler.__contains__.return_value = True - overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -705,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) self.voice_channel.overwrites_for.return_value = self.voice_overwrite + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) + self.cog.previous_overwrites = overwrites_cache + + overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' + async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index b870a9945..b1f32c210 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -6,9 +6,10 @@ from discord import AllowedMentions from discord.ext import commands from bot import constants +from bot.errors import LockedResourceError from bot.exts.utils import snekbox from bot.exts.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser +from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser class SnekboxTests(unittest.IsolatedAsyncioTestCase): @@ -26,7 +27,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_job("import random"), "return") + self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -84,28 +85,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval') + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), - ('Your eval job has completed with return code 127', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127', '') ) @patch('bot.exts.utils.snekbox.Signals') def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), - ('Your eval job has completed with return code 127 (SIGTEST)', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') ) def test_get_status_emoji(self): @@ -179,9 +180,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.send_job = AsyncMock(return_value=response) self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) - self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval') - self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command) + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" @@ -192,23 +193,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.continue_job = AsyncMock() self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) self.cog.send_job.assert_called_with( - ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' + ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' ) - self.cog.continue_job.assert_called_with(ctx, response, ctx.command) + self.cog.continue_job.assert_called_with(ctx, response, 'eval') async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" ctx = MockContext() ctx.author.id = 42 - ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.send = AsyncMock() - self.cog.jobs = (42,) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - ctx.send.assert_called_once_with( - "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" - ) + + async def delay_with_side_effect(*args, **kwargs) -> dict: + """Delay the post_job call to ensure the job runs long enough to conflict.""" + await asyncio.sleep(1) + return {'stdout': '', 'returncode': 0} + + self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) + with self.assertRaises(LockedResourceError): + await asyncio.gather( + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + ) async def test_send_job(self): """Test the send_job function.""" @@ -226,7 +232,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -237,9 +243,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') self.cog.format_output.assert_called_once_with('') async def test_send_job_with_paste_link(self): @@ -258,7 +264,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -267,9 +273,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') + self.cog.get_results_message.assert_called_once_with( + {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' + ) self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_job_with_non_zero_eval(self): @@ -287,7 +295,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -295,17 +303,25 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") async def test_continue_job_does_continue(self, partial_mock): """Test that the continue_job function does continue if required conditions are met.""" - ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) - response = MockMessage(delete=AsyncMock()) + ctx = MockContext( + message=MockMessage( + id=4, + add_reaction=AsyncMock(), + clear_reactions=AsyncMock() + ), + author=MockMember(id=14) + ) + response = MockMessage(id=42, delete=AsyncMock()) new_msg = MockMessage() + self.cog.jobs = {4: 42} self.bot.wait_for.side_effect = ((None, new_msg), None) expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py new file mode 100644 index 000000000..e1f904917 --- /dev/null +++ b/tests/bot/rules/test_mentions.py @@ -0,0 +1,131 @@ +from typing import Iterable, Optional + +import discord + +from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMember, MockMessage, MockMessageReference + + +def make_msg( + author: str, + total_user_mentions: int, + total_bot_mentions: int = 0, + *, + reference: Optional[MockMessageReference] = None +) -> MockMessage: + """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" + user_mentions = [MockMember() for _ in range(total_user_mentions)] + bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] + + mentions = user_mentions + bot_mentions + if reference is not None: + # For the sake of these tests we assume that all references are mentions. + mentions.append(reference.resolved.author) + msg_type = discord.MessageType.reply + else: + msg_type = discord.MessageType.default + + return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) + + +class TestMentions(RuleTest): + """Tests applying the `mentions` antispam rule.""" + + def setUp(self): + self.apply = mentions.apply + self.config = { + "max": 2, + "interval": 10, + } + + async def test_mentions_within_limit(self): + """Messages with an allowed amount of mentions.""" + cases = ( + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 1), make_msg("alice", 2)], + ) + + await self.run_allowed(cases) + + async def test_mentions_exceeding_limit(self): + """Messages with a higher than allowed amount of mentions.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], + ("alice",), + 3, + ), + DisallowedCase( + [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], + ("bob",), + 4, + ), + DisallowedCase( + [make_msg("bob", 3, 1)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference())], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], + ("bob",), + 3 + ) + ) + + await self.run_disallowed(cases) + + async def test_ignore_bot_mentions(self): + """Messages with an allowed amount of mentions, also containing bot mentions.""" + cases = ( + [make_msg("bob", 0, 3)], + [make_msg("bob", 2, 1)], + [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], + [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] + ) + + await self.run_allowed(cases) + + async def test_ignore_reply_mentions(self): + """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" + cases = ( + [ + make_msg("bob", 2, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) + ], + [ + make_msg("bob", 2, reference=MockMessageReference()), + make_msg("bob", 0, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), + make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) + ] + ) + + await self.run_allowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/helpers.py b/tests/helpers.py index e74306d23..28a8e40a7 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -317,7 +317,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): guild_id=1, intents=discord.Intents.all(), ) - additional_spec_asyncs = ("wait_for", "redis_ready") + additional_spec_asyncs = ("wait_for",) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -492,6 +492,28 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): spec_set = attachment_instance +message_reference_instance = discord.MessageReference( + message_id=unittest.mock.MagicMock(id=1), + channel_id=unittest.mock.MagicMock(id=2), + guild_id=unittest.mock.MagicMock(id=3) +) + + +class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock MessageReference objects. + + Instances of this class will follow the specification of `discord.MessageReference` instances. + For more information, see the `MockGuild` docstring. + """ + spec_set = message_reference_instance + + def __init__(self, *, reference_author_is_bot: bool = False, **kwargs): + super().__init__(**kwargs) + referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot) + self.resolved = MockMessage(author=referenced_msg_author) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. diff --git a/tests/test_helpers.py b/tests/test_helpers.py index f3040b305..b2686b1d0 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -14,7 +14,7 @@ class DiscordMocksTests(unittest.TestCase): """Test if the default initialization of MockRole results in the correct object.""" role = helpers.MockRole() - # The `spec` argument makes sure `isistance` checks with `discord.Role` pass + # The `spec` argument makes sure `isinstance` checks with `discord.Role` pass self.assertIsInstance(role, discord.Role) self.assertEqual(role.name, "role") |