diff options
author | 2021-05-13 06:13:31 +0300 | |
---|---|---|
committer | 2021-05-13 06:14:10 +0300 | |
commit | 50294816787562abd5f3f984f513d039612d2f3b (patch) | |
tree | fa1d7ffce9613360041707eac2df6334c0eac42d | |
parent | Updates Silence To Accept Duration Or Channel (diff) |
Updates Shh Command To Mirror Silence
Updates the shh and unshh commands from the error handler to accept
channel and kick arguments, to give them the same interface as the
silence and unsilence command.
Signed-off-by: Hassan Abouelela <[email protected]>
-rw-r--r-- | bot/exts/backend/error_handler.py | 27 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 84 |
2 files changed, 98 insertions, 13 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index d8de177f5..a3c04437f 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -3,7 +3,7 @@ import logging import typing as t from discord import Embed -from discord.ext.commands import Cog, Context, errors +from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from sentry_sdk import push_scope from bot.api import ResponseCodeError @@ -115,8 +115,10 @@ class ErrorHandler(Cog): Return bool depending on success of command. """ command = ctx.invoked_with.lower() + args = ctx.message.content.lower().split(" ") silence_command = self.bot.get_command("silence") ctx.invoked_from_error_handler = True + try: if not await silence_command.can_run(ctx): log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") @@ -124,11 +126,30 @@ class ErrorHandler(Cog): except errors.CommandError: log.debug("Cancelling attempt to invoke silence/unsilence due to failed checks.") return False + + # Parse optional args + channel = None + duration = min(command.count("h") * 2, 15) + kick = False + + if len(args) > 1: + # Parse channel + for converter in (TextChannelConverter(), VoiceChannelConverter()): + try: + channel = await converter.convert(ctx, args[1]) + break + except ChannelNotFound: + continue + + if len(args) > 2 and channel is not None: + # Parse kick + kick = args[2].lower() == "true" + if command.startswith("shh"): - await ctx.invoke(silence_command, duration=min(command.count("h")*2, 15)) + await ctx.invoke(silence_command, duration_or_channel=channel, duration=duration, kick=kick) return True elif command.startswith("unshh"): - await ctx.invoke(self.bot.get_command("unsilence")) + await ctx.invoke(self.bot.get_command("unsilence"), channel=channel) return True return False diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index bd4fb5942..37e8108fc 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.backend.error_handler import ErrorHandler, setup from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockRole +from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -226,8 +226,8 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError()) self.assertFalse(await self.cog.try_silence(self.ctx)) - async def test_try_silence_silencing(self): - """Should run silence command with correct arguments.""" + async def test_try_silence_silence_duration(self): + """Should run silence command with correct duration argument.""" self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh") @@ -238,21 +238,85 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(await self.cog.try_silence(self.ctx)) self.ctx.invoke.assert_awaited_once_with( self.bot.get_command.return_value, - duration=min(case.count("h")*2, 15) + duration_or_channel=None, + duration=min(case.count("h")*2, 15), + kick=False ) + async def test_try_silence_silence_arguments(self): + """Should run silence with the correct channel, duration, and kick arguments.""" + self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) + + test_cases = ( + (MockTextChannel(), None), # None represents the case when no argument is passed + (MockTextChannel(), False), + (MockTextChannel(), True) + ) + + for channel, kick in test_cases: + with self.subTest(kick=kick, channel=channel): + self.ctx.reset_mock() + self.ctx.invoked_with = "shh" + + self.ctx.message.content = f"!shh {channel.name} {kick if kick is not None else ''}" + self.ctx.guild.text_channels = [channel] + + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with( + self.bot.get_command.return_value, + duration_or_channel=channel, + duration=4, + kick=(kick if kick is not None else False) + ) + + async def test_try_silence_silence_message(self): + """If the words after the command could not be converted to a channel, None should be passed as channel.""" + self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) + self.ctx.invoked_with = "shh" + self.ctx.message.content = "!shh not_a_channel true" + + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with( + self.bot.get_command.return_value, + duration_or_channel=None, + duration=4, + kick=False + ) + async def test_try_silence_unsilence(self): - """Should call unsilence command.""" + """Should call unsilence command with correct duration and channel arguments.""" self.silence.silence.can_run = AsyncMock(return_value=True) - test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh") + test_cases = ( + ("unshh", None), + ("unshhhhh", None), + ("unshhhhhhhhh", None), + ("unshh", MockTextChannel()) + ) - for case in test_cases: - with self.subTest(message=case): + for invoke, channel in test_cases: + with self.subTest(message=invoke, channel=channel): self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) self.ctx.reset_mock() - self.ctx.invoked_with = case + + self.ctx.invoked_with = invoke + self.ctx.message.content = f"!{invoke}" + if channel is not None: + self.ctx.message.content += f" {channel.name}" + self.ctx.guild.text_channels = [channel] + self.assertTrue(await self.cog.try_silence(self.ctx)) - self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence) + self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=channel) + + async def test_try_silence_unsilence_message(self): + """If the words after the command could not be converted to a channel, None should be passed as channel.""" + self.silence.silence.can_run = AsyncMock(return_value=True) + self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) + + self.ctx.invoked_with = "unshh" + self.ctx.message.content = "!unshh not_a_channel" + + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=None) async def test_try_silence_no_match(self): """Should return `False` when message don't match.""" |