diff options
| author | 2020-06-11 18:49:00 +0300 | |
|---|---|---|
| committer | 2020-06-11 18:49:00 +0300 | |
| commit | 49ba7b98b424f9f4199f5f57b4237481954fa06c (patch) | |
| tree | ef17e82e4f018381432828285780aa453eac54e3 /tests | |
| parent | Mod Utils Tests: Replace `has_active_infraction` with `get_active_infraction` (diff) | |
| parent | Fix trailing whitespace in Action file (diff) | |
Merge branch 'master' into mod-utils-tests
Diffstat (limited to '')
| -rw-r--r-- | tests/bot/cogs/moderation/test_infractions.py | 55 | ||||
| -rw-r--r-- | tests/bot/cogs/moderation/test_modlog.py | 29 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 3 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_users.py | 2 | ||||
| -rw-r--r-- | tests/bot/cogs/test_antimalware.py | 159 | ||||
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 2 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 15 | ||||
| -rw-r--r-- | tests/bot/cogs/test_snekbox.py | 26 | ||||
| -rw-r--r-- | tests/bot/test_constants.py | 43 | ||||
| -rw-r--r-- | tests/bot/test_converters.py | 113 | ||||
| -rw-r--r-- | tests/bot/test_decorators.py | 4 | ||||
| -rw-r--r-- | tests/bot/utils/test_checks.py | 52 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 273 | ||||
| -rw-r--r-- | tests/helpers.py | 34 | 
14 files changed, 715 insertions, 95 deletions
| diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py new file mode 100644 index 000000000..da4e92ccc --- /dev/null +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): +    """Tests for ban and kick command reason truncation.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = Infractions(self.bot) +        self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) +        self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) +        self.guild = MockGuild(id=4567) +        self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + +    @patch("bot.cogs.moderation.utils.get_active_infraction") +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): +        """Should truncate reason for `ctx.guild.ban`.""" +        get_active_mock.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.bot.get_cog.return_value = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.ctx.guild.ban = Mock() + +        await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) +        self.ctx.guild.ban.assert_called_once_with( +            self.target, +            reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), +            delete_message_days=0 +        ) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value +        ) + +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_kick_reason_truncation(self, post_infraction_mock): +        """Should truncate reason for `Member.kick`.""" +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.target.kick = Mock() + +        await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) +        self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value +        ) diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py new file mode 100644 index 000000000..f2809f40a --- /dev/null +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.cogs.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for moderation logs.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = ModLog(self.bot) +        self.channel = MockTextChannel() + +    async def test_log_entry_description_truncation(self): +        """Test that embed description for ModLog entry is truncated.""" +        self.bot.get_channel.return_value = self.channel +        await self.cog.send_log_message( +            icon_url="foo", +            colour=discord.Colour.blue(), +            title="bar", +            text="foo bar" * 3000 +        ) +        embed = self.channel.send.call_args[1]["embed"] +        self.assertEqual( +            embed.description, ("foo bar" * 3000)[:2045] + "..." +        ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 81398c61f..14fd909c4 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -247,14 +247,12 @@ class SyncCogListenerTests(SyncCogTestCase):          before_data = {              "name": "old name",              "discriminator": "1234", -            "avatar": "old avatar",              "bot": False,          }          subtests = (              (True, "name", "name", "new name", "new name"),              (True, "discriminator", "discriminator", "8765", 8765), -            (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),              (False, "bot", "bot", True, True),          ) @@ -295,7 +293,6 @@ class SyncCogListenerTests(SyncCogTestCase):          )          data = { -            "avatar_hash": member.avatar,              "discriminator": int(member.discriminator),              "id": member.id,              "in_guild": True, diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 818883012..002a947ad 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -10,7 +10,6 @@ def fake_user(**kwargs):      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man")      kwargs.setdefault("discriminator", 1337) -    kwargs.setdefault("avatar_hash", None)      kwargs.setdefault("roles", (666,))      kwargs.setdefault("in_guild", True) @@ -32,7 +31,6 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          for member in members:              member = member.copy() -            member["avatar"] = member.pop("avatar_hash")              del member["in_guild"]              mock_member = helpers.MockMember(**member) diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..f219fc1ba --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from discord import NotFound + +from bot.cogs import antimalware +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + +MODULE = "bot.cogs.antimalware" + + +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): +    """Test the AntiMalware cog.""" + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.bot = MockBot() +        self.cog = antimalware.AntiMalware(self.bot) +        self.message = MockMessage() + +    async def test_message_with_allowed_attachment(self): +        """Messages with allowed extensions should not be deleted""" +        attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_message_without_attachment(self): +        """Messages without attachments should result in no action.""" +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_direct_message_with_attachment(self): +        """Direct messages should have no action taken.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.guild = None + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_message_with_illegal_extension_gets_deleted(self): +        """A message containing an illegal extension should send an embed.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_called_once() + +    async def test_message_send_by_staff(self): +        """A message send by a member of staff should be ignored.""" +        staff_role = MockRole(id=STAFF_ROLES[0]) +        self.message.author.roles.append(staff_role) +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_python_file_redirect_embed_description(self): +        """A message containing a .py file should result in an embed redirecting the user to our paste site""" +        attachment = MockAttachment(filename="python.py") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") + +        self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + +    async def test_txt_file_redirect_embed_description(self): +        """A message containing a .txt file should result in the correct embed.""" +        attachment = MockAttachment(filename="python.txt") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.TXT_EMBED_DESCRIPTION = Mock() +        antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        cmd_channel = self.bot.get_channel(Channels.bot_commands) + +        self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) +        antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + +    async def test_other_disallowed_extention_embed_description(self): +        """Test the description for a non .py/.txt disallowed extension.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        meta_channel = self.bot.get_channel(Channels.meta) + +        self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( +            blocked_extensions_str=".disallowed", +            meta_channel_mention=meta_channel.mention +        ) + +    async def test_removing_deleted_message_logs(self): +        """Removing an already deleted message logs the correct message""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) +        self.message.delete.assert_called_once() + +    async def test_message_with_illegal_attachment_logs(self): +        """Deleting a message with an illegal attachment should result in a log.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) + +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            ([], []), +            (AntiMalwareConfig.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], [".disallowed"]), +            ([".disallowed"], [".disallowed"]), +            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] +                disallowed_extensions = self.cog.get_disallowed_extensions(self.message) +                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): +    """Tests setup of the `AntiMalware` cog.""" + +    def test_setup(self): +        """Setup of the extension should call add_cog.""" +        bot = MockBot() +        antimalware.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 7e6bfc748..a8c0107c6 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          self.assertEqual(cog.bot, bot)          self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) -        bot.loop.create_loop.called_once_with(cog.fetch_webhook()) +        bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())      def test_fetch_webhook_succeeds_without_connectivity_issues(self):          """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index b5f928dd6..79c0e0ad3 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,10 +7,9 @@ import discord  from bot import constants  from bot.cogs import information -from bot.decorators import InWhitelistCheckFailure +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  COG_PATH = "bot.cogs.information.Information" @@ -149,14 +148,18 @@ class InformationCogTests(unittest.TestCase):                  Voice region: {self.ctx.guild.region}                  Features: {', '.join(self.ctx.guild.features)} -                **Counts** -                Members: {self.ctx.guild.member_count:,} -                Roles: {len(self.ctx.guild.roles)} +                **Channel counts**                  Category channels: 1                  Text channels: 1                  Voice channels: 1 +                Staff channels: 0 + +                **Member counts** +                Members: {self.ctx.guild.member_count:,} +                Staff members: 0 +                Roles: {len(self.ctx.guild.roles)} -                **Members** +                **Member statuses**                  {constants.Emojis.status_online} 2                  {constants.Emojis.status_idle} 1                  {constants.Emojis.status_dnd} 4 diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..cf9adbee0 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -21,7 +21,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Post the eval code to the URLs.snekbox_eval_api endpoint."""          resp = MagicMock()          resp.json = AsyncMock(return_value="return") -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(await self.cog.post_eval("import random"), "return")          self.bot.http_session.post.assert_called_with( @@ -41,7 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          key = "MarkDiamond"          resp = MagicMock()          resp.json = AsyncMock(return_value={"key": key}) -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(              await self.cog.upload_output("My awesome output"), @@ -57,7 +63,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Output upload gracefully fallback if the upload fail."""          resp = MagicMock()          resp.json = AsyncMock(side_effect=Exception) -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          log = logging.getLogger("bot.cogs.snekbox")          with self.assertLogs(logger=log, level='ERROR'): @@ -208,10 +217,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      async def test_eval_command_call_help(self):          """Test if the eval command call the help command if no code is provided.""" -        ctx = MockContext() -        ctx.invoke = AsyncMock() +        ctx = MockContext(command="sentinel")          await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') -        ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") +        ctx.send_help.assert_called_once_with("sentinel")      async def test_send_eval(self):          """Test the send_eval function.""" @@ -291,7 +299,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(actual, expected)          self.bot.wait_for.assert_has_awaits(              ( -                call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), +                call( +                    'message_edit', +                    check=partial_mock(snekbox.predicate_eval_message_edit, ctx), +                    timeout=snekbox.REEVAL_TIMEOUT, +                ),                  call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)              )          ) diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index dae7c066c..f10d6fbe8 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -1,14 +1,40 @@  import inspect +import typing  import unittest  from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: +    """ +    Return True if `value` is an instance of the type represented by `annotation`. + +    This doesn't account for things like Unions or checking for homogenous types in collections. +    """ +    origin = typing.get_origin(annotation) + +    # This is done in case a bare e.g. `typing.List` is used. +    # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. +    # `get_origin()` does this normalisation for us. +    type_ = annotation if origin is None else origin + +    return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: +    """Return True if `value` is an instance of any type in `types`.""" +    for type_ in types: +        if is_annotation_instance(value, type_): +            return True + +    return False + +  class ConstantsTests(unittest.TestCase):      """Tests for our constants."""      def test_section_configuration_matches_type_specification(self): -        """The section annotations should match the actual types of the sections.""" +        """"The section annotations should match the actual types of the sections."""          sections = (              cls @@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):          )          for section in sections:              for name, annotation in section.__annotations__.items(): -                with self.subTest(section=section, name=name, annotation=annotation): +                with self.subTest(section=section.__name__, name=name, annotation=annotation):                      value = getattr(section, name) +                    origin = typing.get_origin(annotation) +                    annotation_args = typing.get_args(annotation) +                    failure_msg = f"{value} is not an instance of {annotation}" -                    if getattr(annotation, '_name', None) in ('Dict', 'List'): -                        self.skipTest("Cannot validate containers yet.") - -                    self.assertIsInstance(value, annotation) +                    if origin is typing.Union: +                        is_instance = is_any_instance(value, annotation_args) +                        self.assertTrue(is_instance, failure_msg) +                    else: +                        is_instance = is_annotation_instance(value, annotation) +                        self.assertTrue(is_instance, failure_msg) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index ca8cb6825..c42111f3f 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -1,5 +1,5 @@ -import asyncio  import datetime +import re  import unittest  from unittest.mock import MagicMock, patch @@ -16,7 +16,7 @@ from bot.converters import (  ) -class ConverterTests(unittest.TestCase): +class ConverterTests(unittest.IsolatedAsyncioTestCase):      """Tests our custom argument converters."""      @classmethod @@ -26,7 +26,7 @@ class ConverterTests(unittest.TestCase):          cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') -    def test_tag_content_converter_for_valid(self): +    async def test_tag_content_converter_for_valid(self):          """TagContentConverter should return correct values for valid input."""          test_values = (              ('hello', 'hello'), @@ -35,10 +35,10 @@ class ConverterTests(unittest.TestCase):          for content, expected_conversion in test_values:              with self.subTest(content=content, expected_conversion=expected_conversion): -                conversion = asyncio.run(TagContentConverter.convert(self.context, content)) +                conversion = await TagContentConverter.convert(self.context, content)                  self.assertEqual(conversion, expected_conversion) -    def test_tag_content_converter_for_invalid(self): +    async def test_tag_content_converter_for_invalid(self):          """TagContentConverter should raise the proper exception for invalid input."""          test_values = (              ('', "Tag contents should not be empty, or filled with whitespace."), @@ -47,10 +47,10 @@ class ConverterTests(unittest.TestCase):          for value, exception_message in test_values:              with self.subTest(tag_content=value, exception_message=exception_message): -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(TagContentConverter.convert(self.context, value)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await TagContentConverter.convert(self.context, value) -    def test_tag_name_converter_for_valid(self): +    async def test_tag_name_converter_for_valid(self):          """TagNameConverter should return the correct values for valid tag names."""          test_values = (              ('tracebacks', 'tracebacks'), @@ -60,10 +60,10 @@ class ConverterTests(unittest.TestCase):          for name, expected_conversion in test_values:              with self.subTest(name=name, expected_conversion=expected_conversion): -                conversion = asyncio.run(TagNameConverter.convert(self.context, name)) +                conversion = await TagNameConverter.convert(self.context, name)                  self.assertEqual(conversion, expected_conversion) -    def test_tag_name_converter_for_invalid(self): +    async def test_tag_name_converter_for_invalid(self):          """TagNameConverter should raise the correct exception for invalid tag names."""          test_values = (              ('👋', "Don't be ridiculous, you can't use that character!"), @@ -75,29 +75,29 @@ class ConverterTests(unittest.TestCase):          for invalid_name, exception_message in test_values:              with self.subTest(invalid_name=invalid_name, exception_message=exception_message): -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(TagNameConverter.convert(self.context, invalid_name)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await TagNameConverter.convert(self.context, invalid_name) -    def test_valid_python_identifier_for_valid(self): +    async def test_valid_python_identifier_for_valid(self):          """ValidPythonIdentifier returns valid identifiers unchanged."""          test_values = ('foo', 'lemon')          for name in test_values:              with self.subTest(identifier=name): -                conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name)) +                conversion = await ValidPythonIdentifier.convert(self.context, name)                  self.assertEqual(name, conversion) -    def test_valid_python_identifier_for_invalid(self): +    async def test_valid_python_identifier_for_invalid(self):          """ValidPythonIdentifier raises the proper exception for invalid identifiers."""          test_values = ('nested.stuff', '#####')          for name in test_values:              with self.subTest(identifier=name):                  exception_message = f'`{name}` is not a valid Python identifier' -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(ValidPythonIdentifier.convert(self.context, name)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await ValidPythonIdentifier.convert(self.context, name) -    def test_duration_converter_for_valid(self): +    async def test_duration_converter_for_valid(self):          """Duration returns the correct `datetime` for valid duration strings."""          test_values = (              # Simple duration strings @@ -159,35 +159,35 @@ class ConverterTests(unittest.TestCase):                  mock_datetime.utcnow.return_value = self.fixed_utc_now                  with self.subTest(duration=duration, duration_dict=duration_dict): -                    converted_datetime = asyncio.run(converter.convert(self.context, duration)) +                    converted_datetime = await converter.convert(self.context, duration)                      self.assertEqual(converted_datetime, expected_datetime) -    def test_duration_converter_for_invalid(self): +    async def test_duration_converter_for_invalid(self):          """Duration raises the right exception for invalid duration strings."""          test_values = (              # Units in wrong order -            ('1d1w'), -            ('1s1y'), +            '1d1w', +            '1s1y',              # Duplicated units -            ('1 year 2 years'), -            ('1 M 10 minutes'), +            '1 year 2 years', +            '1 M 10 minutes',              # Unknown substrings -            ('1MVes'), -            ('1y3breads'), +            '1MVes', +            '1y3breads',              # Missing amount -            ('ym'), +            'ym',              # Incorrect whitespace -            (" 1y"), -            ("1S "), -            ("1y  1m"), +            " 1y", +            "1S ", +            "1y  1m",              # Garbage -            ('Guido van Rossum'), -            ('lemon lemon lemon lemon lemon lemon lemon'), +            'Guido van Rossum', +            'lemon lemon lemon lemon lemon lemon lemon',          )          converter = Duration() @@ -195,10 +195,21 @@ class ConverterTests(unittest.TestCase):          for invalid_duration in test_values:              with self.subTest(invalid_duration=invalid_duration):                  exception_message = f'`{invalid_duration}` is not a valid duration string.' -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(converter.convert(self.context, invalid_duration)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await converter.convert(self.context, invalid_duration) -    def test_isodatetime_converter_for_valid(self): +    @patch("bot.converters.datetime") +    async def test_duration_converter_out_of_range(self, mock_datetime): +        """Duration converter should raise BadArgument if datetime raises a ValueError.""" +        mock_datetime.__add__.side_effect = ValueError +        mock_datetime.utcnow.return_value = mock_datetime + +        duration = f"{datetime.MAXYEAR}y" +        exception_message = f"`{duration}` results in a datetime outside the supported range." +        with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +            await Duration().convert(self.context, duration) + +    async def test_isodatetime_converter_for_valid(self):          """ISODateTime converter returns correct datetime for valid datetime string."""          test_values = (              # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` @@ -243,37 +254,37 @@ class ConverterTests(unittest.TestCase):          for datetime_string, expected_dt in test_values:              with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt): -                converted_dt = asyncio.run(converter.convert(self.context, datetime_string)) +                converted_dt = await converter.convert(self.context, datetime_string)                  self.assertIsNone(converted_dt.tzinfo)                  self.assertEqual(converted_dt, expected_dt) -    def test_isodatetime_converter_for_invalid(self): +    async def test_isodatetime_converter_for_invalid(self):          """ISODateTime converter raises the correct exception for invalid datetime strings."""          test_values = (              # Make sure it doesn't interfere with the Duration converter -            ('1Y'), -            ('1d'), -            ('1H'), +            '1Y', +            '1d', +            '1H',              # Check if it fails when only providing the optional time part -            ('10:10:10'), -            ('10:00'), +            '10:10:10', +            '10:00',              # Invalid date format -            ('19-01-01'), +            '19-01-01',              # Other non-valid strings -            ('fisk the tag master'), +            'fisk the tag master',          )          converter = ISODateTime()          for datetime_string in test_values:              with self.subTest(datetime_string=datetime_string):                  exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" -                with self.assertRaises(BadArgument, msg=exception_message): -                    asyncio.run(converter.convert(self.context, datetime_string)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await converter.convert(self.context, datetime_string) -    def test_hush_duration_converter_for_valid(self): +    async def test_hush_duration_converter_for_valid(self):          """HushDurationConverter returns correct value for minutes duration or `"forever"` strings."""          test_values = (              ("0", 0), @@ -286,10 +297,10 @@ class ConverterTests(unittest.TestCase):          converter = HushDurationConverter()          for minutes_string, expected_minutes in test_values:              with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes): -                converted = asyncio.run(converter.convert(self.context, minutes_string)) +                converted = await converter.convert(self.context, minutes_string)                  self.assertEqual(expected_minutes, converted) -    def test_hush_duration_converter_for_invalid(self): +    async def test_hush_duration_converter_for_invalid(self):          """HushDurationConverter raises correct exception for invalid minutes duration strings."""          test_values = (              ("16", "Duration must be at most 15 minutes."), @@ -299,5 +310,5 @@ class ConverterTests(unittest.TestCase):          converter = HushDurationConverter()          for invalid_minutes_string, exception_message in test_values:              with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message): -                with self.assertRaisesRegex(BadArgument, exception_message): -                    asyncio.run(converter.convert(self.context, invalid_minutes_string)) +                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): +                    await converter.convert(self.context, invalid_minutes_string) diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index a17dd3e16..3d450caa0 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -3,10 +3,10 @@ import unittest  import unittest.mock  from bot import constants -from bot.decorators import InWhitelistCheckFailure, in_whitelist +from bot.decorators import in_whitelist +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 9610771e5..de72e5748 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,6 +1,8 @@  import unittest +from unittest.mock import MagicMock  from bot.utils import checks +from bot.utils.checks import InWhitelistCheckFailure  from tests.helpers import MockContext, MockRole @@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):          self.ctx.author.roles.append(MockRole(id=role_id))          self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) -    def test_in_channel_check_for_correct_channel(self): -        self.ctx.channel.id = 42 -        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_correct_channel(self): +        """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" +        channel_id = 3 +        self.ctx.channel.id = channel_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id])) -    def test_in_channel_check_for_incorrect_channel(self): -        self.ctx.channel.id = 42 + 10 -        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_incorrect_channel(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match.""" +        self.ctx.channel.id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, [4]) + +    def test_in_whitelist_check_correct_category(self): +        """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list.""" +        category_id = 3 +        self.ctx.channel.category_id = category_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id])) + +    def test_in_whitelist_check_incorrect_category(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match.""" +        self.ctx.channel.category_id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, categories=[4]) + +    def test_in_whitelist_check_correct_role(self): +        """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6])) + +    def test_in_whitelist_check_incorrect_role(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, roles=[4]) + +    def test_in_whitelist_check_fail_silently(self): +        """`in_whitelist_check` test no exception raised if `fail_silently` is `True`""" +        self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True)) + +    def test_in_whitelist_check_complex(self): +        """`in_whitelist_check` test with multiple parameters""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.ctx.channel.category_id = 3 +        self.ctx.channel.id = 5 +        self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2])) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py new file mode 100644 index 000000000..8c1a40640 --- /dev/null +++ b/tests/bot/utils/test_redis_cache.py @@ -0,0 +1,273 @@ +import asyncio +import unittest + +import fakeredis.aioredis + +from bot.utils import RedisCache +from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError +from tests import helpers + + +class RedisCacheTests(unittest.IsolatedAsyncioTestCase): +    """Tests the RedisCache class from utils.redis_dict.py.""" + +    async def asyncSetUp(self):  # noqa: N802 +        """Sets up the objects that only have to be initialized once.""" +        self.bot = helpers.MockBot() +        self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() + +        # Okay, so this is necessary so that we can create a clean new +        # class for every test method, and we want that because it will +        # ensure we get a fresh loop, which is necessary for test_increment_lock +        # to be able to pass. +        class DummyCog: +            """A dummy cog, for dummies.""" + +            redis = RedisCache() + +            def __init__(self, bot: helpers.MockBot): +                self.bot = bot + +        self.cog = DummyCog(self.bot) + +        await self.cog.redis.clear() + +    def test_class_attribute_namespace(self): +        """Test that RedisDict creates a namespace automatically for class attributes.""" +        self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") + +    async def test_class_attribute_required(self): +        """Test that errors are raised when not assigned as a class attribute.""" +        bad_cache = RedisCache() +        self.assertIs(bad_cache._namespace, None) + +        with self.assertRaises(RuntimeError): +            await bad_cache.set("test", "me_up_deadman") + +    def test_namespace_collision(self): +        """Test that we prevent colliding namespaces.""" +        bob_cache_1 = RedisCache() +        bob_cache_1._set_namespace("BobRoss") +        self.assertEqual(bob_cache_1._namespace, "BobRoss") + +        bob_cache_2 = RedisCache() +        bob_cache_2._set_namespace("BobRoss") +        self.assertEqual(bob_cache_2._namespace, "BobRoss_") + +    async def test_set_get_item(self): +        """Test that users can set and get items from the RedisDict.""" +        test_cases = ( +            ('favorite_fruit', 'melon'), +            ('favorite_number', 86), +            ('favorite_fraction', 86.54) +        ) + +        # Test that we can get and set different types. +        for test in test_cases: +            await self.cog.redis.set(*test) +            self.assertEqual(await self.cog.redis.get(test[0]), test[1]) + +        # Test that .get allows a default value +        self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + +    async def test_set_item_type(self): +        """Test that .set rejects keys and values that are not permitted.""" +        fruits = ["lemon", "melon", "apple"] + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(fruits, "nice") + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(4.23, "nice") + +    async def test_delete_item(self): +        """Test that .delete allows us to delete stuff from the RedisCache.""" +        # Add an item and verify that it gets added +        await self.cog.redis.set("internet", "firetruck") +        self.assertEqual(await self.cog.redis.get("internet"), "firetruck") + +        # Delete that item and verify that it gets deleted +        await self.cog.redis.delete("internet") +        self.assertIs(await self.cog.redis.get("internet"), None) + +    async def test_contains(self): +        """Test that we can check membership with .contains.""" +        await self.cog.redis.set('favorite_country', "Burkina Faso") + +        self.assertIs(await self.cog.redis.contains('favorite_country'), True) +        self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) + +    async def test_items(self): +        """Test that the RedisDict can be iterated.""" +        # Set up our test cases in the Redis cache +        test_cases = [ +            ('favorite_turtle', 'Donatello'), +            ('second_favorite_turtle', 'Leonardo'), +            ('third_favorite_turtle', 'Raphael'), +        ] +        for key, value in test_cases: +            await self.cog.redis.set(key, value) + +        # Consume the AsyncIterator into a regular list, easier to compare that way. +        redis_items = [item for item in await self.cog.redis.items()] + +        # These sequences are probably in the same order now, but probably +        # isn't good enough for tests. Let's not rely on .hgetall always +        # returning things in sequence, and just sort both lists to be safe. +        redis_items = sorted(redis_items) +        test_cases = sorted(test_cases) + +        # If these are equal now, everything works fine. +        self.assertSequenceEqual(test_cases, redis_items) + +    async def test_length(self): +        """Test that we can get the correct .length from the RedisDict.""" +        await self.cog.redis.set('one', 1) +        await self.cog.redis.set('two', 2) +        await self.cog.redis.set('three', 3) +        self.assertEqual(await self.cog.redis.length(), 3) + +        await self.cog.redis.set('four', 4) +        self.assertEqual(await self.cog.redis.length(), 4) + +    async def test_to_dict(self): +        """Test that the .to_dict method returns a workable dictionary copy.""" +        copy = await self.cog.redis.to_dict() +        local_copy = {key: value for key, value in await self.cog.redis.items()} +        self.assertIs(type(copy), dict) +        self.assertDictEqual(copy, local_copy) + +    async def test_clear(self): +        """Test that the .clear method removes the entire hash.""" +        await self.cog.redis.set('teddy', 'with me') +        await self.cog.redis.set('in my dreams', 'you have a weird hat') +        self.assertEqual(await self.cog.redis.length(), 2) + +        await self.cog.redis.clear() +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_pop(self): +        """Test that we can .pop an item from the RedisDict.""" +        await self.cog.redis.set('john', 'was afraid') + +        self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') +        self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_update(self): +        """Test that we can .update the RedisDict with multiple items.""" +        await self.cog.redis.set("reckfried", "lona") +        await self.cog.redis.set("bel air", "prince") +        await self.cog.redis.update({ +            "reckfried": "jona", +            "mega": "hungry, though", +        }) + +        result = { +            "reckfried": "jona", +            "bel air": "prince", +            "mega": "hungry, though", +        } +        self.assertDictEqual(await self.cog.redis.to_dict(), result) + +    def test_typestring_conversion(self): +        """Test the typestring-related helper functions.""" +        conversion_tests = ( +            (12, "i|12"), +            (12.4, "f|12.4"), +            ("cowabunga", "s|cowabunga"), +        ) + +        # Test conversion to typestring +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) + +        # Test conversion from typestrings +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) + +        # Test that exceptions are raised on invalid input +        with self.assertRaises(TypeError): +            self.cog.redis._value_to_typestring(["internet"]) +            self.cog.redis._value_from_typestring("o|firedog") + +    async def test_increment_decrement(self): +        """Test .increment and .decrement methods.""" +        await self.cog.redis.set("entropic", 5) +        await self.cog.redis.set("disentropic", 12.5) + +        # Test default increment +        await self.cog.redis.increment("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 6) + +        # Test default decrement +        await self.cog.redis.decrement("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # Test float increment with float +        await self.cog.redis.increment("disentropic", 2.0) +        self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) + +        # Test float increment with int +        await self.cog.redis.increment("disentropic", 2) +        self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) + +        # Test negative increments, because why not. +        await self.cog.redis.increment("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 0) + +        # Negative decrements? Sure. +        await self.cog.redis.decrement("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # What about if we use a negative float to decrement an int? +        # This should convert the type into a float. +        await self.cog.redis.decrement("entropic", -2.5) +        self.assertEqual(await self.cog.redis.get("entropic"), 7.5) + +        # Let's test that they raise the right errors +        with self.assertRaises(KeyError): +            await self.cog.redis.increment("doesn't_exist!") + +        await self.cog.redis.set("stringthing", "stringthing") +        with self.assertRaises(TypeError): +            await self.cog.redis.increment("stringthing") + +    async def test_increment_lock(self): +        """Test that we can't produce a race condition in .increment.""" +        await self.cog.redis.set("test_key", 0) +        tasks = [] + +        # Increment this a lot in different tasks +        for _ in range(100): +            task = asyncio.create_task( +                self.cog.redis.increment("test_key", 1) +            ) +            tasks.append(task) +        await asyncio.gather(*tasks) + +        # Confirm that the value has been incremented the exact right number of times. +        value = await self.cog.redis.get("test_key") +        self.assertEqual(value, 100) + +    async def test_exceptions_raised(self): +        """Testing that the various RuntimeErrors are reachable.""" +        class MyCog: +            cache = RedisCache() + +            def __init__(self): +                self.other_cache = RedisCache() + +        cog = MyCog() + +        # Raises "No Bot instance" +        with self.assertRaises(NoBotInstanceError): +            await cog.cache.get("john") + +        # Raises "RedisCache has no namespace" +        with self.assertRaises(NoNamespaceError): +            await cog.other_cache.get("was") + +        # Raises "You must access the RedisCache instance through the cog instance" +        with self.assertRaises(NoParentInstanceError): +            await MyCog.cache.get("afraid") diff --git a/tests/helpers.py b/tests/helpers.py index 2b79a6c2a..faa839370 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,12 +4,15 @@ import collections  import itertools  import logging  import unittest.mock +from asyncio import AbstractEventLoop  from typing import Iterable, Optional  import discord +from aiohttp import ClientSession  from discord.ext.commands import Context  from bot.api import APIClient +from bot.async_stats import AsyncStatsClient  from bot.bot import Bot @@ -205,6 +208,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):          """Simplified position-based comparisons similar to those of `discord.Role`."""          return self.position < other.position +    def __ge__(self, other): +        """Simplified position-based comparisons similar to those of `discord.Role`.""" +        return self.position >= other.position +  # Create a Member instance to get a realistic Mock of `discord.Member`  member_data = {'user': 'lemon', 'roles': [1]} @@ -264,10 +271,16 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):      spec_set = APIClient -# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` -bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) -bot_instance.http_session = None -bot_instance.api_client = None +def _get_mock_loop() -> unittest.mock.Mock: +    """Return a mocked asyncio.AbstractEventLoop.""" +    loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) + +    # Since calling `create_task` on our MockBot does not actually schedule the coroutine object +    # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object +    # to prevent "has not been awaited"-warnings. +    loop.create_task.side_effect = lambda coroutine: coroutine.close() + +    return loop  class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -277,17 +290,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ -    spec_set = bot_instance -    additional_spec_asyncs = ("wait_for",) +    spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) +    additional_spec_asyncs = ("wait_for", "redis_ready")      def __init__(self, **kwargs) -> None:          super().__init__(**kwargs) -        self.api_client = MockAPIClient() -        # Since calling `create_task` on our MockBot does not actually schedule the coroutine object -        # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object -        # to prevent "has not been awaited"-warnings. -        self.loop.create_task.side_effect = lambda coroutine: coroutine.close() +        self.loop = _get_mock_loop() +        self.api_client = MockAPIClient(loop=self.loop) +        self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) +        self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)  # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` | 
