diff options
Diffstat (limited to '')
| -rw-r--r-- | tests/base.py | 24 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 103 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 22 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 103 | 
4 files changed, 112 insertions, 140 deletions
| diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager  from typing import Dict  import discord +from async_rediscache import RedisSession  from discord.ext import commands  from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase):              await cmd.can_run(ctx)          self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): +    """ +    Use this as a base class for any test cases that require a redis session. + +    This will prepare a fresh redis instance for each test function, and will +    not make any assertions on its own. Tests can mutate the instance as they wish. +    """ + +    session = None + +    async def flush(self): +        """Flush everything from the redis database to prevent carry-overs between tests.""" +        await self.session.client.flushall() + +    async def asyncSetUp(self): +        self.session = await RedisSession(use_fakeredis=True).connect() +        await self.flush() + +    async def asyncTearDown(self): +        if self.session: +            await self.session.client.close() diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError  from discord.ext.commands import errors  from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler  from bot.exts.info.tags import Tags  from bot.exts.moderation.silence import Silence  from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = MockBot()          self.ctx = MockContext(bot=self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_error_handler_already_handled(self):          """Should not do anything when error is already handled by local error handler."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot)          error = errors.CommandError()          error.handled = "foo" -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          self.ctx.send.assert_not_awaited()      async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):                  "called_try_get_tag": True              }          ) -        cog = ErrorHandler(self.bot) -        cog.try_silence = AsyncMock() -        cog.try_get_tag = AsyncMock() -        cog.try_run_eval = AsyncMock(return_value=False) +        self.cog.try_silence = AsyncMock() +        self.cog.try_get_tag = AsyncMock() +        self.cog.try_run_eval = AsyncMock(return_value=False)          for case in test_cases:              with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):                  self.ctx.reset_mock() -                cog.try_silence.reset_mock(return_value=True) -                cog.try_get_tag.reset_mock() +                self.cog.try_silence.reset_mock(return_value=True) +                self.cog.try_get_tag.reset_mock() -                cog.try_silence.return_value = case["try_silence_return"] +                self.cog.try_silence.return_value = case["try_silence_return"]                  self.ctx.channel.id = 1234 -                self.assertIsNone(await cog.on_command_error(self.ctx, error)) +                self.assertIsNone(await self.cog.on_command_error(self.ctx, error))                  if case["try_silence_return"]: -                    cog.try_get_tag.assert_not_awaited() -                    cog.try_silence.assert_awaited_once() +                    self.cog.try_get_tag.assert_not_awaited() +                    self.cog.try_silence.assert_awaited_once()                  else: -                    cog.try_silence.assert_awaited_once() -                    cog.try_get_tag.assert_awaited_once() +                    self.cog.try_silence.assert_awaited_once() +                    self.cog.try_get_tag.assert_awaited_once()                  self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""          ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) -        cog = ErrorHandler(self.bot) -        cog.try_silence = AsyncMock() -        cog.try_get_tag = AsyncMock() -        cog.try_run_eval = AsyncMock() +        self.cog.try_silence = AsyncMock() +        self.cog.try_get_tag = AsyncMock() +        self.cog.try_run_eval = AsyncMock()          error = errors.CommandNotFound() -        self.assertIsNone(await cog.on_command_error(ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(ctx, error)) -        cog.try_silence.assert_not_awaited() -        cog.try_get_tag.assert_not_awaited() -        cog.try_run_eval.assert_not_awaited() +        self.cog.try_silence.assert_not_awaited() +        self.cog.try_get_tag.assert_not_awaited() +        self.cog.try_run_eval.assert_not_awaited()          self.ctx.send.assert_not_awaited()      async def test_error_handler_user_input_error(self):          """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot) -        cog.handle_user_input_error = AsyncMock() +        self.cog.handle_user_input_error = AsyncMock()          error = errors.UserInputError() -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) -        cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) +        self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)      async def test_error_handler_check_failure(self):          """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot) -        cog.handle_check_failure = AsyncMock() +        self.cog.handle_check_failure = AsyncMock()          error = errors.CheckFailure() -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) -        cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) +        self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)      async def test_error_handler_command_on_cooldown(self):          """Should send error with `ctx.send` when error is `CommandOnCooldown`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot)          error = errors.CommandOnCooldown(10, 9, type=None) -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          self.ctx.send.assert_awaited_once_with(error)      async def test_error_handler_command_invoke_error(self):          """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" -        cog = ErrorHandler(self.bot) -        cog.handle_api_error = AsyncMock() -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_api_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          test_cases = (              {                  "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), -                "expect_mock_call": cog.handle_api_error +                "expect_mock_call": self.cog.handle_api_error              },              {                  "args": (self.ctx, errors.CommandInvokeError(TypeError)), -                "expect_mock_call": cog.handle_unexpected_error +                "expect_mock_call": self.cog.handle_unexpected_error              },              {                  "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          for case in test_cases:              with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):                  self.ctx.send.reset_mock() -                self.assertIsNone(await cog.on_command_error(*case["args"])) +                self.assertIsNone(await self.cog.on_command_error(*case["args"]))                  if case["expect_mock_call"] == "send":                      self.ctx.send.assert_awaited_once()                  else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      async def test_error_handler_conversion_error(self):          """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" -        cog = ErrorHandler(self.bot) -        cog.handle_api_error = AsyncMock() -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_api_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          cases = (              {                  "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), -                "mock_function_to_call": cog.handle_api_error +                "mock_function_to_call": self.cog.handle_api_error              },              {                  "error": errors.ConversionError(AsyncMock(), TypeError), -                "mock_function_to_call": cog.handle_unexpected_error +                "mock_function_to_call": self.cog.handle_unexpected_error              }          )          for case in cases:              with self.subTest(**case): -                self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) +                self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"]))                  case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)      async def test_error_handler_two_other_errors(self):          """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" -        cog = ErrorHandler(self.bot) -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          errs = (              errors.MaxConcurrencyReached(1, MagicMock()),              errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          for err in errs:              with self.subTest(error=err): -                cog.handle_unexpected_error.reset_mock() -                self.assertIsNone(await cog.on_command_error(self.ctx, err)) -                cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) +                self.cog.handle_unexpected_error.reset_mock() +                self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) +                self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)      @patch("bot.exts.backend.error_handler.log")      async def test_error_handler_other_errors(self, log_mock):          """Should `log.debug` other errors.""" -        cog = ErrorHandler(self.bot)          error = errors.DisabledCommand()  # Use this just as a other error -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):          self.silence = Silence(self.bot)          self.bot.get_command.return_value = self.silence.silence          self.ctx = MockContext(bot=self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_try_silence_context_invoked_from_error_handler(self):          """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):          self.bot = MockBot()          self.ctx = MockContext()          self.tag = Tags(self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)          self.bot.get_command.return_value = self.tag.get_command      async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = MockBot()          self.ctx = MockContext(bot=self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_handle_input_error_handler_errors(self):          """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase):      async def test_setup(self):          """Should call `bot.add_cog` with `ErrorHandler`."""          bot = MockBot() -        await setup(bot) +        await error_handler.setup(bot)          bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..97682163f 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch  import aiohttp  import discord -from async_rediscache import RedisSession  from bot.constants import Colours  from bot.exts.moderation import incidents  from bot.utils.messages import format_user +from tests.base import RedisTestCase  from tests.helpers import (      MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,      MockUser @@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase):          self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase):      """      Tests for bound methods of the `Incidents` cog. @@ -279,22 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):      the instance as they wish.      """ -    session = None - -    async def flush(self): -        """Flush everything from the database to prevent carry-overs between tests.""" -        with await self.session.pool as connection: -            await connection.flushall() - -    async def asyncSetUp(self):  # noqa: N802 -        self.session = RedisSession(use_fakeredis=True) -        await self.session.connect() -        await self.flush() - -    async def asyncTearDown(self):  # noqa: N802 -        if self.session: -            await self.session.close() -      def setUp(self):          """          Prepare a fresh `Incidents` instance for each test. @@ -656,7 +640,7 @@ class TestOnRawReactionAdd(TestIncidents):              emoji="reaction",          ) -    async def asyncSetUp(self):  # noqa: N802 +    async def asyncSetUp(self):          """          Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio  import itertools  import unittest  from datetime import datetime, timezone @@ -6,31 +5,15 @@ from typing import List, Tuple  from unittest import mock  from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession  from discord import PermissionOverwrite  from bot.constants import Channels, Guild, MODERATION_ROLES, Roles  from bot.exts.moderation import silence +from tests.base import RedisTestCase  from tests.helpers import (      MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec  ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule():  # noqa: N802 -    """Create and connect to the fakeredis session.""" -    global redis_session -    redis_session = RedisSession(use_fakeredis=True) -    redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule():  # noqa: N802 -    """Close the fakeredis session.""" -    if redis_session: -        redis_loop.run_until_complete(redis_session.close()) -  # Have to subclass it because builtins can't be patched.  class PatchedDatetime(datetime): @@ -39,8 +22,24 @@ class PatchedDatetime(datetime):      now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): +    """A base class for Silence tests that correctly sets up the cog and redis.""" + +    @autospec(silence, "Scheduler", pass_mocks=False) +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) +        self.cog = silence.Silence(self.bot) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        await self.cog.cog_load()  # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest):      def setUp(self) -> None: +        super().setUp()          self.alert_channel = MockTextChannel()          self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock() @@ -105,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(SilenceTest):      """Tests for the general functionality of the Silence cog.""" -    @autospec(silence, "Scheduler", pass_mocks=False) -    def setUp(self) -> None: -        self.bot = MockBot() -        self.cog = silence.Silence(self.bot) -      @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog.cog_load()          self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id)      @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_channels(self):          """Got channels from bot.""" -        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -          await self.cog.cog_load()          self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)      @autospec(silence, "SilenceNotifier")      async def test_cog_load_got_notifier(self, notifier):          """Notifier was started with channel.""" -        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -          await self.cog.cog_load()          notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))          self.assertEqual(self.cog.notifier, notifier.return_value) @@ -245,15 +234,9 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):              self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(SilenceTest):      """Tests for the silence argument parser utility function.""" -    def setUp(self): -        self.bot = MockBot() -        self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None) -      @autospec(silence.Silence, "send_message", pass_mocks=False)      @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False)      @autospec(silence.Silence, "parse_silence_args") @@ -321,17 +304,19 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase):      """Tests for the rescheduling of cached unsilences.""" -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) -    def setUp(self): +    @autospec(silence, "Scheduler", pass_mocks=False) +    def setUp(self) -> None:          self.bot = MockBot()          self.cog = silence.Silence(self.bot)          self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) -        with mock.patch.object(self.cog, "_reschedule", autospec=True): -            asyncio.run(self.cog.cog_load())  # Populate instance attributes. +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        await self.cog.cog_load()  # Populate instance attributes.      async def test_skipped_missing_channel(self):          """Did nothing because the channel couldn't be retrieved.""" @@ -406,22 +391,14 @@ def voice_sync_helper(function):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(SilenceTest):      """Tests for the silence command and its related helper methods.""" -    @autospec(silence.Silence, "_reschedule", pass_mocks=False) -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) -        self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None) +        super().setUp()          # Avoid unawaited coroutine warnings.          self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - -        asyncio.run(self.cog.cog_load())  # Populate instance attributes. -          self.text_channel = MockTextChannel()          self.text_overwrite = PermissionOverwrite(              send_messages=True, @@ -679,24 +656,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest):      """Tests for the unsilence command and its related helper methods.""" -    @autospec(silence.Silence, "_reschedule", pass_mocks=False) -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) -        self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None) - -        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) -        self.cog.previous_overwrites = overwrites_cache - -        asyncio.run(self.cog.cog_load())  # Populate instance attributes. +        super().setUp()          self.cog.scheduler.__contains__.return_value = True -        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'          self.text_channel = MockTextChannel()          self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False)          self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -705,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):          self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)          self.voice_channel.overwrites_for.return_value = self.voice_overwrite +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +      async def test_sent_correct_message(self):          """Appropriate failure/success message was sent by the command."""          unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) | 
