diff options
Diffstat (limited to 'tests')
28 files changed, 565 insertions, 524 deletions
diff --git a/tests/README.md b/tests/README.md index fc03b3d43..b7fddfaa2 100644 --- a/tests/README.md +++ b/tests/README.md @@ -121,9 +121,9 @@ As we are trying to test our "units" of code independently, we want to make sure However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". -To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `disnake` types (see the section on the below.). +To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). -An example of mocking is when we provide a command with a mocked version of `disnake.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): +An example of mocking is when we provide a command with a mocked version of `discord.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): ```py import asyncio @@ -152,15 +152,15 @@ class BotCogTests(unittest.TestCase): By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected. -### Special mocks for some `disnake` types +### Special mocks for some `discord.py` types To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. -In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual disnake types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `disnake` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. -**Note:** These mock types only "know" the attributes that are set by default when these `disnake` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: +**Note:** These mock types only "know" the attributes that are set by default when these `discord.py` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: ```py import unittest.mock @@ -245,7 +245,7 @@ All in all, it's not only important to consider if all statements or branches we ### Unit Testing vs Integration Testing -Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `disnake` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. +Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written. diff --git a/tests/base.py b/tests/base.py index dea7dd678..4863a1821 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,8 +3,9 @@ import unittest from contextlib import contextmanager from typing import Dict -import disnake -from disnake.ext import commands +import discord +from async_rediscache import RedisSession +from discord.ext import commands from bot.log import get_logger from tests import helpers @@ -80,7 +81,7 @@ class LoggingTestsMixin: class CommandTestCase(unittest.IsolatedAsyncioTestCase): - """TestCase with additional assertions that are useful for testing disnake commands.""" + """TestCase with additional assertions that are useful for testing Discord commands.""" async def assertHasPermissionsCheck( # noqa: N802 self, @@ -98,9 +99,32 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase): permissions = {k: not v for k, v in permissions.items()} ctx = helpers.MockContext() - ctx.channel.permissions_for.return_value = disnake.Permissions(**permissions) + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) with self.assertRaises(commands.MissingPermissions) as cm: await cmd.can_run(ctx) self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions) + + +class RedisTestCase(unittest.IsolatedAsyncioTestCase): + """ + Use this as a base class for any test cases that require a redis session. + + This will prepare a fresh redis instance for each test function, and will + not make any assertions on its own. Tests can mutate the instance as they wish. + """ + + session = None + + async def flush(self): + """Flush everything from the redis database to prevent carry-overs between tests.""" + await self.session.client.flushall() + + async def asyncSetUp(self): + self.session = await RedisSession(use_fakeredis=True).connect() + await self.flush() + + async def asyncTearDown(self): + if self.session: + await self.session.client.close() diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 9dc46005b..a17c1fa10 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -1,7 +1,8 @@ import unittest from unittest import mock -from bot.api import ResponseCodeError +from botcore.site_api import ResponseCodeError + from bot.exts.backend.sync._syncers import Syncer from tests import helpers diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 4ed7de64d..87b76c6b4 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -1,10 +1,10 @@ import unittest from unittest import mock -import disnake +import discord +from botcore.site_api import ResponseCodeError from bot import constants -from bot.api import ResponseCodeError from bot.exts.backend import sync from bot.exts.backend.sync._cog import Sync from bot.exts.backend.sync._syncers import Syncer @@ -16,11 +16,11 @@ class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the sync extension.""" @staticmethod - def test_extension_setup(): + async def test_extension_setup(): """The Sync cog should be added.""" bot = helpers.MockBot() - sync.setup(bot) - bot.add_cog.assert_called_once() + await sync.setup(bot) + bot.add_cog.assert_awaited_once() class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): @@ -60,22 +60,18 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): class SyncCogTests(SyncCogTestCase): """Tests for the Sync cog.""" - @mock.patch("bot.utils.scheduling.create_task") - @mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock) - def test_sync_cog_init(self, sync_guild, create_task): - """Should instantiate syncers and run a sync for the guild.""" - # Reset because a Sync cog was already instantiated in setUp. + async def test_sync_cog_sync_on_load(self): + """Roles and users should be synced on cog load.""" + guild = helpers.MockGuild() + self.bot.get_guild = mock.MagicMock(return_value=guild) + self.RoleSyncer.reset_mock() self.UserSyncer.reset_mock() - mock_sync_guild_coro = mock.MagicMock() - sync_guild.return_value = mock_sync_guild_coro - - Sync(self.bot) + await self.cog.cog_load() - sync_guild.assert_called_once_with() - create_task.assert_called_once() - self.assertEqual(create_task.call_args.args[0], mock_sync_guild_coro) + self.RoleSyncer.sync.assert_called_once_with(guild) + self.UserSyncer.sync.assert_called_once_with(guild) async def test_sync_cog_sync_guild(self): """Roles and users should be synced only if a guild is successfully retrieved.""" @@ -87,7 +83,7 @@ class SyncCogTests(SyncCogTestCase): self.bot.get_guild = mock.MagicMock(return_value=guild) - await self.cog.sync_guild() + await self.cog.cog_load() self.bot.wait_until_guild_available.assert_called_once() self.bot.get_guild.assert_called_once_with(constants.Guild.id) @@ -257,9 +253,9 @@ class SyncCogListenerTests(SyncCogTestCase): self.assertTrue(self.cog.on_member_update.__cog_listener__) subtests = ( - ("activities", disnake.Game("Pong"), disnake.Game("Frogger")), + ("activities", discord.Game("Pong"), discord.Game("Frogger")), ("nick", "old nick", "new nick"), - ("status", disnake.Status.online, disnake.Status.offline), + ("status", discord.Status.online, discord.Status.offline), ) for attribute, old_value, new_value in subtests: diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 9ecb8fae0..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -import disnake +import discord from bot.exts.backend.sync._syncers import RoleSyncer, _Diff, _Role from tests import helpers @@ -34,8 +34,8 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): for role in roles: mock_role = helpers.MockRole(**role) - mock_role.colour = disnake.Colour(role["colour"]) - mock_role.permissions = disnake.Permissions(role["permissions"]) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) guild.roles.append(mock_role) return guild diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index f55f5360f..2fc97af2d 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from disnake.errors import NotFound +from discord.errors import NotFound from bot.exts.backend.sync._syncers import UserSyncer, _Diff from tests import helpers diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 83b5f2749..7562f6aa8 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -1,11 +1,11 @@ import unittest from unittest.mock import AsyncMock, MagicMock, call, patch -from disnake.ext.commands import errors +from botcore.site_api import ResponseCodeError +from discord.ext.commands import errors -from bot.api import ResponseCodeError from bot.errors import InvalidInfractedUserError, LockedResourceError -from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure @@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_error_handler_already_handled(self): """Should not do anything when error is already handled by local error handler.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandError() error.handled = "foo" - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_not_awaited() async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): @@ -45,27 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): "called_try_get_tag": True } ) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock(return_value=False) for case in test_cases: with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): self.ctx.reset_mock() - cog.try_silence.reset_mock(return_value=True) - cog.try_get_tag.reset_mock() + self.cog.try_silence.reset_mock(return_value=True) + self.cog.try_get_tag.reset_mock() - cog.try_silence.return_value = case["try_silence_return"] + self.cog.try_silence.return_value = case["try_silence_return"] self.ctx.channel.id = 1234 - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) if case["try_silence_return"]: - cog.try_get_tag.assert_not_awaited() - cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_awaited_once() else: - cog.try_silence.assert_awaited_once() - cog.try_get_tag.assert_awaited_once() + self.cog.try_silence.assert_awaited_once() + self.cog.try_get_tag.assert_awaited_once() self.ctx.send.assert_not_awaited() @@ -73,57 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) - cog = ErrorHandler(self.bot) - cog.try_silence = AsyncMock() - cog.try_get_tag = AsyncMock() + self.cog.try_silence = AsyncMock() + self.cog.try_get_tag = AsyncMock() + self.cog.try_run_eval = AsyncMock() error = errors.CommandNotFound() - self.assertIsNone(await cog.on_command_error(ctx, error)) + self.assertIsNone(await self.cog.on_command_error(ctx, error)) - cog.try_silence.assert_not_awaited() - cog.try_get_tag.assert_not_awaited() + self.cog.try_silence.assert_not_awaited() + self.cog.try_get_tag.assert_not_awaited() + self.cog.try_run_eval.assert_not_awaited() self.ctx.send.assert_not_awaited() async def test_error_handler_user_input_error(self): """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_user_input_error = AsyncMock() + self.cog.handle_user_input_error = AsyncMock() error = errors.UserInputError() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) async def test_error_handler_check_failure(self): """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) - cog.handle_check_failure = AsyncMock() + self.cog.handle_check_failure = AsyncMock() error = errors.CheckFailure() - self.assertIsNone(await cog.on_command_error(self.ctx, error)) - cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) + self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) async def test_error_handler_command_on_cooldown(self): """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" self.ctx.reset_mock() - cog = ErrorHandler(self.bot) error = errors.CommandOnCooldown(10, 9, type=None) - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) self.ctx.send.assert_awaited_once_with(error) async def test_error_handler_command_invoke_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() test_cases = ( { "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), - "expect_mock_call": cog.handle_api_error + "expect_mock_call": self.cog.handle_api_error }, { "args": (self.ctx, errors.CommandInvokeError(TypeError)), - "expect_mock_call": cog.handle_unexpected_error + "expect_mock_call": self.cog.handle_unexpected_error }, { "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), @@ -138,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): self.ctx.send.reset_mock() - self.assertIsNone(await cog.on_command_error(*case["args"])) + self.assertIsNone(await self.cog.on_command_error(*case["args"])) if case["expect_mock_call"] == "send": self.ctx.send.assert_awaited_once() else: @@ -148,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): async def test_error_handler_conversion_error(self): """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" - cog = ErrorHandler(self.bot) - cog.handle_api_error = AsyncMock() - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_api_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() cases = ( { "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), - "mock_function_to_call": cog.handle_api_error + "mock_function_to_call": self.cog.handle_api_error }, { "error": errors.ConversionError(AsyncMock(), TypeError), - "mock_function_to_call": cog.handle_unexpected_error + "mock_function_to_call": self.cog.handle_unexpected_error } ) for case in cases: with self.subTest(**case): - self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) + self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"])) case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) async def test_error_handler_two_other_errors(self): """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" - cog = ErrorHandler(self.bot) - cog.handle_unexpected_error = AsyncMock() + self.cog.handle_unexpected_error = AsyncMock() errs = ( errors.MaxConcurrencyReached(1, MagicMock()), errors.ExtensionError(name="foo") @@ -178,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for err in errs: with self.subTest(error=err): - cog.handle_unexpected_error.reset_mock() - self.assertIsNone(await cog.on_command_error(self.ctx, err)) - cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + self.cog.handle_unexpected_error.reset_mock() + self.assertIsNone(await self.cog.on_command_error(self.ctx, err)) + self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) @patch("bot.exts.backend.error_handler.log") async def test_error_handler_other_errors(self, log_mock): """Should `log.debug` other errors.""" - cog = ErrorHandler(self.bot) error = errors.DisabledCommand() # Use this just as a other error - self.assertIsNone(await cog.on_command_error(self.ctx, error)) + self.assertIsNone(await self.cog.on_command_error(self.ctx, error)) log_mock.debug.assert_called_once() @@ -199,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.silence = Silence(self.bot) self.bot.get_command.return_value = self.silence.silence self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_try_silence_context_invoked_from_error_handler(self): """Should set `Context.invoked_from_error_handler` to `True`.""" @@ -331,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.ctx = MockContext() self.tag = Tags(self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command async def test_try_get_tag_get_command(self): @@ -396,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() self.ctx = MockContext(bot=self.bot) - self.cog = ErrorHandler(self.bot) + self.cog = error_handler.ErrorHandler(self.bot) async def test_handle_input_error_handler_errors(self): """Should handle each error probably.""" @@ -477,11 +471,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.backend.error_handler.log") async def test_handle_api_error(self, log_mock): - """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" + """Should `ctx.send` on HTTP error codes, and log at correct level.""" test_cases = ( { "error": ResponseCodeError(AsyncMock(status=400)), - "log_level": "debug" + "log_level": "error" }, { "error": ResponseCodeError(AsyncMock(status=404)), @@ -505,6 +499,8 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): self.ctx.send.assert_awaited_once() if case["log_level"] == "warning": log_mock.warning.assert_called_once() + elif case["log_level"] == "error": + log_mock.error.assert_called_once() else: log_mock.debug.assert_called_once() @@ -544,11 +540,11 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): push_scope_mock.set_extra.has_calls(set_extra_calls) -class ErrorHandlerSetupTests(unittest.TestCase): +class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase): """Tests for `ErrorHandler` `setup` function.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog` with `ErrorHandler`.""" bot = MockBot() - setup(bot) - bot.add_cog.assert_called_once() + await error_handler.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py index fdff36b61..684f7abcd 100644 --- a/tests/bot/exts/events/test_code_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import AsyncMock, MagicMock, create_autospec, patch -from disnake import CategoryChannel -from disnake.ext.commands import BadArgument +from discord import CategoryChannel +from discord.ext.commands import BadArgument from bot.constants import Roles from bot.exts.events import code_jams @@ -160,11 +160,11 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase): member.add_roles.assert_not_awaited() -class CodeJamSetup(unittest.TestCase): +class CodeJamSetup(unittest.IsolatedAsyncioTestCase): """Test for `setup` function of `CodeJam` cog.""" - def test_setup(self): + async def test_setup(self): """Should call `bot.add_cog`.""" bot = MockBot() - code_jams.setup(bot) - bot.add_cog.assert_called_once() + await code_jams.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py index 0cab405d0..7282334e2 100644 --- a/tests/bot/exts/filters/test_antimalware.py +++ b/tests/bot/exts/filters/test_antimalware.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock, Mock -from disnake import NotFound +from discord import NotFound from bot.constants import Channels, STAFF_ROLES from bot.exts.filters import antimalware @@ -192,11 +192,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) -class AntiMalwareSetupTests(unittest.TestCase): +class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `AntiMalware` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - antimalware.setup(bot) - bot.add_cog.assert_called_once() + await antimalware.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py index 8ae59c1f1..bd26532f1 100644 --- a/tests/bot/exts/filters/test_filtering.py +++ b/tests/bot/exts/filters/test_filtering.py @@ -11,7 +11,7 @@ class FilteringCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Instantiate the bot and cog.""" self.bot = MockBot() - with patch("bot.utils.scheduling.create_task", new=lambda task, **_: task.close()): + with patch("botcore.utils.scheduling.create_task", new=lambda task, **_: task.close()): self.cog = filtering.Filtering(self.bot) @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py index 46fa82fd7..007b7b1eb 100644 --- a/tests/bot/exts/filters/test_security.py +++ b/tests/bot/exts/filters/test_security.py @@ -1,7 +1,6 @@ import unittest -from unittest.mock import MagicMock -from disnake.ext.commands import NoPrivateMessage +from discord.ext.commands import NoPrivateMessage from bot.exts.filters import security from tests.helpers import MockBot, MockContext @@ -44,11 +43,11 @@ class SecurityCogTests(unittest.TestCase): self.assertTrue(self.cog.check_on_guild(self.ctx)) -class SecurityCogLoadTests(unittest.TestCase): +class SecurityCogLoadTests(unittest.IsolatedAsyncioTestCase): """Tests loading the `Security` cog.""" - def test_security_cog_load(self): + async def test_security_cog_load(self): """Setup of the extension should call add_cog.""" - bot = MagicMock() - security.setup(bot) - bot.add_cog.assert_called_once() + bot = MockBot() + await security.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py index dd56c10dd..c1f3762ac 100644 --- a/tests/bot/exts/filters/test_token_remover.py +++ b/tests/bot/exts/filters/test_token_remover.py @@ -3,7 +3,7 @@ from re import Match from unittest import mock from unittest.mock import MagicMock -from disnake import Colour, NotFound +from discord import Colour, NotFound from bot import constants from bot.exts.filters import token_remover @@ -395,15 +395,15 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): self.msg.channel.send.assert_not_awaited() -class TokenRemoverExtensionTests(unittest.TestCase): +class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): """Tests for the token_remover extension.""" @autospec("bot.exts.filters.token_remover", "TokenRemover") - def test_extension_setup(self, cog): + async def test_extension_setup(self, cog): """The TokenRemover cog should be added.""" bot = MockBot() - token_remover.setup(bot) + await token_remover.setup(bot) cog.assert_called_once_with(bot) - bot.add_cog.assert_called_once() + bot.add_cog.assert_awaited_once() self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py index 604c69671..2644ae40d 100644 --- a/tests/bot/exts/info/test_help.py +++ b/tests/bot/exts/info/test_help.py @@ -12,7 +12,6 @@ class HelpCogTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = help.Help(self.bot) self.ctx = MockContext(bot=self.bot) - self.bot.help_command.context = self.ctx @autospec(help.CustomHelpCommand, "get_all_help_choices", return_value={"help"}, pass_mocks=False) async def test_help_fuzzy_matching(self): diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 9a35de7a9..d896b7652 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -3,7 +3,7 @@ import unittest import unittest.mock from datetime import datetime -import disnake +import discord from bot import constants from bot.exts.info import information @@ -43,7 +43,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): embed = kwargs.pop('embed') self.assertEqual(embed.title, "Role information (Total 1 role)") - self.assertEqual(embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(embed.colour, discord.Colour.og_blurple()) self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") async def test_role_info_command(self): @@ -51,19 +51,19 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): dummy_role = helpers.MockRole( name="Dummy", id=112233445566778899, - colour=disnake.Colour.og_blurple(), + colour=discord.Colour.og_blurple(), position=10, members=[self.ctx.author], - permissions=disnake.Permissions(0) + permissions=discord.Permissions(0) ) admin_role = helpers.MockRole( name="Admins", id=998877665544332211, - colour=disnake.Colour.red(), + colour=discord.Colour.red(), position=3, members=[self.ctx.author], - permissions=disnake.Permissions(0), + permissions=discord.Permissions(0), ) self.ctx.guild.roles.extend([dummy_role, admin_role]) @@ -81,7 +81,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): admin_embed = admin_kwargs["embed"] self.assertEqual(dummy_embed.title, "Dummy info") - self.assertEqual(dummy_embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(dummy_embed.colour, discord.Colour.og_blurple()) self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") @@ -91,7 +91,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(dummy_embed.fields[5].value, "0") self.assertEqual(admin_embed.title, "Admins info") - self.assertEqual(admin_embed.colour, disnake.Colour.red()) + self.assertEqual(admin_embed.colour, discord.Colour.red()) class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase): @@ -449,7 +449,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, disnake.Colour(100)) + self.assertEqual(embed.colour, discord.Colour(100)) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", @@ -463,11 +463,11 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase): """The embed should be created with the og blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() - user = helpers.MockMember(id=217, colour=disnake.Colour.default()) + user = helpers.MockMember(id=217, colour=discord.Colour.default()) user.created_at = user.joined_at = datetime.utcnow() embed = await self.cog.create_user_embed(ctx, user, False) - self.assertEqual(embed.colour, disnake.Colour.og_blurple()) + self.assertEqual(embed.colour, discord.Colour.og_blurple()) @unittest.mock.patch( f"{COG_PATH}.basic_user_infraction_counts", diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b85d086c9..052048053 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -3,7 +3,7 @@ import textwrap import unittest from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch -from disnake.errors import NotFound +from discord.errors import NotFound from bot.constants import Event from bot.exts.moderation.clean import Clean diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index eaa0e701e..5cf02033d 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -3,9 +3,9 @@ from collections import namedtuple from datetime import datetime from unittest.mock import AsyncMock, MagicMock, call, patch -from disnake import Embed, Forbidden, HTTPException, NotFound +from botcore.site_api import ResponseCodeError +from discord import Embed, Forbidden, HTTPException, NotFound -from bot.api import ResponseCodeError from bot.constants import Colours, Icons from bot.exts.moderation.infraction import _utils as utils from tests.helpers import MockBot, MockContext, MockMember, MockUser diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 725455bbe..53d98360c 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -1,4 +1,5 @@ import asyncio +import datetime import enum import logging import typing as t @@ -7,24 +8,27 @@ from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import aiohttp -import disnake -from async_rediscache import RedisSession +import discord from bot.constants import Colours from bot.exts.moderation import incidents from bot.utils.messages import format_user +from bot.utils.time import TimestampFormats, discord_timestamp +from tests.base import RedisTestCase from tests.helpers import ( MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel, MockUser ) +CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) + class MockAsyncIterable: """ Helper for mocking asynchronous for loops. It does not appear that the `unittest` library currently provides anything that would - allow us to simply mock an async iterator, such as `disnake.TextChannel.history`. + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. We therefore write our own helper to wrap a regular synchronous iterable, and feed its values via `__anext__` rather than `__next__`. @@ -60,7 +64,7 @@ class MockSignal(enum.Enum): B = "B" -mock_404 = disnake.NotFound( +mock_404 = discord.NotFound( response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response message="Not found", ) @@ -70,8 +74,8 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): """Collection of tests for the `download_file` helper function.""" async def test_download_file_success(self): - """If `to_file` succeeds, function returns the acquired `disnake.File`.""" - file = MagicMock(disnake.File, filename="bigbadlemon.jpg") + """If `to_file` succeeds, function returns the acquired `discord.File`.""" + file = MagicMock(discord.File, filename="bigbadlemon.jpg") attachment = MockAttachment(to_file=AsyncMock(return_value=file)) acquired_file = await incidents.download_file(attachment) @@ -86,7 +90,7 @@ class TestDownloadFile(unittest.IsolatedAsyncioTestCase): async def test_download_file_fail(self): """If `to_file` fails on a non-404 error, function logs the exception & returns None.""" - arbitrary_error = disnake.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") + arbitrary_error = discord.HTTPException(MagicMock(aiohttp.ClientResponse), "Arbitrary API error") attachment = MockAttachment(to_file=AsyncMock(side_effect=arbitrary_error)) with self.assertLogs(logger=incidents.log, level=logging.ERROR): @@ -100,30 +104,45 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_actioned(self): """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_green) self.assertIn("Actioned", embed.footer.text) async def test_make_embed_not_actioned(self): """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" - embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) + embed, file = await incidents.make_embed( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=incidents.Signal.NOT_ACTIONED, + actioned_by=MockMember() + ) self.assertEqual(embed.colour.value, Colours.soft_red) self.assertIn("Rejected", embed.footer.text) async def test_make_embed_content(self): """Incident content appears as embed description.""" - incident = MockMessage(content="this is an incident") + incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME) + + reported_timestamp = discord_timestamp(CURRENT_TIME) + relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE) + embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) - self.assertEqual(incident.content, embed.description) + self.assertEqual( + f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*", + embed.description + ) async def test_make_embed_with_attachment_succeeds(self): """Incident's attachment is downloaded and displayed in the embed's image field.""" - file = MagicMock(disnake.File, filename="bigbadjoe.jpg") + file = MagicMock(discord.File, filename="bigbadjoe.jpg") attachment = MockAttachment(filename="bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return our `file` with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): @@ -135,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase): async def test_make_embed_with_attachment_fails(self): """Incident's attachment fails to download, proxy url is linked instead.""" attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") - incident = MockMessage(content="this is an incident", attachments=[attachment]) + incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME) # Patch `download_file` to return None as if the download failed with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): @@ -270,7 +289,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase): self.incident.add_reaction.assert_not_called() -class TestIncidents(unittest.IsolatedAsyncioTestCase): +class TestIncidents(RedisTestCase): """ Tests for bound methods of the `Incidents` cog. @@ -279,22 +298,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase): the instance as they wish. """ - session = None - - async def flush(self): - """Flush everything from the database to prevent carry-overs between tests.""" - with await self.session.pool as connection: - await connection.flushall() - - async def asyncSetUp(self): # noqa: N802 - self.session = RedisSession(use_fakeredis=True) - await self.session.connect() - await self.flush() - - async def asyncTearDown(self): # noqa: N802 - if self.session: - await self.session.close() - def setUp(self): """ Prepare a fresh `Incidents` instance for each test. @@ -365,7 +368,6 @@ class TestCrawlIncidents(TestIncidents): class TestArchive(TestIncidents): """Tests for the `Incidents.archive` coroutine.""" - async def test_archive_webhook_not_found(self): """ Method recovers and returns False when the webhook is not found. @@ -375,7 +377,11 @@ class TestArchive(TestIncidents): """ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404) self.assertFalse( - await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) + await self.cog_instance.archive( + incident=MockMessage(created_at=CURRENT_TIME), + outcome=MagicMock(), + actioned_by=MockMember() + ) ) async def test_archive_relays_incident(self): @@ -391,10 +397,10 @@ class TestArchive(TestIncidents): # Define our own `incident` to be archived incident = MockMessage( content="this is an incident", - author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), + author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")), id=123, ) - built_embed = MagicMock(disnake.Embed, id=123) # We patch `make_embed` to return this + built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this with patch("bot.exts.moderation.incidents.make_embed", AsyncMock(return_value=(built_embed, None))): archive_return = await self.cog_instance.archive(incident, MagicMock(value="A"), MockMember()) @@ -422,7 +428,7 @@ class TestArchive(TestIncidents): webhook = MockAsyncWebhook() self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) - message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) + message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME) await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember()) self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -521,12 +527,13 @@ class TestProcessEvent(TestIncidents): async def test_process_event_confirmation_task_is_awaited(self): """Task given by `Incidents.make_confirmation_task` is awaited before method exits.""" mock_task = AsyncMock() + mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)]) with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), - member=MockMember(roles=[MockRole(id=1)]) + incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME), + member=mock_member ) mock_task.assert_awaited() @@ -545,7 +552,7 @@ class TestProcessEvent(TestIncidents): with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task): await self.cog_instance.process_event( reaction=incidents.Signal.ACTIONED.value, - incident=MockMessage(id=123), + incident=MockMessage(id=123, created_at=CURRENT_TIME), member=MockMember(roles=[MockRole(id=1)]) ) except asyncio.TimeoutError: @@ -616,7 +623,7 @@ class TestResolveMessage(TestIncidents): """ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None - arbitrary_error = disnake.HTTPException( + arbitrary_error = discord.HTTPException( response=MagicMock(aiohttp.ClientResponse), message="Arbitrary error", ) @@ -649,14 +656,14 @@ class TestOnRawReactionAdd(TestIncidents): super().setUp() # Ensure `cog_instance` is assigned self.payload = MagicMock( - disnake.RawReactionActionEvent, + discord.RawReactionActionEvent, channel_id=123, # Patched at class level message_id=456, member=MockMember(bot=False), emoji="reaction", ) - async def asyncSetUp(self): # noqa: N802 + async def asyncSetUp(self): """ Prepare an empty task and assign it as `crawl_task`. diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index 6c9ebed95..79e04837d 100644 --- a/tests/bot/exts/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -1,6 +1,6 @@ import unittest -import disnake +import discord from bot.exts.moderation.modlog import ModLog from tests.helpers import MockBot, MockTextChannel @@ -19,7 +19,7 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase): self.bot.get_channel.return_value = self.channel await self.cog.send_log_message( icon_url="foo", - colour=disnake.Colour.blue(), + colour=discord.Colour.blue(), title="bar", text="foo bar" * 3000 ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 539651d6c..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio import itertools import unittest from datetime import datetime, timezone @@ -6,31 +5,15 @@ from typing import List, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from async_rediscache import RedisSession -from disnake import PermissionOverwrite +from discord import PermissionOverwrite from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence +from tests.base import RedisTestCase from tests.helpers import ( MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec ) -redis_session = None -redis_loop = asyncio.get_event_loop() - - -def setUpModule(): # noqa: N802 - """Create and connect to the fakeredis session.""" - global redis_session - redis_session = RedisSession(use_fakeredis=True) - redis_loop.run_until_complete(redis_session.connect()) - - -def tearDownModule(): # noqa: N802 - """Close the fakeredis session.""" - if redis_session: - redis_loop.run_until_complete(redis_session.close()) - # Have to subclass it because builtins can't be patched. class PatchedDatetime(datetime): @@ -39,8 +22,24 @@ class PatchedDatetime(datetime): now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): + """A base class for Silence tests that correctly sets up the cog and redis.""" + + @autospec(silence, "Scheduler", pass_mocks=False) + @autospec(silence.Silence, "_reschedule", pass_mocks=False) def setUp(self) -> None: + self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) + self.cog = silence.Silence(self.bot) + + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest): + def setUp(self) -> None: + super().setUp() self.alert_channel = MockTextChannel() self.notifier = silence.SilenceNotifier(self.alert_channel) self.notifier.stop = self.notifier_stop_mock = Mock() @@ -105,54 +104,36 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +class SilenceCogTests(SilenceTest): """Tests for the general functionality of the Silence cog.""" - @autospec(silence, "Scheduler", pass_mocks=False) - def setUp(self) -> None: - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_guild(self): + async def test_cog_load_got_guild(self): """Bot got guild after it became available.""" - await self.cog._async_init() self.bot.wait_until_guild_available.assert_awaited_once() self.bot.get_guild.assert_called_once_with(Guild.id) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_got_channels(self): + async def test_cog_load_got_channels(self): """Got channels from bot.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - - await self.cog._async_init() + await self.cog.cog_load() self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) @autospec(silence, "SilenceNotifier") - async def test_async_init_got_notifier(self, notifier): + async def test_cog_load_got_notifier(self, notifier): """Notifier was started with channel.""" - self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) - - await self.cog._async_init() + await self.cog.cog_load() notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) self.assertEqual(self.cog.notifier, notifier.return_value) @autospec(silence, "SilenceNotifier", pass_mocks=False) - async def test_async_init_rescheduled(self): + async def testcog_load_rescheduled(self): """`_reschedule_` coroutine was awaited.""" self.cog._reschedule = mock.create_autospec(self.cog._reschedule) - await self.cog._async_init() + await self.cog.cog_load() self.cog._reschedule.assert_awaited_once_with() - def test_cog_unload_cancelled_tasks(self): - """The init task was cancelled.""" - self.cog._init_task = asyncio.Future() - self.cog.cog_unload() - - # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. - self.assertTrue(self.cog._init_task.cancelled()) - - @autospec("disnake.ext.commands", "has_any_role") + @autospec("discord.ext.commands", "has_any_role") @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" @@ -165,7 +146,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync(self): """Tests the _force_voice_sync helper function.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -187,7 +168,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_force_voice_sync_no_channel(self): """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" - await self.cog._async_init() + await self.cog.cog_load() channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) new_channel = MockVoiceChannel(delete=AsyncMock()) @@ -206,7 +187,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_voice_kick(self): """Test to ensure kick function can remove all members from a voice channel.""" - await self.cog._async_init() + await self.cog.cog_load() # Create a regular member, and one member for each of the moderation roles moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] @@ -236,7 +217,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_kick_move_to_error(self): """Test to ensure move_to gets called on all members during kick, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() _, members = self.create_erroneous_members() await self.cog._kick_voice_members(MockVoiceChannel(members=members)) @@ -245,7 +226,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): async def test_sync_move_to_error(self): """Test to ensure move_to gets called on all members during sync, even if some fail.""" - await self.cog._async_init() + await self.cog.cog_load() failing_member, members = self.create_erroneous_members() await self.cog._force_voice_sync(MockVoiceChannel(members=members)) @@ -253,15 +234,9 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +class SilenceArgumentParserTests(SilenceTest): """Tests for the silence argument parser utility function.""" - def setUp(self): - self.bot = MockBot() - self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) - @autospec(silence.Silence, "send_message", pass_mocks=False) @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) @autospec(silence.Silence, "parse_silence_args") @@ -329,17 +304,19 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase): """Tests for the rescheduling of cached unsilences.""" - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) - def setUp(self): + @autospec(silence, "Scheduler", pass_mocks=False) + def setUp(self) -> None: self.bot = MockBot() self.cog = silence.Silence(self.bot) self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) - with mock.patch.object(self.cog, "_reschedule", autospec=True): - asyncio.run(self.cog._async_init()) # Populate instance attributes. + @autospec(silence, "SilenceNotifier", pass_mocks=False) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + await self.cog.cog_load() # Populate instance attributes. async def test_skipped_missing_channel(self): """Did nothing because the channel couldn't be retrieved.""" @@ -414,22 +391,14 @@ def voice_sync_helper(function): @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(unittest.IsolatedAsyncioTestCase): +class SilenceTests(SilenceTest): """Tests for the silence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) + super().setUp() # Avoid unawaited coroutine warnings. self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - - asyncio.run(self.cog._async_init()) # Populate instance attributes. - self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite( send_messages=True, @@ -687,24 +656,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest): """Tests for the unsilence command and its related helper methods.""" - @autospec(silence.Silence, "_reschedule", pass_mocks=False) - @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot(get_channel=lambda _: MockTextChannel()) - self.cog = silence.Silence(self.bot) - self.cog._init_task = asyncio.Future() - self.cog._init_task.set_result(None) - - overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) - self.cog.previous_overwrites = overwrites_cache - - asyncio.run(self.cog._async_init()) # Populate instance attributes. + super().setUp() self.cog.scheduler.__contains__.return_value = True - overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' self.text_channel = MockTextChannel() self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -713,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) self.voice_channel.overwrites_for.return_value = self.voice_overwrite + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) + self.cog.previous_overwrites = overwrites_cache + + overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' + async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) diff --git a/tests/bot/exts/test_cogs.py b/tests/bot/exts/test_cogs.py index 5cb071d58..f8e120262 100644 --- a/tests/bot/exts/test_cogs.py +++ b/tests/bot/exts/test_cogs.py @@ -8,7 +8,7 @@ from collections import defaultdict from types import ModuleType from unittest import mock -from disnake.ext import commands +from discord.ext import commands from bot import exts @@ -34,7 +34,7 @@ class CommandNameTests(unittest.TestCase): raise ImportError(name=name) # pragma: no cover # The mock prevents asyncio.get_event_loop() from being called. - with mock.patch("disnake.ext.tasks.loop"): + with mock.patch("discord.ext.tasks.loop"): prefix = f"{exts.__name__}." for module in pkgutil.walk_packages(exts.__path__, prefix, onerror=on_error): if not module.ispkg: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index bec7574fb..b1f32c210 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -2,13 +2,14 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch -from disnake import AllowedMentions -from disnake.ext import commands +from discord import AllowedMentions +from discord.ext import commands from bot import constants +from bot.errors import LockedResourceError from bot.exts.utils import snekbox from bot.exts.utils.snekbox import Snekbox -from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser +from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser class SnekboxTests(unittest.IsolatedAsyncioTestCase): @@ -17,7 +18,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = Snekbox(bot=self.bot) - async def test_post_eval(self): + async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() resp.json = AsyncMock(return_value="return") @@ -26,7 +27,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_eval("import random"), "return") + self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -35,17 +36,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): resp.json.assert_awaited_once() async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LEN.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + """Reject output longer than MAX_PASTE_LENGTH.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) self.assertEqual(result, "too long to upload") @patch("bot.exts.utils.snekbox.send_to_paste_service") async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with("Test output.", extension="txt") + mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) - def test_prepare_input(self): + async def test_codeblock_converter(self): + ctx = MockContext() cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), @@ -61,33 +63,50 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) + self.assertEqual( + '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + ) + + def test_prepare_timeit_input(self): + """Test the prepare_timeit_input codeblock detection.""" + base_args = ('-m', 'timeit', '-s') + cases = ( + (['print("Hello World")'], '', 'single block of code'), + (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), + (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') + ) + + for case, setup_code, testname in cases: + setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + self.assertEqual(self.cog.prepare_timeit_input(case), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( - ('ERROR', None, ('Your eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), - ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127', '') ) @patch('bot.exts.utils.snekbox.Signals') def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), - ('Your eval job has completed with return code 127 (SIGTEST)', '') + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), + ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') ) def test_get_status_emoji(self): @@ -156,64 +175,64 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=None) + ctx.command = MagicMock() + + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') - self.cog.continue_eval.assert_called_once_with(ctx, response) + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) - - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') - self.cog.continue_eval.assert_called_with(ctx, response) + ctx.command = MagicMock() + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock() + self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + + await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) + self.cog.send_job.assert_called_with( + ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' + ) + self.cog.continue_job.assert_called_with(ctx, response, 'eval') async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" ctx = MockContext() ctx.author.id = 42 - ctx.author.mention = '@LemonLemonishBeard#0042' - ctx.send = AsyncMock() - self.cog.jobs = (42,) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - ctx.send.assert_called_once_with( - "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" - ) - async def test_eval_command_call_help(self): - """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext(command="sentinel") - await self.cog.eval_command(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with(ctx.command) + async def delay_with_side_effect(*args, **kwargs) -> dict: + """Delay the post_job call to ensure the job runs long enough to conflict.""" + await asyncio.sleep(1) + return {'stdout': '', 'returncode': 0} + + self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) + with self.assertRaises(LockedResourceError): + await asyncio.gather( + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + ) - async def test_send_eval(self): - """Test the send_eval function.""" + async def test_send_job(self): + """Test the send_job function.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -224,28 +243,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') self.cog.format_output.assert_called_once_with('') - async def test_send_eval_with_paste_link(self): - """Test the send_eval function with a too long output that generate a paste link.""" + async def test_send_job_with_paste_link(self): + """Test the send_job function with a too long output that generate a paste link.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -254,27 +273,29 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with( + {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' + ) self.cog.format_output.assert_called_once_with('Way too long beard') - async def test_send_eval_with_non_zero_eval(self): - """Test the send_eval function with a code returning a non-zero code.""" + async def test_send_job_with_non_zero_eval(self): + """Test the send_job function with a code returning a non-zero code.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -282,45 +303,53 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") - async def test_continue_eval_does_continue(self, partial_mock): - """Test that the continue_eval function does continue if required conditions are met.""" - ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) - response = MockMessage(delete=AsyncMock()) + async def test_continue_job_does_continue(self, partial_mock): + """Test that the continue_job function does continue if required conditions are met.""" + ctx = MockContext( + message=MockMessage( + id=4, + add_reaction=AsyncMock(), + clear_reactions=AsyncMock() + ), + author=MockMember(id=14) + ) + response = MockMessage(id=42, delete=AsyncMock()) new_msg = MockMessage() + self.cog.jobs = {4: 42} self.bot.wait_for.side_effect = ((None, new_msg), None) expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response) - self.cog.get_code.assert_awaited_once_with(new_msg) - self.assertEqual(actual, expected) + actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) + self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( 'message_edit', - check=partial_mock(snekbox.predicate_eval_message_edit, ctx), - timeout=snekbox.REEVAL_TIMEOUT, + check=partial_mock(snekbox.predicate_message_edit, ctx), + timeout=snekbox.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) response.delete.assert_called_once() - async def test_continue_eval_does_not_continue(self): + async def test_continue_job_does_not_continue(self): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage()) - self.assertEqual(actual, None) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) + self.assertEqual(actual, (None, None)) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -343,13 +372,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.return_value = MockContext(command=command) message = MockMessage(content=content) - actual_code = await self.cog.get_code(message) + actual_code = await self.cog.get_code(message, self.cog.eval_command) self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) - def test_predicate_eval_message_edit(self): - """Test the predicate_eval_message_edit function.""" + def test_predicate_message_edit(self): + """Test the predicate_message_edit function.""" msg0 = MockMessage(id=1, content='abc') msg1 = MockMessage(id=2, content='abcdef') msg2 = MockMessage(id=1, content='abcdef') @@ -362,18 +391,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) - def test_predicate_eval_emoji_reaction(self): - """Test the predicate_eval_emoji_reaction function.""" + def test_predicate_emoji_reaction(self): + """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_reaction.__str__.return_value = snekbox.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -386,15 +415,15 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) -class SnekboxSetupTests(unittest.TestCase): +class SnekboxSetupTests(unittest.IsolatedAsyncioTestCase): """Tests setup of the `Snekbox` cog.""" - def test_setup(self): + async def test_setup(self): """Setup of the extension should call add_cog.""" bot = MockBot() - snekbox.setup(bot) - bot.add_cog.assert_called_once() + await snekbox.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index f8805ac48..e1f904917 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,15 +1,32 @@ -from typing import Iterable +from typing import Iterable, Optional + +import discord from bot.rules import mentions from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage +from tests.helpers import MockMember, MockMessage, MockMessageReference -def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: - """Makes a message with `total_mentions` mentions.""" +def make_msg( + author: str, + total_user_mentions: int, + total_bot_mentions: int = 0, + *, + reference: Optional[MockMessageReference] = None +) -> MockMessage: + """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" user_mentions = [MockMember() for _ in range(total_user_mentions)] bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - return MockMessage(author=author, mentions=user_mentions+bot_mentions) + + mentions = user_mentions + bot_mentions + if reference is not None: + # For the sake of these tests we assume that all references are mentions. + mentions.append(reference.resolved.author) + msg_type = discord.MessageType.reply + else: + msg_type = discord.MessageType.default + + return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) class TestMentions(RuleTest): @@ -56,6 +73,16 @@ class TestMentions(RuleTest): ("bob",), 3, ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference())], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], + ("bob",), + 3 + ) ) await self.run_disallowed(cases) @@ -71,6 +98,27 @@ class TestMentions(RuleTest): await self.run_allowed(cases) + async def test_ignore_reply_mentions(self): + """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" + cases = ( + [ + make_msg("bob", 2, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) + ], + [ + make_msg("bob", 2, reference=MockMessageReference()), + make_msg("bob", 0, reference=MockMessageReference()) + ], + [ + make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), + make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) + ] + ) + + await self.run_allowed(cases) + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: last_message = case.recent_messages[0] return tuple( diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py deleted file mode 100644 index 76bcb481d..000000000 --- a/tests/bot/test_api.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from bot import api - - -class APIClientTests(unittest.IsolatedAsyncioTestCase): - """Tests for the bot's API client.""" - - @classmethod - def setUpClass(cls): - """Sets up the shared fixtures for the tests.""" - cls.error_api_response = MagicMock() - cls.error_api_response.status = 999 - - def test_response_code_error_default_initialization(self): - """Test the default initialization of `ResponseCodeError` without `text` or `json`""" - error = api.ResponseCodeError(response=self.error_api_response) - - self.assertIs(error.status, self.error_api_response.status) - self.assertEqual(error.response_json, {}) - self.assertEqual(error.response_text, "") - self.assertIs(error.response, self.error_api_response) - - def test_response_code_error_string_representation_default_initialization(self): - """Test the string representation of `ResponseCodeError` initialized without text or json.""" - error = api.ResponseCodeError(response=self.error_api_response) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") - - def test_response_code_error_initialization_with_json(self): - """Test the initialization of `ResponseCodeError` with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data, - ) - self.assertEqual(error.response_json, json_data) - self.assertEqual(error.response_text, "") - - def test_response_code_error_string_representation_with_nonempty_response_json(self): - """Test the string representation of `ResponseCodeError` initialized with json.""" - json_data = {'hello': 'world'} - error = api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") - - def test_response_code_error_initialization_with_text(self): - """Test the initialization of `ResponseCodeError` with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data, - ) - self.assertEqual(error.response_text, text_data) - self.assertEqual(error.response_json, {}) - - def test_response_code_error_string_representation_with_nonempty_response_text(self): - """Test the string representation of `ResponseCodeError` initialized with text.""" - text_data = 'Lemon will eat your soul' - error = api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index afb8a973d..1bb678db2 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -4,7 +4,7 @@ from datetime import MAXYEAR, datetime, timezone from unittest.mock import MagicMock, patch from dateutil.relativedelta import relativedelta -from disnake.ext.commands import BadArgument +from discord.ext.commands import BadArgument from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 5675e10ec..4ae11d5d3 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock -from disnake import DMChannel +from discord import DMChannel from bot.utils import checks from bot.utils.checks import InWhitelistCheckFailure diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 3b71022db..d0e801299 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch from aiohttp import ClientConnectorError -from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from bot.utils.services import ( + FAILED_REQUEST_ATTEMPTS, MAX_PASTE_LENGTH, PasteTooLongError, PasteUploadError, send_to_paste_service +) from tests.helpers import MockBot @@ -55,23 +57,34 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): for error_json in test_cases: with self.subTest(error_json=error_json): response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) self.bot.http_session.post.reset_mock() async def test_request_repeated_on_connection_errors(self): """Requests are repeated in the case of connection errors.""" self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) - self.assertIsNone(result) async def test_general_error_handled_and_request_repeated(self): """All `Exception`s are handled, logged and request repeated.""" self.bot.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service("") + with self.assertRaises(PasteUploadError): + await send_to_paste_service("") self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertLogs("bot.utils", logging.ERROR) - self.assertIsNone(result) + + async def test_raises_error_on_too_long_input(self): + """Ensure PasteTooLongError is raised if `contents` is longer than `MAX_PASTE_LENGTH`.""" + contents = "a" * (MAX_PASTE_LENGTH + 1) + with self.assertRaises(PasteTooLongError): + await send_to_paste_service(contents) + + async def test_raises_on_too_large_max_length(self): + """Ensure ValueError is raised if `max_length` passed is greater than `MAX_PASTE_LENGTH`.""" + with self.assertRaises(ValueError): + await send_to_paste_service("Hello World!", max_length=MAX_PASTE_LENGTH + 1) diff --git a/tests/helpers.py b/tests/helpers.py index bd1418ab9..687e15b96 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,12 +7,12 @@ import unittest.mock from asyncio import AbstractEventLoop from typing import Iterable, Optional -import disnake +import discord from aiohttp import ClientSession -from disnake.ext.commands import Context +from botcore.async_stats import AsyncStatsClient +from botcore.site_api import APIClient +from discord.ext.commands import Context -from bot.api import APIClient -from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module @@ -26,11 +26,11 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -class HashableMixin(disnake.mixins.EqualityComparable): +class HashableMixin(discord.mixins.EqualityComparable): """ - Mixin that provides similar hashing and equality functionality as disnake's `Hashable` mixin. + Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. - Note: disnake`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions + Note: discord.py`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions for the relative small `id` integers we generally use in tests, this bit-shift is omitted. """ @@ -39,22 +39,22 @@ class HashableMixin(disnake.mixins.EqualityComparable): class ColourMixin: - """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like disnake does.""" + """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like discord.py does.""" @property - def color(self) -> disnake.Colour: + def color(self) -> discord.Colour: return self.colour @color.setter - def color(self, color: disnake.Colour) -> None: + def color(self, color: discord.Colour) -> None: self.colour = color @property - def accent_color(self) -> disnake.Colour: + def accent_color(self) -> discord.Colour: return self.accent_colour @accent_color.setter - def accent_color(self, color: disnake.Colour) -> None: + def accent_color(self, color: discord.Colour) -> None: self.accent_colour = color @@ -63,7 +63,7 @@ class CustomMockMixin: Provides common functionality for our custom Mock types. The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock - object. As disnake also uses synchronous methods that nonetheless return coroutine objects, the + object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The class method `spec_set` can be overwritten with the object that should be uses as the specification @@ -119,7 +119,7 @@ class CustomMockMixin: return klass(**kw) -# Create a guild instance to get a realistic Mock of `disnake.Guild` +# Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { 'id': 1, 'name': 'guild', @@ -139,20 +139,20 @@ guild_data = { 'owner_id': 1, 'afk_channel_id': 464033278631084042, } -guild_instance = disnake.Guild(data=guild_data, state=unittest.mock.MagicMock()) +guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock()) class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ - A `Mock` subclass to mock `disnake.Guild` objects. + A `Mock` subclass to mock `discord.Guild` objects. - A MockGuild instance will follow the specifications of a `disnake.Guild` instance. This means + A MockGuild instance will follow the specifications of a `discord.Guild` instance. This means that if the code you're testing tries to access an attribute or method that normally does not - exist for a `disnake.Guild` object this will raise an `AttributeError`. This is to make sure our - tests fail if the code we're testing uses a `disnake.Guild` object in the wrong way. + exist for a `discord.Guild` object this will raise an `AttributeError`. This is to make sure our + tests fail if the code we're testing uses a `discord.Guild` object in the wrong way. One restriction of that is that if the code tries to access an attribute that normally does not - exist for `disnake.Guild` instance but was added dynamically, this will raise an exception with + exist for `discord.Guild` instance but was added dynamically, this will raise an exception with the mocked object. To get around that, you can set the non-standard attribute explicitly for the instance of `MockGuild`: @@ -160,10 +160,10 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): >>> guild.attribute_that_normally_does_not_exist = unittest.mock.MagicMock() In addition to attribute simulation, mocked guild object will pass an `isinstance` check against - `disnake.Guild`: + `discord.Guild`: >>> guild = MockGuild() - >>> isinstance(guild, disnake.Guild) + >>> isinstance(guild, discord.Guild) True For more info, see the `Mocking` section in `tests/README.md`. @@ -171,7 +171,7 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): spec_set = guild_instance def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'members': []} + default_kwargs = {'id': next(self.discord_id), 'members': [], "chunked": True} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] @@ -179,16 +179,16 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): self.roles.extend(roles) -# Create a Role instance to get a realistic Mock of `disnake.Role` +# Create a Role instance to get a realistic Mock of `discord.Role` role_data = {'name': 'role', 'id': 1} -role_instance = disnake.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) +role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ - A Mock subclass to mock `disnake.Role` objects. + A Mock subclass to mock `discord.Role` objects. - Instances of this class will follow the specifications of `disnake.Role` instances. For more + Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ spec_set = role_instance @@ -198,40 +198,40 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): 'id': next(self.discord_id), 'name': 'role', 'position': 1, - 'colour': disnake.Colour(0xdeadbf), - 'permissions': disnake.Permissions(), + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), } super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if isinstance(self.colour, int): - self.colour = disnake.Colour(self.colour) + self.colour = discord.Colour(self.colour) if isinstance(self.permissions, int): - self.permissions = disnake.Permissions(self.permissions) + self.permissions = discord.Permissions(self.permissions) if 'mention' not in kwargs: self.mention = f'&{self.name}' def __lt__(self, other): - """Simplified position-based comparisons similar to those of `disnake.Role`.""" + """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position < other.position def __ge__(self, other): - """Simplified position-based comparisons similar to those of `disnake.Role`.""" + """Simplified position-based comparisons similar to those of `discord.Role`.""" return self.position >= other.position -# Create a Member instance to get a realistic Mock of `disnake.Member` +# Create a Member instance to get a realistic Mock of `discord.Member` member_data = {'user': 'lemon', 'roles': [1]} state_mock = unittest.mock.MagicMock() -member_instance = disnake.Member(data=member_data, guild=guild_instance, state=state_mock) +member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock Member objects. - Instances of this class will follow the specifications of `disnake.Member` instances. For more + Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ spec_set = member_instance @@ -249,11 +249,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" -# Create a User instance to get a realistic Mock of `disnake.User` +# Create a User instance to get a realistic Mock of `discord.User` _user_data_mock = collections.defaultdict(unittest.mock.MagicMock, { "accent_color": 0 }) -user_instance = disnake.User( +user_instance = discord.User( data=unittest.mock.MagicMock(get=unittest.mock.Mock(side_effect=_user_data_mock.get)), state=unittest.mock.MagicMock() ) @@ -263,7 +263,7 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): """ A Mock subclass to mock User objects. - Instances of this class will follow the specifications of `disnake.User` instances. For more + Instances of this class will follow the specifications of `discord.User` instances. For more information, see the `MockGuild` docstring. """ spec_set = user_instance @@ -305,13 +305,17 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Bot objects. - Instances of this class will follow the specifications of `disnake.ext.commands.Bot` instances. + Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ spec_set = Bot( command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop(), redis_session=unittest.mock.MagicMock(), + http_session=unittest.mock.MagicMock(), + allowed_roles=[1], + guild_id=1, + intents=discord.Intents.all(), ) additional_spec_asyncs = ("wait_for", "redis_ready") @@ -322,9 +326,10 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): self.api_client = MockAPIClient(loop=self.loop) self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True) + self.add_cog = unittest.mock.AsyncMock() -# Create a TextChannel instance to get a realistic MagicMock of `disnake.TextChannel` +# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` channel_data = { 'id': 1, 'type': 'TextChannel', @@ -334,20 +339,22 @@ channel_data = { 'position': 1, 'nsfw': False, 'last_message_id': 1, + 'bitrate': 1337, + 'user_limit': 25, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -text_channel_instance = disnake.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) channel_data["type"] = "VoiceChannel" -voice_channel_instance = disnake.VoiceChannel(state=state, guild=guild, data=channel_data) +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `disnake.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = text_channel_instance @@ -364,7 +371,7 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock VoiceChannel objects. - Instances of this class will follow the specifications of `disnake.VoiceChannel` instances. For + Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = voice_channel_instance @@ -381,14 +388,14 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): state = unittest.mock.MagicMock() me = unittest.mock.MagicMock() dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} -dm_channel_instance = disnake.DMChannel(me=me, state=state, data=dm_channel_data) +dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ A MagicMock subclass to mock TextChannel objects. - Instances of this class will follow the specifications of `disnake.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = dm_channel_instance @@ -398,17 +405,17 @@ class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(kwargs, default_kwargs)) -# Create CategoryChannel instance to get a realistic MagicMock of `disnake.CategoryChannel` +# Create CategoryChannel instance to get a realistic MagicMock of `discord.CategoryChannel` category_channel_data = { 'id': 1, - 'type': disnake.ChannelType.category, + 'type': discord.ChannelType.category, 'name': 'category', 'position': 1, } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -category_channel_instance = disnake.CategoryChannel( +category_channel_instance = discord.CategoryChannel( state=state, guild=guild, data=category_channel_data ) @@ -419,13 +426,13 @@ class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): super().__init__(**collections.ChainMap(default_kwargs, kwargs)) -# Create a Message instance to get a realistic MagicMock of `disnake.Message` +# Create a Message instance to get a realistic MagicMock of `discord.Message` message_data = { 'id': 1, 'webhook_id': 431341013479718912, 'attachments': [], 'embeds': [], - 'application': 'Python Discord', + 'application': {"id": 4, "description": "A Python Bot", "name": "Python Discord", "icon": None}, 'activity': 'mocking', 'channel': unittest.mock.MagicMock(), 'edited_timestamp': '2019-10-14T15:33:48+00:00', @@ -438,10 +445,11 @@ message_data = { } state = unittest.mock.MagicMock() channel = unittest.mock.MagicMock() -message_instance = disnake.Message(state=state, channel=channel, data=message_data) +channel.type = discord.ChannelType.text +message_instance = discord.Message(state=state, channel=channel, data=message_data) -# Create a Context instance to get a realistic MagicMock of `disnake.ext.commands.Context` +# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` context_instance = Context( message=unittest.mock.MagicMock(), prefix="$", @@ -455,7 +463,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Context objects. - Instances of this class will follow the specifications of `disnake.ext.commands.Context` + Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ spec_set = context_instance @@ -471,24 +479,46 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) -attachment_instance = disnake.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Attachment objects. - Instances of this class will follow the specifications of `disnake.Attachment` instances. For + Instances of this class will follow the specifications of `discord.Attachment` instances. For more information, see the `MockGuild` docstring. """ spec_set = attachment_instance +message_reference_instance = discord.MessageReference( + message_id=unittest.mock.MagicMock(id=1), + channel_id=unittest.mock.MagicMock(id=2), + guild_id=unittest.mock.MagicMock(id=3) +) + + +class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock MessageReference objects. + + Instances of this class will follow the specification of `discord.MessageReference` instances. + For more information, see the `MockGuild` docstring. + """ + spec_set = message_reference_instance + + def __init__(self, *, reference_author_is_bot: bool = False, **kwargs): + super().__init__(**kwargs) + referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot) + self.resolved = MockMessage(author=referenced_msg_author) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. - Instances of this class will follow the specifications of `disnake.Message` instances. For more + Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ spec_set = message_instance @@ -501,14 +531,14 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} -emoji_instance = disnake.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) +emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Emoji objects. - Instances of this class will follow the specifications of `disnake.Emoji` instances. For more + Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = emoji_instance @@ -518,27 +548,27 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): self.guild = kwargs.get('guild', MockGuild()) -partial_emoji_instance = disnake.PartialEmoji(animated=False, name='guido') +partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido') class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock PartialEmoji objects. - Instances of this class will follow the specifications of `disnake.PartialEmoji` instances. For + Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ spec_set = partial_emoji_instance -reaction_instance = disnake.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) +reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) class MockReaction(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Reaction objects. - Instances of this class will follow the specifications of `disnake.Reaction` instances. For + Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ spec_set = reaction_instance @@ -556,14 +586,14 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.__str__.return_value = str(self.emoji) -webhook_instance = disnake.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock()) class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. - Instances of this class will follow the specifications of `disnake.Webhook` instances. For + Instances of this class will follow the specifications of `discord.Webhook` instances. For more information, see the `MockGuild` docstring. """ spec_set = webhook_instance diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c5e799a85..f3040b305 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,20 +2,20 @@ import asyncio import unittest import unittest.mock -import disnake +import discord from tests import helpers class DiscordMocksTests(unittest.TestCase): - """Tests for our specialized disnake mocks.""" + """Tests for our specialized discord.py mocks.""" def test_mock_role_default_initialization(self): """Test if the default initialization of MockRole results in the correct object.""" role = helpers.MockRole() - # The `spec` argument makes sure `isistance` checks with `disnake.Role` pass - self.assertIsInstance(role, disnake.Role) + # The `spec` argument makes sure `isistance` checks with `discord.Role` pass + self.assertIsInstance(role, discord.Role) self.assertEqual(role.name, "role") self.assertEqual(role.position, 1) @@ -61,8 +61,8 @@ class DiscordMocksTests(unittest.TestCase): """Test if the default initialization of Mockmember results in the correct object.""" member = helpers.MockMember() - # The `spec` argument makes sure `isistance` checks with `disnake.Member` pass - self.assertIsInstance(member, disnake.Member) + # The `spec` argument makes sure `isistance` checks with `discord.Member` pass + self.assertIsInstance(member, discord.Member) self.assertEqual(member.name, "member") self.assertListEqual(member.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) @@ -86,18 +86,18 @@ class DiscordMocksTests(unittest.TestCase): """Test if MockMember accepts and sets abitrary keyword arguments.""" member = helpers.MockMember( nick="Dino Man", - colour=disnake.Colour.default(), + colour=discord.Colour.default(), ) self.assertEqual(member.nick, "Dino Man") - self.assertEqual(member.colour, disnake.Colour.default()) + self.assertEqual(member.colour, discord.Colour.default()) def test_mock_guild_default_initialization(self): """Test if the default initialization of Mockguild results in the correct object.""" guild = helpers.MockGuild() - # The `spec` argument makes sure `isistance` checks with `disnake.Guild` pass - self.assertIsInstance(guild, disnake.Guild) + # The `spec` argument makes sure `isistance` checks with `discord.Guild` pass + self.assertIsInstance(guild, discord.Guild) self.assertListEqual(guild.roles, [helpers.MockRole(name="@everyone", position=1, id=0)]) self.assertListEqual(guild.members, []) @@ -127,15 +127,15 @@ class DiscordMocksTests(unittest.TestCase): """Tests if MockBot initializes with the correct values.""" bot = helpers.MockBot() - # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Bot` pass - self.assertIsInstance(bot, disnake.ext.commands.Bot) + # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Bot` pass + self.assertIsInstance(bot, discord.ext.commands.Bot) def test_mock_context_default_initialization(self): """Tests if MockContext initializes with the correct values.""" context = helpers.MockContext() - # The `spec` argument makes sure `isistance` checks with `disnake.ext.commands.Context` pass - self.assertIsInstance(context, disnake.ext.commands.Context) + # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Context` pass + self.assertIsInstance(context, discord.ext.commands.Context) self.assertIsInstance(context.bot, helpers.MockBot) self.assertIsInstance(context.guild, helpers.MockGuild) @@ -327,7 +327,7 @@ class MockObjectTests(unittest.TestCase): def test_spec_propagation_of_mock_subclasses(self): """Test if the `spec` does not propagate to attributes of the mock object.""" test_values = ( - (helpers.MockGuild, "region"), + (helpers.MockGuild, "features"), (helpers.MockRole, "mentionable"), (helpers.MockMember, "display_name"), (helpers.MockBot, "owner_id"), |