diff options
author | 2022-08-14 19:59:33 +0100 | |
---|---|---|
committer | 2022-08-14 19:59:33 +0100 | |
commit | c206138d27954f3692ff22f9bd94acaf8e81b06a (patch) | |
tree | 56a2ddbce449ec5cb022c8420e23052a65fe5ad0 /tests | |
parent | Address Reviews (diff) | |
parent | Merge pull request #2229 from python-discord/py3.10-rediscache (diff) |
Merge branch 'main' into incident-archive-msg-improvements
Diffstat (limited to 'tests')
22 files changed, 603 insertions, 447 deletions
diff --git a/tests/base.py b/tests/base.py index 5e304ea9d..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Dict import discord +from async_rediscache import RedisSession from discord.ext import commands from bot.log import get_logger @@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 9dc46005b..a17c1fa10 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -1,7 +1,8 @@ import unittest from unittest import mock -from bot.api import ResponseCodeError +from botcore.site_api import ResponseCodeError + from bot.exts.backend.sync._syncers import Syncer from tests import helpers diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index fdd0ab74a..87b76c6b4 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -2,9 +2,9 @@ import unittest from unittest import mock import discord +from botcore.site_api import ResponseCodeError from bot import constants -from bot.api import ResponseCodeError from bot.exts.backend import sync from bot.exts.backend.sync._cog import Sync from bot.exts.backend.sync._syncers import Syncer @@ -16,11 +16,11 @@ class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @staticmethod - def test_extension_setup(): + async def test_extension_setup(): """The Sync cog should be added.""" bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() + await sync.setup(bot) + bot.add_cog.assert_awaited_once() class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): @@ -60,22 +60,18 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("bot.utils.scheduling.create_task") - @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild, create_task): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. + async def test_sync_cog_sync_on_load(self): + """Roles and users should be synced on cog load.""" + guild = helpers.MockGuild() + self.bot.get_guild = mock.MagicMock(return_value=guild) + self.RoleSyncer.reset_mock() self.UserSyncer.reset_mock() - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - Sync(self.bot) + await self.cog.cog_load() - sync_guild.assert_called_once_with() - create_task.assert_called_once() - self.assertEqual(create_task.call_args.args[0], mock_sync_guild_coro) + self.RoleSyncer.sync.assert_called_once_with(guild) + self.UserSyncer.sync.assert_called_once_with(guild) async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" @@ -87,7 +83,7 @@ class SyncCogTests(SyncCogTestCase): self.bot.get_guild = mock.MagicMock(return_value=guild) - await self.cog.sync_guild() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_called_once() self.bot.get_guild.assert_called_once_with(constants.Guild.id) diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 35fa0ee59..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,11 +1,11 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch +from botcore.site_api import ResponseCodeError from discord.ext.commands import errors -from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,27 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -73,57 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -138,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -148,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -178,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -199,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -331,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -396,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -477,11 +471,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.backend.error_handler.log") async def test_handle_api_error(self, log_mock): - """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" + """Should `ctx.send` on HTTP error codes, and log at correct level.""" test_cases = ( { "error": ResponseCodeError(AsyncMock(status=400)), - "log_level": "debug" + "log_level": "error" }, { "error": ResponseCodeError(AsyncMock(status=404)), @@ -505,6 +499,8 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_awaited_once() if case["log_level"] == "warning": log_mock.warning.assert_called_once() + elif case["log_level"] == "error": + log_mock.error.assert_called_once() else: log_mock.debug.assert_called_once() @@ -544,11 +540,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): push_scope_mock.set_extra.has_calls(set_extra_calls) -class ErrorHandlerSetupTests(unittest.TestCase): +class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): """Tests for `ErrorHandler` `setup` function.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - setup(bot) - bot.add_cog.assert_called_once() + await error_handler.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index 0856546af..684f7abcd 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -160,11 +160,11 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase): member.add_roles.assert_not_awaited() -class CodeJamSetup(unittest.TestCase): +class CodeJamSetup(unittest.IsolatedAsyncioTestCase): """Test for `setup` function of `CodeJam` cog.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog`.""" bot = MockBot() - code_jams.setup(bot) - bot.add_cog.assert_called_once() + await code_jams.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 06d78de9d..7282334e2 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -192,11 +192,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) -class AntiMalwareSetupTests(unittest.TestCase): +class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `AntiMalware` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() + await antimalware.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py index 8ae59c1f1..bd26532f1 100644 --- a/tests/bot/exts/filters/test_filtering.py +++ b/tests/bot/exts/filters/test_filtering.py @@ -11,7 +11,7 @@ class FilteringCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Instantiate the bot and cog.""" self.bot = MockBot() - with patch("bot.utils.scheduling.create_task", new=lambda task, **_: task.close()): + with patch("botcore.utils.scheduling.create_task", new=lambda task, **_: task.close()): self.cog = filtering.Filtering(self.bot) @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index c0c3baa42..007b7b1eb 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,5 +1,4 @@ import unittest -from unittest.mock import MagicMock from discord.ext.commands import NoPrivateMessage @@ -44,11 +43,11 @@ class SecurityCogTests(unittest.TestCase): self.assertTrue(self.cog.check_on_guild(self.ctx)) -class SecurityCogLoadTests(unittest.TestCase): +class SecurityCogLoadTests(unittest.IsolatedAsyncioTestCase): """Tests loading the `Security` cog.""" - def test_security_cog_load(self): + async def test_security_cog_load(self): """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() + bot = MockBot() + await security.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index 4db27269a..c1f3762ac 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -395,15 +395,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.channel.send.assert_not_awaited() -class TokenRemoverExtensionTests(unittest.TestCase): +class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the token_remover extension.""" @autospec("bot.exts.filters.token_remover", "TokenRemover") - def test_extension_setup(self, cog): + async def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() - token_remover.setup(bot) + await token_remover.setup(bot) cog.assert_called_once_with(bot) - bot.add_cog.assert_called_once() + bot.add_cog.assert_awaited_once() self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py index 604c69671..2644ae40d 100644 --- a/tests/bot/exts/info/test_help.py +++ b/tests/bot/exts/info/test_help.py @@ -12,7 +12,6 @@ class HelpCogTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = help.Help(self.bot) self.ctx = MockContext(bot=self.bot) - self.bot.help_command.context = self.ctx @autospec(help.CustomHelpCommand, "get_all_help_choices", return_value={"help"}, pass_mocks=False) async def test_help_fuzzy_matching(self): diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 724456b04..d896b7652 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -1,6 +1,7 @@ import textwrap import unittest import unittest.mock +from datetime import datetime import discord @@ -288,8 +289,9 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.nick = None user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 + user.created_at = user.joined_at = datetime.utcnow() - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.title, "Mr. Hemlock") @@ -309,8 +311,9 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.nick = "Cat lover" user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") user.colour = 0 + user.created_at = user.joined_at = datetime.utcnow() - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") @@ -329,8 +332,9 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): # A `MockMember` has the @Everyone role by default; we add the Admins to that. user = helpers.MockMember(roles=[admins_role], colour=100) + user.created_at = user.joined_at = datetime.utcnow() - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertIn("&Admins", embed.fields[1].value) self.assertNotIn("&Everyone", embed.fields[1].value) @@ -355,7 +359,8 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): nomination_counts.return_value = ("Nominations", "nomination info") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + user.created_at = user.joined_at = datetime.utcfromtimestamp(1) + embed = await self.cog.create_user_embed(ctx, user, False) infraction_counts.assert_called_once_with(user) nomination_counts.assert_called_once_with(user) @@ -394,7 +399,8 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user_messages.return_value = ("Messages", "user message counts") user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + user.created_at = user.joined_at = datetime.utcfromtimestamp(1) + embed = await self.cog.create_user_embed(ctx, user, False) infraction_counts.assert_called_once_with(user) @@ -440,7 +446,8 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): moderators_role = helpers.MockRole(name='Moderators') user = helpers.MockMember(id=314, roles=[moderators_role], colour=100) - embed = await self.cog.create_user_embed(ctx, user) + user.created_at = user.joined_at = datetime.utcnow() + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.colour, discord.Colour(100)) @@ -457,7 +464,8 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ctx = helpers.MockContext() user = helpers.MockMember(id=217, colour=discord.Colour.default()) - embed = await self.cog.create_user_embed(ctx, user) + user.created_at = user.joined_at = datetime.utcnow() + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.colour, discord.Colour.og_blurple()) @@ -474,8 +482,9 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): ctx = helpers.MockContext() user = helpers.MockMember(id=217, colour=0) + user.created_at = user.joined_at = datetime.utcnow() user.display_avatar.url = "avatar url" - embed = await self.cog.create_user_embed(ctx, user) + embed = await self.cog.create_user_embed(ctx, user, False) self.assertEqual(embed.thumbnail.url, "avatar url") @@ -528,7 +537,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx) - create_embed.assert_called_once_with(ctx, self.author) + create_embed.assert_called_once_with(ctx, self.author, False) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") @@ -539,7 +548,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx, self.author) - create_embed.assert_called_once_with(ctx, self.author) + create_embed.assert_called_once_with(ctx, self.author, False) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") @@ -550,7 +559,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx) - create_embed.assert_called_once_with(ctx, self.moderator) + create_embed.assert_called_once_with(ctx, self.moderator, False) ctx.send.assert_called_once() @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") @@ -562,5 +571,5 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): await self.cog.user_info(self.cog, ctx, self.target) - create_embed.assert_called_once_with(ctx, self.target) + create_embed.assert_called_once_with(ctx, self.target, False) ctx.send.assert_called_once() diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 4d01e18a5..052048053 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,13 +1,15 @@ import inspect import textwrap import unittest -from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch +from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch from discord.errors import NotFound from bot.constants import Event +from bot.exts.moderation.clean import Clean from bot.exts.moderation.infraction import _utils from bot.exts.moderation.infraction.infractions import Infractions +from bot.exts.moderation.infraction.management import ModManagement from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockUser, autospec @@ -62,8 +64,8 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) -class VoiceBanTests(unittest.IsolatedAsyncioTestCase): - """Tests for voice ban related functions and commands.""" +class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): + """Tests for voice mute related functions and commands.""" def setUp(self): self.bot = MockBot() @@ -73,59 +75,59 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): self.ctx = MockContext(bot=self.bot, author=self.mod) self.cog = Infractions(self.bot) - async def test_permanent_voice_ban(self): - """Should call voice ban applying function without expiry.""" - self.cog.apply_voice_ban = AsyncMock() - self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) - self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) + async def test_permanent_voice_mute(self): + """Should call voice mute applying function without expiry.""" + self.cog.apply_voice_mute = AsyncMock() + self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) - async def test_temporary_voice_ban(self): - """Should call voice ban applying function with expiry.""" - self.cog.apply_voice_ban = AsyncMock() - self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) - self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + async def test_temporary_voice_mute(self): + """Should call voice mute applying function with expiry.""" + self.cog.apply_voice_mute = AsyncMock() + self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) + self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") - async def test_voice_unban(self): + async def test_voice_unmute(self): """Should call infraction pardoning function.""" self.cog.pardon_infraction = AsyncMock() - self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) - self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user)) + self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): - """Should return early when user already have Voice Ban infraction.""" + async def test_voice_mute_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): + """Should return early when user already have Voice Mute infraction.""" get_active_infraction.return_value = {"foo": "bar"} - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) - get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar")) + get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_mute") post_infraction_mock.assert_not_awaited() @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): + async def test_voice_mute_infraction_post_failed(self, get_active_infraction, post_infraction_mock): """Should return early when posting infraction fails.""" self.cog.mod_log.ignore = MagicMock() get_active_infraction.return_value = None post_infraction_mock.return_value = None - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar")) post_infraction_mock.assert_awaited_once() self.cog.mod_log.ignore.assert_not_called() @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): - """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" + async def test_voice_mute_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): + """Should pass all kwargs passed to apply_voice_mute to post_infraction.""" get_active_infraction.return_value = None # We don't want that this continue yet post_infraction_mock.return_value = None - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar", my_kwarg=23)) post_infraction_mock.assert_awaited_once_with( - self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 + self.ctx, self.user, "voice_mute", "foobar", active=True, my_kwarg=23 ) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): + async def test_voice_mute_mod_log_ignore(self, get_active_infraction, post_infraction_mock): """Should ignore Voice Verified role removing.""" self.cog.mod_log.ignore = MagicMock() self.cog.apply_infraction = AsyncMock() @@ -134,11 +136,11 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): get_active_infraction.return_value = None post_infraction_mock.return_value = {"foo": "bar"} - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar")) self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) async def action_tester(self, action, reason: str) -> None: - """Helper method to test voice ban action.""" + """Helper method to test voice mute action.""" self.assertTrue(inspect.iscoroutine(action)) await action @@ -147,7 +149,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): + async def test_voice_mute_apply_infraction(self, get_active_infraction, post_infraction_mock): """Should ignore Voice Verified role removing.""" self.cog.mod_log.ignore = MagicMock() self.cog.apply_infraction = AsyncMock() @@ -156,22 +158,22 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): post_infraction_mock.return_value = {"foo": "bar"} reason = "foobar" - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, reason)) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, reason)) self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, ANY) await self.action_tester(self.cog.apply_infraction.call_args[0][-1], reason) @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") - async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): - """Should truncate reason for voice ban.""" + async def test_voice_mute_truncate_reason(self, get_active_infraction, post_infraction_mock): + """Should truncate reason for voice mute.""" self.cog.mod_log.ignore = MagicMock() self.cog.apply_infraction = AsyncMock() get_active_infraction.return_value = None post_infraction_mock.return_value = {"foo": "bar"} - self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) + self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar" * 3000)) self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, ANY) # Test action @@ -180,14 +182,14 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): @autospec(_utils, "post_infraction", "get_active_infraction", return_value=None) @autospec(Infractions, "apply_infraction") - async def test_voice_ban_user_left_guild(self, apply_infraction_mock, post_infraction_mock, _): - """Should voice ban user that left the guild without throwing an error.""" + async def test_voice_mute_user_left_guild(self, apply_infraction_mock, post_infraction_mock, _): + """Should voice mute user that left the guild without throwing an error.""" infraction = {"foo": "bar"} post_infraction_mock.return_value = {"foo": "bar"} user = MockUser() - await self.cog.voiceban(self.cog, self.ctx, user, reason=None) - post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_ban", None, active=True, expires_at=None) + await self.cog.voicemute(self.cog, self.ctx, user, reason=None) + post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY) # Test action @@ -195,22 +197,22 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(inspect.iscoroutine(action)) await action - async def test_voice_unban_user_not_found(self): + async def test_voice_unmute_user_not_found(self): """Should include info to return dict when user was not found from guild.""" self.guild.get_member.return_value = None self.guild.fetch_member.side_effect = NotFound(Mock(status=404), "Not found") - result = await self.cog.pardon_voice_ban(self.user.id, self.guild) + result = await self.cog.pardon_voice_mute(self.user.id, self.guild) self.assertEqual(result, {"Info": "User was not found in the guild."}) @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @patch("bot.exts.moderation.infraction.infractions.format_user") - async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): + async def test_voice_unmute_user_found(self, format_user_mock, notify_pardon_mock): """Should add role back with ignoring, notify user and return log dictionary..""" self.guild.get_member.return_value = self.user notify_pardon_mock.return_value = True format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild) + result = await self.cog.pardon_voice_mute(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "Sent" @@ -219,15 +221,100 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @patch("bot.exts.moderation.infraction.infractions.format_user") - async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): + async def test_voice_unmute_dm_fail(self, format_user_mock, notify_pardon_mock): """Should add role back with ignoring, notify user and return log dictionary..""" self.guild.get_member.return_value = self.user notify_pardon_mock.return_value = False format_user_mock.return_value = "my-user" - result = await self.cog.pardon_voice_ban(self.user.id, self.guild) + result = await self.cog.pardon_voice_mute(self.user.id, self.guild) self.assertEqual(result, { "Member": "my-user", "DM": "**Failed**" }) notify_pardon_mock.assert_awaited_once() + + +class CleanBanTests(unittest.IsolatedAsyncioTestCase): + """Tests for cleanban functionality.""" + + def setUp(self): + self.bot = MockBot() + self.mod = MockMember(roles=[MockRole(id=7890123, position=10)]) + self.user = MockMember(roles=[MockRole(id=123456, position=1)]) + self.guild = MockGuild() + self.ctx = MockContext(bot=self.bot, author=self.mod) + self.cog = Infractions(self.bot) + self.clean_cog = Clean(self.bot) + self.management_cog = ModManagement(self.bot) + + self.cog.apply_ban = AsyncMock(return_value={"id": 42}) + self.log_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" + self.clean_cog._clean_messages = AsyncMock(return_value=self.log_url) + + def mock_get_cog(self, enable_clean, enable_manage): + """Mock get cog factory that allows the user to specify whether clean and manage cogs are enabled.""" + def inner(name): + if name == "ModManagement": + return self.management_cog if enable_manage else None + elif name == "Clean": + return self.clean_cog if enable_clean else None + else: + return DEFAULT + return inner + + async def test_cleanban_falls_back_to_native_purge_without_clean_cog(self): + """Should fallback to native purge if the Clean cog is not available.""" + self.bot.get_cog.side_effect = self.mock_get_cog(False, False) + + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + self.cog.apply_ban.assert_awaited_once_with( + self.ctx, + self.user, + "FooBar", + purge_days=1, + expires_at=None, + ) + + async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): + """Cleanban command should use the native purge messages if the clean cog is available.""" + self.bot.get_cog.side_effect = self.mock_get_cog(True, False) + + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + self.cog.apply_ban.assert_awaited_once_with( + self.ctx, + self.user, + "FooBar", + expires_at=None, + ) + + @patch("bot.exts.moderation.infraction.infractions.Age") + async def test_cleanban_uses_clean_cog_when_available(self, mocked_age_converter): + """Test cleanban uses the clean cog to clean messages if it's available.""" + self.bot.api_client.patch = AsyncMock() + self.bot.get_cog.side_effect = self.mock_get_cog(True, False) + + mocked_age_converter.return_value.convert = AsyncMock(return_value="81M") + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + + self.clean_cog._clean_messages.assert_awaited_once_with( + self.ctx, + users=[self.user], + channels="*", + first_limit="81M", + attempt_delete_invocation=False, + ) + + async def test_cleanban_edits_infraction_reason(self): + """Ensure cleanban edits the ban reason with a link to the clean log.""" + self.bot.get_cog.side_effect = self.mock_get_cog(True, True) + + self.management_cog.infraction_append = AsyncMock() + self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar")) + + self.management_cog.infraction_append.assert_awaited_once_with( + self.ctx, + {"id": 42}, + None, + reason=f"[Clean log]({self.log_url})" + ) diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 350274ecd..5cf02033d 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,9 +3,9 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch +from botcore.site_api import ResponseCodeError from discord import Embed, Forbidden, HTTPException, NotFound -from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.exts.moderation.infraction import _utils as utils from tests.helpers import MockBot, MockContext, MockMember, MockUser @@ -15,7 +15,10 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """Tests Moderation utils.""" def setUp(self): - self.bot = MockBot() + patcher = patch("bot.instance", new=MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) + self.member = MockMember(id=1234) self.user = MockUser(id=1234) self.ctx = MockContext(bot=self.bot, author=self.member) @@ -123,8 +126,9 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): else: self.ctx.send.assert_not_awaited() + @unittest.skip("Current time needs to be patched so infraction duration is correct.") @patch("bot.exts.moderation.infraction._utils.send_private_embed") - async def test_notify_infraction(self, send_private_embed_mock): + async def test_send_infraction_embed(self, send_private_embed_mock): """ Should send an embed of a certain format as a DM and return `True` if DM successful. @@ -132,7 +136,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): """ test_cases = [ { - "args": (self.bot, self.user, 0, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"), + "args": (dict(id=0, type="ban", reason=None, expires_at=datetime(2020, 2, 26, 9, 20)), self.user), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -145,12 +149,12 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ).set_author( name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, - icon_url=Icons.token_removed + icon_url=Icons.user_ban ), "send_result": True }, { - "args": (self.bot, self.user, 0, "warning", None, "Test reason."), + "args": (dict(id=0, type="warning", reason="Test reason.", expires_at=None), self.user), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -163,14 +167,14 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ).set_author( name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, - icon_url=Icons.token_removed + icon_url=Icons.user_warn ), "send_result": False }, # Note that this test case asserts that the DM that *would* get sent to the user is formatted # correctly, even though that message is deliberately never sent. { - "args": (self.bot, self.user, 0, "note", None, None, Icons.defcon_denied), + "args": (dict(id=0, type="note", reason=None, expires_at=None), self.user), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -183,20 +187,12 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ).set_author( name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, - icon_url=Icons.defcon_denied + icon_url=Icons.user_warn ), "send_result": False }, { - "args": ( - self.bot, - self.user, - 0, - "mute", - "2020-02-26 09:20 (23 hours and 59 minutes)", - "Test", - Icons.defcon_denied - ), + "args": (dict(id=0, type="mute", reason="Test", expires_at=datetime(2020, 2, 26, 9, 20)), self.user), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -209,12 +205,12 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ).set_author( name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, - icon_url=Icons.defcon_denied + icon_url=Icons.user_mute ), "send_result": False }, { - "args": (self.bot, self.user, 0, "mute", None, "foo bar" * 4000, Icons.defcon_denied), + "args": (dict(id=0, type="mute", reason="foo bar" * 4000, expires_at=None), self.user), "expected_output": Embed( title=utils.INFRACTION_TITLE, description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format( @@ -227,7 +223,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase): ).set_author( name=utils.INFRACTION_AUTHOR_NAME, url=utils.RULES_URL, - icon_url=Icons.defcon_denied + icon_url=Icons.user_mute ), "send_result": True } diff --git a/tests/bot/exts/moderation/test_clean.py b/tests/bot/exts/moderation/test_clean.py new file mode 100644 index 000000000..d7647fa48 --- /dev/null +++ b/tests/bot/exts/moderation/test_clean.py @@ -0,0 +1,104 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from bot.exts.moderation.clean import Clean +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockMessage, MockRole, MockTextChannel + + +class CleanTests(unittest.IsolatedAsyncioTestCase): + """Tests for clean cog functionality.""" + + def setUp(self): + self.bot = MockBot() + self.mod = MockMember(roles=[MockRole(id=7890123, position=10)]) + self.user = MockMember(roles=[MockRole(id=123456, position=1)]) + self.guild = MockGuild() + self.ctx = MockContext(bot=self.bot, author=self.mod) + self.cog = Clean(self.bot) + + self.log_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" + self.cog._modlog_cleaned_messages = AsyncMock(return_value=self.log_url) + + self.cog._use_cache = MagicMock(return_value=True) + self.cog._delete_found = AsyncMock(return_value=[42, 84]) + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_deletes_invocation_in_non_mod_channel(self, mod_channel_check): + """Clean command should delete the invocation message if ran in a non mod channel.""" + mod_channel_check.return_value = False + self.ctx.message.delete = AsyncMock() + + self.assertIsNone(await self.cog._delete_invocation(self.ctx)) + + self.ctx.message.delete.assert_awaited_once() + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_doesnt_delete_invocation_in_mod_channel(self, mod_channel_check): + """Clean command should not delete the invocation message if ran in a mod channel.""" + mod_channel_check.return_value = True + self.ctx.message.delete = AsyncMock() + + self.assertIsNone(await self.cog._delete_invocation(self.ctx)) + + self.ctx.message.delete.assert_not_awaited() + + async def test_clean_doesnt_attempt_deletion_when_attempt_delete_invocation_is_false(self): + """Clean command should not attempt to delete the invocation message if attempt_delete_invocation is false.""" + self.cog._delete_invocation = AsyncMock() + self.bot.get_channel = MagicMock(return_value=False) + + self.assertEqual( + await self.cog._clean_messages( + self.ctx, + None, + first_limit=MockMessage(), + attempt_delete_invocation=False, + ), + self.log_url, + ) + + self.cog._delete_invocation.assert_not_awaited() + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_replies_with_success_message_when_ran_in_mod_channel(self, mod_channel_check): + """Clean command should reply to the message with a confirmation message if invoked in a mod channel.""" + mod_channel_check.return_value = True + self.ctx.reply = AsyncMock() + + self.assertEqual( + await self.cog._clean_messages( + self.ctx, + None, + first_limit=MockMessage(), + attempt_delete_invocation=False, + ), + self.log_url, + ) + + self.ctx.reply.assert_awaited_once() + sent_message = self.ctx.reply.await_args[0][0] + self.assertIn(self.log_url, sent_message) + self.assertIn("2 messages", sent_message) + + @patch("bot.exts.moderation.clean.is_mod_channel") + async def test_clean_send_success_message_to_mods_when_ran_in_non_mod_channel(self, mod_channel_check): + """Clean command should send a confirmation message to #mods if invoked in a non-mod channel.""" + mod_channel_check.return_value = False + mocked_mods = MockTextChannel(id=1234567) + mocked_mods.send = AsyncMock() + self.bot.get_channel = MagicMock(return_value=mocked_mods) + + self.assertEqual( + await self.cog._clean_messages( + self.ctx, + None, + first_limit=MockMessage(), + attempt_delete_invocation=False, + ), + self.log_url, + ) + + mocked_mods.send.assert_awaited_once() + sent_message = mocked_mods.send.await_args[0][0] + self.assertIn(self.log_url, sent_message) + self.assertIn("2 messages", sent_message) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index da0a79ce8..11fe565fc 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -9,12 +9,12 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp import discord -from async_rediscache import RedisSession from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user from bot.utils.time import TimestampFormats, discord_timestamp +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser @@ -280,7 +280,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -289,22 +289,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() - - async def asyncSetUp(self): # noqa: N802 - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): # noqa: N802 - if self.session: - await self.session.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. @@ -667,7 +651,7 @@ class TestOnRawReactionAdd(TestIncidents): emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 92ce3418a..98547e2bc 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -6,31 +6,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule(): # noqa: N802 - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule(): # noqa: N802 - """Close the fakeredis session.""" - if redis_session: - redis_loop.run_until_complete(redis_session.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -105,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(RedisTestCase): """Tests for the general functionality of the Silence cog.""" @autospec(silence, "Scheduler", pass_mocks=False) @@ -114,44 +98,36 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.cog = silence.Silence(self.bot) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_guild(self): + async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog._async_init() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_channels(self): + async def test_cog_load_got_channels(self): """Got channels from bot.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog._async_init() + await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") - async def test_async_init_got_notifier(self, notifier): + async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - await self.cog._async_init() + await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_rescheduled(self): + async def testcog_load_rescheduled(self): """`_reschedule_` coroutine was awaited.""" self.cog._reschedule = mock.create_autospec(self.cog._reschedule) - await self.cog._async_init() + await self.cog.cog_load() self.cog._reschedule.assert_awaited_once_with() - def test_cog_unload_cancelled_tasks(self): - """The init task was cancelled.""" - self.cog._init_task = asyncio.Future() - self.cog.cog_unload() - - # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. - self.assertTrue(self.cog._init_task.cancelled()) - @autospec("discord.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): @@ -165,7 +141,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync(self): """Tests the _force_voice_sync helper function.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -187,7 +163,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync_no_channel(self): """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" - await self.cog._async_init() + await self.cog.cog_load() channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) new_channel = MockVoiceChannel(delete=AsyncMock()) @@ -206,7 +182,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_voice_kick(self): """Test to ensure kick function can remove all members from a voice channel.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -236,7 +212,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_kick_move_to_error(self): """Test to ensure move_to gets called on all members during kick, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() _, members = self.create_erroneous_members() await self.cog._kick_voice_members(MockVoiceChannel(members=members)) @@ -245,7 +221,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_sync_move_to_error(self): """Test to ensure move_to gets called on all members during sync, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() failing_member, members = self.create_erroneous_members() await self.cog._force_voice_sync(MockVoiceChannel(members=members)) @@ -253,14 +229,12 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(RedisTestCase): """Tests for the silence argument parser utility function.""" def setUp(self): self.bot = MockBot() self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @@ -339,7 +313,7 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase): self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -414,7 +388,7 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(RedisTestCase): """Tests for the silence command and its related helper methods.""" @autospec(silence.Silence, "_reschedule", pass_mocks=False) @@ -422,13 +396,11 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( @@ -695,13 +667,11 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) self.cog.previous_overwrites = overwrites_cache - asyncio.run(self.cog._async_init()) # Populate instance attributes. + asyncio.run(self.cog.cog_load()) # Populate instance attributes. self.cog.scheduler.__contains__.return_value = True overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 321a92445..b1f32c210 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -2,12 +2,14 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch +from discord import AllowedMentions from discord.ext import commands from bot import constants +from bot.errors import LockedResourceError from bot.exts.utils import snekbox from bot.exts.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser +from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser class SnekboxTests(unittest.IsolatedAsyncioTestCase): @@ -16,7 +18,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = Snekbox(bot=self.bot) - async def test_post_eval(self): + async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() resp.json = AsyncMock(return_value="return") @@ -25,7 +27,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_eval("import random"), "return") + self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -34,17 +36,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): resp.json.assert_awaited_once() async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + """Reject output longer than MAX_PASTE_LENGTH.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) self.assertEqual(result, "too long to upload") @patch("bot.exts.utils.snekbox.send_to_paste_service") async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with("Test output.", extension="txt") + mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) - def test_prepare_input(self): + async def test_codeblock_converter(self): + ctx = MockContext() cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), @@ -60,33 +63,50 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) + self.assertEqual( + '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + ) + + def test_prepare_timeit_input(self): + """Test the prepare_timeit_input codeblock detection.""" + base_args = ('-m', 'timeit', '-s') + cases = ( + (['print("Hello World")'], '', 'single block of code'), + (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), + (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') + ) + + for case, setup_code, testname in cases: + setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + self.assertEqual(self.cog.prepare_timeit_input(case), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127', '') ) @patch('bot.exts.utils.snekbox.Signals') def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127 (SIGTEST)', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') ) def test_get_status_emoji(self): @@ -155,156 +175,181 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=None) + ctx.command = MagicMock() + + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') - self.cog.continue_eval.assert_called_once_with(ctx, response) + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) - - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') - self.cog.continue_eval.assert_called_with(ctx, response) + ctx.command = MagicMock() + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock() + self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) + self.cog.send_job.assert_called_with( + ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' + ) + self.cog.continue_job.assert_called_with(ctx, response, 'eval') async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" ctx = MockContext() ctx.author.id = 42 - ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.send = AsyncMock() - self.cog.jobs = (42,) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - ctx.send.assert_called_once_with( - "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" - ) - async def test_eval_command_call_help(self): - """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext(command="sentinel") - await self.cog.eval_command(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with(ctx.command) + async def delay_with_side_effect(*args, **kwargs) -> dict: + """Delay the post_job call to ensure the job runs long enough to conflict.""" + await asyncio.sleep(1) + return {'stdout': '', 'returncode': 0} + + self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) + with self.assertRaises(LockedResourceError): + await asyncio.gather( + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + ) - async def test_send_eval(self): - """Test the send_eval function.""" + async def test_send_job(self): + """Test the send_job function.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() - ctx.author.mention = '@LemonLemonishBeard#0042' + ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') - ctx.send.assert_called_once_with( + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions'] + expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) + self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) + + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') self.cog.format_output.assert_called_once_with('') - async def test_send_eval_with_paste_link(self): - """Test the send_eval function with a too long output that generate a paste link.""" + async def test_send_job_with_paste_link(self): + """Test the send_job function with a too long output that generate a paste link.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') - ctx.send.assert_called_once_with( + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], '@LemonLemonishBeard#0042 :yay!: Return code 0.' '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with( + {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' + ) self.cog.format_output.assert_called_once_with('Way too long beard') - async def test_send_eval_with_non_zero_eval(self): - """Test the send_eval function with a code returning a non-zero code.""" + async def test_send_job_with_non_zero_eval(self): + """Test the send_job function with a code returning a non-zero code.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') - ctx.send.assert_called_once_with( + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") - async def test_continue_eval_does_continue(self, partial_mock): - """Test that the continue_eval function does continue if required conditions are met.""" - ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) - response = MockMessage(delete=AsyncMock()) + async def test_continue_job_does_continue(self, partial_mock): + """Test that the continue_job function does continue if required conditions are met.""" + ctx = MockContext( + message=MockMessage( + id=4, + add_reaction=AsyncMock(), + clear_reactions=AsyncMock() + ), + author=MockMember(id=14) + ) + response = MockMessage(id=42, delete=AsyncMock()) new_msg = MockMessage() + self.cog.jobs = {4: 42} self.bot.wait_for.side_effect = ((None, new_msg), None) expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response) - self.cog.get_code.assert_awaited_once_with(new_msg) - self.assertEqual(actual, expected) + actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) + self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( 'message_edit', - check=partial_mock(snekbox.predicate_eval_message_edit, ctx), - timeout=snekbox.REEVAL_TIMEOUT, + check=partial_mock(snekbox.predicate_message_edit, ctx), + timeout=snekbox.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) response.delete.assert_called_once() - async def test_continue_eval_does_not_continue(self): + async def test_continue_job_does_not_continue(self): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage()) - self.assertEqual(actual, None) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) + self.assertEqual(actual, (None, None)) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -327,13 +372,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.return_value = MockContext(command=command) message = MockMessage(content=content) - actual_code = await self.cog.get_code(message) + actual_code = await self.cog.get_code(message, self.cog.eval_command) self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) - def test_predicate_eval_message_edit(self): - """Test the predicate_eval_message_edit function.""" + def test_predicate_message_edit(self): + """Test the predicate_message_edit function.""" msg0 = MockMessage(id=1, content='abc') msg1 = MockMessage(id=2, content='abcdef') msg2 = MockMessage(id=1, content='abcdef') @@ -346,18 +391,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) - def test_predicate_eval_emoji_reaction(self): - """Test the predicate_eval_emoji_reaction function.""" + def test_predicate_emoji_reaction(self): + """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_reaction.__str__.return_value = snekbox.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -370,15 +415,15 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) -class SnekboxSetupTests(unittest.TestCase): +class SnekboxSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `Snekbox` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - snekbox.setup(bot) - bot.add_cog.assert_called_once() + await snekbox.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py deleted file mode 100644 index 76bcb481d..000000000 --- a/tests/bot/test_api.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from bot import api - - -class APIClientTests(unittest.IsolatedAsyncioTestCase): - """Tests for the bot's API client.""" - - @classmethod - def setUpClass(cls): - """Sets up the shared fixtures for the tests.""" - cls.error_api_response = MagicMock() - cls.error_api_response.status = 999 - - def test_response_code_error_default_initialization(self): - """Test the default initialization of `ResponseCodeError` without `text` or `json`""" - error = api.ResponseCodeError(response=self.error_api_response) - - self.assertIs(error.status, self.error_api_response.status) - self.assertEqual(error.response_json, {}) - self.assertEqual(error.response_text, "") - self.assertIs(error.response, self.error_api_response) - - def test_response_code_error_string_representation_default_initialization(self): - """Test the string representation of `ResponseCodeError` initialized without text or json.""" - error = api.ResponseCodeError(response=self.error_api_response) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") - - def test_response_code_error_initialization_with_json(self): - """Test the initialization of `ResponseCodeError` with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data, - ) - self.assertEqual(error.response_json, json_data) - self.assertEqual(error.response_text, "") - - def test_response_code_error_string_representation_with_nonempty_response_json(self): - """Test the string representation of `ResponseCodeError` initialized with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") - - def test_response_code_error_initialization_with_text(self): - """Test the initialization of `ResponseCodeError` with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data, - ) - self.assertEqual(error.response_text, text_data) - self.assertEqual(error.response_json, {}) - - def test_response_code_error_string_representation_with_nonempty_response_text(self): - """Test the string representation of `ResponseCodeError` initialized with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 3b71022db..d0e801299 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch from aiohttp import ClientConnectorError -from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from bot.utils.services import ( + FAILED_REQUEST_ATTEMPTS, MAX_PASTE_LENGTH, PasteTooLongError, PasteUploadError, send_to_paste_service +) from tests.helpers import MockBot @@ -55,23 +57,34 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): for error_json in test_cases: with self.subTest(error_json=error_json): response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) self.bot.http_session.post.reset_mock() async def test_request_repeated_on_connection_errors(self): """Requests are repeated in the case of connection errors.""" self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) async def test_general_error_handled_and_request_repeated(self): """All `Exception`s are handled, logged and request repeated.""" self.bot.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertLogs("bot.utils", logging.ERROR) - self.assertIsNone(result) + + async def test_raises_error_on_too_long_input(self): + """Ensure PasteTooLongError is raised if `contents` is longer than `MAX_PASTE_LENGTH`.""" + contents = "a" * (MAX_PASTE_LENGTH + 1) + with self.assertRaises(PasteTooLongError): + await send_to_paste_service(contents) + + async def test_raises_on_too_large_max_length(self): + """Ensure ValueError is raised if `max_length` passed is greater than `MAX_PASTE_LENGTH`.""" + with self.assertRaises(ValueError): + await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH + 1) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index a3dcbfc0a..120d65176 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -13,13 +13,15 @@ class TimeTests(unittest.TestCase): """humanize_delta should be able to handle unknown units, and will not abort.""" # Does not abort for unknown units, as the unit name is checked # against the attribute of the relativedelta instance. - self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + actual = time.humanize_delta(relativedelta(days=2, hours=2), precision='elephants', max_units=2) + self.assertEqual(actual, '2 days and 2 hours') def test_humanize_delta_handle_high_units(self): """humanize_delta should be able to handle very high units.""" # Very high maximum units, but it only ever iterates over # each value the relativedelta might have. - self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + actual = time.humanize_delta(relativedelta(days=2, hours=2), precision='hours', max_units=20) + self.assertEqual(actual, '2 days and 2 hours') def test_humanize_delta_should_normal_usage(self): """Testing humanize delta.""" @@ -32,7 +34,8 @@ class TimeTests(unittest.TestCase): for delta, precision, max_units, expected in test_cases: with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + actual = time.humanize_delta(delta, precision=precision, max_units=max_units) + self.assertEqual(actual, expected) def test_humanize_delta_raises_for_invalid_max_units(self): """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" @@ -40,22 +43,11 @@ class TimeTests(unittest.TestCase): for max_units in test_cases: with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: - time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - self.assertEqual(str(error.exception), 'max_units must be positive') - - def test_parse_rfc1123(self): - """Testing parse_rfc1123.""" - self.assertEqual( - time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), - datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) - ) - - def test_format_infraction(self): - """Testing format_infraction.""" - self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '<t:1576108860:f>') + time.humanize_delta(relativedelta(days=2, hours=2), precision='hours', max_units=max_units) + self.assertEqual(str(error.exception), 'max_units must be positive.') - def test_format_infraction_with_duration_none_expiry(self): - """format_infraction_with_duration should work for None expiry.""" + def test_format_with_duration_none_expiry(self): + """format_with_duration should work for None expiry.""" test_cases = ( (None, None, None, None), @@ -67,10 +59,10 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) - def test_format_infraction_with_duration_custom_units(self): - """format_infraction_with_duration should work for custom max_units.""" + def test_format_with_duration_custom_units(self): + """format_with_duration should work for custom max_units.""" test_cases = ( ('3000-12-12T00:01:00Z', datetime(3000, 12, 11, 12, 5, 5, tzinfo=timezone.utc), 6, '<t:32533488060:f> (11 hours, 55 minutes and 55 seconds)'), @@ -80,10 +72,10 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) - def test_format_infraction_with_duration_normal_usage(self): - """format_infraction_with_duration should work for normal usage, across various durations.""" + def test_format_with_duration_normal_usage(self): + """format_with_duration should work for normal usage, across various durations.""" utc = timezone.utc test_cases = ( ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5, tzinfo=utc), 2, @@ -105,11 +97,11 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected) def test_until_expiration_with_duration_none_expiry(self): - """until_expiration should work for None expiry.""" - self.assertEqual(time.until_expiration(None), None) + """until_expiration should return "Permanent" is expiry is None.""" + self.assertEqual(time.until_expiration(None), "Permanent") def test_until_expiration_with_duration_custom_units(self): """until_expiration should work for custom max_units.""" @@ -130,7 +122,6 @@ class TimeTests(unittest.TestCase): ('3000-12-12T00:00:00Z', '<t:32533488000:R>'), ('3000-11-23T20:09:00Z', '<t:32531918940:R>'), ('3000-11-23T20:09:00Z', '<t:32531918940:R>'), - (None, None), ) for expiry, expected in test_cases: diff --git a/tests/helpers.py b/tests/helpers.py index 9d4988d23..17214553c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,10 +9,10 @@ from typing import Iterable, Optional import discord from aiohttp import ClientSession +from botcore.async_stats import AsyncStatsClient +from botcore.site_api import APIClient from discord.ext.commands import Context -from bot.api import APIClient -from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module @@ -171,7 +171,7 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): spec_set = guild_instance def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'members': []} + default_kwargs = {'id': next(self.discord_id), 'members': [], "chunked": True} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] @@ -312,6 +312,10 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop(), redis_session=unittest.mock.MagicMock(), + http_session=unittest.mock.MagicMock(), + allowed_roles=[1], + guild_id=1, + intents=discord.Intents.all(), ) additional_spec_asyncs = ("wait_for", "redis_ready") @@ -322,6 +326,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.api_client = MockAPIClient(loop=self.loop) self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) + self.add_cog = unittest.mock.AsyncMock() # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` @@ -334,6 +339,8 @@ channel_data = { 'position': 1, 'nsfw': False, 'last_message_id': 1, + 'bitrate': 1337, + 'user_limit': 25, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() @@ -425,7 +432,7 @@ message_data = { 'webhook_id': 431341013479718912, 'attachments': [], 'embeds': [], - 'application': 'Python Discord', + 'application': {"id": 4, "description": "A Python Bot", "name": "Python Discord", "icon": None}, 'activity': 'mocking', 'channel': unittest.mock.MagicMock(), 'edited_timestamp': '2019-10-14T15:33:48+00:00', @@ -438,6 +445,7 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() +channel.type = discord.ChannelType.text message_instance = discord.Message(state=state, channel=channel, data=message_data) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 81285e009..f3040b305 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -327,7 +327,7 @@ class MockObjectTests(unittest.TestCase): def test_spec_propagation_of_mock_subclasses(self): """Test if the `spec` does not propagate to attributes of the mock object.""" test_values = ( - (helpers.MockGuild, "region"), + (helpers.MockGuild, "features"), (helpers.MockRole, "mentionable"), (helpers.MockMember, "display_name"), (helpers.MockBot, "owner_id"), |