aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/base.py24
-rw-r--r--tests/bot/exts/backend/test_error_handler.py103
-rw-r--r--tests/bot/exts/info/test_information.py74
-rw-r--r--tests/bot/exts/moderation/infraction/test_infractions.py49
-rw-r--r--tests/bot/exts/moderation/infraction/test_utils.py29
-rw-r--r--tests/bot/exts/moderation/test_incidents.py71
-rw-r--r--tests/bot/exts/moderation/test_silence.py103
-rw-r--r--tests/bot/exts/utils/test_snekbox.py84
-rw-r--r--tests/bot/rules/test_mentions.py131
-rw-r--r--tests/helpers.py24
-rw-r--r--tests/test_helpers.py2
11 files changed, 477 insertions, 217 deletions
diff --git a/tests/base.py b/tests/base.py
index 5e304ea9d..4863a1821 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -4,6 +4,7 @@ from contextlib import contextmanager
from typing import Dict
import discord
+from async_rediscache import RedisSession
from discord.ext import commands
from bot.log import get_logger
@@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase):
await cmd.can_run(ctx)
self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions)
+
+
+class RedisTestCase(unittest.IsolatedAsyncioTestCase):
+ """
+ Use this as a base class for any test cases that require a redis session.
+
+ This will prepare a fresh redis instance for each test function, and will
+ not make any assertions on its own. Tests can mutate the instance as they wish.
+ """
+
+ session = None
+
+ async def flush(self):
+ """Flush everything from the redis database to prevent carry-overs between tests."""
+ await self.session.client.flushall()
+
+ async def asyncSetUp(self):
+ self.session = await RedisSession(use_fakeredis=True).connect()
+ await self.flush()
+
+ async def asyncTearDown(self):
+ if self.session:
+ await self.session.client.close()
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 0a58126e7..7562f6aa8 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError
from discord.ext.commands import errors
from bot.errors import InvalidInfractedUserError, LockedResourceError
-from bot.exts.backend.error_handler import ErrorHandler, setup
+from bot.exts.backend import error_handler
from bot.exts.info.tags import Tags
from bot.exts.moderation.silence import Silence
from bot.utils.checks import InWhitelistCheckFailure
@@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.bot = MockBot()
self.ctx = MockContext(bot=self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
async def test_error_handler_already_handled(self):
"""Should not do anything when error is already handled by local error handler."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
error = errors.CommandError()
error.handled = "foo"
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
self.ctx.send.assert_not_awaited()
async def test_error_handler_command_not_found_error_not_invoked_by_handler(self):
@@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
"called_try_get_tag": True
}
)
- cog = ErrorHandler(self.bot)
- cog.try_silence = AsyncMock()
- cog.try_get_tag = AsyncMock()
- cog.try_run_eval = AsyncMock(return_value=False)
+ self.cog.try_silence = AsyncMock()
+ self.cog.try_get_tag = AsyncMock()
+ self.cog.try_run_eval = AsyncMock(return_value=False)
for case in test_cases:
with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):
self.ctx.reset_mock()
- cog.try_silence.reset_mock(return_value=True)
- cog.try_get_tag.reset_mock()
+ self.cog.try_silence.reset_mock(return_value=True)
+ self.cog.try_get_tag.reset_mock()
- cog.try_silence.return_value = case["try_silence_return"]
+ self.cog.try_silence.return_value = case["try_silence_return"]
self.ctx.channel.id = 1234
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
if case["try_silence_return"]:
- cog.try_get_tag.assert_not_awaited()
- cog.try_silence.assert_awaited_once()
+ self.cog.try_get_tag.assert_not_awaited()
+ self.cog.try_silence.assert_awaited_once()
else:
- cog.try_silence.assert_awaited_once()
- cog.try_get_tag.assert_awaited_once()
+ self.cog.try_silence.assert_awaited_once()
+ self.cog.try_get_tag.assert_awaited_once()
self.ctx.send.assert_not_awaited()
@@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
"""Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""
ctx = MockContext(bot=self.bot, invoked_from_error_handler=True)
- cog = ErrorHandler(self.bot)
- cog.try_silence = AsyncMock()
- cog.try_get_tag = AsyncMock()
- cog.try_run_eval = AsyncMock()
+ self.cog.try_silence = AsyncMock()
+ self.cog.try_get_tag = AsyncMock()
+ self.cog.try_run_eval = AsyncMock()
error = errors.CommandNotFound()
- self.assertIsNone(await cog.on_command_error(ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(ctx, error))
- cog.try_silence.assert_not_awaited()
- cog.try_get_tag.assert_not_awaited()
- cog.try_run_eval.assert_not_awaited()
+ self.cog.try_silence.assert_not_awaited()
+ self.cog.try_get_tag.assert_not_awaited()
+ self.cog.try_run_eval.assert_not_awaited()
self.ctx.send.assert_not_awaited()
async def test_error_handler_user_input_error(self):
"""Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
- cog.handle_user_input_error = AsyncMock()
+ self.cog.handle_user_input_error = AsyncMock()
error = errors.UserInputError()
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
- cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
+ self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
async def test_error_handler_check_failure(self):
"""Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
- cog.handle_check_failure = AsyncMock()
+ self.cog.handle_check_failure = AsyncMock()
error = errors.CheckFailure()
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
- cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
+ self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
async def test_error_handler_command_on_cooldown(self):
"""Should send error with `ctx.send` when error is `CommandOnCooldown`."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
error = errors.CommandOnCooldown(10, 9, type=None)
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
self.ctx.send.assert_awaited_once_with(error)
async def test_error_handler_command_invoke_error(self):
"""Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
- cog = ErrorHandler(self.bot)
- cog.handle_api_error = AsyncMock()
- cog.handle_unexpected_error = AsyncMock()
+ self.cog.handle_api_error = AsyncMock()
+ self.cog.handle_unexpected_error = AsyncMock()
test_cases = (
{
"args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))),
- "expect_mock_call": cog.handle_api_error
+ "expect_mock_call": self.cog.handle_api_error
},
{
"args": (self.ctx, errors.CommandInvokeError(TypeError)),
- "expect_mock_call": cog.handle_unexpected_error
+ "expect_mock_call": self.cog.handle_unexpected_error
},
{
"args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))),
@@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
for case in test_cases:
with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):
self.ctx.send.reset_mock()
- self.assertIsNone(await cog.on_command_error(*case["args"]))
+ self.assertIsNone(await self.cog.on_command_error(*case["args"]))
if case["expect_mock_call"] == "send":
self.ctx.send.assert_awaited_once()
else:
@@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
async def test_error_handler_conversion_error(self):
"""Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
- cog = ErrorHandler(self.bot)
- cog.handle_api_error = AsyncMock()
- cog.handle_unexpected_error = AsyncMock()
+ self.cog.handle_api_error = AsyncMock()
+ self.cog.handle_unexpected_error = AsyncMock()
cases = (
{
"error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())),
- "mock_function_to_call": cog.handle_api_error
+ "mock_function_to_call": self.cog.handle_api_error
},
{
"error": errors.ConversionError(AsyncMock(), TypeError),
- "mock_function_to_call": cog.handle_unexpected_error
+ "mock_function_to_call": self.cog.handle_unexpected_error
}
)
for case in cases:
with self.subTest(**case):
- self.assertIsNone(await cog.on_command_error(self.ctx, case["error"]))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"]))
case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)
async def test_error_handler_two_other_errors(self):
"""Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`."""
- cog = ErrorHandler(self.bot)
- cog.handle_unexpected_error = AsyncMock()
+ self.cog.handle_unexpected_error = AsyncMock()
errs = (
errors.MaxConcurrencyReached(1, MagicMock()),
errors.ExtensionError(name="foo")
@@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
for err in errs:
with self.subTest(error=err):
- cog.handle_unexpected_error.reset_mock()
- self.assertIsNone(await cog.on_command_error(self.ctx, err))
- cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
+ self.cog.handle_unexpected_error.reset_mock()
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, err))
+ self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
@patch("bot.exts.backend.error_handler.log")
async def test_error_handler_other_errors(self, log_mock):
"""Should `log.debug` other errors."""
- cog = ErrorHandler(self.bot)
error = errors.DisabledCommand() # Use this just as a other error
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
log_mock.debug.assert_called_once()
@@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
self.silence = Silence(self.bot)
self.bot.get_command.return_value = self.silence.silence
self.ctx = MockContext(bot=self.bot)
- self.cog = ErrorHandler(self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
async def test_try_silence_context_invoked_from_error_handler(self):
"""Should set `Context.invoked_from_error_handler` to `True`."""
@@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
self.bot = MockBot()
self.ctx = MockContext()
self.tag = Tags(self.bot)
- self.cog = ErrorHandler(self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
self.bot.get_command.return_value = self.tag.get_command
async def test_try_get_tag_get_command(self):
@@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.bot = MockBot()
self.ctx = MockContext(bot=self.bot)
- self.cog = ErrorHandler(self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
async def test_handle_input_error_handler_errors(self):
"""Should handle each error probably."""
@@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase):
async def test_setup(self):
"""Should call `bot.add_cog` with `ErrorHandler`."""
bot = MockBot()
- await setup(bot)
+ await error_handler.setup(bot)
bot.add_cog.assert_awaited_once()
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index d896b7652..9f5143c01 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -2,6 +2,7 @@ import textwrap
import unittest
import unittest.mock
from datetime import datetime
+from textwrap import shorten
import discord
@@ -573,3 +574,76 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase):
create_embed.assert_called_once_with(ctx, self.target, False)
ctx.send.assert_called_once()
+
+
+class RuleCommandTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the `!rule` command."""
+
+ def setUp(self) -> None:
+ """Set up steps executed before each test is run."""
+ self.bot = helpers.MockBot()
+ self.cog = information.Information(self.bot)
+ self.ctx = helpers.MockContext(author=helpers.MockMember(id=1, name="Bellaluma"))
+ self.full_rules = [
+ (
+ "First rule",
+ ["first", "number_one"]
+ ),
+ (
+ "Second rule",
+ ["second", "number_two"]
+ ),
+ (
+ "Third rule",
+ ["third", "number_three"]
+ )
+ ]
+ self.bot.api_client.get.return_value = self.full_rules
+
+ async def test_return_none_if_one_rule_number_is_invalid(self):
+
+ test_cases = [
+ (('1', '6', '7', '8'), (6, 7, 8)),
+ (('10', "first"), (10, )),
+ (("first", 10), (10, ))
+ ]
+
+ for raw_user_input, extracted_rule_numbers in test_cases:
+ with self.subTest(identifier=raw_user_input):
+ invalid = ", ".join(
+ str(rule_number) for rule_number in extracted_rule_numbers
+ if rule_number < 1 or rule_number > len(self.full_rules))
+
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+
+ self.assertEqual(
+ self.ctx.send.call_args,
+ unittest.mock.call(shorten(":x: Invalid rule indices: " + invalid, 75, placeholder=" ...")))
+ self.assertEqual(None, final_rule_numbers)
+
+ async def test_return_correct_rule_numbers(self):
+
+ test_cases = [
+ (("1", "2", "first"), {1, 2}),
+ (("1", "hello", "2", "second"), {1}),
+ (("second", "third", "unknown", "999"), {2, 3})
+ ]
+
+ for raw_user_input, expected_matched_rule_numbers in test_cases:
+ with self.subTest(identifier=raw_user_input):
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)
+
+ async def test_return_default_rules_when_no_input_or_no_match_are_found(self):
+ test_cases = [
+ ((), None),
+ (("hello", "2", "second"), None),
+ (("hello", "999"), None),
+ ]
+
+ for raw_user_input, expected_matched_rule_numbers in test_cases:
+ with self.subTest(identifier=raw_user_input):
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ embed = self.ctx.send.call_args.kwargs['embed']
+ self.assertEqual(information.DEFAULT_RULES_DESCRIPTION, embed.description)
+ self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)
diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py
index 052048053..b78328137 100644
--- a/tests/bot/exts/moderation/infraction/test_infractions.py
+++ b/tests/bot/exts/moderation/infraction/test_infractions.py
@@ -35,17 +35,20 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
self.cog.apply_infraction = AsyncMock()
self.bot.get_cog.return_value = AsyncMock()
self.cog.mod_log.ignore = Mock()
- self.ctx.guild.ban = Mock()
+ self.ctx.guild.ban = AsyncMock()
await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000)
- self.ctx.guild.ban.assert_called_once_with(
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar", "purge": ""}, self.target, ANY
+ )
+
+ action = self.cog.apply_infraction.call_args.args[-1]
+ await action()
+ self.ctx.guild.ban.assert_awaited_once_with(
self.target,
reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."),
delete_message_days=0
)
- self.cog.apply_infraction.assert_awaited_once_with(
- self.ctx, {"foo": "bar", "purge": ""}, self.target, self.ctx.guild.ban.return_value
- )
@patch("bot.exts.moderation.infraction._utils.post_infraction")
async def test_apply_kick_reason_truncation(self, post_infraction_mock):
@@ -54,14 +57,17 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
self.cog.apply_infraction = AsyncMock()
self.cog.mod_log.ignore = Mock()
- self.target.kick = Mock()
+ self.target.kick = AsyncMock()
await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000)
- self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))
self.cog.apply_infraction.assert_awaited_once_with(
- self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value
+ self.ctx, {"foo": "bar"}, self.target, ANY
)
+ action = self.cog.apply_infraction.call_args.args[-1]
+ await action()
+ self.target.kick.assert_awaited_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))
+
@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456)
class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):
@@ -79,19 +85,25 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):
"""Should call voice mute applying function without expiry."""
self.cog.apply_voice_mute = AsyncMock()
self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar"))
- self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None)
+ self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry=None)
async def test_temporary_voice_mute(self):
"""Should call voice mute applying function with expiry."""
self.cog.apply_voice_mute = AsyncMock()
self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar"))
- self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz")
+ self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry="baz")
async def test_voice_unmute(self):
"""Should call infraction pardoning function."""
self.cog.pardon_infraction = AsyncMock()
+ self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user, pardon_reason="foobar"))
+ self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user, "foobar")
+
+ async def test_voice_unmute_reasonless(self):
+ """Should call infraction pardoning function without a pardon reason."""
+ self.cog.pardon_infraction = AsyncMock()
self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user))
- self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user)
+ self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user, None)
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
@@ -141,8 +153,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):
async def action_tester(self, action, reason: str) -> None:
"""Helper method to test voice mute action."""
- self.assertTrue(inspect.iscoroutine(action))
- await action
+ self.assertTrue(inspect.iscoroutinefunction(action))
+ await action()
self.user.move_to.assert_called_once_with(None, reason=ANY)
self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason=reason)
@@ -189,13 +201,14 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):
user = MockUser()
await self.cog.voicemute(self.cog, self.ctx, user, reason=None)
- post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None)
+ post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True,
+ duration_or_expiry=None)
apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY)
# Test action
action = self.cog.apply_infraction.call_args[0][-1]
- self.assertTrue(inspect.iscoroutine(action))
- await action
+ self.assertTrue(inspect.iscoroutinefunction(action))
+ await action()
async def test_voice_unmute_user_not_found(self):
"""Should include info to return dict when user was not found from guild."""
@@ -273,7 +286,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):
self.user,
"FooBar",
purge_days=1,
- expires_at=None,
+ duration_or_expiry=None,
)
async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self):
@@ -285,7 +298,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):
self.ctx,
self.user,
"FooBar",
- expires_at=None,
+ duration_or_expiry=None,
)
@patch("bot.exts.moderation.infraction.infractions.Age")
diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py
index 5cf02033d..29dadf372 100644
--- a/tests/bot/exts/moderation/infraction/test_utils.py
+++ b/tests/bot/exts/moderation/infraction/test_utils.py
@@ -1,7 +1,7 @@
import unittest
from collections import namedtuple
from datetime import datetime
-from unittest.mock import AsyncMock, MagicMock, call, patch
+from unittest.mock import AsyncMock, MagicMock, patch
from botcore.site_api import ResponseCodeError
from discord import Embed, Forbidden, HTTPException, NotFound
@@ -309,8 +309,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):
async def test_normal_post_infraction(self):
"""Should return response from POST request if there are no errors."""
- now = datetime.now()
- payload = {
+ now = datetime.utcnow()
+ expected = {
"actor": self.ctx.author.id,
"hidden": True,
"reason": "Test reason",
@@ -318,14 +318,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):
"user": self.member.id,
"active": False,
"expires_at": now.isoformat(),
- "dm_sent": False
+ "dm_sent": False,
}
self.ctx.bot.api_client.post.return_value = "foo"
actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False)
-
self.assertEqual(actual, "foo")
- self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload)
+ self.ctx.bot.api_client.post.assert_awaited_once()
+
+ # Since `last_applied` is based on current time, just check if expected is a subset of payload
+ payload: dict = self.ctx.bot.api_client.post.await_args_list[0].kwargs["json"]
+ self.assertEqual(payload, payload | expected)
async def test_unknown_error_post_infraction(self):
"""Should send an error message to chat when a non-400 error occurs."""
@@ -349,19 +352,25 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):
@patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar")
async def test_first_fail_second_success_user_post_infraction(self, post_user_mock):
"""Should post the user if they don't exist, POST infraction again, and return the response if successful."""
- payload = {
+ expected = {
"actor": self.ctx.author.id,
"hidden": False,
"reason": "Test reason",
"type": "mute",
"user": self.user.id,
"active": True,
- "dm_sent": False
+ "dm_sent": False,
}
self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"]
-
actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason")
self.assertEqual(actual, "foo")
- self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2)
+ await_args = self.bot.api_client.post.await_args_list
+ self.assertEqual(len(await_args), 2, "Expected 2 awaits")
+
+ # Since `last_applied` is based on current time, just check if expected is a subset of payload
+ for args in await_args:
+ payload: dict = args.kwargs["json"]
+ self.assertEqual(payload, payload | expected)
+
post_user_mock.assert_awaited_once_with(self.ctx, self.user)
diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py
index cfe0c4b03..53d98360c 100644
--- a/tests/bot/exts/moderation/test_incidents.py
+++ b/tests/bot/exts/moderation/test_incidents.py
@@ -1,4 +1,5 @@
import asyncio
+import datetime
import enum
import logging
import typing as t
@@ -8,16 +9,19 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
import aiohttp
import discord
-from async_rediscache import RedisSession
from bot.constants import Colours
from bot.exts.moderation import incidents
from bot.utils.messages import format_user
+from bot.utils.time import TimestampFormats, discord_timestamp
+from tests.base import RedisTestCase
from tests.helpers import (
MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,
MockUser
)
+CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc)
+
class MockAsyncIterable:
"""
@@ -100,30 +104,45 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase):
async def test_make_embed_actioned(self):
"""Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`."""
- embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember())
+ embed, file = await incidents.make_embed(
+ incident=MockMessage(created_at=CURRENT_TIME),
+ outcome=incidents.Signal.ACTIONED,
+ actioned_by=MockMember()
+ )
self.assertEqual(embed.colour.value, Colours.soft_green)
self.assertIn("Actioned", embed.footer.text)
async def test_make_embed_not_actioned(self):
"""Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`."""
- embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember())
+ embed, file = await incidents.make_embed(
+ incident=MockMessage(created_at=CURRENT_TIME),
+ outcome=incidents.Signal.NOT_ACTIONED,
+ actioned_by=MockMember()
+ )
self.assertEqual(embed.colour.value, Colours.soft_red)
self.assertIn("Rejected", embed.footer.text)
async def test_make_embed_content(self):
"""Incident content appears as embed description."""
- incident = MockMessage(content="this is an incident")
+ incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME)
+
+ reported_timestamp = discord_timestamp(CURRENT_TIME)
+ relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE)
+
embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember())
- self.assertEqual(incident.content, embed.description)
+ self.assertEqual(
+ f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*",
+ embed.description
+ )
async def test_make_embed_with_attachment_succeeds(self):
"""Incident's attachment is downloaded and displayed in the embed's image field."""
file = MagicMock(discord.File, filename="bigbadjoe.jpg")
attachment = MockAttachment(filename="bigbadjoe.jpg")
- incident = MockMessage(content="this is an incident", attachments=[attachment])
+ incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME)
# Patch `download_file` to return our `file`
with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)):
@@ -135,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase):
async def test_make_embed_with_attachment_fails(self):
"""Incident's attachment fails to download, proxy url is linked instead."""
attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg")
- incident = MockMessage(content="this is an incident", attachments=[attachment])
+ incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME)
# Patch `download_file` to return None as if the download failed
with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)):
@@ -270,7 +289,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase):
self.incident.add_reaction.assert_not_called()
-class TestIncidents(unittest.IsolatedAsyncioTestCase):
+class TestIncidents(RedisTestCase):
"""
Tests for bound methods of the `Incidents` cog.
@@ -279,22 +298,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):
the instance as they wish.
"""
- session = None
-
- async def flush(self):
- """Flush everything from the database to prevent carry-overs between tests."""
- with await self.session.pool as connection:
- await connection.flushall()
-
- async def asyncSetUp(self): # noqa: N802
- self.session = RedisSession(use_fakeredis=True)
- await self.session.connect()
- await self.flush()
-
- async def asyncTearDown(self): # noqa: N802
- if self.session:
- await self.session.close()
-
def setUp(self):
"""
Prepare a fresh `Incidents` instance for each test.
@@ -365,7 +368,6 @@ class TestCrawlIncidents(TestIncidents):
class TestArchive(TestIncidents):
"""Tests for the `Incidents.archive` coroutine."""
-
async def test_archive_webhook_not_found(self):
"""
Method recovers and returns False when the webhook is not found.
@@ -375,7 +377,11 @@ class TestArchive(TestIncidents):
"""
self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404)
self.assertFalse(
- await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember())
+ await self.cog_instance.archive(
+ incident=MockMessage(created_at=CURRENT_TIME),
+ outcome=MagicMock(),
+ actioned_by=MockMember()
+ )
)
async def test_archive_relays_incident(self):
@@ -391,7 +397,7 @@ class TestArchive(TestIncidents):
# Define our own `incident` to be archived
incident = MockMessage(
content="this is an incident",
- author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")),
+ author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")),
id=123,
)
built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this
@@ -422,7 +428,7 @@ class TestArchive(TestIncidents):
webhook = MockAsyncWebhook()
self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook)
- message_from_clyde = MockMessage(author=MockUser(name="clyde the great"))
+ message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME)
await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember())
self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"])
@@ -521,12 +527,13 @@ class TestProcessEvent(TestIncidents):
async def test_process_event_confirmation_task_is_awaited(self):
"""Task given by `Incidents.make_confirmation_task` is awaited before method exits."""
mock_task = AsyncMock()
+ mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)])
with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):
await self.cog_instance.process_event(
reaction=incidents.Signal.ACTIONED.value,
- incident=MockMessage(id=123),
- member=MockMember(roles=[MockRole(id=1)])
+ incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME),
+ member=mock_member
)
mock_task.assert_awaited()
@@ -545,7 +552,7 @@ class TestProcessEvent(TestIncidents):
with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):
await self.cog_instance.process_event(
reaction=incidents.Signal.ACTIONED.value,
- incident=MockMessage(id=123),
+ incident=MockMessage(id=123, created_at=CURRENT_TIME),
member=MockMember(roles=[MockRole(id=1)])
)
except asyncio.TimeoutError:
@@ -656,7 +663,7 @@ class TestOnRawReactionAdd(TestIncidents):
emoji="reaction",
)
- async def asyncSetUp(self): # noqa: N802
+ async def asyncSetUp(self):
"""
Prepare an empty task and assign it as `crawl_task`.
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index 65aecad28..2622f46a7 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -1,4 +1,3 @@
-import asyncio
import itertools
import unittest
from datetime import datetime, timezone
@@ -6,31 +5,15 @@ from typing import List, Tuple
from unittest import mock
from unittest.mock import AsyncMock, Mock
-from async_rediscache import RedisSession
from discord import PermissionOverwrite
from bot.constants import Channels, Guild, MODERATION_ROLES, Roles
from bot.exts.moderation import silence
+from tests.base import RedisTestCase
from tests.helpers import (
MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec
)
-redis_session = None
-redis_loop = asyncio.get_event_loop()
-
-
-def setUpModule(): # noqa: N802
- """Create and connect to the fakeredis session."""
- global redis_session
- redis_session = RedisSession(use_fakeredis=True)
- redis_loop.run_until_complete(redis_session.connect())
-
-
-def tearDownModule(): # noqa: N802
- """Close the fakeredis session."""
- if redis_session:
- redis_loop.run_until_complete(redis_session.close())
-
# Have to subclass it because builtins can't be patched.
class PatchedDatetime(datetime):
@@ -39,8 +22,24 @@ class PatchedDatetime(datetime):
now = mock.create_autospec(datetime, "now")
-class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
+class SilenceTest(RedisTestCase):
+ """A base class for Silence tests that correctly sets up the cog and redis."""
+
+ @autospec(silence, "Scheduler", pass_mocks=False)
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id))
+ self.cog = silence.Silence(self.bot)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def asyncSetUp(self) -> None:
+ await super().asyncSetUp()
+ await self.cog.cog_load() # Populate instance attributes.
+
+
+class SilenceNotifierTests(SilenceTest):
def setUp(self) -> None:
+ super().setUp()
self.alert_channel = MockTextChannel()
self.notifier = silence.SilenceNotifier(self.alert_channel)
self.notifier.stop = self.notifier_stop_mock = Mock()
@@ -105,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
-class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
+class SilenceCogTests(SilenceTest):
"""Tests for the general functionality of the Silence cog."""
- @autospec(silence, "Scheduler", pass_mocks=False)
- def setUp(self) -> None:
- self.bot = MockBot()
- self.cog = silence.Silence(self.bot)
-
@autospec(silence, "SilenceNotifier", pass_mocks=False)
async def test_cog_load_got_guild(self):
"""Bot got guild after it became available."""
- await self.cog.cog_load()
self.bot.wait_until_guild_available.assert_awaited_once()
self.bot.get_guild.assert_called_once_with(Guild.id)
@autospec(silence, "SilenceNotifier", pass_mocks=False)
async def test_cog_load_got_channels(self):
"""Got channels from bot."""
- self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
-
await self.cog.cog_load()
self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)
@autospec(silence, "SilenceNotifier")
async def test_cog_load_got_notifier(self, notifier):
"""Notifier was started with channel."""
- self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
-
await self.cog.cog_load()
notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))
self.assertEqual(self.cog.notifier, notifier.return_value)
@@ -245,15 +234,9 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2)
-class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):
+class SilenceArgumentParserTests(SilenceTest):
"""Tests for the silence argument parser utility function."""
- def setUp(self):
- self.bot = MockBot()
- self.cog = silence.Silence(self.bot)
- self.cog._init_task = asyncio.Future()
- self.cog._init_task.set_result(None)
-
@autospec(silence.Silence, "send_message", pass_mocks=False)
@autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False)
@autospec(silence.Silence, "parse_silence_args")
@@ -321,17 +304,19 @@ class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
-class RescheduleTests(unittest.IsolatedAsyncioTestCase):
+class RescheduleTests(RedisTestCase):
"""Tests for the rescheduling of cached unsilences."""
- @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
- def setUp(self):
+ @autospec(silence, "Scheduler", pass_mocks=False)
+ def setUp(self) -> None:
self.bot = MockBot()
self.cog = silence.Silence(self.bot)
self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper)
- with mock.patch.object(self.cog, "_reschedule", autospec=True):
- asyncio.run(self.cog.cog_load()) # Populate instance attributes.
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def asyncSetUp(self) -> None:
+ await super().asyncSetUp()
+ await self.cog.cog_load() # Populate instance attributes.
async def test_skipped_missing_channel(self):
"""Did nothing because the channel couldn't be retrieved."""
@@ -406,22 +391,14 @@ def voice_sync_helper(function):
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
-class SilenceTests(unittest.IsolatedAsyncioTestCase):
+class SilenceTests(SilenceTest):
"""Tests for the silence command and its related helper methods."""
- @autospec(silence.Silence, "_reschedule", pass_mocks=False)
- @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
def setUp(self) -> None:
- self.bot = MockBot(get_channel=lambda _: MockTextChannel())
- self.cog = silence.Silence(self.bot)
- self.cog._init_task = asyncio.Future()
- self.cog._init_task.set_result(None)
+ super().setUp()
# Avoid unawaited coroutine warnings.
self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close()
-
- asyncio.run(self.cog.cog_load()) # Populate instance attributes.
-
self.text_channel = MockTextChannel()
self.text_overwrite = PermissionOverwrite(
send_messages=True,
@@ -679,24 +656,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)
-class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
+class UnsilenceTests(SilenceTest):
"""Tests for the unsilence command and its related helper methods."""
- @autospec(silence.Silence, "_reschedule", pass_mocks=False)
- @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
def setUp(self) -> None:
- self.bot = MockBot(get_channel=lambda _: MockTextChannel())
- self.cog = silence.Silence(self.bot)
- self.cog._init_task = asyncio.Future()
- self.cog._init_task.set_result(None)
-
- overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
- self.cog.previous_overwrites = overwrites_cache
-
- asyncio.run(self.cog.cog_load()) # Populate instance attributes.
+ super().setUp()
self.cog.scheduler.__contains__.return_value = True
- overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
self.text_channel = MockTextChannel()
self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False)
self.text_channel.overwrites_for.return_value = self.text_overwrite
@@ -705,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)
self.voice_channel.overwrites_for.return_value = self.voice_overwrite
+ async def asyncSetUp(self) -> None:
+ await super().asyncSetUp()
+ overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
+ self.cog.previous_overwrites = overwrites_cache
+
+ overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
+
async def test_sent_correct_message(self):
"""Appropriate failure/success message was sent by the command."""
unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index b870a9945..b1f32c210 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -6,9 +6,10 @@ from discord import AllowedMentions
from discord.ext import commands
from bot import constants
+from bot.errors import LockedResourceError
from bot.exts.utils import snekbox
from bot.exts.utils.snekbox import Snekbox
-from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser
+from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser
class SnekboxTests(unittest.IsolatedAsyncioTestCase):
@@ -26,7 +27,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
context_manager.__aenter__.return_value = resp
self.bot.http_session.post.return_value = context_manager
- self.assertEqual(await self.cog.post_job("import random"), "return")
+ self.assertEqual(await self.cog.post_job("import random", "3.10"), "return")
self.bot.http_session.post.assert_called_with(
constants.URLs.snekbox_eval_api,
json={"input": "import random"},
@@ -84,28 +85,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
def test_get_results_message(self):
"""Return error and message according to the eval result."""
cases = (
- ('ERROR', None, ('Your eval job has failed', 'ERROR')),
- ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')),
- ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred'))
+ ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')),
+ ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')),
+ ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred'))
)
for stdout, returncode, expected in cases:
with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
- actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval')
+ actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11')
self.assertEqual(actual, expected)
@patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)
def test_get_results_message_invalid_signal(self, mock_signals: Mock):
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),
- ('Your eval job has completed with return code 127', '')
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'),
+ ('Your 3.11 eval job has completed with return code 127', '')
)
@patch('bot.exts.utils.snekbox.Signals')
def test_get_results_message_valid_signal(self, mock_signals: Mock):
mock_signals.return_value.name = 'SIGTEST'
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),
- ('Your eval job has completed with return code 127 (SIGTEST)', '')
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'),
+ ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '')
)
def test_get_status_emoji(self):
@@ -179,9 +180,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.send_job = AsyncMock(return_value=response)
self.cog.continue_job = AsyncMock(return_value=(None, None))
- await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode'])
- self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval')
- self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command)
+ await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode'])
+ self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval')
+ self.cog.continue_job.assert_called_once_with(ctx, response, 'eval')
async def test_eval_command_evaluate_twice(self):
"""Test the eval and re-eval command procedure."""
@@ -192,23 +193,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.continue_job = AsyncMock()
self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None))
- await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode'])
+ await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode'])
self.cog.send_job.assert_called_with(
- ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval'
+ ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval'
)
- self.cog.continue_job.assert_called_with(ctx, response, ctx.command)
+ self.cog.continue_job.assert_called_with(ctx, response, 'eval')
async def test_eval_command_reject_two_eval_at_the_same_time(self):
"""Test if the eval command rejects an eval if the author already have a running eval."""
ctx = MockContext()
ctx.author.id = 42
- ctx.author.mention = '@LemonLemonishBeard#0042'
- ctx.send = AsyncMock()
- self.cog.jobs = (42,)
- await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
- ctx.send.assert_called_once_with(
- "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
- )
+
+ async def delay_with_side_effect(*args, **kwargs) -> dict:
+ """Delay the post_job call to ensure the job runs long enough to conflict."""
+ await asyncio.sleep(1)
+ return {'stdout': '', 'returncode': 0}
+
+ self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect)
+ with self.assertRaises(LockedResourceError):
+ await asyncio.gather(
+ self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'),
+ self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'),
+ )
async def test_send_job(self):
"""Test the send_job function."""
@@ -226,7 +232,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')
+ await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -237,9 +243,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict())
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)
+ self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval')
+ self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11')
self.cog.format_output.assert_called_once_with('')
async def test_send_job_with_paste_link(self):
@@ -258,7 +264,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')
+ await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -267,9 +273,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
'\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'
)
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)
+ self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval')
+ self.cog.get_results_message.assert_called_once_with(
+ {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11'
+ )
self.cog.format_output.assert_called_once_with('Way too long beard')
async def test_send_job_with_non_zero_eval(self):
@@ -287,7 +295,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')
+ await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -295,17 +303,25 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
'@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'
)
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)
+ self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
- self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval')
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11')
self.cog.format_output.assert_not_called()
@patch("bot.exts.utils.snekbox.partial")
async def test_continue_job_does_continue(self, partial_mock):
"""Test that the continue_job function does continue if required conditions are met."""
- ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
- response = MockMessage(delete=AsyncMock())
+ ctx = MockContext(
+ message=MockMessage(
+ id=4,
+ add_reaction=AsyncMock(),
+ clear_reactions=AsyncMock()
+ ),
+ author=MockMember(id=14)
+ )
+ response = MockMessage(id=42, delete=AsyncMock())
new_msg = MockMessage()
+ self.cog.jobs = {4: 42}
self.bot.wait_for.side_effect = ((None, new_msg), None)
expected = "NewCode"
self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
new file mode 100644
index 000000000..e1f904917
--- /dev/null
+++ b/tests/bot/rules/test_mentions.py
@@ -0,0 +1,131 @@
+from typing import Iterable, Optional
+
+import discord
+
+from bot.rules import mentions
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMember, MockMessage, MockMessageReference
+
+
+def make_msg(
+ author: str,
+ total_user_mentions: int,
+ total_bot_mentions: int = 0,
+ *,
+ reference: Optional[MockMessageReference] = None
+) -> MockMessage:
+ """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions."""
+ user_mentions = [MockMember() for _ in range(total_user_mentions)]
+ bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)]
+
+ mentions = user_mentions + bot_mentions
+ if reference is not None:
+ # For the sake of these tests we assume that all references are mentions.
+ mentions.append(reference.resolved.author)
+ msg_type = discord.MessageType.reply
+ else:
+ msg_type = discord.MessageType.default
+
+ return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type)
+
+
+class TestMentions(RuleTest):
+ """Tests applying the `mentions` antispam rule."""
+
+ def setUp(self):
+ self.apply = mentions.apply
+ self.config = {
+ "max": 2,
+ "interval": 10,
+ }
+
+ async def test_mentions_within_limit(self):
+ """Messages with an allowed amount of mentions."""
+ cases = (
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 1), make_msg("alice", 2)],
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_mentions_exceeding_limit(self):
+ """Messages with a higher than allowed amount of mentions."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
+ ("alice",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
+ ("bob",),
+ 4,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 3, 1)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 3, reference=MockMessageReference())],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))],
+ ("bob",),
+ 3
+ )
+ )
+
+ await self.run_disallowed(cases)
+
+ async def test_ignore_bot_mentions(self):
+ """Messages with an allowed amount of mentions, also containing bot mentions."""
+ cases = (
+ [make_msg("bob", 0, 3)],
+ [make_msg("bob", 2, 1)],
+ [make_msg("bob", 1, 2), make_msg("bob", 1, 2)],
+ [make_msg("bob", 1, 5), make_msg("alice", 2, 5)]
+ )
+
+ await self.run_allowed(cases)
+
+ async def test_ignore_reply_mentions(self):
+ """Messages with an allowed amount of mentions in the content, also containing reply mentions."""
+ cases = (
+ [
+ make_msg("bob", 2, reference=MockMessageReference())
+ ],
+ [
+ make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True))
+ ],
+ [
+ make_msg("bob", 2, reference=MockMessageReference()),
+ make_msg("bob", 0, reference=MockMessageReference())
+ ],
+ [
+ make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)),
+ make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True))
+ ]
+ )
+
+ await self.run_allowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/helpers.py b/tests/helpers.py
index e74306d23..28a8e40a7 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -317,7 +317,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
guild_id=1,
intents=discord.Intents.all(),
)
- additional_spec_asyncs = ("wait_for", "redis_ready")
+ additional_spec_asyncs = ("wait_for",)
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@@ -492,6 +492,28 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
spec_set = attachment_instance
+message_reference_instance = discord.MessageReference(
+ message_id=unittest.mock.MagicMock(id=1),
+ channel_id=unittest.mock.MagicMock(id=2),
+ guild_id=unittest.mock.MagicMock(id=3)
+)
+
+
+class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock MessageReference objects.
+
+ Instances of this class will follow the specification of `discord.MessageReference` instances.
+ For more information, see the `MockGuild` docstring.
+ """
+ spec_set = message_reference_instance
+
+ def __init__(self, *, reference_author_is_bot: bool = False, **kwargs):
+ super().__init__(**kwargs)
+ referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot)
+ self.resolved = MockMessage(author=referenced_msg_author)
+
+
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
"""
A MagicMock subclass to mock Message objects.
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index f3040b305..b2686b1d0 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -14,7 +14,7 @@ class DiscordMocksTests(unittest.TestCase):
"""Test if the default initialization of MockRole results in the correct object."""
role = helpers.MockRole()
- # The `spec` argument makes sure `isistance` checks with `discord.Role` pass
+ # The `spec` argument makes sure `isinstance` checks with `discord.Role` pass
self.assertIsInstance(role, discord.Role)
self.assertEqual(role.name, "role")