From f599c7bb945a4d0e26ff3e9f5f234f3f34f5ff16 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:44:58 +0100 Subject: Remove call to get_event_loop in tests get_event_loop is deprecated as of 3.10 if there is no running loop. --- tests/bot/exts/moderation/test_silence.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..82ec138db 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -16,20 +16,19 @@ from tests.helpers import ( ) redis_session = None -redis_loop = asyncio.get_event_loop() def setUpModule(): # noqa: N802 """Create and connect to the fakeredis session.""" global redis_session redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) + asyncio.run(redis_session.connect()) def tearDownModule(): # noqa: N802 """Close the fakeredis session.""" if redis_session: - redis_loop.run_until_complete(redis_session.close()) + asyncio.run(redis_session.client.close()) # Have to subclass it because builtins can't be patched. -- cgit v1.2.3 From c906daa2250558962f00be1c423a9a0cff98f905 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:46:18 +0100 Subject: Remove warnings in error handler tests These warnings were caused by the setup coro from error_handler.py being imported directly, causing a warning about an un-awaited coro whenever the Cog was accessed from the same file. --- tests/bot/exts/backend/test_error_handler.py | 103 ++++++++++++--------------- 1 file changed, 47 insertions(+), 56 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError from discord.ext.commands import errors from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock(return_value=False) + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() - cog.try_run_eval.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - await setup(bot) + await error_handler.setup(bot) bot.add_cog.assert_awaited_once() -- cgit v1.2.3 From 7782c196830098f81f39d235354636cd0d4a481d Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:52:52 +0100 Subject: No longer use the removed RedisSession connection object This has been abstracted away, the correct way to do this now is to directly access the client. --- bot/exts/info/doc/_redis_cache.py | 92 ++++++++++++++--------------- tests/bot/exts/moderation/test_incidents.py | 5 +- 2 files changed, 46 insertions(+), 51 deletions(-) (limited to 'tests') diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py index 8e08e7ae4..0f4d663d1 100644 --- a/bot/exts/info/doc/_redis_cache.py +++ b/bot/exts/info/doc/_redis_cache.py @@ -34,55 +34,52 @@ class DocRedisCache(RedisObject): redis_key = f"{self.namespace}:{item_key(item)}" needs_expire = False - with await self._get_pool_connection() as connection: - set_expire = self._set_expires.get(redis_key) - if set_expire is None: - # An expire is only set if the key didn't exist before. - ttl = await connection.ttl(redis_key) - log.debug(f"Checked TTL for `{redis_key}`.") - - if ttl == -1: - log.warning(f"Key `{redis_key}` had no expire set.") - if ttl < 0: # not set or didn't exist - needs_expire = True - else: - log.debug(f"Key `{redis_key}` has a {ttl} TTL.") - self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - - elif time.monotonic() > set_expire: - # If we got here the key expired in redis and we can be sure it doesn't exist. + set_expire = self._set_expires.get(redis_key) + if set_expire is None: + # An expire is only set if the key didn't exist before. + ttl = await self.redis_session.client.ttl(redis_key) + log.debug(f"Checked TTL for `{redis_key}`.") + + if ttl == -1: + log.warning(f"Key `{redis_key}` had no expire set.") + if ttl < 0: # not set or didn't exist needs_expire = True - log.debug(f"Key `{redis_key}` expired in internal key cache.") + else: + log.debug(f"Key `{redis_key}` has a {ttl} TTL.") + self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - await connection.hset(redis_key, item.symbol_id, value) - if needs_expire: - self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS - await connection.expire(redis_key, WEEK_SECONDS) - log.info(f"Set {redis_key} to expire in a week.") + elif time.monotonic() > set_expire: + # If we got here the key expired in redis and we can be sure it doesn't exist. + needs_expire = True + log.debug(f"Key `{redis_key}` expired in internal key cache.") + + await self.redis_session.client.hset(redis_key, item.symbol_id, value) + if needs_expire: + self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS + await self.redis_session.client.expire(redis_key, WEEK_SECONDS) + log.info(f"Set {redis_key} to expire in a week.") @namespace_lock async def get(self, item: DocItem) -> Optional[str]: """Return the Markdown content of the symbol `item` if it exists.""" - with await self._get_pool_connection() as connection: - return await connection.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id, encoding="utf8") + return await self.redis_session.client.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" pattern = f"{self.namespace}:{package}:*" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=pattern) - ] - if package_keys: - await connection.delete(*package_keys) - log.info(f"Deleted keys from redis: {package_keys}.") - self._set_expires = { - key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) - } - return True - return False + package_keys = [ + package_key async for package_key in self.redis_session.client.iscan(match=pattern) + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + log.info(f"Deleted keys from redis: {package_keys}.") + self._set_expires = { + key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) + } + return True + return False class StaleItemCounter(RedisObject): @@ -96,21 +93,20 @@ class StaleItemCounter(RedisObject): If the counter didn't exist, initialize it with 1. """ key = f"{self.namespace}:{item_key(item)}:{item.symbol_id}" - with await self._get_pool_connection() as connection: - await connection.expire(key, WEEK_SECONDS * 3) - return int(await connection.incr(key)) + await self.redis_session.client.expire(key, WEEK_SECONDS * 3) + return int(await self.redis_session.client.incr(key)) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*") - ] - if package_keys: - await connection.delete(*package_keys) - return True - return False + package_keys = [ + package_key + async for package_key in self.redis_session.client.iscan(match=f"{self.namespace}:{package}:*") + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + return True + return False def item_key(item: DocItem) -> str: diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..f60c177c5 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -283,8 +283,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def flush(self): """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() + await self.session.client.flushall() async def asyncSetUp(self): # noqa: N802 self.session = RedisSession(use_fakeredis=True) @@ -293,7 +292,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def asyncTearDown(self): # noqa: N802 if self.session: - await self.session.close() + await self.session.client.close() def setUp(self): """ -- cgit v1.2.3 From 46da1ecf621a64e6d8f0a37572378ae363ba76a2 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 25 Jul 2022 23:24:25 +0100 Subject: Stop creating futures in tests with no event loop running --- tests/bot/exts/moderation/test_silence.py | 6 ------ 1 file changed, 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 82ec138db..03b7b2fdb 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -250,8 +250,6 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @@ -413,8 +411,6 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() @@ -686,8 +682,6 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache -- cgit v1.2.3 From 9cf3de3e9bf6725b2baa2e7adb77e058c216b332 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 26 Jul 2022 00:56:12 +0100 Subject: Remove unneeded N802 noqas pep-naming now supports these functions being in camel case. --- tests/bot/exts/moderation/test_incidents.py | 6 +++--- tests/bot/exts/moderation/test_silence.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index f60c177c5..211eb1bf8 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -285,12 +285,12 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): """Flush everything from the database to prevent carry-overs between tests.""" await self.session.client.flushall() - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): self.session = RedisSession(use_fakeredis=True) await self.session.connect() await self.flush() - async def asyncTearDown(self): # noqa: N802 + async def asyncTearDown(self): if self.session: await self.session.client.close() @@ -655,7 +655,7 @@ class TestOnRawReactionAdd(TestIncidents): emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 03b7b2fdb..f5caefdca 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -18,14 +18,14 @@ from tests.helpers import ( redis_session = None -def setUpModule(): # noqa: N802 +def setUpModule(): """Create and connect to the fakeredis session.""" global redis_session redis_session = RedisSession(use_fakeredis=True) asyncio.run(redis_session.connect()) -def tearDownModule(): # noqa: N802 +def tearDownModule(): """Close the fakeredis session.""" if redis_session: asyncio.run(redis_session.client.close()) -- cgit v1.2.3 From 4a47c816641332fbb49f8c88c8a7720849cabf06 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 30 Jul 2022 19:28:36 +0100 Subject: Add a new test helper for managing redis sessions This helper ensures that a fresh RedisSession is given to each test case that inherits from it. --- tests/base.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Dict import discord +from async_rediscache import RedisSession from discord.ext import commands from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() -- cgit v1.2.3 From f044f36833e9dc003e89dd81868ea3f48a9da002 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 30 Jul 2022 19:29:27 +0100 Subject: Use RedisTestCase helper class for both Incidents and Silence test cases. --- tests/bot/exts/moderation/test_incidents.py | 19 ++----------------- tests/bot/exts/moderation/test_silence.py | 23 ++++------------------- 2 files changed, 6 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 211eb1bf8..97682163f 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp import discord -from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -279,21 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - await self.session.client.flushall() - - async def asyncSetUp(self): - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): - if self.session: - await self.session.client.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index f5caefdca..98547e2bc 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -6,30 +6,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None - - -def setUpModule(): - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - asyncio.run(redis_session.connect()) - - -def tearDownModule(): - """Close the fakeredis session.""" - if redis_session: - asyncio.run(redis_session.client.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -104,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(RedisTestCase): """Tests for the general functionality of the Silence cog.""" @autospec(silence, "Scheduler", pass_mocks=False) @@ -244,7 +229,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(RedisTestCase): """Tests for the silence argument parser utility function.""" def setUp(self): @@ -403,7 +388,7 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(RedisTestCase): """Tests for the silence command and its related helper methods.""" @autospec(silence.Silence, "_reschedule", pass_mocks=False) -- cgit v1.2.3