diff options
| author | 2022-09-18 17:49:16 +0100 | |
|---|---|---|
| committer | 2022-09-18 17:49:16 +0100 | |
| commit | 7fb3f25278c1254a5c98f8f69817c8978d70db78 (patch) | |
| tree | 00684c73a2ddf1e5ab508e0b425c83aabdc59f27 /tests | |
| parent | Disable nose plugin in pytest (diff) | |
| parent | Use Python Poetry Base Action (#2277) (diff) | |
Merge branch 'main' into disable-pytest-nose-plugin
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/base.py | 24 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 103 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 11 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 29 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 71 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 103 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 58 | ||||
| -rw-r--r-- | tests/helpers.py | 24 | 
8 files changed, 249 insertions, 174 deletions
| diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager  from typing import Dict  import discord +from async_rediscache import RedisSession  from discord.ext import commands  from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase):              await cmd.can_run(ctx)          self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): +    """ +    Use this as a base class for any test cases that require a redis session. + +    This will prepare a fresh redis instance for each test function, and will +    not make any assertions on its own. Tests can mutate the instance as they wish. +    """ + +    session = None + +    async def flush(self): +        """Flush everything from the redis database to prevent carry-overs between tests.""" +        await self.session.client.flushall() + +    async def asyncSetUp(self): +        self.session = await RedisSession(use_fakeredis=True).connect() +        await self.flush() + +    async def asyncTearDown(self): +        if self.session: +            await self.session.client.close() diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError  from discord.ext.commands import errors  from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler  from bot.exts.info.tags import Tags  from bot.exts.moderation.silence import Silence  from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = MockBot()          self.ctx = MockContext(bot=self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_error_handler_already_handled(self):          """Should not do anything when error is already handled by local error handler."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot)          error = errors.CommandError()          error.handled = "foo" -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          self.ctx.send.assert_not_awaited()      async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):                  "called_try_get_tag": True              }          ) -        cog = ErrorHandler(self.bot) -        cog.try_silence = AsyncMock() -        cog.try_get_tag = AsyncMock() -        cog.try_run_eval = AsyncMock(return_value=False) +        self.cog.try_silence = AsyncMock() +        self.cog.try_get_tag = AsyncMock() +        self.cog.try_run_eval = AsyncMock(return_value=False)          for case in test_cases:              with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):                  self.ctx.reset_mock() -                cog.try_silence.reset_mock(return_value=True) -                cog.try_get_tag.reset_mock() +                self.cog.try_silence.reset_mock(return_value=True) +                self.cog.try_get_tag.reset_mock() -                cog.try_silence.return_value = case["try_silence_return"] +                self.cog.try_silence.return_value = case["try_silence_return"]                  self.ctx.channel.id = 1234 -                self.assertIsNone(await cog.on_command_error(self.ctx, error)) +                self.assertIsNone(await self.cog.on_command_error(self.ctx, error))                  if case["try_silence_return"]: -                    cog.try_get_tag.assert_not_awaited() -                    cog.try_silence.assert_awaited_once() +                    self.cog.try_get_tag.assert_not_awaited() +                    self.cog.try_silence.assert_awaited_once()                  else: -                    cog.try_silence.assert_awaited_once() -                    cog.try_get_tag.assert_awaited_once() +                    self.cog.try_silence.assert_awaited_once() +                    self.cog.try_get_tag.assert_awaited_once()                  self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""          ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) -        cog = ErrorHandler(self.bot) -        cog.try_silence = AsyncMock() -        cog.try_get_tag = AsyncMock() -        cog.try_run_eval = AsyncMock() +        self.cog.try_silence = AsyncMock() +        self.cog.try_get_tag = AsyncMock() +        self.cog.try_run_eval = AsyncMock()          error = errors.CommandNotFound() -        self.assertIsNone(await cog.on_command_error(ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(ctx, error)) -        cog.try_silence.assert_not_awaited() -        cog.try_get_tag.assert_not_awaited() -        cog.try_run_eval.assert_not_awaited() +        self.cog.try_silence.assert_not_awaited() +        self.cog.try_get_tag.assert_not_awaited() +        self.cog.try_run_eval.assert_not_awaited()          self.ctx.send.assert_not_awaited()      async def test_error_handler_user_input_error(self):          """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot) -        cog.handle_user_input_error = AsyncMock() +        self.cog.handle_user_input_error = AsyncMock()          error = errors.UserInputError() -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) -        cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) +        self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)      async def test_error_handler_check_failure(self):          """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot) -        cog.handle_check_failure = AsyncMock() +        self.cog.handle_check_failure = AsyncMock()          error = errors.CheckFailure() -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) -        cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) +        self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)      async def test_error_handler_command_on_cooldown(self):          """Should send error with `ctx.send` when error is `CommandOnCooldown`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot)          error = errors.CommandOnCooldown(10, 9, type=None) -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          self.ctx.send.assert_awaited_once_with(error)      async def test_error_handler_command_invoke_error(self):          """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" -        cog = ErrorHandler(self.bot) -        cog.handle_api_error = AsyncMock() -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_api_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          test_cases = (              {                  "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), -                "expect_mock_call": cog.handle_api_error +                "expect_mock_call": self.cog.handle_api_error              },              {                  "args": (self.ctx, errors.CommandInvokeError(TypeError)), -                "expect_mock_call": cog.handle_unexpected_error +                "expect_mock_call": self.cog.handle_unexpected_error              },              {                  "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          for case in test_cases:              with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):                  self.ctx.send.reset_mock() -                self.assertIsNone(await cog.on_command_error(*case["args"])) +                self.assertIsNone(await self.cog.on_command_error(*case["args"]))                  if case["expect_mock_call"] == "send":                      self.ctx.send.assert_awaited_once()                  else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      async def test_error_handler_conversion_error(self):          """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" -        cog = ErrorHandler(self.bot) -        cog.handle_api_error = AsyncMock() -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_api_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          cases = (              {                  "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), -                "mock_function_to_call": cog.handle_api_error +                "mock_function_to_call": self.cog.handle_api_error              },              {                  "error": errors.ConversionError(AsyncMock(), TypeError), -                "mock_function_to_call": cog.handle_unexpected_error +                "mock_function_to_call": self.cog.handle_unexpected_error              }          )          for case in cases:              with self.subTest(**case): -                self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) +                self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"]))                  case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)      async def test_error_handler_two_other_errors(self):          """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" -        cog = ErrorHandler(self.bot) -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          errs = (              errors.MaxConcurrencyReached(1, MagicMock()),              errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          for err in errs:              with self.subTest(error=err): -                cog.handle_unexpected_error.reset_mock() -                self.assertIsNone(await cog.on_command_error(self.ctx, err)) -                cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) +                self.cog.handle_unexpected_error.reset_mock() +                self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) +                self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)      @patch("bot.exts.backend.error_handler.log")      async def test_error_handler_other_errors(self, log_mock):          """Should `log.debug` other errors.""" -        cog = ErrorHandler(self.bot)          error = errors.DisabledCommand()  # Use this just as a other error -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):          self.silence = Silence(self.bot)          self.bot.get_command.return_value = self.silence.silence          self.ctx = MockContext(bot=self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_try_silence_context_invoked_from_error_handler(self):          """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):          self.bot = MockBot()          self.ctx = MockContext()          self.tag = Tags(self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)          self.bot.get_command.return_value = self.tag.get_command      async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = MockBot()          self.ctx = MockContext(bot=self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_handle_input_error_handler_errors(self):          """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase):      async def test_setup(self):          """Should call `bot.add_cog` with `ErrorHandler`."""          bot = MockBot() -        await setup(bot) +        await error_handler.setup(bot)          bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 052048053..a18a4d23b 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -79,13 +79,13 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          """Should call voice mute applying function without expiry."""          self.cog.apply_voice_mute = AsyncMock()          self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) -        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) +        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry=None)      async def test_temporary_voice_mute(self):          """Should call voice mute applying function with expiry."""          self.cog.apply_voice_mute = AsyncMock()          self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) -        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") +        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry="baz")      async def test_voice_unmute(self):          """Should call infraction pardoning function.""" @@ -189,7 +189,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          user = MockUser()          await self.cog.voicemute(self.cog, self.ctx, user, reason=None) -        post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) +        post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, +                                                     duration_or_expiry=None)          apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY)          # Test action @@ -273,7 +274,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):              self.user,              "FooBar",              purge_days=1, -            expires_at=None, +            duration_or_expiry=None,          )      async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): @@ -285,7 +286,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):              self.ctx,              self.user,              "FooBar", -            expires_at=None, +            duration_or_expiry=None,          )      @patch("bot.exts.moderation.infraction.infractions.Age") diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5cf02033d..29dadf372 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -1,7 +1,7 @@  import unittest  from collections import namedtuple  from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, patch  from botcore.site_api import ResponseCodeError  from discord import Embed, Forbidden, HTTPException, NotFound @@ -309,8 +309,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):      async def test_normal_post_infraction(self):          """Should return response from POST request if there are no errors.""" -        now = datetime.now() -        payload = { +        now = datetime.utcnow() +        expected = {              "actor": self.ctx.author.id,              "hidden": True,              "reason": "Test reason", @@ -318,14 +318,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):              "user": self.member.id,              "active": False,              "expires_at": now.isoformat(), -            "dm_sent": False +            "dm_sent": False,          }          self.ctx.bot.api_client.post.return_value = "foo"          actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) -          self.assertEqual(actual, "foo") -        self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) +        self.ctx.bot.api_client.post.assert_awaited_once() + +        # Since `last_applied` is based on current time, just check if expected is a subset of payload +        payload: dict = self.ctx.bot.api_client.post.await_args_list[0].kwargs["json"] +        self.assertEqual(payload, payload | expected)      async def test_unknown_error_post_infraction(self):          """Should send an error message to chat when a non-400 error occurs.""" @@ -349,19 +352,25 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):      @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar")      async def test_first_fail_second_success_user_post_infraction(self, post_user_mock):          """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" -        payload = { +        expected = {              "actor": self.ctx.author.id,              "hidden": False,              "reason": "Test reason",              "type": "mute",              "user": self.user.id,              "active": True, -            "dm_sent": False +            "dm_sent": False,          }          self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] -          actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason")          self.assertEqual(actual, "foo") -        self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) +        await_args = self.bot.api_client.post.await_args_list +        self.assertEqual(len(await_args), 2, "Expected 2 awaits") + +        # Since `last_applied` is based on current time, just check if expected is a subset of payload +        for args in await_args: +            payload: dict = args.kwargs["json"] +            self.assertEqual(payload, payload | expected) +          post_user_mock.assert_awaited_once_with(self.ctx, self.user) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..53d98360c 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -1,4 +1,5 @@  import asyncio +import datetime  import enum  import logging  import typing as t @@ -8,16 +9,19 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch  import aiohttp  import discord -from async_rediscache import RedisSession  from bot.constants import Colours  from bot.exts.moderation import incidents  from bot.utils.messages import format_user +from bot.utils.time import TimestampFormats, discord_timestamp +from tests.base import RedisTestCase  from tests.helpers import (      MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,      MockUser  ) +CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) +  class MockAsyncIterable:      """ @@ -100,30 +104,45 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase):      async def test_make_embed_actioned(self):          """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" -        embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) +        embed, file = await incidents.make_embed( +            incident=MockMessage(created_at=CURRENT_TIME), +            outcome=incidents.Signal.ACTIONED, +            actioned_by=MockMember() +        )          self.assertEqual(embed.colour.value, Colours.soft_green)          self.assertIn("Actioned", embed.footer.text)      async def test_make_embed_not_actioned(self):          """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" -        embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) +        embed, file = await incidents.make_embed( +            incident=MockMessage(created_at=CURRENT_TIME), +            outcome=incidents.Signal.NOT_ACTIONED, +            actioned_by=MockMember() +        )          self.assertEqual(embed.colour.value, Colours.soft_red)          self.assertIn("Rejected", embed.footer.text)      async def test_make_embed_content(self):          """Incident content appears as embed description.""" -        incident = MockMessage(content="this is an incident") +        incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME) + +        reported_timestamp = discord_timestamp(CURRENT_TIME) +        relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE) +          embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) -        self.assertEqual(incident.content, embed.description) +        self.assertEqual( +            f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*", +            embed.description +        )      async def test_make_embed_with_attachment_succeeds(self):          """Incident's attachment is downloaded and displayed in the embed's image field."""          file = MagicMock(discord.File, filename="bigbadjoe.jpg")          attachment = MockAttachment(filename="bigbadjoe.jpg") -        incident = MockMessage(content="this is an incident", attachments=[attachment]) +        incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME)          # Patch `download_file` to return our `file`          with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): @@ -135,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase):      async def test_make_embed_with_attachment_fails(self):          """Incident's attachment fails to download, proxy url is linked instead."""          attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") -        incident = MockMessage(content="this is an incident", attachments=[attachment]) +        incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME)          # Patch `download_file` to return None as if the download failed          with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): @@ -270,7 +289,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase):          self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase):      """      Tests for bound methods of the `Incidents` cog. @@ -279,22 +298,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):      the instance as they wish.      """ -    session = None - -    async def flush(self): -        """Flush everything from the database to prevent carry-overs between tests.""" -        with await self.session.pool as connection: -            await connection.flushall() - -    async def asyncSetUp(self):  # noqa: N802 -        self.session = RedisSession(use_fakeredis=True) -        await self.session.connect() -        await self.flush() - -    async def asyncTearDown(self):  # noqa: N802 -        if self.session: -            await self.session.close() -      def setUp(self):          """          Prepare a fresh `Incidents` instance for each test. @@ -365,7 +368,6 @@ class TestCrawlIncidents(TestIncidents):  class TestArchive(TestIncidents):      """Tests for the `Incidents.archive` coroutine.""" -      async def test_archive_webhook_not_found(self):          """          Method recovers and returns False when the webhook is not found. @@ -375,7 +377,11 @@ class TestArchive(TestIncidents):          """          self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404)          self.assertFalse( -            await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) +            await self.cog_instance.archive( +                incident=MockMessage(created_at=CURRENT_TIME), +                outcome=MagicMock(), +                actioned_by=MockMember() +            )          )      async def test_archive_relays_incident(self): @@ -391,7 +397,7 @@ class TestArchive(TestIncidents):          # Define our own `incident` to be archived          incident = MockMessage(              content="this is an incident", -            author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), +            author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")),              id=123,          )          built_embed = MagicMock(discord.Embed, id=123)  # We patch `make_embed` to return this @@ -422,7 +428,7 @@ class TestArchive(TestIncidents):          webhook = MockAsyncWebhook()          self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) -        message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) +        message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME)          await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember())          self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -521,12 +527,13 @@ class TestProcessEvent(TestIncidents):      async def test_process_event_confirmation_task_is_awaited(self):          """Task given by `Incidents.make_confirmation_task` is awaited before method exits."""          mock_task = AsyncMock() +        mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)])          with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):              await self.cog_instance.process_event(                  reaction=incidents.Signal.ACTIONED.value, -                incident=MockMessage(id=123), -                member=MockMember(roles=[MockRole(id=1)]) +                incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME), +                member=mock_member              )          mock_task.assert_awaited() @@ -545,7 +552,7 @@ class TestProcessEvent(TestIncidents):              with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):                  await self.cog_instance.process_event(                      reaction=incidents.Signal.ACTIONED.value, -                    incident=MockMessage(id=123), +                    incident=MockMessage(id=123, created_at=CURRENT_TIME),                      member=MockMember(roles=[MockRole(id=1)])                  )          except asyncio.TimeoutError: @@ -656,7 +663,7 @@ class TestOnRawReactionAdd(TestIncidents):              emoji="reaction",          ) -    async def asyncSetUp(self):  # noqa: N802 +    async def asyncSetUp(self):          """          Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio  import itertools  import unittest  from datetime import datetime, timezone @@ -6,31 +5,15 @@ from typing import List, Tuple  from unittest import mock  from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession  from discord import PermissionOverwrite  from bot.constants import Channels, Guild, MODERATION_ROLES, Roles  from bot.exts.moderation import silence +from tests.base import RedisTestCase  from tests.helpers import (      MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec  ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule():  # noqa: N802 -    """Create and connect to the fakeredis session.""" -    global redis_session -    redis_session = RedisSession(use_fakeredis=True) -    redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule():  # noqa: N802 -    """Close the fakeredis session.""" -    if redis_session: -        redis_loop.run_until_complete(redis_session.close()) -  # Have to subclass it because builtins can't be patched.  class PatchedDatetime(datetime): @@ -39,8 +22,24 @@ class PatchedDatetime(datetime):      now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): +    """A base class for Silence tests that correctly sets up the cog and redis.""" + +    @autospec(silence, "Scheduler", pass_mocks=False) +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) +        self.cog = silence.Silence(self.bot) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        await self.cog.cog_load()  # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest):      def setUp(self) -> None: +        super().setUp()          self.alert_channel = MockTextChannel()          self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock() @@ -105,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(SilenceTest):      """Tests for the general functionality of the Silence cog.""" -    @autospec(silence, "Scheduler", pass_mocks=False) -    def setUp(self) -> None: -        self.bot = MockBot() -        self.cog = silence.Silence(self.bot) -      @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog.cog_load()          self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id)      @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_channels(self):          """Got channels from bot.""" -        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -          await self.cog.cog_load()          self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)      @autospec(silence, "SilenceNotifier")      async def test_cog_load_got_notifier(self, notifier):          """Notifier was started with channel.""" -        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -          await self.cog.cog_load()          notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))          self.assertEqual(self.cog.notifier, notifier.return_value) @@ -245,15 +234,9 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):              self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(SilenceTest):      """Tests for the silence argument parser utility function.""" -    def setUp(self): -        self.bot = MockBot() -        self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None) -      @autospec(silence.Silence, "send_message", pass_mocks=False)      @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False)      @autospec(silence.Silence, "parse_silence_args") @@ -321,17 +304,19 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase):      """Tests for the rescheduling of cached unsilences.""" -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) -    def setUp(self): +    @autospec(silence, "Scheduler", pass_mocks=False) +    def setUp(self) -> None:          self.bot = MockBot()          self.cog = silence.Silence(self.bot)          self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) -        with mock.patch.object(self.cog, "_reschedule", autospec=True): -            asyncio.run(self.cog.cog_load())  # Populate instance attributes. +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        await self.cog.cog_load()  # Populate instance attributes.      async def test_skipped_missing_channel(self):          """Did nothing because the channel couldn't be retrieved.""" @@ -406,22 +391,14 @@ def voice_sync_helper(function):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(SilenceTest):      """Tests for the silence command and its related helper methods.""" -    @autospec(silence.Silence, "_reschedule", pass_mocks=False) -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) -        self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None) +        super().setUp()          # Avoid unawaited coroutine warnings.          self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - -        asyncio.run(self.cog.cog_load())  # Populate instance attributes. -          self.text_channel = MockTextChannel()          self.text_overwrite = PermissionOverwrite(              send_messages=True, @@ -679,24 +656,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest):      """Tests for the unsilence command and its related helper methods.""" -    @autospec(silence.Silence, "_reschedule", pass_mocks=False) -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) -        self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None) - -        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) -        self.cog.previous_overwrites = overwrites_cache - -        asyncio.run(self.cog.cog_load())  # Populate instance attributes. +        super().setUp()          self.cog.scheduler.__contains__.return_value = True -        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'          self.text_channel = MockTextChannel()          self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False)          self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -705,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):          self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)          self.voice_channel.overwrites_for.return_value = self.voice_overwrite +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +      async def test_sent_correct_message(self):          """Appropriate failure/success message was sent by the command."""          unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index f8805ac48..e1f904917 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,15 +1,32 @@ -from typing import Iterable +from typing import Iterable, Optional + +import discord  from bot.rules import mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage +from tests.helpers import MockMember, MockMessage, MockMessageReference -def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: -    """Makes a message with `total_mentions` mentions.""" +def make_msg( +    author: str, +    total_user_mentions: int, +    total_bot_mentions: int = 0, +    *, +    reference: Optional[MockMessageReference] = None +) -> MockMessage: +    """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions."""      user_mentions = [MockMember() for _ in range(total_user_mentions)]      bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] -    return MockMessage(author=author, mentions=user_mentions+bot_mentions) + +    mentions = user_mentions + bot_mentions +    if reference is not None: +        # For the sake of these tests we assume that all references are mentions. +        mentions.append(reference.resolved.author) +        msg_type = discord.MessageType.reply +    else: +        msg_type = discord.MessageType.default + +    return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type)  class TestMentions(RuleTest): @@ -56,6 +73,16 @@ class TestMentions(RuleTest):                  ("bob",),                  3,              ), +            DisallowedCase( +                [make_msg("bob", 3, reference=MockMessageReference())], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], +                ("bob",), +                3 +            )          )          await self.run_disallowed(cases) @@ -71,6 +98,27 @@ class TestMentions(RuleTest):          await self.run_allowed(cases) +    async def test_ignore_reply_mentions(self): +        """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" +        cases = ( +            [ +                make_msg("bob", 2, reference=MockMessageReference()) +            ], +            [ +                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) +            ], +            [ +                make_msg("bob", 2, reference=MockMessageReference()), +                make_msg("bob", 0, reference=MockMessageReference()) +            ], +            [ +                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), +                make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) +            ] +        ) + +        await self.run_allowed(cases) +      def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:          last_message = case.recent_messages[0]          return tuple( diff --git a/tests/helpers.py b/tests/helpers.py index 17214553c..a4b919dcb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -317,7 +317,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):          guild_id=1,          intents=discord.Intents.all(),      ) -    additional_spec_asyncs = ("wait_for", "redis_ready") +    additional_spec_asyncs = ("wait_for",)      def __init__(self, **kwargs) -> None:          super().__init__(**kwargs) @@ -492,6 +492,28 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):      spec_set = attachment_instance +message_reference_instance = discord.MessageReference( +    message_id=unittest.mock.MagicMock(id=1), +    channel_id=unittest.mock.MagicMock(id=2), +    guild_id=unittest.mock.MagicMock(id=3) +) + + +class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock MessageReference objects. + +    Instances of this class will follow the specification of `discord.MessageReference` instances. +    For more information, see the `MockGuild` docstring. +    """ +    spec_set = message_reference_instance + +    def __init__(self, *, reference_author_is_bot: bool = False, **kwargs): +        super().__init__(**kwargs) +        referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot) +        self.resolved = MockMessage(author=referenced_msg_author) + +  class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      """      A MagicMock subclass to mock Message objects. | 
