diff options
| author | 2022-09-14 20:58:02 +0100 | |
|---|---|---|
| committer | 2022-09-14 20:58:02 +0100 | |
| commit | 15e0491a3ba533a2423d44b415de355d1152f84c (patch) | |
| tree | ec4d951b1a23713b35840d6db7d051c4a48750e0 /tests | |
| parent | Update docstrings & comment. (diff) | |
| parent | Merge branch 'main' into bot-2231-bug (diff) | |
Merge remote-tracking branch 'origin/bot-2231-bug' into bot-2231-bug
Diffstat (limited to '')
| -rw-r--r-- | tests/base.py | 24 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 103 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 22 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 30 | 
4 files changed, 78 insertions, 101 deletions
| diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager  from typing import Dict  import discord +from async_rediscache import RedisSession  from discord.ext import commands  from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase):              await cmd.can_run(ctx)          self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): +    """ +    Use this as a base class for any test cases that require a redis session. + +    This will prepare a fresh redis instance for each test function, and will +    not make any assertions on its own. Tests can mutate the instance as they wish. +    """ + +    session = None + +    async def flush(self): +        """Flush everything from the redis database to prevent carry-overs between tests.""" +        await self.session.client.flushall() + +    async def asyncSetUp(self): +        self.session = await RedisSession(use_fakeredis=True).connect() +        await self.flush() + +    async def asyncTearDown(self): +        if self.session: +            await self.session.client.close() diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError  from discord.ext.commands import errors  from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler  from bot.exts.info.tags import Tags  from bot.exts.moderation.silence import Silence  from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = MockBot()          self.ctx = MockContext(bot=self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_error_handler_already_handled(self):          """Should not do anything when error is already handled by local error handler."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot)          error = errors.CommandError()          error.handled = "foo" -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          self.ctx.send.assert_not_awaited()      async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):                  "called_try_get_tag": True              }          ) -        cog = ErrorHandler(self.bot) -        cog.try_silence = AsyncMock() -        cog.try_get_tag = AsyncMock() -        cog.try_run_eval = AsyncMock(return_value=False) +        self.cog.try_silence = AsyncMock() +        self.cog.try_get_tag = AsyncMock() +        self.cog.try_run_eval = AsyncMock(return_value=False)          for case in test_cases:              with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):                  self.ctx.reset_mock() -                cog.try_silence.reset_mock(return_value=True) -                cog.try_get_tag.reset_mock() +                self.cog.try_silence.reset_mock(return_value=True) +                self.cog.try_get_tag.reset_mock() -                cog.try_silence.return_value = case["try_silence_return"] +                self.cog.try_silence.return_value = case["try_silence_return"]                  self.ctx.channel.id = 1234 -                self.assertIsNone(await cog.on_command_error(self.ctx, error)) +                self.assertIsNone(await self.cog.on_command_error(self.ctx, error))                  if case["try_silence_return"]: -                    cog.try_get_tag.assert_not_awaited() -                    cog.try_silence.assert_awaited_once() +                    self.cog.try_get_tag.assert_not_awaited() +                    self.cog.try_silence.assert_awaited_once()                  else: -                    cog.try_silence.assert_awaited_once() -                    cog.try_get_tag.assert_awaited_once() +                    self.cog.try_silence.assert_awaited_once() +                    self.cog.try_get_tag.assert_awaited_once()                  self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""          ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) -        cog = ErrorHandler(self.bot) -        cog.try_silence = AsyncMock() -        cog.try_get_tag = AsyncMock() -        cog.try_run_eval = AsyncMock() +        self.cog.try_silence = AsyncMock() +        self.cog.try_get_tag = AsyncMock() +        self.cog.try_run_eval = AsyncMock()          error = errors.CommandNotFound() -        self.assertIsNone(await cog.on_command_error(ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(ctx, error)) -        cog.try_silence.assert_not_awaited() -        cog.try_get_tag.assert_not_awaited() -        cog.try_run_eval.assert_not_awaited() +        self.cog.try_silence.assert_not_awaited() +        self.cog.try_get_tag.assert_not_awaited() +        self.cog.try_run_eval.assert_not_awaited()          self.ctx.send.assert_not_awaited()      async def test_error_handler_user_input_error(self):          """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot) -        cog.handle_user_input_error = AsyncMock() +        self.cog.handle_user_input_error = AsyncMock()          error = errors.UserInputError() -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) -        cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) +        self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)      async def test_error_handler_check_failure(self):          """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot) -        cog.handle_check_failure = AsyncMock() +        self.cog.handle_check_failure = AsyncMock()          error = errors.CheckFailure() -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) -        cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) +        self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)      async def test_error_handler_command_on_cooldown(self):          """Should send error with `ctx.send` when error is `CommandOnCooldown`."""          self.ctx.reset_mock() -        cog = ErrorHandler(self.bot)          error = errors.CommandOnCooldown(10, 9, type=None) -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          self.ctx.send.assert_awaited_once_with(error)      async def test_error_handler_command_invoke_error(self):          """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" -        cog = ErrorHandler(self.bot) -        cog.handle_api_error = AsyncMock() -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_api_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          test_cases = (              {                  "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), -                "expect_mock_call": cog.handle_api_error +                "expect_mock_call": self.cog.handle_api_error              },              {                  "args": (self.ctx, errors.CommandInvokeError(TypeError)), -                "expect_mock_call": cog.handle_unexpected_error +                "expect_mock_call": self.cog.handle_unexpected_error              },              {                  "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          for case in test_cases:              with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):                  self.ctx.send.reset_mock() -                self.assertIsNone(await cog.on_command_error(*case["args"])) +                self.assertIsNone(await self.cog.on_command_error(*case["args"]))                  if case["expect_mock_call"] == "send":                      self.ctx.send.assert_awaited_once()                  else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      async def test_error_handler_conversion_error(self):          """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" -        cog = ErrorHandler(self.bot) -        cog.handle_api_error = AsyncMock() -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_api_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          cases = (              {                  "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), -                "mock_function_to_call": cog.handle_api_error +                "mock_function_to_call": self.cog.handle_api_error              },              {                  "error": errors.ConversionError(AsyncMock(), TypeError), -                "mock_function_to_call": cog.handle_unexpected_error +                "mock_function_to_call": self.cog.handle_unexpected_error              }          )          for case in cases:              with self.subTest(**case): -                self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) +                self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"]))                  case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)      async def test_error_handler_two_other_errors(self):          """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" -        cog = ErrorHandler(self.bot) -        cog.handle_unexpected_error = AsyncMock() +        self.cog.handle_unexpected_error = AsyncMock()          errs = (              errors.MaxConcurrencyReached(1, MagicMock()),              errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          for err in errs:              with self.subTest(error=err): -                cog.handle_unexpected_error.reset_mock() -                self.assertIsNone(await cog.on_command_error(self.ctx, err)) -                cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) +                self.cog.handle_unexpected_error.reset_mock() +                self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) +                self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)      @patch("bot.exts.backend.error_handler.log")      async def test_error_handler_other_errors(self, log_mock):          """Should `log.debug` other errors.""" -        cog = ErrorHandler(self.bot)          error = errors.DisabledCommand()  # Use this just as a other error -        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.assertIsNone(await self.cog.on_command_error(self.ctx, error))          log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):          self.silence = Silence(self.bot)          self.bot.get_command.return_value = self.silence.silence          self.ctx = MockContext(bot=self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_try_silence_context_invoked_from_error_handler(self):          """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):          self.bot = MockBot()          self.ctx = MockContext()          self.tag = Tags(self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)          self.bot.get_command.return_value = self.tag.get_command      async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = MockBot()          self.ctx = MockContext(bot=self.bot) -        self.cog = ErrorHandler(self.bot) +        self.cog = error_handler.ErrorHandler(self.bot)      async def test_handle_input_error_handler_errors(self):          """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase):      async def test_setup(self):          """Should call `bot.add_cog` with `ErrorHandler`."""          bot = MockBot() -        await setup(bot) +        await error_handler.setup(bot)          bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..97682163f 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch  import aiohttp  import discord -from async_rediscache import RedisSession  from bot.constants import Colours  from bot.exts.moderation import incidents  from bot.utils.messages import format_user +from tests.base import RedisTestCase  from tests.helpers import (      MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,      MockUser @@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase):          self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase):      """      Tests for bound methods of the `Incidents` cog. @@ -279,22 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):      the instance as they wish.      """ -    session = None - -    async def flush(self): -        """Flush everything from the database to prevent carry-overs between tests.""" -        with await self.session.pool as connection: -            await connection.flushall() - -    async def asyncSetUp(self):  # noqa: N802 -        self.session = RedisSession(use_fakeredis=True) -        await self.session.connect() -        await self.flush() - -    async def asyncTearDown(self):  # noqa: N802 -        if self.session: -            await self.session.close() -      def setUp(self):          """          Prepare a fresh `Incidents` instance for each test. @@ -656,7 +640,7 @@ class TestOnRawReactionAdd(TestIncidents):              emoji="reaction",          ) -    async def asyncSetUp(self):  # noqa: N802 +    async def asyncSetUp(self):          """          Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..98547e2bc 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -6,31 +6,15 @@ from typing import List, Tuple  from unittest import mock  from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession  from discord import PermissionOverwrite  from bot.constants import Channels, Guild, MODERATION_ROLES, Roles  from bot.exts.moderation import silence +from tests.base import RedisTestCase  from tests.helpers import (      MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec  ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule():  # noqa: N802 -    """Create and connect to the fakeredis session.""" -    global redis_session -    redis_session = RedisSession(use_fakeredis=True) -    redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule():  # noqa: N802 -    """Close the fakeredis session.""" -    if redis_session: -        redis_loop.run_until_complete(redis_session.close()) -  # Have to subclass it because builtins can't be patched.  class PatchedDatetime(datetime): @@ -105,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(RedisTestCase):      """Tests for the general functionality of the Silence cog."""      @autospec(silence, "Scheduler", pass_mocks=False) @@ -245,14 +229,12 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):              self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(RedisTestCase):      """Tests for the silence argument parser utility function."""      def setUp(self):          self.bot = MockBot()          self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None)      @autospec(silence.Silence, "send_message", pass_mocks=False)      @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @@ -406,7 +388,7 @@ def voice_sync_helper(function):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(RedisTestCase):      """Tests for the silence command and its related helper methods."""      @autospec(silence.Silence, "_reschedule", pass_mocks=False) @@ -414,8 +396,6 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None:          self.bot = MockBot(get_channel=lambda _: MockTextChannel())          self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None)          # Avoid unawaited coroutine warnings.          self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() @@ -687,8 +667,6 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None:          self.bot = MockBot(get_channel=lambda _: MockTextChannel())          self.cog = silence.Silence(self.bot) -        self.cog._init_task = asyncio.Future() -        self.cog._init_task.set_result(None)          overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)          self.cog.previous_overwrites = overwrites_cache | 
