aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ChrisJL <[email protected]>2021-04-30 12:32:47 +0100
committerGravatar GitHub <[email protected]>2021-04-30 12:32:47 +0100
commitd3d9e2655ff5dcb79fca8ff24b860ee84c473e50 (patch)
treea515094f5df18615af14a016cfb4c174769cf205
parentWait for cache to fill before redirecting (diff)
parentMerge pull request #954 from ks129/error-handler-test (diff)
Merge branch 'main' into master
Diffstat (limited to '')
-rw-r--r--bot/exts/backend/error_handler.py22
-rw-r--r--bot/exts/moderation/infraction/infractions.py18
-rw-r--r--bot/exts/moderation/stream.py30
-rw-r--r--tests/bot/exts/backend/test_error_handler.py550
-rw-r--r--tests/helpers.py2
5 files changed, 603 insertions, 19 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index da0e94a7e..d8de177f5 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -1,4 +1,3 @@
-import contextlib
import difflib
import logging
import typing as t
@@ -60,7 +59,7 @@ class ErrorHandler(Cog):
log.trace(f"Command {command} had its error already handled locally; ignoring.")
return
- if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
+ if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False):
if await self.try_silence(ctx):
return
# Try to look for a tag with the command's name
@@ -162,9 +161,8 @@ class ErrorHandler(Cog):
f"and the fallback tag failed validation in TagNameConverter."
)
else:
- with contextlib.suppress(ResponseCodeError):
- if await ctx.invoke(tags_get_command, tag_name=tag_name):
- return
+ if await ctx.invoke(tags_get_command, tag_name=tag_name):
+ return
if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
await self.send_command_suggestion(ctx, ctx.invoked_with)
@@ -214,32 +212,30 @@ class ErrorHandler(Cog):
* ArgumentParsingError: send an error message
* Other: send an error message and the help command
"""
- prepared_help_command = self.get_help_command(ctx)
-
if isinstance(e, errors.MissingRequiredArgument):
embed = self._get_error_embed("Missing required argument", e.param.name)
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.missing_required_argument")
elif isinstance(e, errors.TooManyArguments):
embed = self._get_error_embed("Too many arguments", str(e))
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.too_many_arguments")
elif isinstance(e, errors.BadArgument):
embed = self._get_error_embed("Bad argument", str(e))
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.bad_argument")
elif isinstance(e, errors.BadUnionArgument):
embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}")
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.bad_union_argument")
elif isinstance(e, errors.ArgumentParsingError):
embed = self._get_error_embed("Argument parsing error", str(e))
await ctx.send(embed=embed)
- prepared_help_command.close()
+ self.get_help_command(ctx).close()
self.bot.stats.incr("errors.argument_parsing_error")
else:
embed = self._get_error_embed(
@@ -247,7 +243,7 @@ class ErrorHandler(Cog):
"Something about your input seems off. Check the arguments and try again."
)
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.other_user_input_error")
@staticmethod
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index d89e80acc..38d1ffc0e 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -54,8 +54,12 @@ class Infractions(InfractionScheduler, commands.Cog):
# region: Permanent infractions
@command()
- async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None:
+ async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:
"""Warn a user for the given reason."""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False)
if infraction is None:
return
@@ -63,8 +67,12 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_infraction(ctx, infraction, user)
@command()
- async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None:
+ async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:
"""Kick a user for the given reason."""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
await self.apply_kick(ctx, user, reason)
@command()
@@ -100,7 +108,7 @@ class Infractions(InfractionScheduler, commands.Cog):
@command(aliases=["mute"])
async def tempmute(
self, ctx: Context,
- user: Member,
+ user: FetchedMember,
duration: t.Optional[Expiry] = None,
*,
reason: t.Optional[str] = None
@@ -122,6 +130,10 @@ class Infractions(InfractionScheduler, commands.Cog):
If no duration is given, a one hour duration is used by default.
"""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
if duration is None:
duration = await Duration().convert(ctx, "1h")
await self.apply_mute(ctx, user, reason, expires_at=duration)
diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py
index 1dbb2a46b..fd856a7f4 100644
--- a/bot/exts/moderation/stream.py
+++ b/bot/exts/moderation/stream.py
@@ -70,6 +70,28 @@ class Stream(commands.Cog):
self._revoke_streaming_permission(member)
)
+ async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None:
+ """Suspend a member's stream."""
+ await self.bot.wait_until_guild_available()
+ voice_state = member.voice
+
+ if not voice_state:
+ return
+
+ # If the user is streaming.
+ if voice_state.self_stream:
+ # End user's stream by moving them to AFK voice channel and back.
+ original_vc = voice_state.channel
+ await member.move_to(ctx.guild.afk_channel)
+ await member.move_to(original_vc)
+
+ # Notify.
+ await ctx.send(f"{member.mention}'s stream has been suspended!")
+ log.debug(f"Successfully suspended stream from {member} ({member.id}).")
+ return
+
+ log.debug(f"No stream found to suspend from {member} ({member.id}).")
+
@commands.command(aliases=("streaming",))
@commands.has_any_role(*MODERATION_ROLES)
async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None:
@@ -170,10 +192,12 @@ class Stream(commands.Cog):
await ctx.send(f"{Emojis.check_mark} Revoked the permission to stream from {member.mention}.")
log.debug(f"Successfully revoked streaming permission from {member} ({member.id}).")
- return
- await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!")
- log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!")
+ else:
+ await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!")
+ log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!")
+
+ await self._suspend_stream(ctx, member)
@commands.command(aliases=('lstream',))
@commands.has_any_role(*MODERATION_ROLES)
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
new file mode 100644
index 000000000..bd4fb5942
--- /dev/null
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -0,0 +1,550 @@
+import unittest
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+from discord.ext.commands import errors
+
+from bot.api import ResponseCodeError
+from bot.errors import InvalidInfractedUser, LockedResourceError
+from bot.exts.backend.error_handler import ErrorHandler, setup
+from bot.exts.info.tags import Tags
+from bot.exts.moderation.silence import Silence
+from bot.utils.checks import InWhitelistCheckFailure
+from tests.helpers import MockBot, MockContext, MockGuild, MockRole
+
+
+class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for error handler functionality."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext(bot=self.bot)
+
+ async def test_error_handler_already_handled(self):
+ """Should not do anything when error is already handled by local error handler."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ error = errors.CommandError()
+ error.handled = "foo"
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_command_not_found_error_not_invoked_by_handler(self):
+ """Should try first (un)silence channel, when fail, try to get tag."""
+ error = errors.CommandNotFound()
+ test_cases = (
+ {
+ "try_silence_return": True,
+ "called_try_get_tag": False
+ },
+ {
+ "try_silence_return": False,
+ "called_try_get_tag": False
+ },
+ {
+ "try_silence_return": False,
+ "called_try_get_tag": True
+ }
+ )
+ cog = ErrorHandler(self.bot)
+ cog.try_silence = AsyncMock()
+ cog.try_get_tag = AsyncMock()
+
+ for case in test_cases:
+ with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):
+ self.ctx.reset_mock()
+ cog.try_silence.reset_mock(return_value=True)
+ cog.try_get_tag.reset_mock()
+
+ cog.try_silence.return_value = case["try_silence_return"]
+ self.ctx.channel.id = 1234
+
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+
+ if case["try_silence_return"]:
+ cog.try_get_tag.assert_not_awaited()
+ cog.try_silence.assert_awaited_once()
+ else:
+ cog.try_silence.assert_awaited_once()
+ cog.try_get_tag.assert_awaited_once()
+
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_command_not_found_error_invoked_by_handler(self):
+ """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""
+ ctx = MockContext(bot=self.bot, invoked_from_error_handler=True)
+
+ cog = ErrorHandler(self.bot)
+ cog.try_silence = AsyncMock()
+ cog.try_get_tag = AsyncMock()
+
+ error = errors.CommandNotFound()
+
+ self.assertIsNone(await cog.on_command_error(ctx, error))
+
+ cog.try_silence.assert_not_awaited()
+ cog.try_get_tag.assert_not_awaited()
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_user_input_error(self):
+ """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ cog.handle_user_input_error = AsyncMock()
+ error = errors.UserInputError()
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
+
+ async def test_error_handler_check_failure(self):
+ """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ cog.handle_check_failure = AsyncMock()
+ error = errors.CheckFailure()
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
+
+ async def test_error_handler_command_on_cooldown(self):
+ """Should send error with `ctx.send` when error is `CommandOnCooldown`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ error = errors.CommandOnCooldown(10, 9)
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.ctx.send.assert_awaited_once_with(error)
+
+ async def test_error_handler_command_invoke_error(self):
+ """Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_api_error = AsyncMock()
+ cog.handle_unexpected_error = AsyncMock()
+ test_cases = (
+ {
+ "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))),
+ "expect_mock_call": cog.handle_api_error
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(TypeError)),
+ "expect_mock_call": cog.handle_unexpected_error
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))),
+ "expect_mock_call": "send"
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))),
+ "expect_mock_call": "send"
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):
+ self.ctx.send.reset_mock()
+ self.assertIsNone(await cog.on_command_error(*case["args"]))
+ if case["expect_mock_call"] == "send":
+ self.ctx.send.assert_awaited_once()
+ else:
+ case["expect_mock_call"].assert_awaited_once_with(
+ self.ctx, case["args"][1].original
+ )
+
+ async def test_error_handler_conversion_error(self):
+ """Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_api_error = AsyncMock()
+ cog.handle_unexpected_error = AsyncMock()
+ cases = (
+ {
+ "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())),
+ "mock_function_to_call": cog.handle_api_error
+ },
+ {
+ "error": errors.ConversionError(AsyncMock(), TypeError),
+ "mock_function_to_call": cog.handle_unexpected_error
+ }
+ )
+
+ for case in cases:
+ with self.subTest(**case):
+ self.assertIsNone(await cog.on_command_error(self.ctx, case["error"]))
+ case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)
+
+ async def test_error_handler_two_other_errors(self):
+ """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_unexpected_error = AsyncMock()
+ errs = (
+ errors.MaxConcurrencyReached(1, MagicMock()),
+ errors.ExtensionError(name="foo")
+ )
+
+ for err in errs:
+ with self.subTest(error=err):
+ cog.handle_unexpected_error.reset_mock()
+ self.assertIsNone(await cog.on_command_error(self.ctx, err))
+ cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
+
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_error_handler_other_errors(self, log_mock):
+ """Should `log.debug` other errors."""
+ cog = ErrorHandler(self.bot)
+ error = errors.DisabledCommand() # Use this just as a other error
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ log_mock.debug.assert_called_once()
+
+
+class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Test for helper functions that handle `CommandNotFound` error."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.silence = Silence(self.bot)
+ self.bot.get_command.return_value = self.silence.silence
+ self.ctx = MockContext(bot=self.bot)
+ self.cog = ErrorHandler(self.bot)
+
+ async def test_try_silence_context_invoked_from_error_handler(self):
+ """Should set `Context.invoked_from_error_handler` to `True`."""
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_silence(self.ctx)
+ self.assertTrue(hasattr(self.ctx, "invoked_from_error_handler"))
+ self.assertTrue(self.ctx.invoked_from_error_handler)
+
+ async def test_try_silence_get_command(self):
+ """Should call `get_command` with `silence`."""
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_silence(self.ctx)
+ self.bot.get_command.assert_called_once_with("silence")
+
+ async def test_try_silence_no_permissions_to_run(self):
+ """Should return `False` because missing permissions."""
+ self.ctx.invoked_with = "foo"
+ self.bot.get_command.return_value.can_run = AsyncMock(return_value=False)
+ self.assertFalse(await self.cog.try_silence(self.ctx))
+
+ async def test_try_silence_no_permissions_to_run_command_error(self):
+ """Should return `False` because `CommandError` raised (no permissions)."""
+ self.ctx.invoked_with = "foo"
+ self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError())
+ self.assertFalse(await self.cog.try_silence(self.ctx))
+
+ async def test_try_silence_silencing(self):
+ """Should run silence command with correct arguments."""
+ self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)
+ test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh")
+
+ for case in test_cases:
+ with self.subTest(message=case):
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = case
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(
+ self.bot.get_command.return_value,
+ duration=min(case.count("h")*2, 15)
+ )
+
+ async def test_try_silence_unsilence(self):
+ """Should call unsilence command."""
+ self.silence.silence.can_run = AsyncMock(return_value=True)
+ test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh")
+
+ for case in test_cases:
+ with self.subTest(message=case):
+ self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence)
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = case
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence)
+
+ async def test_try_silence_no_match(self):
+ """Should return `False` when message don't match."""
+ self.ctx.invoked_with = "foo"
+ self.assertFalse(await self.cog.try_silence(self.ctx))
+
+
+class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for `try_get_tag` function."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext()
+ self.tag = Tags(self.bot)
+ self.cog = ErrorHandler(self.bot)
+ self.bot.get_command.return_value = self.tag.get_command
+
+ async def test_try_get_tag_get_command(self):
+ """Should call `Bot.get_command` with `tags get` argument."""
+ self.bot.get_command.reset_mock()
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_get_tag(self.ctx)
+ self.bot.get_command.assert_called_once_with("tags get")
+
+ async def test_try_get_tag_invoked_from_error_handler(self):
+ """`self.ctx` should have `invoked_from_error_handler` `True`."""
+ self.ctx.invoked_from_error_handler = False
+ self.ctx.invoked_with = "foo"
+ await self.cog.try_get_tag(self.ctx)
+ self.assertTrue(self.ctx.invoked_from_error_handler)
+
+ async def test_try_get_tag_no_permissions(self):
+ """Test how to handle checks failing."""
+ self.tag.get_command.can_run = AsyncMock(return_value=False)
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+
+ async def test_try_get_tag_command_error(self):
+ """Should call `on_command_error` when `CommandError` raised."""
+ err = errors.CommandError()
+ self.tag.get_command.can_run = AsyncMock(side_effect=err)
+ self.cog.on_command_error = AsyncMock()
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
+
+ @patch("bot.exts.backend.error_handler.TagNameConverter")
+ async def test_try_get_tag_convert_success(self, tag_converter):
+ """Converting tag should successful."""
+ self.ctx.invoked_with = "foo"
+ tag_converter.convert = AsyncMock(return_value="foo")
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ tag_converter.convert.assert_awaited_once_with(self.ctx, "foo")
+ self.ctx.invoke.assert_awaited_once()
+
+ @patch("bot.exts.backend.error_handler.TagNameConverter")
+ async def test_try_get_tag_convert_fail(self, tag_converter):
+ """Converting tag should raise `BadArgument`."""
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = "bar"
+ tag_converter.convert = AsyncMock(side_effect=errors.BadArgument())
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.ctx.invoke.assert_not_awaited()
+
+ async def test_try_get_tag_ctx_invoke(self):
+ """Should call `ctx.invoke` with proper args/kwargs."""
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo")
+
+ async def test_dont_call_suggestion_tag_sent(self):
+ """Should never call command suggestion if tag is already sent."""
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=True)
+ self.cog.send_command_suggestion = AsyncMock()
+
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_not_awaited()
+
+ @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234])
+ async def test_dont_call_suggestion_if_user_mod(self):
+ """Should not call command suggestion if user is a mod."""
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=False)
+ self.ctx.author.roles = [MockRole(id=1234)]
+ self.cog.send_command_suggestion = AsyncMock()
+
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_not_awaited()
+
+ async def test_call_suggestion(self):
+ """Should call command suggestion if user is not a mod."""
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=False)
+ self.cog.send_command_suggestion = AsyncMock()
+
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo")
+
+
+class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Individual error categories handler tests."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext(bot=self.bot)
+ self.cog = ErrorHandler(self.bot)
+
+ async def test_handle_input_error_handler_errors(self):
+ """Should handle each error probably."""
+ test_cases = (
+ {
+ "error": errors.MissingRequiredArgument(MagicMock()),
+ "call_prepared": True
+ },
+ {
+ "error": errors.TooManyArguments(),
+ "call_prepared": True
+ },
+ {
+ "error": errors.BadArgument(),
+ "call_prepared": True
+ },
+ {
+ "error": errors.BadUnionArgument(MagicMock(), MagicMock(), MagicMock()),
+ "call_prepared": True
+ },
+ {
+ "error": errors.ArgumentParsingError(),
+ "call_prepared": False
+ },
+ {
+ "error": errors.UserInputError(),
+ "call_prepared": True
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(error=case["error"], call_prepared=case["call_prepared"]):
+ self.ctx.reset_mock()
+ self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"]))
+ self.ctx.send.assert_awaited_once()
+ if case["call_prepared"]:
+ self.ctx.send_help.assert_awaited_once()
+ else:
+ self.ctx.send_help.assert_not_awaited()
+
+ async def test_handle_check_failure_errors(self):
+ """Should await `ctx.send` when error is check failure."""
+ test_cases = (
+ {
+ "error": errors.BotMissingPermissions(MagicMock()),
+ "call_ctx_send": True
+ },
+ {
+ "error": errors.BotMissingRole(MagicMock()),
+ "call_ctx_send": True
+ },
+ {
+ "error": errors.BotMissingAnyRole(MagicMock()),
+ "call_ctx_send": True
+ },
+ {
+ "error": errors.NoPrivateMessage(),
+ "call_ctx_send": True
+ },
+ {
+ "error": InWhitelistCheckFailure(1234),
+ "call_ctx_send": True
+ },
+ {
+ "error": ResponseCodeError(MagicMock()),
+ "call_ctx_send": False
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(error=case["error"], call_ctx_send=case["call_ctx_send"]):
+ self.ctx.reset_mock()
+ await self.cog.handle_check_failure(self.ctx, case["error"])
+ if case["call_ctx_send"]:
+ self.ctx.send.assert_awaited_once()
+ else:
+ self.ctx.send.assert_not_awaited()
+
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_handle_api_error(self, log_mock):
+ """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code."""
+ test_cases = (
+ {
+ "error": ResponseCodeError(AsyncMock(status=400)),
+ "log_level": "debug"
+ },
+ {
+ "error": ResponseCodeError(AsyncMock(status=404)),
+ "log_level": "debug"
+ },
+ {
+ "error": ResponseCodeError(AsyncMock(status=550)),
+ "log_level": "warning"
+ },
+ {
+ "error": ResponseCodeError(AsyncMock(status=1000)),
+ "log_level": "warning"
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(error=case["error"], log_level=case["log_level"]):
+ self.ctx.reset_mock()
+ log_mock.reset_mock()
+ await self.cog.handle_api_error(self.ctx, case["error"])
+ self.ctx.send.assert_awaited_once()
+ if case["log_level"] == "warning":
+ log_mock.warning.assert_called_once()
+ else:
+ log_mock.debug.assert_called_once()
+
+ @patch("bot.exts.backend.error_handler.push_scope")
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_handle_unexpected_error(self, log_mock, push_scope_mock):
+ """Should `ctx.send` this error, error log this and sent to Sentry."""
+ for case in (None, MockGuild()):
+ with self.subTest(guild=case):
+ self.ctx.reset_mock()
+ log_mock.reset_mock()
+ push_scope_mock.reset_mock()
+
+ self.ctx.guild = case
+ await self.cog.handle_unexpected_error(self.ctx, errors.CommandError())
+
+ self.ctx.send.assert_awaited_once()
+ log_mock.error.assert_called_once()
+ push_scope_mock.assert_called_once()
+
+ set_tag_calls = [
+ call("command", self.ctx.command.qualified_name),
+ call("message_id", self.ctx.message.id),
+ call("channel_id", self.ctx.channel.id),
+ ]
+ set_extra_calls = [
+ call("full_message", self.ctx.message.content)
+ ]
+ if case:
+ url = (
+ f"https://discordapp.com/channels/"
+ f"{self.ctx.guild.id}/{self.ctx.channel.id}/{self.ctx.message.id}"
+ )
+ set_extra_calls.append(call("jump_to", url))
+
+ push_scope_mock.set_tag.has_calls(set_tag_calls)
+ push_scope_mock.set_extra.has_calls(set_extra_calls)
+
+
+class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Other `ErrorHandler` tests."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext()
+
+ async def test_get_help_command_command_specified(self):
+ """Should return coroutine of help command of specified command."""
+ self.ctx.command = "foo"
+ result = ErrorHandler.get_help_command(self.ctx)
+ expected = self.ctx.send_help("foo")
+ self.assertEqual(result.__qualname__, expected.__qualname__)
+ self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
+
+ # Await coroutines to avoid warnings
+ await result
+ await expected
+
+ async def test_get_help_command_no_command_specified(self):
+ """Should return coroutine of help command."""
+ self.ctx.command = None
+ result = ErrorHandler.get_help_command(self.ctx)
+ expected = self.ctx.send_help()
+ self.assertEqual(result.__qualname__, expected.__qualname__)
+ self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
+
+ # Await coroutines to avoid warnings
+ await result
+ await expected
+
+
+class ErrorHandlerSetupTests(unittest.TestCase):
+ """Tests for `ErrorHandler` `setup` function."""
+
+ def test_setup(self):
+ """Should call `bot.add_cog` with `ErrorHandler`."""
+ bot = MockBot()
+ setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/helpers.py b/tests/helpers.py
index 496363ae3..e3dc5fe5b 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -385,6 +385,7 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da
# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
+context_instance.invoked_from_error_handler = None
class MockContext(CustomMockMixin, unittest.mock.MagicMock):
@@ -402,6 +403,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
+ self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)
attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock())