aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bot/exts/moderation/silence.py35
-rw-r--r--tests/bot/exts/moderation/test_silence.py91
2 files changed, 109 insertions, 17 deletions
diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py
index e5c96e76f..8e4ce7ae2 100644
--- a/bot/exts/moderation/silence.py
+++ b/bot/exts/moderation/silence.py
@@ -1,5 +1,6 @@
import json
import logging
+import typing
from contextlib import suppress
from datetime import datetime, timedelta, timezone
from typing import Optional, OrderedDict, Union
@@ -84,12 +85,8 @@ class SilenceNotifier(tasks.Loop):
async def _select_lock_channel(args: OrderedDict[str, any]) -> TextOrVoiceChannel:
"""Passes the channel to be silenced to the resource lock."""
- channel = args["channel"]
- if channel is not None:
- return channel
-
- else:
- return args["ctx"].channel
+ channel, _ = Silence.parse_silence_args(args["ctx"], args["duration_or_channel"], args["duration"])
+ return channel
class Silence(commands.Cog):
@@ -155,8 +152,8 @@ class Silence(commands.Cog):
async def silence(
self,
ctx: Context,
+ duration_or_channel: typing.Union[TextOrVoiceChannel, HushDurationConverter] = None,
duration: HushDurationConverter = 10,
- channel: TextOrVoiceChannel = None,
*,
kick: bool = False
) -> None:
@@ -170,8 +167,8 @@ class Silence(commands.Cog):
If `kick` is True, members will not be added back to the voice channel, and members will be unable to rejoin.
"""
await self._init_task
- if channel is None:
- channel = ctx.channel
+ channel, duration = self.parse_silence_args(ctx, duration_or_channel, duration)
+
channel_info = f"#{channel} ({channel.id})"
log.debug(f"{ctx.author} is silencing channel {channel_info}.")
@@ -198,6 +195,26 @@ class Silence(commands.Cog):
formatted_message = MSG_SILENCE_SUCCESS.format(duration=duration)
await self.send_message(formatted_message, ctx.channel, channel, alert_target=True)
+ @staticmethod
+ def parse_silence_args(
+ ctx: Context,
+ duration_or_channel: typing.Union[TextOrVoiceChannel, int],
+ duration: HushDurationConverter
+ ) -> typing.Tuple[TextOrVoiceChannel, int]:
+ """Helper method to parse the arguments of the silence command."""
+ duration: int
+
+ if duration_or_channel:
+ if isinstance(duration_or_channel, (TextChannel, VoiceChannel)):
+ channel = duration_or_channel
+ else:
+ channel = ctx.channel
+ duration = duration_or_channel
+ else:
+ channel = ctx.channel
+
+ return channel, duration
+
async def _set_silence_overwrites(self, channel: TextOrVoiceChannel, *, kick: bool = False) -> bool:
"""Set silence permission overwrites for `channel` and return True if successful."""
# Get the original channel overwrites
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index af6dd5a37..a7ea733c5 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -260,6 +260,81 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2)
+class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the silence argument parser utility function."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ @autospec(silence.Silence, "send_message", pass_mocks=False)
+ @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False)
+ @autospec(silence.Silence, "parse_silence_args")
+ async def test_command(self, parser_mock):
+ """Test that the command passes in the correct arguments for different calls."""
+ test_cases = (
+ (),
+ (15, ),
+ (MockTextChannel(),),
+ (MockTextChannel(), 15),
+ )
+
+ ctx = MockContext()
+ parser_mock.return_value = (ctx.channel, 10)
+
+ for case in test_cases:
+ with self.subTest("Test command converters", args=case):
+ await self.cog.silence.callback(self.cog, ctx, *case)
+
+ try:
+ first_arg = case[0]
+ except IndexError:
+ # Default value when the first argument is not passed
+ first_arg = None
+
+ try:
+ second_arg = case[1]
+ except IndexError:
+ # Default value when the second argument is not passed
+ second_arg = 10
+
+ parser_mock.assert_called_with(ctx, first_arg, second_arg)
+
+ async def test_no_arguments(self):
+ """Test the parser when no arguments are passed to the command."""
+ ctx = MockContext()
+ channel, duration = self.cog.parse_silence_args(ctx, None, 10)
+
+ self.assertEqual(ctx.channel, channel)
+ self.assertEqual(10, duration)
+
+ async def test_channel_only(self):
+ """Test the parser when just the channel argument is passed."""
+ expected_channel = MockTextChannel()
+ actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 10)
+
+ self.assertEqual(expected_channel, actual_channel)
+ self.assertEqual(10, duration)
+
+ async def test_duration_only(self):
+ """Test the parser when just the duration argument is passed."""
+ ctx = MockContext()
+ channel, duration = self.cog.parse_silence_args(ctx, 15, 10)
+
+ self.assertEqual(ctx.channel, channel)
+ self.assertEqual(15, duration)
+
+ async def test_all_args(self):
+ """Test the parser when both channel and duration are passed."""
+ expected_channel = MockTextChannel()
+ actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 15)
+
+ self.assertEqual(expected_channel, actual_channel)
+ self.assertEqual(15, duration)
+
+
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
class RescheduleTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the rescheduling of cached unsilences."""
@@ -387,7 +462,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
with self.subTest(was_silenced=was_silenced, target=target, message=message):
with mock.patch.object(self.cog, "send_message") as send_message:
ctx = MockContext()
- await self.cog.silence.callback(self.cog, ctx, duration, target)
+ await self.cog.silence.callback(self.cog, ctx, target, duration)
send_message.assert_called_once_with(
message,
ctx.channel,
@@ -399,7 +474,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_sync_called(self, ctx, sync, kick):
"""Tests if silence command calls sync on a voice channel."""
channel = MockVoiceChannel()
- await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False)
+ await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=False)
sync.assert_awaited_once_with(self.cog, channel)
kick.assert_not_called()
@@ -408,7 +483,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_kick_called(self, ctx, sync, kick):
"""Tests if silence command calls kick on a voice channel."""
channel = MockVoiceChannel()
- await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True)
+ await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=True)
kick.assert_awaited_once_with(channel)
sync.assert_not_called()
@@ -417,7 +492,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_sync_not_called(self, ctx, sync, kick):
"""Tests that silence command does not call sync on a text channel."""
channel = MockTextChannel()
- await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False)
+ await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=False)
sync.assert_not_called()
kick.assert_not_called()
@@ -426,7 +501,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_kick_not_called(self, ctx, sync, kick):
"""Tests that silence command does not call kick on a text channel."""
channel = MockTextChannel()
- await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True)
+ await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=True)
sync.assert_not_called()
kick.assert_not_called()
@@ -515,7 +590,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_indefinite_added_to_notifier(self):
"""Channel was added to notifier if a duration was not set for the silence."""
with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
- await self.cog.silence.callback(self.cog, MockContext(), None)
+ await self.cog.silence.callback(self.cog, MockContext(), None, None)
self.cog.notifier.add_channel.assert_called_once()
async def test_silenced_not_added_to_notifier(self):
@@ -547,7 +622,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_cached_indefinite_time(self):
"""A value of -1 was cached for a permanent silence."""
ctx = MockContext(channel=self.text_channel)
- await self.cog.silence.callback(self.cog, ctx, None)
+ await self.cog.silence.callback(self.cog, ctx, None, None)
self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)
async def test_scheduled_task(self):
@@ -563,7 +638,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_permanent_not_scheduled(self):
"""A task was not scheduled for a permanent silence."""
ctx = MockContext(channel=self.text_channel)
- await self.cog.silence.callback(self.cog, ctx, None)
+ await self.cog.silence.callback(self.cog, ctx, None, None)
self.cog.scheduler.schedule_later.assert_not_called()