aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/error_handler.py27
-rw-r--r--tests/bot/exts/backend/test_error_handler.py84
2 files changed, 98 insertions, 13 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index d8de177f5..a3c04437f 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -3,7 +3,7 @@ import logging
import typing as t
from discord import Embed
-from discord.ext.commands import Cog, Context, errors
+from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors
from sentry_sdk import push_scope
from bot.api import ResponseCodeError
@@ -115,8 +115,10 @@ class ErrorHandler(Cog):
Return bool depending on success of command.
"""
command = ctx.invoked_with.lower()
+ args = ctx.message.content.lower().split(" ")
silence_command = self.bot.get_command("silence")
ctx.invoked_from_error_handler = True
+
try:
if not await silence_command.can_run(ctx):
log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.")
@@ -124,11 +126,30 @@ class ErrorHandler(Cog):
except errors.CommandError:
log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.")
return False
+
+ # Parse optional args
+ channel = None
+ duration = min(command.count("h") * 2, 15)
+ kick = False
+
+ if len(args) > 1:
+ # Parse channel
+ for converter in (TextChannelConverter(), VoiceChannelConverter()):
+ try:
+ channel = await converter.convert(ctx, args[1])
+ break
+ except ChannelNotFound:
+ continue
+
+ if len(args) > 2 and channel is not None:
+ # Parse kick
+ kick = args[2].lower() == "true"
+
if command.startswith("shh"):
- await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15))
+ await ctx.invoke(silence_command, duration_or_channel=channel, duration=duration, kick=kick)
return True
elif command.startswith("unshh"):
- await ctx.invoke(self.bot.get_command("unsilence"))
+ await ctx.invoke(self.bot.get_command("unsilence"), channel=channel)
return True
return False
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index bd4fb5942..37e8108fc 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -9,7 +9,7 @@ from bot.exts.backend.error_handler import ErrorHandler, setup
from bot.exts.info.tags import Tags
from bot.exts.moderation.silence import Silence
from bot.utils.checks import InWhitelistCheckFailure
-from tests.helpers import MockBot, MockContext, MockGuild, MockRole
+from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel
class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
@@ -226,8 +226,8 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError())
self.assertFalse(await self.cog.try_silence(self.ctx))
- async def test_try_silence_silencing(self):
- """Should run silence command with correct arguments."""
+ async def test_try_silence_silence_duration(self):
+ """Should run silence command with correct duration argument."""
self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)
test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh")
@@ -238,21 +238,85 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue(await self.cog.try_silence(self.ctx))
self.ctx.invoke.assert_awaited_once_with(
self.bot.get_command.return_value,
- duration=min(case.count("h")*2, 15)
+ duration_or_channel=None,
+ duration=min(case.count("h")*2, 15),
+ kick=False
)
+ async def test_try_silence_silence_arguments(self):
+ """Should run silence with the correct channel, duration, and kick arguments."""
+ self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)
+
+ test_cases = (
+ (MockTextChannel(), None), # None represents the case when no argument is passed
+ (MockTextChannel(), False),
+ (MockTextChannel(), True)
+ )
+
+ for channel, kick in test_cases:
+ with self.subTest(kick=kick, channel=channel):
+ self.ctx.reset_mock()
+ self.ctx.invoked_with = "shh"
+
+ self.ctx.message.content = f"!shh {channel.name} {kick if kick is not None else ''}"
+ self.ctx.guild.text_channels = [channel]
+
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(
+ self.bot.get_command.return_value,
+ duration_or_channel=channel,
+ duration=4,
+ kick=(kick if kick is not None else False)
+ )
+
+ async def test_try_silence_silence_message(self):
+ """If the words after the command could not be converted to a channel, None should be passed as channel."""
+ self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)
+ self.ctx.invoked_with = "shh"
+ self.ctx.message.content = "!shh not_a_channel true"
+
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(
+ self.bot.get_command.return_value,
+ duration_or_channel=None,
+ duration=4,
+ kick=False
+ )
+
async def test_try_silence_unsilence(self):
- """Should call unsilence command."""
+ """Should call unsilence command with correct duration and channel arguments."""
self.silence.silence.can_run = AsyncMock(return_value=True)
- test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh")
+ test_cases = (
+ ("unshh", None),
+ ("unshhhhh", None),
+ ("unshhhhhhhhh", None),
+ ("unshh", MockTextChannel())
+ )
- for case in test_cases:
- with self.subTest(message=case):
+ for invoke, channel in test_cases:
+ with self.subTest(message=invoke, channel=channel):
self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence)
self.ctx.reset_mock()
- self.ctx.invoked_with = case
+
+ self.ctx.invoked_with = invoke
+ self.ctx.message.content = f"!{invoke}"
+ if channel is not None:
+ self.ctx.message.content += f" {channel.name}"
+ self.ctx.guild.text_channels = [channel]
+
self.assertTrue(await self.cog.try_silence(self.ctx))
- self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence)
+ self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=channel)
+
+ async def test_try_silence_unsilence_message(self):
+ """If the words after the command could not be converted to a channel, None should be passed as channel."""
+ self.silence.silence.can_run = AsyncMock(return_value=True)
+ self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence)
+
+ self.ctx.invoked_with = "unshh"
+ self.ctx.message.content = "!unshh not_a_channel"
+
+ self.assertTrue(await self.cog.try_silence(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=None)
async def test_try_silence_no_match(self):
"""Should return `False` when message don't match."""