diff options
author | 2022-08-14 19:53:25 +0100 | |
---|---|---|
committer | 2022-08-14 19:53:25 +0100 | |
commit | 02cff856c9159a8e1acf4fb502405fe30d2c9f68 (patch) | |
tree | bc09829a0af5c164546e34565e146f5793ac416e /tests | |
parent | Merge pull request #2240 from python-discord/2238-purge-cmd (diff) | |
parent | revert bump to markdownify version (diff) |
Merge pull request #2229 from python-discord/py3.10-rediscache
Diffstat (limited to 'tests')
-rw-r--r-- | tests/base.py | 24 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 103 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 22 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 30 |
4 files changed, 78 insertions, 101 deletions
diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Dict import discord +from async_rediscache import RedisSession from discord.ext import commands from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError from discord.ext.commands import errors from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock(return_value=False) + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() - cog.try_run_eval.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - await setup(bot) + await error_handler.setup(bot) bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..97682163f 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp import discord -from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -279,22 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() - - async def asyncSetUp(self): # noqa: N802 - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): # noqa: N802 - if self.session: - await self.session.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. @@ -656,7 +640,7 @@ class TestOnRawReactionAdd(TestIncidents): emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..98547e2bc 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -6,31 +6,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule(): # noqa: N802 - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule(): # noqa: N802 - """Close the fakeredis session.""" - if redis_session: - redis_loop.run_until_complete(redis_session.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -105,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(RedisTestCase): """Tests for the general functionality of the Silence cog.""" @autospec(silence, "Scheduler", pass_mocks=False) @@ -245,14 +229,12 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(RedisTestCase): """Tests for the silence argument parser utility function.""" def setUp(self): self.bot = MockBot() self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @@ -406,7 +388,7 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(RedisTestCase): """Tests for the silence command and its related helper methods.""" @autospec(silence.Silence, "_reschedule", pass_mocks=False) @@ -414,8 +396,6 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() @@ -687,8 +667,6 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache |