aboutsummaryrefslogtreecommitdiffstats
path: root/tests/bot
diff options
context:
space:
mode:
authorGravatar ks129 <[email protected]>2020-05-14 19:24:13 +0300
committerGravatar GitHub <[email protected]>2020-05-14 19:24:13 +0300
commit188623e556e43c9c4444bc2a01533a0d8e678dd2 (patch)
treefae2611611e58ea6200dba3ebd5481c8b4eaf5b7 /tests/bot
parent(Mod Utils): Removed unnecessary `textwrap` import (diff)
parentRemove @Admins ping from the #verification message (diff)
Merge branch 'master' into mod-utils-tests
Diffstat (limited to 'tests/bot')
-rw-r--r--tests/bot/cogs/moderation/test_silence.py251
-rw-r--r--tests/bot/cogs/sync/test_base.py5
-rw-r--r--tests/bot/cogs/test_cogs.py80
-rw-r--r--tests/bot/cogs/test_information.py11
-rw-r--r--tests/bot/cogs/test_snekbox.py51
-rw-r--r--tests/bot/test_converters.py30
-rw-r--r--tests/bot/test_decorators.py147
-rw-r--r--tests/bot/test_utils.py37
8 files changed, 557 insertions, 55 deletions
diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py
new file mode 100644
index 000000000..3fd149f04
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_silence.py
@@ -0,0 +1,251 @@
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock, Mock
+
+from discord import PermissionOverwrite
+
+from bot.cogs.moderation.silence import Silence, SilenceNotifier
+from bot.constants import Channels, Emojis, Guild, Roles
+from tests.helpers import MockBot, MockContext, MockTextChannel
+
+
+class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
+ def setUp(self) -> None:
+ self.alert_channel = MockTextChannel()
+ self.notifier = SilenceNotifier(self.alert_channel)
+ self.notifier.stop = self.notifier_stop_mock = Mock()
+ self.notifier.start = self.notifier_start_mock = Mock()
+
+ def test_add_channel_adds_channel(self):
+ """Channel in FirstHash with current loop is added to internal set."""
+ channel = Mock()
+ with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
+ self.notifier.add_channel(channel)
+ silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop)
+
+ def test_add_channel_starts_loop(self):
+ """Loop is started if `_silenced_channels` was empty."""
+ self.notifier.add_channel(Mock())
+ self.notifier_start_mock.assert_called_once()
+
+ def test_add_channel_skips_start_with_channels(self):
+ """Loop start is not called when `_silenced_channels` is not empty."""
+ with mock.patch.object(self.notifier, "_silenced_channels"):
+ self.notifier.add_channel(Mock())
+ self.notifier_start_mock.assert_not_called()
+
+ def test_remove_channel_removes_channel(self):
+ """Channel in FirstHash is removed from `_silenced_channels`."""
+ channel = Mock()
+ with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
+ self.notifier.remove_channel(channel)
+ silenced_channels.__delitem__.assert_called_with(channel)
+
+ def test_remove_channel_stops_loop(self):
+ """Notifier loop is stopped if `_silenced_channels` is empty after remove."""
+ with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False):
+ self.notifier.remove_channel(Mock())
+ self.notifier_stop_mock.assert_called_once()
+
+ def test_remove_channel_skips_stop_with_channels(self):
+ """Notifier loop is not stopped if `_silenced_channels` is not empty after remove."""
+ self.notifier.remove_channel(Mock())
+ self.notifier_stop_mock.assert_not_called()
+
+ async def test_notifier_private_sends_alert(self):
+ """Alert is sent on 15 min intervals."""
+ test_cases = (900, 1800, 2700)
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
+ await self.notifier._notifier()
+ self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ")
+ self.alert_channel.send.reset_mock()
+
+ async def test_notifier_skips_alert(self):
+ """Alert is skipped on first loop or not an increment of 900."""
+ test_cases = (0, 15, 5000)
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
+ await self.notifier._notifier()
+ self.alert_channel.send.assert_not_called()
+
+
+class SilenceTests(unittest.IsolatedAsyncioTestCase):
+ def setUp(self) -> None:
+ self.bot = MockBot()
+ self.cog = Silence(self.bot)
+ self.ctx = MockContext()
+ self.cog._verified_role = None
+ # Set event so command callbacks can continue.
+ self.cog._get_instance_vars_event.set()
+
+ async def test_instance_vars_got_guild(self):
+ """Bot got guild after it became available."""
+ await self.cog._get_instance_vars()
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(Guild.id)
+
+ async def test_instance_vars_got_role(self):
+ """Got `Roles.verified` role from guild."""
+ await self.cog._get_instance_vars()
+ guild = self.bot.get_guild()
+ guild.get_role.assert_called_once_with(Roles.verified)
+
+ async def test_instance_vars_got_channels(self):
+ """Got channels from bot."""
+ await self.cog._get_instance_vars()
+ self.bot.get_channel.called_once_with(Channels.mod_alerts)
+ self.bot.get_channel.called_once_with(Channels.mod_log)
+
+ @mock.patch("bot.cogs.moderation.silence.SilenceNotifier")
+ async def test_instance_vars_got_notifier(self, notifier):
+ """Notifier was started with channel."""
+ mod_log = MockTextChannel()
+ self.bot.get_channel.side_effect = (None, mod_log)
+ await self.cog._get_instance_vars()
+ notifier.assert_called_once_with(mod_log)
+ self.bot.get_channel.side_effect = None
+
+ async def test_silence_sent_correct_discord_message(self):
+ """Check if proper message was sent when called with duration in channel with previous state."""
+ test_cases = (
+ (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,),
+ (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,),
+ (5, f"{Emojis.cross_mark} current channel is already silenced.", False,),
+ )
+ for duration, result_message, _silence_patch_return in test_cases:
+ with self.subTest(
+ silence_duration=duration,
+ result_message=result_message,
+ starting_unsilenced_state=_silence_patch_return
+ ):
+ with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return):
+ await self.cog.silence.callback(self.cog, self.ctx, duration)
+ self.ctx.send.assert_called_once_with(result_message)
+ self.ctx.reset_mock()
+
+ async def test_unsilence_sent_correct_discord_message(self):
+ """Proper reply after a successful unsilence."""
+ with mock.patch.object(self.cog, "_unsilence", return_value=True):
+ await self.cog.unsilence.callback(self.cog, self.ctx)
+ self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.")
+
+ async def test_silence_private_for_false(self):
+ """Permissions are not set and `False` is returned in an already silenced channel."""
+ perm_overwrite = Mock(send_messages=False)
+ channel = Mock(overwrites_for=Mock(return_value=perm_overwrite))
+
+ self.assertFalse(await self.cog._silence(channel, True, None))
+ channel.set_permissions.assert_not_called()
+
+ async def test_silence_private_silenced_channel(self):
+ """Channel had `send_message` permissions revoked."""
+ channel = MockTextChannel()
+ self.assertTrue(await self.cog._silence(channel, False, None))
+ channel.set_permissions.assert_called_once()
+ self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages'])
+
+ async def test_silence_private_preserves_permissions(self):
+ """Previous permissions were preserved when channel was silenced."""
+ channel = MockTextChannel()
+ # Set up mock channel permission state.
+ mock_permissions = PermissionOverwrite()
+ mock_permissions_dict = dict(mock_permissions)
+ channel.overwrites_for.return_value = mock_permissions
+ await self.cog._silence(channel, False, None)
+ new_permissions = channel.set_permissions.call_args.kwargs
+ # Remove 'send_messages' key because it got changed in the method.
+ del new_permissions['send_messages']
+ del mock_permissions_dict['send_messages']
+ self.assertDictEqual(mock_permissions_dict, new_permissions)
+
+ async def test_silence_private_notifier(self):
+ """Channel should be added to notifier with `persistent` set to `True`, and the other way around."""
+ channel = MockTextChannel()
+ with mock.patch.object(self.cog, "notifier", create=True):
+ with self.subTest(persistent=True):
+ await self.cog._silence(channel, True, None)
+ self.cog.notifier.add_channel.assert_called_once()
+
+ with mock.patch.object(self.cog, "notifier", create=True):
+ with self.subTest(persistent=False):
+ await self.cog._silence(channel, False, None)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_silence_private_added_muted_channel(self):
+ """Channel was added to `muted_channels` on silence."""
+ channel = MockTextChannel()
+ with mock.patch.object(self.cog, "muted_channels") as muted_channels:
+ await self.cog._silence(channel, False, None)
+ muted_channels.add.assert_called_once_with(channel)
+
+ async def test_unsilence_private_for_false(self):
+ """Permissions are not set and `False` is returned in an unsilenced channel."""
+ channel = Mock()
+ self.assertFalse(await self.cog._unsilence(channel))
+ channel.set_permissions.assert_not_called()
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_unsilenced_channel(self, _):
+ """Channel had `send_message` permissions restored"""
+ perm_overwrite = MagicMock(send_messages=False)
+ channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
+ self.assertTrue(await self.cog._unsilence(channel))
+ channel.set_permissions.assert_called_once()
+ self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages'])
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_removed_notifier(self, notifier):
+ """Channel was removed from `notifier` on unsilence."""
+ perm_overwrite = MagicMock(send_messages=False)
+ channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
+ await self.cog._unsilence(channel)
+ notifier.remove_channel.assert_called_once_with(channel)
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_removed_muted_channel(self, _):
+ """Channel was removed from `muted_channels` on unsilence."""
+ perm_overwrite = MagicMock(send_messages=False)
+ channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
+ with mock.patch.object(self.cog, "muted_channels") as muted_channels:
+ await self.cog._unsilence(channel)
+ muted_channels.discard.assert_called_once_with(channel)
+
+ @mock.patch.object(Silence, "notifier", create=True)
+ async def test_unsilence_private_preserves_permissions(self, _):
+ """Previous permissions were preserved when channel was unsilenced."""
+ channel = MockTextChannel()
+ # Set up mock channel permission state.
+ mock_permissions = PermissionOverwrite(send_messages=False)
+ mock_permissions_dict = dict(mock_permissions)
+ channel.overwrites_for.return_value = mock_permissions
+ await self.cog._unsilence(channel)
+ new_permissions = channel.set_permissions.call_args.kwargs
+ # Remove 'send_messages' key because it got changed in the method.
+ del new_permissions['send_messages']
+ del mock_permissions_dict['send_messages']
+ self.assertDictEqual(mock_permissions_dict, new_permissions)
+
+ @mock.patch("bot.cogs.moderation.silence.asyncio")
+ @mock.patch.object(Silence, "_mod_alerts_channel", create=True)
+ def test_cog_unload_starts_task(self, alert_channel, asyncio_mock):
+ """Task for sending an alert was created with present `muted_channels`."""
+ with mock.patch.object(self.cog, "muted_channels"):
+ self.cog.cog_unload()
+ alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ")
+ asyncio_mock.create_task.assert_called_once_with(alert_channel.send())
+
+ @mock.patch("bot.cogs.moderation.silence.asyncio")
+ def test_cog_unload_skips_task_start(self, asyncio_mock):
+ """No task created with no channels."""
+ self.cog.cog_unload()
+ asyncio_mock.create_task.assert_not_called()
+
+ @mock.patch("bot.cogs.moderation.silence.with_role_check")
+ @mock.patch("bot.cogs.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))
+ def test_cog_check(self, role_check):
+ """Role check is called with `MODERATION_ROLES`"""
+ self.cog.cog_check(self.ctx)
+ role_check.assert_called_once_with(self.ctx, *(1, 2, 3))
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
index fe0594efe..70aea2bab 100644
--- a/tests/bot/cogs/sync/test_base.py
+++ b/tests/bot/cogs/sync/test_base.py
@@ -1,3 +1,4 @@
+import asyncio
import unittest
from unittest import mock
@@ -84,7 +85,7 @@ class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):
method.assert_called_once_with(constants.Channels.dev_core)
- async def test_send_prompt_returns_None_if_channel_fetch_fails(self):
+ async def test_send_prompt_returns_none_if_channel_fetch_fails(self):
"""None should be returned if there's an HTTPException when fetching the channel."""
self.bot.get_channel.return_value = None
self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
@@ -211,7 +212,7 @@ class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):
subtests = (
(constants.Emojis.check_mark, True, None),
("InVaLiD", False, None),
- (None, False, TimeoutError),
+ (None, False, asyncio.TimeoutError),
)
for emoji, ret_val, side_effect in subtests:
diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py
new file mode 100644
index 000000000..fdda59a8f
--- /dev/null
+++ b/tests/bot/cogs/test_cogs.py
@@ -0,0 +1,80 @@
+"""Test suite for general tests which apply to all cogs."""
+
+import importlib
+import pkgutil
+import typing as t
+import unittest
+from collections import defaultdict
+from types import ModuleType
+from unittest import mock
+
+from discord.ext import commands
+
+from bot import cogs
+
+
+class CommandNameTests(unittest.TestCase):
+ """Tests for shadowing command names and aliases."""
+
+ @staticmethod
+ def walk_commands(cog: commands.Cog) -> t.Iterator[commands.Command]:
+ """An iterator that recursively walks through `cog`'s commands and subcommands."""
+ # Can't use Bot.walk_commands() or Cog.get_commands() cause those are instance methods.
+ for command in cog.__cog_commands__:
+ if command.parent is None:
+ yield command
+ if isinstance(command, commands.GroupMixin):
+ # Annoyingly it returns duplicates for each alias so use a set to fix that
+ yield from set(command.walk_commands())
+
+ @staticmethod
+ def walk_modules() -> t.Iterator[ModuleType]:
+ """Yield imported modules from the bot.cogs subpackage."""
+ def on_error(name: str) -> t.NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ # The mock prevents asyncio.get_event_loop() from being called.
+ with mock.patch("discord.ext.tasks.loop"):
+ for module in pkgutil.walk_packages(cogs.__path__, "bot.cogs.", onerror=on_error):
+ if not module.ispkg:
+ yield importlib.import_module(module.name)
+
+ @staticmethod
+ def walk_cogs(module: ModuleType) -> t.Iterator[commands.Cog]:
+ """Yield all cogs defined in an extension."""
+ for obj in module.__dict__.values():
+ # Check if it's a class type cause otherwise issubclass() may raise a TypeError.
+ is_cog = isinstance(obj, type) and issubclass(obj, commands.Cog)
+ if is_cog and obj.__module__ == module.__name__:
+ yield obj
+
+ @staticmethod
+ def get_qualified_names(command: commands.Command) -> t.List[str]:
+ """Return a list of all qualified names, including aliases, for the `command`."""
+ names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases]
+ names.append(command.qualified_name)
+
+ return names
+
+ def get_all_commands(self) -> t.Iterator[commands.Command]:
+ """Yield all commands for all cogs in all extensions."""
+ for module in self.walk_modules():
+ for cog in self.walk_cogs(module):
+ for cmd in self.walk_commands(cog):
+ yield cmd
+
+ def test_names_dont_shadow(self):
+ """Names and aliases of commands should be unique."""
+ all_names = defaultdict(list)
+ for cmd in self.get_all_commands():
+ func_name = f"{cmd.module}.{cmd.callback.__qualname__}"
+
+ for name in self.get_qualified_names(cmd):
+ with self.subTest(cmd=func_name, name=name):
+ if name in all_names: # pragma: no cover
+ conflicts = ", ".join(all_names.get(name, ""))
+ self.fail(
+ f"Name '{name}' of the command {func_name} conflicts with {conflicts}."
+ )
+
+ all_names[name].append(func_name)
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 5693d2946..b5f928dd6 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -7,7 +7,7 @@ import discord
from bot import constants
from bot.cogs import information
-from bot.decorators import InChannelCheckFailure
+from bot.decorators import InWhitelistCheckFailure
from tests import helpers
@@ -45,10 +45,9 @@ class InformationCogTests(unittest.TestCase):
_, kwargs = self.ctx.send.call_args
embed = kwargs.pop('embed')
- self.assertEqual(embed.title, "Role information")
+ self.assertEqual(embed.title, "Role information (Total 1 role)")
self.assertEqual(embed.colour, discord.Colour.blurple())
- self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
- self.assertEqual(embed.footer.text, "Total roles: 1")
+ self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
def test_role_info_command(self):
"""Tests the `role info` command."""
@@ -486,7 +485,7 @@ class UserEmbedTests(unittest.TestCase):
user.avatar_url_as.return_value = "avatar url"
embed = asyncio.run(self.cog.create_user_embed(ctx, user))
- user.avatar_url_as.assert_called_once_with(format="png")
+ user.avatar_url_as.assert_called_once_with(static_format="png")
self.assertEqual(embed.thumbnail.url, "avatar url")
@@ -526,7 +525,7 @@ class UserCommandTests(unittest.TestCase):
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))
msg = "Sorry, but you may only use this command within <#50>."
- with self.assertRaises(InChannelCheckFailure, msg=msg):
+ with self.assertRaises(InWhitelistCheckFailure, msg=msg):
asyncio.run(self.cog.user_info.callback(self.cog, ctx))
@unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
index 9cd7f0154..1dec0ccaf 100644
--- a/tests/bot/cogs/test_snekbox.py
+++ b/tests/bot/cogs/test_snekbox.py
@@ -1,11 +1,13 @@
import asyncio
import logging
import unittest
-from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch
+from discord.ext import commands
+
+from bot import constants
from bot.cogs import snekbox
from bot.cogs.snekbox import Snekbox
-from bot.constants import URLs
from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser
@@ -23,7 +25,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(await self.cog.post_eval("import random"), "return")
self.bot.http_session.post.assert_called_with(
- URLs.snekbox_eval_api,
+ constants.URLs.snekbox_eval_api,
json={"input": "import random"},
raise_for_status=True
)
@@ -43,10 +45,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
await self.cog.upload_output("My awesome output"),
- URLs.paste_service.format(key=key)
+ constants.URLs.paste_service.format(key=key)
)
self.bot.http_session.post.assert_called_with(
- URLs.paste_service.format(key="documents"),
+ constants.URLs.paste_service.format(key="documents"),
data="My awesome output",
raise_for_status=True
)
@@ -89,15 +91,15 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(actual, expected)
@patch('bot.cogs.snekbox.Signals', side_effect=ValueError)
- def test_get_results_message_invalid_signal(self, mock_Signals: Mock):
+ def test_get_results_message_invalid_signal(self, mock_signals: Mock):
self.assertEqual(
self.cog.get_results_message({'stdout': '', 'returncode': 127}),
('Your eval job has completed with return code 127', '')
)
@patch('bot.cogs.snekbox.Signals')
- def test_get_results_message_valid_signal(self, mock_Signals: Mock):
- mock_Signals.return_value.name = 'SIGTEST'
+ def test_get_results_message_valid_signal(self, mock_signals: Mock):
+ mock_signals.return_value.name = 'SIGTEST'
self.assertEqual(
self.cog.get_results_message({'stdout': '', 'returncode': 127}),
('Your eval job has completed with return code 127 (SIGTEST)', '')
@@ -279,11 +281,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Test that the continue_eval function does continue if required conditions are met."""
ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
response = MockMessage(delete=AsyncMock())
- new_msg = MockMessage(content='!e NewCode')
+ new_msg = MockMessage()
self.bot.wait_for.side_effect = ((None, new_msg), None)
+ expected = "NewCode"
+ self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)
actual = await self.cog.continue_eval(ctx, response)
- self.assertEqual(actual, 'NewCode')
+ self.cog.get_code.assert_awaited_once_with(new_msg)
+ self.assertEqual(actual, expected)
self.bot.wait_for.assert_has_awaits(
(
call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10),
@@ -302,6 +307,32 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(actual, None)
ctx.message.clear_reactions.assert_called_once()
+ async def test_get_code(self):
+ """Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
+ prefix = constants.Bot.prefix
+ subtests = (
+ (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"),
+ (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None),
+ (MagicMock(spec=commands.Command), f"{prefix}tags get foo"),
+ (None, "print(123)")
+ )
+
+ for command, content, *expected_code in subtests:
+ if not expected_code:
+ expected_code = content
+ else:
+ [expected_code] = expected_code
+
+ with self.subTest(content=content, expected_code=expected_code):
+ self.bot.get_context.reset_mock()
+ self.bot.get_context.return_value = MockContext(command=command)
+ message = MockMessage(content=content)
+
+ actual_code = await self.cog.get_code(message)
+
+ self.bot.get_context.assert_awaited_once_with(message)
+ self.assertEqual(actual_code, expected_code)
+
def test_predicate_eval_message_edit(self):
"""Test the predicate_eval_message_edit function."""
msg0 = MockMessage(id=1, content='abc')
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index 1e5ca62ae..ca8cb6825 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -8,6 +8,7 @@ from discord.ext.commands import BadArgument
from bot.converters import (
Duration,
+ HushDurationConverter,
ISODateTime,
TagContentConverter,
TagNameConverter,
@@ -271,3 +272,32 @@ class ConverterTests(unittest.TestCase):
exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string"
with self.assertRaises(BadArgument, msg=exception_message):
asyncio.run(converter.convert(self.context, datetime_string))
+
+ def test_hush_duration_converter_for_valid(self):
+ """HushDurationConverter returns correct value for minutes duration or `"forever"` strings."""
+ test_values = (
+ ("0", 0),
+ ("15", 15),
+ ("10", 10),
+ ("5m", 5),
+ ("5M", 5),
+ ("forever", None),
+ )
+ converter = HushDurationConverter()
+ for minutes_string, expected_minutes in test_values:
+ with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes):
+ converted = asyncio.run(converter.convert(self.context, minutes_string))
+ self.assertEqual(expected_minutes, converted)
+
+ def test_hush_duration_converter_for_invalid(self):
+ """HushDurationConverter raises correct exception for invalid minutes duration strings."""
+ test_values = (
+ ("16", "Duration must be at most 15 minutes."),
+ ("10d", "10d is not a valid minutes duration."),
+ ("-1", "-1 is not a valid minutes duration."),
+ )
+ converter = HushDurationConverter()
+ for invalid_minutes_string, exception_message in test_values:
+ with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message):
+ with self.assertRaisesRegex(BadArgument, exception_message):
+ asyncio.run(converter.convert(self.context, invalid_minutes_string))
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
new file mode 100644
index 000000000..a17dd3e16
--- /dev/null
+++ b/tests/bot/test_decorators.py
@@ -0,0 +1,147 @@
+import collections
+import unittest
+import unittest.mock
+
+from bot import constants
+from bot.decorators import InWhitelistCheckFailure, in_whitelist
+from tests import helpers
+
+
+InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description"))
+
+
+class InWhitelistTests(unittest.TestCase):
+ """Tests for the `in_whitelist` check."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Set up helpers that only need to be defined once."""
+ cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456)
+ cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654)
+ cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666)
+ cls.dm_channel = helpers.MockDMChannel()
+
+ cls.non_staff_member = helpers.MockMember()
+ cls.staff_role = helpers.MockRole(id=121212)
+ cls.staff_member = helpers.MockMember(roles=(cls.staff_role,))
+
+ cls.channels = (cls.bot_commands.id,)
+ cls.categories = (cls.help_channel.category_id,)
+ cls.roles = (cls.staff_role.id,)
+
+ def test_predicate_returns_true_for_whitelisted_context(self):
+ """The predicate should return `True` if a whitelisted context was passed to it."""
+ test_cases = (
+ InWhitelistTestCase(
+ kwargs={"channels": self.channels},
+ ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member),
+ description="In whitelisted channels by members without whitelisted roles",
+ ),
+ InWhitelistTestCase(
+ kwargs={"redirect": self.bot_commands.id},
+ ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member),
+ description="`redirect` should be implicitly added to `channels`",
+ ),
+ InWhitelistTestCase(
+ kwargs={"categories": self.categories},
+ ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member),
+ description="Whitelisted category without whitelisted role",
+ ),
+ InWhitelistTestCase(
+ kwargs={"roles": self.roles},
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member),
+ description="Whitelisted role outside of whitelisted channel/category"
+ ),
+ InWhitelistTestCase(
+ kwargs={
+ "channels": self.channels,
+ "categories": self.categories,
+ "roles": self.roles,
+ "redirect": self.bot_commands,
+ },
+ ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member),
+ description="Case with all whitelist kwargs used",
+ ),
+ )
+
+ for test_case in test_cases:
+ # patch `commands.check` with a no-op lambda that just returns the predicate passed to it
+ # so we can test the predicate that was generated from the specified kwargs.
+ with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ predicate = in_whitelist(**test_case.kwargs)
+
+ with self.subTest(test_description=test_case.description):
+ self.assertTrue(predicate(test_case.ctx))
+
+ def test_predicate_raises_exception_for_non_whitelisted_context(self):
+ """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context."""
+ test_cases = (
+ # Failing check with explicit `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": self.bot_commands.id,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check with an explicit redirect channel",
+ ),
+
+ # Failing check with implicit `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check with an implicit redirect channel",
+ ),
+
+ # Failing check without `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": None,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check without a redirect channel",
+ ),
+
+ # Command issued in DM channel
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": None,
+ },
+ ctx=helpers.MockContext(channel=self.dm_channel, author=self.dm_channel.me),
+ description="Commands issued in DM channel should be rejected",
+ ),
+ )
+
+ for test_case in test_cases:
+ if "redirect" not in test_case.kwargs or test_case.kwargs["redirect"] is not None:
+ # There are two cases in which we have a redirect channel:
+ # 1. No redirect channel was passed; the default value of `bot_commands` is used
+ # 2. An explicit `redirect` is set that is "not None"
+ redirect_channel = test_case.kwargs.get("redirect", constants.Channels.bot_commands)
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ # If an explicit `None` was passed for `redirect`, there is no redirect channel
+ redirect_message = ""
+
+ exception_message = f"You are not allowed to use that command{redirect_message}."
+
+ # patch `commands.check` with a no-op lambda that just returns the predicate passed to it
+ # so we can test the predicate that was generated from the specified kwargs.
+ with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ predicate = in_whitelist(**test_case.kwargs)
+
+ with self.subTest(test_description=test_case.description):
+ with self.assertRaisesRegex(InWhitelistCheckFailure, exception_message):
+ predicate(test_case.ctx)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
deleted file mode 100644
index d7bcc3ba6..000000000
--- a/tests/bot/test_utils.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import unittest
-
-from bot import utils
-
-
-class CaseInsensitiveDictTests(unittest.TestCase):
- """Tests for the `CaseInsensitiveDict` container."""
-
- def test_case_insensitive_key_access(self):
- """Tests case insensitive key access and storage."""
- instance = utils.CaseInsensitiveDict()
-
- key = 'LEMON'
- value = 'trees'
-
- instance[key] = value
- self.assertIn(key, instance)
- self.assertEqual(instance.get(key), value)
- self.assertEqual(instance.get(key.casefold()), value)
- self.assertEqual(instance.pop(key.casefold()), value)
- self.assertNotIn(key, instance)
- self.assertNotIn(key.casefold(), instance)
-
- instance.setdefault(key, value)
- del instance[key]
- self.assertNotIn(key, instance)
-
- def test_initialization_from_kwargs(self):
- """Tests creating the dictionary from keyword arguments."""
- instance = utils.CaseInsensitiveDict({'FOO': 'bar'})
- self.assertEqual(instance['foo'], 'bar')
-
- def test_update_from_other_mapping(self):
- """Tests updating the dictionary from another mapping."""
- instance = utils.CaseInsensitiveDict()
- instance.update({'FOO': 'bar'})
- self.assertEqual(instance['foo'], 'bar')