From 1e0c0cfe37eb9a868454508fcb813d7cf19e12cc Mon Sep 17 00:00:00 2001 From: Izan Date: Wed, 29 Dec 2021 15:09:22 +0000 Subject: Fix tests --- tests/bot/exts/moderation/test_incidents.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..ef33aa62b 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -1,4 +1,5 @@ import asyncio +import datetime import enum import logging import typing as t @@ -13,6 +14,7 @@ from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from bot.utils.time import TimestampFormats, discord_timestamp from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -114,10 +116,19 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_content(self): """Incident content appears as embed description.""" - incident = MockMessage(content="this is an incident") + current_time = datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) + incident = MockMessage(content="this is an incident", created_at=current_time) + + day_timestamp = discord_timestamp(current_time, TimestampFormats.DATE) + time_timestamp = discord_timestamp(current_time, TimestampFormats.TIME) + relative_timestamp = discord_timestamp(current_time, TimestampFormats.RELATIVE) + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - self.assertEqual(incident.content, embed.description) + self.assertEqual( + f"{incident.content}\n\n__*Reported {day_timestamp} at {time_timestamp} ({relative_timestamp}).*__", + embed.description + ) async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" @@ -391,7 +402,7 @@ class TestArchive(TestIncidents): # Define our own `incident` to be archived incident = MockMessage( content="this is an incident", - author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), + author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this @@ -422,7 +433,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great")) await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -521,12 +532,13 @@ class TestProcessEvent(TestIncidents): async def test_process_event_confirmation_task_is_awaited(self): """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" mock_task = AsyncMock() + mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)]) with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), - member=MockMember(roles=[MockRole(id=1)]) + incident=MockMessage(author=mock_member, id=123), + member=mock_member ) mock_task.assert_awaited() -- cgit v1.2.3 From b7e03616ac3fc0b5e8a5a77a352df593983d187a Mon Sep 17 00:00:00 2001 From: Izan Date: Thu, 14 Jul 2022 22:21:34 +0100 Subject: Address Reviews - Use the more concise DATETIME timestamp instead of both a DATE and a TIME timestamp. - Remove underline from the "Reported ..." section at the bottom of the embed. - Re-add time of action/rejection timestamp to footer of embed. --- bot/exts/moderation/incidents.py | 7 ++++--- tests/bot/exts/moderation/test_incidents.py | 5 ++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index bd9e5b88e..f29cfcdd6 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -1,5 +1,6 @@ import asyncio import re +from datetime import datetime, timezone from enum import Enum from typing import Optional @@ -97,10 +98,9 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di colour = Colours.soft_red footer = f"Rejected by {actioned_by}" - day_timestamp = discord_timestamp(incident.created_at, TimestampFormats.DATE) - time_timestamp = discord_timestamp(incident.created_at, TimestampFormats.TIME) + reported_timestamp = discord_timestamp(incident.created_at) relative_timestamp = discord_timestamp(incident.created_at, TimestampFormats.RELATIVE) - reported_on_msg = f"__*Reported {day_timestamp} at {time_timestamp} ({relative_timestamp}).*__" + reported_on_msg = f"*Reported {reported_timestamp} ({relative_timestamp}).*" # If the description will be too long (>4096 total characters), truncate the incident content if len(incident.content) > (allowed_content_chars := 4096-len(reported_on_msg)-2): # -2 for the newlines @@ -111,6 +111,7 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di embed = discord.Embed( description=description, colour=colour, + timestamp=datetime.now(timezone.utc) ) embed.set_footer(text=footer, icon_url=actioned_by.display_avatar.url) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index ef33aa62b..da0a79ce8 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -119,14 +119,13 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): current_time = datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) incident = MockMessage(content="this is an incident", created_at=current_time) - day_timestamp = discord_timestamp(current_time, TimestampFormats.DATE) - time_timestamp = discord_timestamp(current_time, TimestampFormats.TIME) + reported_timestamp = discord_timestamp(current_time) relative_timestamp = discord_timestamp(current_time, TimestampFormats.RELATIVE) embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) self.assertEqual( - f"{incident.content}\n\n__*Reported {day_timestamp} at {time_timestamp} ({relative_timestamp}).*__", + f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*", embed.description ) -- cgit v1.2.3 From f599c7bb945a4d0e26ff3e9f5f234f3f34f5ff16 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:44:58 +0100 Subject: Remove call to get_event_loop in tests get_event_loop is deprecated as of 3.10 if there is no running loop. --- tests/bot/exts/moderation/test_silence.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 65aecad28..82ec138db 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -16,20 +16,19 @@ from tests.helpers import ( ) redis_session = None -redis_loop = asyncio.get_event_loop() def setUpModule(): # noqa: N802 """Create and connect to the fakeredis session.""" global redis_session redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) + asyncio.run(redis_session.connect()) def tearDownModule(): # noqa: N802 """Close the fakeredis session.""" if redis_session: - redis_loop.run_until_complete(redis_session.close()) + asyncio.run(redis_session.client.close()) # Have to subclass it because builtins can't be patched. -- cgit v1.2.3 From c906daa2250558962f00be1c423a9a0cff98f905 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:46:18 +0100 Subject: Remove warnings in error handler tests These warnings were caused by the setup coro from error_handler.py being imported directly, causing a warning about an un-awaited coro whenever the Cog was accessed from the same file. --- tests/bot/exts/backend/test_error_handler.py | 103 ++++++++++++--------------- 1 file changed, 47 insertions(+), 56 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 0a58126e7..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError from discord.ext.commands import errors from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock(return_value=False) + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() - cog.try_run_eval = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() - cog.try_run_eval.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - await setup(bot) + await error_handler.setup(bot) bot.add_cog.assert_awaited_once() -- cgit v1.2.3 From 7782c196830098f81f39d235354636cd0d4a481d Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 23 Jul 2022 22:52:52 +0100 Subject: No longer use the removed RedisSession connection object This has been abstracted away, the correct way to do this now is to directly access the client. --- bot/exts/info/doc/_redis_cache.py | 92 ++++++++++++++--------------- tests/bot/exts/moderation/test_incidents.py | 5 +- 2 files changed, 46 insertions(+), 51 deletions(-) (limited to 'tests') diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py index 8e08e7ae4..0f4d663d1 100644 --- a/bot/exts/info/doc/_redis_cache.py +++ b/bot/exts/info/doc/_redis_cache.py @@ -34,55 +34,52 @@ class DocRedisCache(RedisObject): redis_key = f"{self.namespace}:{item_key(item)}" needs_expire = False - with await self._get_pool_connection() as connection: - set_expire = self._set_expires.get(redis_key) - if set_expire is None: - # An expire is only set if the key didn't exist before. - ttl = await connection.ttl(redis_key) - log.debug(f"Checked TTL for `{redis_key}`.") - - if ttl == -1: - log.warning(f"Key `{redis_key}` had no expire set.") - if ttl < 0: # not set or didn't exist - needs_expire = True - else: - log.debug(f"Key `{redis_key}` has a {ttl} TTL.") - self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - - elif time.monotonic() > set_expire: - # If we got here the key expired in redis and we can be sure it doesn't exist. + set_expire = self._set_expires.get(redis_key) + if set_expire is None: + # An expire is only set if the key didn't exist before. + ttl = await self.redis_session.client.ttl(redis_key) + log.debug(f"Checked TTL for `{redis_key}`.") + + if ttl == -1: + log.warning(f"Key `{redis_key}` had no expire set.") + if ttl < 0: # not set or didn't exist needs_expire = True - log.debug(f"Key `{redis_key}` expired in internal key cache.") + else: + log.debug(f"Key `{redis_key}` has a {ttl} TTL.") + self._set_expires[redis_key] = time.monotonic() + ttl - .1 # we need this to expire before redis - await connection.hset(redis_key, item.symbol_id, value) - if needs_expire: - self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS - await connection.expire(redis_key, WEEK_SECONDS) - log.info(f"Set {redis_key} to expire in a week.") + elif time.monotonic() > set_expire: + # If we got here the key expired in redis and we can be sure it doesn't exist. + needs_expire = True + log.debug(f"Key `{redis_key}` expired in internal key cache.") + + await self.redis_session.client.hset(redis_key, item.symbol_id, value) + if needs_expire: + self._set_expires[redis_key] = time.monotonic() + WEEK_SECONDS + await self.redis_session.client.expire(redis_key, WEEK_SECONDS) + log.info(f"Set {redis_key} to expire in a week.") @namespace_lock async def get(self, item: DocItem) -> Optional[str]: """Return the Markdown content of the symbol `item` if it exists.""" - with await self._get_pool_connection() as connection: - return await connection.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id, encoding="utf8") + return await self.redis_session.client.hget(f"{self.namespace}:{item_key(item)}", item.symbol_id) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" pattern = f"{self.namespace}:{package}:*" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=pattern) - ] - if package_keys: - await connection.delete(*package_keys) - log.info(f"Deleted keys from redis: {package_keys}.") - self._set_expires = { - key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) - } - return True - return False + package_keys = [ + package_key async for package_key in self.redis_session.client.iscan(match=pattern) + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + log.info(f"Deleted keys from redis: {package_keys}.") + self._set_expires = { + key: expire for key, expire in self._set_expires.items() if not fnmatch.fnmatchcase(key, pattern) + } + return True + return False class StaleItemCounter(RedisObject): @@ -96,21 +93,20 @@ class StaleItemCounter(RedisObject): If the counter didn't exist, initialize it with 1. """ key = f"{self.namespace}:{item_key(item)}:{item.symbol_id}" - with await self._get_pool_connection() as connection: - await connection.expire(key, WEEK_SECONDS * 3) - return int(await connection.incr(key)) + await self.redis_session.client.expire(key, WEEK_SECONDS * 3) + return int(await self.redis_session.client.incr(key)) @namespace_lock async def delete(self, package: str) -> bool: """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" - with await self._get_pool_connection() as connection: - package_keys = [ - package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*") - ] - if package_keys: - await connection.delete(*package_keys) - return True - return False + package_keys = [ + package_key + async for package_key in self.redis_session.client.iscan(match=f"{self.namespace}:{package}:*") + ] + if package_keys: + await self.redis_session.client.delete(*package_keys) + return True + return False def item_key(item: DocItem) -> str: diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index cfe0c4b03..f60c177c5 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -283,8 +283,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def flush(self): """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() + await self.session.client.flushall() async def asyncSetUp(self): # noqa: N802 self.session = RedisSession(use_fakeredis=True) @@ -293,7 +292,7 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): async def asyncTearDown(self): # noqa: N802 if self.session: - await self.session.close() + await self.session.client.close() def setUp(self): """ -- cgit v1.2.3 From 46da1ecf621a64e6d8f0a37572378ae363ba76a2 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 25 Jul 2022 23:24:25 +0100 Subject: Stop creating futures in tests with no event loop running --- tests/bot/exts/moderation/test_silence.py | 6 ------ 1 file changed, 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 82ec138db..03b7b2fdb 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -250,8 +250,6 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @@ -413,8 +411,6 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() @@ -686,8 +682,6 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache -- cgit v1.2.3 From 9cf3de3e9bf6725b2baa2e7adb77e058c216b332 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 26 Jul 2022 00:56:12 +0100 Subject: Remove unneeded N802 noqas pep-naming now supports these functions being in camel case. --- tests/bot/exts/moderation/test_incidents.py | 6 +++--- tests/bot/exts/moderation/test_silence.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index f60c177c5..211eb1bf8 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -285,12 +285,12 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): """Flush everything from the database to prevent carry-overs between tests.""" await self.session.client.flushall() - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): self.session = RedisSession(use_fakeredis=True) await self.session.connect() await self.flush() - async def asyncTearDown(self): # noqa: N802 + async def asyncTearDown(self): if self.session: await self.session.client.close() @@ -655,7 +655,7 @@ class TestOnRawReactionAdd(TestIncidents): emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 03b7b2fdb..f5caefdca 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -18,14 +18,14 @@ from tests.helpers import ( redis_session = None -def setUpModule(): # noqa: N802 +def setUpModule(): """Create and connect to the fakeredis session.""" global redis_session redis_session = RedisSession(use_fakeredis=True) asyncio.run(redis_session.connect()) -def tearDownModule(): # noqa: N802 +def tearDownModule(): """Close the fakeredis session.""" if redis_session: asyncio.run(redis_session.client.close()) -- cgit v1.2.3 From 4a47c816641332fbb49f8c88c8a7720849cabf06 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 30 Jul 2022 19:28:36 +0100 Subject: Add a new test helper for managing redis sessions This helper ensures that a fresh RedisSession is given to each test case that inherits from it. --- tests/base.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Dict import discord +from async_rediscache import RedisSession from discord.ext import commands from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() -- cgit v1.2.3 From f044f36833e9dc003e89dd81868ea3f48a9da002 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 30 Jul 2022 19:29:27 +0100 Subject: Use RedisTestCase helper class for both Incidents and Silence test cases. --- tests/bot/exts/moderation/test_incidents.py | 19 ++----------------- tests/bot/exts/moderation/test_silence.py | 23 ++++------------------- 2 files changed, 6 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 211eb1bf8..97682163f 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp import discord -from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -279,21 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - await self.session.client.flushall() - - async def asyncSetUp(self): - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): - if self.session: - await self.session.client.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index f5caefdca..98547e2bc 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -6,30 +6,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None - - -def setUpModule(): - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - asyncio.run(redis_session.connect()) - - -def tearDownModule(): - """Close the fakeredis session.""" - if redis_session: - asyncio.run(redis_session.client.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -104,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(RedisTestCase): """Tests for the general functionality of the Silence cog.""" @autospec(silence, "Scheduler", pass_mocks=False) @@ -244,7 +229,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(RedisTestCase): """Tests for the silence argument parser utility function.""" def setUp(self): @@ -403,7 +388,7 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(RedisTestCase): """Tests for the silence command and its related helper methods.""" @autospec(silence.Silence, "_reschedule", pass_mocks=False) -- cgit v1.2.3 From 1df6034a1723fd3ff1bd88047ca6a62f920767e6 Mon Sep 17 00:00:00 2001 From: Izan Date: Mon, 15 Aug 2022 12:04:33 +0100 Subject: Fix incident tests. --- tests/bot/exts/moderation/test_incidents.py | 38 +++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 11fe565fc..53d98360c 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -20,6 +20,8 @@ from tests.helpers import ( MockUser ) +CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) + class MockAsyncIterable: """ @@ -102,25 +104,32 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_actioned(self): """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_green) self.assertIn("Actioned", embed.footer.text) async def test_make_embed_not_actioned(self): """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.NOT_ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_red) self.assertIn("Rejected", embed.footer.text) async def test_make_embed_content(self): """Incident content appears as embed description.""" - current_time = datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) - incident = MockMessage(content="this is an incident", created_at=current_time) + incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME) - reported_timestamp = discord_timestamp(current_time) - relative_timestamp = discord_timestamp(current_time, TimestampFormats.RELATIVE) + reported_timestamp = discord_timestamp(CURRENT_TIME) + relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE) embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) @@ -133,7 +142,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): """Incident's attachment is downloaded and displayed in the embed's image field.""" file = MagicMock(discord.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return our `file` with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): @@ -145,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_fails(self): """Incident's attachment fails to download, proxy url is linked instead.""" attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return None as if the download failed with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): @@ -359,7 +368,6 @@ class TestCrawlIncidents(TestIncidents): class TestArchive(TestIncidents): """Tests for the `Incidents.archive` coroutine.""" - async def test_archive_webhook_not_found(self): """ Method recovers and returns False when the webhook is not found. @@ -369,7 +377,11 @@ class TestArchive(TestIncidents): """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) self.assertFalse( - await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + await self.cog_instance.archive( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=MagicMock(), + actioned_by=MockMember() + ) ) async def test_archive_relays_incident(self): @@ -416,7 +428,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great")) + message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME) await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -520,7 +532,7 @@ class TestProcessEvent(TestIncidents): with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(author=mock_member, id=123), + incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME), member=mock_member ) @@ -540,7 +552,7 @@ class TestProcessEvent(TestIncidents): with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), + incident=MockMessage(id=123, created_at=CURRENT_TIME), member=MockMember(roles=[MockRole(id=1)]) ) except asyncio.TimeoutError: -- cgit v1.2.3 From e0b593318eba77d6fe93f2145b43838d6eb09278 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Mon, 15 Aug 2022 22:14:32 +0100 Subject: Correctly initialise redis tests Calling the cog_load from within the setUp function resulted in interaction with a RedisSession before it was initialised. This wasn't noticed in CI as it only error under certain concurrency timings due to xdist. To resolve this, we moved the setup and async setup logic to a base class. Co-authored-by: Hassan Abouelela --- tests/bot/exts/moderation/test_silence.py | 79 +++++++++++++++---------------- 1 file changed, 37 insertions(+), 42 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 98547e2bc..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio import itertools import unittest from datetime import datetime, timezone @@ -23,8 +22,24 @@ class PatchedDatetime(datetime): now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): + """A base class for Silence tests that correctly sets up the cog and redis.""" + + @autospec(silence, "Scheduler", pass_mocks=False) + @autospec(silence.Silence, "_reschedule", pass_mocks=False) + def setUp(self) -> None: + self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) + self.cog = silence.Silence(self.bot) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest): def setUp(self) -> None: + super().setUp() self.alert_channel = MockTextChannel() self.notifier = silence.SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() @@ -89,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(RedisTestCase): +class SilenceCogTests(SilenceTest): """Tests for the general functionality of the Silence cog.""" - @autospec(silence, "Scheduler", pass_mocks=False) - def setUp(self) -> None: - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_cog_load_got_channels(self): """Got channels from bot.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @@ -229,13 +234,9 @@ class SilenceCogTests(RedisTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(RedisTestCase): +class SilenceArgumentParserTests(SilenceTest): """Tests for the silence argument parser utility function.""" - def setUp(self): - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @autospec(silence.Silence, "parse_silence_args") @@ -303,17 +304,19 @@ class SilenceArgumentParserTests(RedisTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase): """Tests for the rescheduling of cached unsilences.""" - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) - def setUp(self): + @autospec(silence, "Scheduler", pass_mocks=False) + def setUp(self) -> None: self.bot = MockBot() self.cog = silence.Silence(self.bot) self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) - with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog.cog_load()) # Populate instance attributes. + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -388,20 +391,14 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(RedisTestCase): +class SilenceTests(SilenceTest): """Tests for the silence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) + super().setUp() # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - - asyncio.run(self.cog.cog_load()) # Populate instance attributes. - self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( send_messages=True, @@ -659,22 +656,13 @@ class SilenceTests(RedisTestCase): @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest): """Tests for the unsilence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) - - overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) - self.cog.previous_overwrites = overwrites_cache - - asyncio.run(self.cog.cog_load()) # Populate instance attributes. + super().setUp() self.cog.scheduler.__contains__.return_value = True - overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -683,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) self.voice_channel.overwrites_for.return_value = self.voice_overwrite + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) + self.cog.previous_overwrites = overwrites_cache + + overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' + async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) -- cgit v1.2.3 From 0e6242f7f4c41329e9270724a9780511f7165240 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Thu, 28 Jul 2022 04:54:26 -0400 Subject: Updated tests - Refactored tests for new time duration arguments --- .../exts/moderation/infraction/test_infractions.py | 11 ++++---- tests/bot/exts/moderation/infraction/test_utils.py | 29 +++++++++++++--------- 2 files changed, 23 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 052048053..a18a4d23b 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -79,13 +79,13 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): """Should call voice mute applying function without expiry.""" self.cog.apply_voice_mute = AsyncMock() self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) - self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry=None) async def test_temporary_voice_mute(self): """Should call voice mute applying function with expiry.""" self.cog.apply_voice_mute = AsyncMock() self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) - self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry="baz") async def test_voice_unmute(self): """Should call infraction pardoning function.""" @@ -189,7 +189,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): user = MockUser() await self.cog.voicemute(self.cog, self.ctx, user, reason=None) - post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) + post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, + duration_or_expiry=None) apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY) # Test action @@ -273,7 +274,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): self.user, "FooBar", purge_days=1, - expires_at=None, + duration_or_expiry=None, ) async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): @@ -285,7 +286,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase): self.ctx, self.user, "FooBar", - expires_at=None, + duration_or_expiry=None, ) @patch("bot.exts.moderation.infraction.infractions.Age") diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5cf02033d..4c78c0bd8 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -318,14 +318,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "user": self.member.id, "active": False, "expires_at": now.isoformat(), - "dm_sent": False + "dm_sent": False, + "last_applied": datetime(2020, 1, 1).isoformat(), } - self.ctx.bot.api_client.post.return_value = "foo" - actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) + # Patch the time.now(tz=timezone.utc) function to return a specific time + with patch("bot.exts.moderation.infraction._utils.datetime.now", return_value=datetime(2020, 1, 1)): + self.ctx.bot.api_client.post.return_value = "foo" + actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) - self.assertEqual(actual, "foo") - self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + self.assertEqual(actual, "foo") + self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) async def test_unknown_error_post_infraction(self): """Should send an error message to chat when a non-400 error occurs.""" @@ -356,12 +359,14 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "type": "mute", "user": self.user.id, "active": True, - "dm_sent": False + "dm_sent": False, + "last_applied": datetime(2020, 1, 1), } - self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] - - actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") - self.assertEqual(actual, "foo") - self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) - post_user_mock.assert_awaited_once_with(self.ctx, self.user) + # Patch the time.now(tz=timezone.utc) function to return a specific time + with patch("bot.exts.moderation.infraction._utils.datetime.now", return_value=datetime(2020, 1, 1)): + self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] + actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") + self.assertEqual(actual, "foo") + self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From 035b5accf78f8623b40e0612e3c057f0ef2b93a7 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Thu, 28 Jul 2022 05:04:38 -0400 Subject: Fixed test patches --- tests/bot/exts/moderation/infraction/test_utils.py | 32 ++++++++++++---------- 1 file changed, 17 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 4c78c0bd8..def06932b 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -307,7 +307,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) - async def test_normal_post_infraction(self): + @patch("bot.exts.moderation.infraction._utils.datetime", wraps=datetime) + async def test_normal_post_infraction(self, mock_datetime): """Should return response from POST request if there are no errors.""" now = datetime.now() payload = { @@ -322,13 +323,13 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "last_applied": datetime(2020, 1, 1).isoformat(), } - # Patch the time.now(tz=timezone.utc) function to return a specific time - with patch("bot.exts.moderation.infraction._utils.datetime.now", return_value=datetime(2020, 1, 1)): - self.ctx.bot.api_client.post.return_value = "foo" - actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) + # Patch the datetime.now function to return a specific time + mock_datetime.now.return_value = datetime(2020, 1, 1) + self.ctx.bot.api_client.post.return_value = "foo" + actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) - self.assertEqual(actual, "foo") - self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + self.assertEqual(actual, "foo") + self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) async def test_unknown_error_post_infraction(self): """Should send an error message to chat when a non-400 error occurs.""" @@ -349,8 +350,9 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.assertIsNone(actual) post_user_mock.assert_awaited_once_with(self.ctx, self.user) + @patch("bot.exts.moderation.infraction._utils.datetime", wraps=datetime) @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar") - async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): + async def test_first_fail_second_success_user_post_infraction(self, post_user_mock, mock_datetime): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" payload = { "actor": self.ctx.author.id, @@ -363,10 +365,10 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "last_applied": datetime(2020, 1, 1), } - # Patch the time.now(tz=timezone.utc) function to return a specific time - with patch("bot.exts.moderation.infraction._utils.datetime.now", return_value=datetime(2020, 1, 1)): - self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] - actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") - self.assertEqual(actual, "foo") - self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) - post_user_mock.assert_awaited_once_with(self.ctx, self.user) + # Patch the datetime.now function to return a specific time + mock_datetime.now.return_value = datetime(2020, 1, 1) + self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] + actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") + self.assertEqual(actual, "foo") + self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From 8db7ef5df087e43804d07dabe2037af18adcb0d6 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Thu, 28 Jul 2022 05:19:30 -0400 Subject: Added isoformat for test payload --- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index def06932b..d3a908b28 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -362,7 +362,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "user": self.user.id, "active": True, "dm_sent": False, - "last_applied": datetime(2020, 1, 1), + "last_applied": datetime(2020, 1, 1).isoformat(), } # Patch the datetime.now function to return a specific time -- cgit v1.2.3 From 3280ac48a9031b727bdde69909729093593bd967 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Thu, 28 Jul 2022 06:11:11 -0400 Subject: Fixed tests - Corrected datetime patching --- tests/bot/exts/moderation/infraction/test_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index d3a908b28..b1f23e31c 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -307,8 +307,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) - @patch("bot.exts.moderation.infraction._utils.datetime", wraps=datetime) - async def test_normal_post_infraction(self, mock_datetime): + async def test_normal_post_infraction(self): """Should return response from POST request if there are no errors.""" now = datetime.now() payload = { @@ -320,16 +319,18 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "active": False, "expires_at": now.isoformat(), "dm_sent": False, - "last_applied": datetime(2020, 1, 1).isoformat(), } - # Patch the datetime.now function to return a specific time - mock_datetime.now.return_value = datetime(2020, 1, 1) self.ctx.bot.api_client.post.return_value = "foo" actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) self.assertEqual(actual, "foo") - self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) + self.ctx.bot.api_client.post.assert_awaited_once() + await_args = str(self.ctx.bot.api_client.post.await_args) + # Check existing keys present, allow for additional keys (e.g. `last_applied`) + for key, value in payload.items(): + self.assertTrue(key in await_args) + self.assertTrue(str(value) in await_args) async def test_unknown_error_post_infraction(self): """Should send an error message to chat when a non-400 error occurs.""" -- cgit v1.2.3 From e74ff62122935da349d38a8c06eef14d1e3ba9aa Mon Sep 17 00:00:00 2001 From: ionite34 Date: Sat, 6 Aug 2022 13:47:49 -0400 Subject: Refactored test to not use datetime patch - Used new method of dict subset comparison instead of datetime patching for better compat. with argument types --- tests/bot/exts/moderation/infraction/test_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index b1f23e31c..5ba0f4273 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -1,7 +1,7 @@ import unittest from collections import namedtuple from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, patch from botcore.site_api import ResponseCodeError from discord import Embed, Forbidden, HTTPException, NotFound @@ -351,11 +351,10 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.assertIsNone(actual) post_user_mock.assert_awaited_once_with(self.ctx, self.user) - @patch("bot.exts.moderation.infraction._utils.datetime", wraps=datetime) @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar") - async def test_first_fail_second_success_user_post_infraction(self, post_user_mock, mock_datetime): + async def test_first_fail_second_success_user_post_infraction(self, post_user_mock): """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" - payload = { + expected = { "actor": self.ctx.author.id, "hidden": False, "reason": "Test reason", @@ -363,13 +362,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): "user": self.user.id, "active": True, "dm_sent": False, - "last_applied": datetime(2020, 1, 1).isoformat(), } - # Patch the datetime.now function to return a specific time - mock_datetime.now.return_value = datetime(2020, 1, 1) self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason") self.assertEqual(actual, "foo") - self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) + await_args = self.bot.api_client.post.await_args_list + self.assertEqual(len(await_args), 2, "Expected 2 awaits") + + # Since `last_applied` is based on current time, just check if expected is a subset of payload + for args in await_args: + payload: dict = args.kwargs["json"] + self.assertEqual(payload, payload | expected) + post_user_mock.assert_awaited_once_with(self.ctx, self.user) -- cgit v1.2.3 From d93afa3dbe16404a60e23acaf394affcda5aff89 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Sat, 6 Aug 2022 13:57:51 -0400 Subject: Updated previous tests to use subset method --- tests/bot/exts/moderation/infraction/test_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5ba0f4273..6c9af2555 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -310,7 +310,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): async def test_normal_post_infraction(self): """Should return response from POST request if there are no errors.""" now = datetime.now() - payload = { + expected = { "actor": self.ctx.author.id, "hidden": True, "reason": "Test reason", @@ -323,14 +323,12 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): self.ctx.bot.api_client.post.return_value = "foo" actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) - self.assertEqual(actual, "foo") self.ctx.bot.api_client.post.assert_awaited_once() - await_args = str(self.ctx.bot.api_client.post.await_args) - # Check existing keys present, allow for additional keys (e.g. `last_applied`) - for key, value in payload.items(): - self.assertTrue(key in await_args) - self.assertTrue(str(value) in await_args) + + # Since `last_applied` is based on current time, just check if expected is a subset of payload + payload: dict = self.ctx.bot.api_client.post.await_args_list[0].kwargs["json"] + self.assertEqual(payload, payload | expected) async def test_unknown_error_post_infraction(self): """Should send an error message to chat when a non-400 error occurs.""" -- cgit v1.2.3 From 08a6bbc306d1ec54998b690eeeb258548e8e08fa Mon Sep 17 00:00:00 2001 From: ionite34 Date: Sat, 13 Aug 2022 15:45:52 -0400 Subject: Corrected test use of utcnow Corrected test case to use `datetime.utcnow()` to be consistent with target --- tests/bot/exts/moderation/infraction/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 6c9af2555..29dadf372 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -309,7 +309,7 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase): async def test_normal_post_infraction(self): """Should return response from POST request if there are no errors.""" - now = datetime.now() + now = datetime.utcnow() expected = { "actor": self.ctx.author.id, "hidden": True, -- cgit v1.2.3 From dada405211eac996196cdfb0496f4ff22f9a656a Mon Sep 17 00:00:00 2001 From: arl Date: Thu, 18 Aug 2022 19:01:22 -0400 Subject: fix: don't include replied mentions in mention filter (#2017) Co-authored-by: Izan Co-authored-by: TizzySaurus <47674925+TizzySaurus@users.noreply.github.com> Co-authored-by: Xithrius <15021300+Xithrius@users.noreply.github.com> --- bot/rules/mentions.py | 56 +++++++++++++++++++++++++++++++++----- tests/bot/rules/test_mentions.py | 58 ++++++++++++++++++++++++++++++++++++---- tests/helpers.py | 22 +++++++++++++++ 3 files changed, 124 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 6f5addad1..ca1d0c01c 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -1,23 +1,65 @@ from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from discord import DeletedReferencedMessage, Member, Message, MessageType, NotFound + +import bot +from bot.log import get_logger + +log = get_logger(__name__) async def apply( last_message: Message, recent_messages: List[Message], config: Dict[str, int] ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - """Detects total mentions exceeding the limit sent by a single user.""" + """ + Detects total mentions exceeding the limit sent by a single user. + + Excludes mentions that are bots, themselves, or replied users. + + In very rare cases, may not be able to determine a + mention was to a reply, in which case it is not ignored. + """ relevant_messages = tuple( msg for msg in recent_messages if msg.author == last_message.author ) + # We use `msg.mentions` here as that is supplied by the api itself, to determine who was mentioned. + # Additionally, `msg.mentions` includes the user replied to, even if the mention doesn't occur in the body. + # In order to exclude users who are mentioned as a reply, we check if the msg has a reference + # + # While we could use regex to parse the message content, and get a list of + # the mentions, that solution is very prone to breaking. + # We would need to deal with codeblocks, escaping markdown, and any discrepancies between + # our implementation and discord's markdown parser which would cause false positives or false negatives. + total_recent_mentions = 0 + for msg in relevant_messages: + # We check if the message is a reply, and if it is try to get the author + # since we ignore mentions of a user that we're replying to + reply_author = None - total_recent_mentions = sum( - not user.bot - for msg in relevant_messages - for user in msg.mentions - ) + if msg.type == MessageType.reply: + ref = msg.reference + + if not (resolved := ref.resolved): + # It is possible, in a very unusual situation, for a message to have a reference + # that is both not in the cache, and deleted while running this function. + # In such a situation, this will throw an error which we catch. + try: + resolved = await bot.instance.get_partial_messageable(resolved.channel_id).fetch_message( + resolved.message_id + ) + except NotFound: + log.info('Could not fetch the reference message as it has been deleted.') + + if resolved and not isinstance(resolved, DeletedReferencedMessage): + reply_author = resolved.author + + for user in msg.mentions: + # Don't count bot or self mentions, or the user being replied to (if applicable) + if user.bot or user in {msg.author, reply_author}: + continue + total_recent_mentions += 1 if total_recent_mentions > config['max']: return ( diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index f8805ac48..e1f904917 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,15 +1,32 @@ -from typing import Iterable +from typing import Iterable, Optional + +import discord from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage +from tests.helpers import MockMember, MockMessage, MockMessageReference -def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: - """Makes a message with `total_mentions` mentions.""" +def make_msg( + author: str, + total_user_mentions: int, + total_bot_mentions: int = 0, + *, + reference: Optional[MockMessageReference] = None +) -> MockMessage: + """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" user_mentions = [MockMember() for _ in range(total_user_mentions)] bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - return MockMessage(author=author, mentions=user_mentions+bot_mentions) + + mentions = user_mentions + bot_mentions + if reference is not None: + # For the sake of these tests we assume that all references are mentions. + mentions.append(reference.resolved.author) + msg_type = discord.MessageType.reply + else: + msg_type = discord.MessageType.default + + return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) class TestMentions(RuleTest): @@ -56,6 +73,16 @@ class TestMentions(RuleTest): ("bob",), 3, ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference())], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], + ("bob",), + 3 + ) ) await self.run_disallowed(cases) @@ -71,6 +98,27 @@ class TestMentions(RuleTest): await self.run_allowed(cases) + async def test_ignore_reply_mentions(self): + """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" + cases = ( + [ + make_msg("bob", 2, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) + ], + [ + make_msg("bob", 2, reference=MockMessageReference()), + make_msg("bob", 0, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), + make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) + ] + ) + + await self.run_allowed(cases) + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: last_message = case.recent_messages[0] return tuple( diff --git a/tests/helpers.py b/tests/helpers.py index 17214553c..687e15b96 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -492,6 +492,28 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): spec_set = attachment_instance +message_reference_instance = discord.MessageReference( + message_id=unittest.mock.MagicMock(id=1), + channel_id=unittest.mock.MagicMock(id=2), + guild_id=unittest.mock.MagicMock(id=3) +) + + +class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock MessageReference objects. + + Instances of this class will follow the specification of `discord.MessageReference` instances. + For more information, see the `MockGuild` docstring. + """ + spec_set = message_reference_instance + + def __init__(self, *, reference_author_is_bot: bool = False, **kwargs): + super().__init__(**kwargs) + referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot) + self.resolved = MockMessage(author=referenced_msg_author) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. -- cgit v1.2.3 From 8028690d4eae27e57dfe1429b8067eabaa94eef9 Mon Sep 17 00:00:00 2001 From: Aleksey Zasorin Date: Fri, 16 Sep 2022 10:42:16 -0700 Subject: Removed "redis_ready" from additional_spec_asyncs in MockBot (#2275) The attribute was removed from Bot in fc05849 --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 687e15b96..a4b919dcb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -317,7 +317,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): guild_id=1, intents=discord.Intents.all(), ) - additional_spec_asyncs = ("wait_for", "redis_ready") + additional_spec_asyncs = ("wait_for",) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) -- cgit v1.2.3