diff options
| author | 2020-02-28 09:45:21 -0800 | |
|---|---|---|
| committer | 2020-02-28 09:45:21 -0800 | |
| commit | a20ab7e27f06a66f1c35bdc30a54898222f7388a (patch) | |
| tree | 7b15013a7f435df9583197be8bf092a2109b9fe4 | |
| parent | Merge pull request #757 from python-discord/feat/backend/b131/error-handling (diff) | |
| parent | Snekbox: mention re-evaluation feature in the command's docstring (diff) | |
Merge pull request #710 from python-discord/eval-enhancements
Eval cog improvements 
| -rw-r--r-- | bot/cogs/snekbox.py | 131 | ||||
| -rw-r--r-- | tests/bot/cogs/test_snekbox.py | 368 | ||||
| -rw-r--r-- | tests/helpers.py | 12 | 
3 files changed, 479 insertions, 32 deletions
| diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index aef12546d..cff7c5786 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -1,10 +1,14 @@ +import asyncio +import contextlib  import datetime  import logging  import re  import textwrap +from functools import partial  from signal import Signals  from typing import Optional, Tuple +from discord import HTTPException, Message, NotFound, Reaction, User  from discord.ext.commands import Cog, Context, command, guild_only  from bot.bot import Bot @@ -36,6 +40,10 @@ RAW_CODE_REGEX = re.compile(  MAX_PASTE_LEN = 1000  EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) +SIGKILL = 9 + +REEVAL_EMOJI = '\U0001f501'  # :repeat: +  class Snekbox(Cog):      """Safe evaluation of Python code using Snekbox.""" @@ -101,7 +109,7 @@ class Snekbox(Cog):          if returncode is None:              msg = "Your eval job has failed"              error = stdout.strip() -        elif returncode == 128 + Signals.SIGKILL: +        elif returncode == 128 + SIGKILL:              msg = "Your eval job timed out or ran out of memory"          elif returncode == 255:              msg = "Your eval job has failed" @@ -135,7 +143,7 @@ class Snekbox(Cog):          """          log.trace("Formatting output...") -        output = output.strip(" \n") +        output = output.rstrip("\n")          original_output = output  # To be uploaded to a pasting service if needed          paste_link = None @@ -152,8 +160,8 @@ class Snekbox(Cog):          lines = output.count("\n")          if lines > 0: -            output = output.split("\n")[:10]  # Only first 10 cause the rest is truncated anyway -            output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1)) +            output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] +            output = output[:11]  # Limiting to only 11 lines              output = "\n".join(output)          if lines > 10: @@ -169,12 +177,72 @@ class Snekbox(Cog):          if truncated:              paste_link = await self.upload_output(original_output) -        output = output.strip() -        if not output: -            output = "[No output]" +        output = output or "[No output]"          return output, paste_link +    async def send_eval(self, ctx: Context, code: str) -> Message: +        """ +        Evaluate code, format it, and send the output to the corresponding channel. + +        Return the bot response. +        """ +        async with ctx.typing(): +            results = await self.post_eval(code) +            msg, error = self.get_results_message(results) + +            if error: +                output, paste_link = error, None +            else: +                output, paste_link = await self.format_output(results["stdout"]) + +            icon = self.get_status_emoji(results) +            msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```" +            if paste_link: +                msg = f"{msg}\nFull output: {paste_link}" + +            response = await ctx.send(msg) +            self.bot.loop.create_task( +                wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) +            ) + +            log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") +        return response + +    async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: +        """ +        Check if the eval session should continue. + +        Return the new code to evaluate or None if the eval session should be terminated. +        """ +        _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) +        _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + +        with contextlib.suppress(NotFound): +            try: +                _, new_message = await self.bot.wait_for( +                    'message_edit', +                    check=_predicate_eval_message_edit, +                    timeout=10 +                ) +                await ctx.message.add_reaction(REEVAL_EMOJI) +                await self.bot.wait_for( +                    'reaction_add', +                    check=_predicate_emoji_reaction, +                    timeout=10 +                ) + +                code = new_message.content.split(' ', maxsplit=1)[1] +                await ctx.message.clear_reactions() +                with contextlib.suppress(HTTPException): +                    await response.delete() + +            except asyncio.TimeoutError: +                await ctx.message.clear_reactions() +                return None + +            return code +      @command(name="eval", aliases=("e",))      @guild_only()      @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) @@ -183,7 +251,10 @@ class Snekbox(Cog):          Run Python code and get the results.          This command supports multiple lines of code, including code wrapped inside a formatted code -        block. We've done our best to make this safe, but do let us know if you manage to find an +        block. Code can be re-evaluated by editing the original message within 10 seconds and +        clicking the reaction that subsequently appears. + +        We've done our best to make this sandboxed, but do let us know if you manage to find an          issue with it!          """          if ctx.author.id in self.jobs: @@ -199,32 +270,28 @@ class Snekbox(Cog):          log.info(f"Received code from {ctx.author} for evaluation:\n{code}") -        self.jobs[ctx.author.id] = datetime.datetime.now() -        code = self.prepare_input(code) +        while True: +            self.jobs[ctx.author.id] = datetime.datetime.now() +            code = self.prepare_input(code) +            try: +                response = await self.send_eval(ctx, code) +            finally: +                del self.jobs[ctx.author.id] + +            code = await self.continue_eval(ctx, response) +            if not code: +                break +            log.info(f"Re-evaluating message {ctx.message.id}") + + +def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: +    """Return True if the edited message is the context message and the content was indeed modified.""" +    return new_msg.id == ctx.message.id and old_msg.content != new_msg.content -        try: -            async with ctx.typing(): -                results = await self.post_eval(code) -                msg, error = self.get_results_message(results) - -                if error: -                    output, paste_link = error, None -                else: -                    output, paste_link = await self.format_output(results["stdout"]) - -                icon = self.get_status_emoji(results) -                msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```" -                if paste_link: -                    msg = f"{msg}\nFull output: {paste_link}" - -                response = await ctx.send(msg) -                self.bot.loop.create_task( -                    wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) -                ) -                log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") -        finally: -            del self.jobs[ctx.author.id] +def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: +    """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" +    return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI  def setup(bot: Bot) -> None: diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py new file mode 100644 index 000000000..985bc66a1 --- /dev/null +++ b/tests/bot/cogs/test_snekbox.py @@ -0,0 +1,368 @@ +import asyncio +import logging +import unittest +from functools import partial +from unittest.mock import MagicMock, Mock, call, patch + +from bot.cogs import snekbox +from bot.cogs.snekbox import Snekbox +from bot.constants import URLs +from tests.helpers import ( +    AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test +) + + +class SnekboxTests(unittest.TestCase): +    def setUp(self): +        """Add mocked bot and cog to the instance.""" +        self.bot = MockBot() + +        self.mocked_post = MagicMock() +        self.mocked_post.json = AsyncMock() +        self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post)) + +        self.cog = Snekbox(bot=self.bot) + +    @async_test +    async def test_post_eval(self): +        """Post the eval code to the URLs.snekbox_eval_api endpoint.""" +        self.mocked_post.json.return_value = {'lemon': 'AI'} + +        self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'}) +        self.bot.http_session.post.assert_called_once_with( +            URLs.snekbox_eval_api, +            json={"input": "import random"}, +            raise_for_status=True +        ) + +    @async_test +    async def test_upload_output_reject_too_long(self): +        """Reject output longer than MAX_PASTE_LEN.""" +        result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) +        self.assertEqual(result, "too long to upload") + +    @async_test +    async def test_upload_output(self): +        """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" +        key = "RainbowDash" +        self.mocked_post.json.return_value = {"key": key} + +        self.assertEqual( +            await self.cog.upload_output("My awesome output"), +            URLs.paste_service.format(key=key) +        ) +        self.bot.http_session.post.assert_called_once_with( +            URLs.paste_service.format(key="documents"), +            data="My awesome output", +            raise_for_status=True +        ) + +    @async_test +    async def test_upload_output_gracefully_fallback_if_exception_during_request(self): +        """Output upload gracefully fallback if the upload fail.""" +        self.mocked_post.json.side_effect = Exception +        log = logging.getLogger("bot.cogs.snekbox") +        with self.assertLogs(logger=log, level='ERROR'): +            await self.cog.upload_output('My awesome output!') + +    @async_test +    async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): +        """Output upload gracefully fallback if there is no key entry in the response body.""" +        self.mocked_post.json.return_value = {} +        self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + +    def test_prepare_input(self): +        cases = ( +            ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), +            ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), +            ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), +            ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), +        ) +        for case, expected, testname in cases: +            with self.subTest(msg=f'Extract code from {testname}.'): +                self.assertEqual(self.cog.prepare_input(case), expected) + +    def test_get_results_message(self): +        """Return error and message according to the eval result.""" +        cases = ( +            ('ERROR', None, ('Your eval job has failed', 'ERROR')), +            ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), +            ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) +        ) +        for stdout, returncode, expected in cases: +            with self.subTest(stdout=stdout, returncode=returncode, expected=expected): +                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) +                self.assertEqual(actual, expected) + +    @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) +    def test_get_results_message_invalid_signal(self, mock_Signals: Mock): +        self.assertEqual( +            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            ('Your eval job has completed with return code 127', '') +        ) + +    @patch('bot.cogs.snekbox.Signals') +    def test_get_results_message_valid_signal(self, mock_Signals: Mock): +        mock_Signals.return_value.name = 'SIGTEST' +        self.assertEqual( +            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            ('Your eval job has completed with return code 127 (SIGTEST)', '') +        ) + +    def test_get_status_emoji(self): +        """Return emoji according to the eval result.""" +        cases = ( +            (' ', -1, ':warning:'), +            ('Hello world!', 0, ':white_check_mark:'), +            ('Invalid beard size', -1, ':x:') +        ) +        for stdout, returncode, expected in cases: +            with self.subTest(stdout=stdout, returncode=returncode, expected=expected): +                actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) +                self.assertEqual(actual, expected) + +    @async_test +    async def test_format_output(self): +        """Test output formatting.""" +        self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + +        too_many_lines = ( +            '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' +            '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' +        ) +        too_long_too_many_lines = ( +            "\n".join( +                f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) +            )[:1000] + "\n... (truncated - too long, too many lines)" +        ) + +        cases = ( +            ('', ('[No output]', None), 'No output'), +            ('My awesome output', ('My awesome output', None), 'One line output'), +            ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), +            ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'), +            ( +                '\u202E\u202E\u202E', +                ('Code block escape attempt detected; will not output result', None), +                'Detect RIGHT-TO-LEFT OVERRIDE' +            ), +            ( +                '\u200B\u200B\u200B', +                ('Code block escape attempt detected; will not output result', None), +                'Detect ZERO WIDTH SPACE' +            ), +            ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'), +            ( +                'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd', +                (too_many_lines, 'https://testificate.com/'), +                '12 lines output' +            ), +            ( +                'verylongbeard' * 100, +                ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'), +                '1300 characters output' +            ), +            ( +                ('verylongbeard' * 10 + '\n') * 15, +                (too_long_too_many_lines, 'https://testificate.com/'), +                '15 lines, 1965 characters output' +            ), +        ) +        for case, expected, testname in cases: +            with self.subTest(msg=testname, case=case, expected=expected): +                self.assertEqual(await self.cog.format_output(case), expected) + +    @async_test +    async def test_eval_command_evaluate_once(self): +        """Test the eval command procedure.""" +        ctx = MockContext() +        response = MockMessage() +        self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') +        self.cog.send_eval = AsyncMock(return_value=response) +        self.cog.continue_eval = AsyncMock(return_value=None) + +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') +        self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') +        self.cog.continue_eval.assert_called_once_with(ctx, response) + +    @async_test +    async def test_eval_command_evaluate_twice(self): +        """Test the eval and re-eval command procedure.""" +        ctx = MockContext() +        response = MockMessage() +        self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') +        self.cog.send_eval = AsyncMock(return_value=response) +        self.cog.continue_eval = AsyncMock() +        self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) +        self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') +        self.cog.continue_eval.assert_called_with(ctx, response) + +    @async_test +    async def test_eval_command_reject_two_eval_at_the_same_time(self): +        """Test if the eval command rejects an eval if the author already have a running eval.""" +        ctx = MockContext() +        ctx.author.id = 42 +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.send = AsyncMock() +        self.cog.jobs = (42,) +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        ctx.send.assert_called_once_with( +            "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" +        ) + +    @async_test +    async def test_eval_command_call_help(self): +        """Test if the eval command call the help command if no code is provided.""" +        ctx = MockContext() +        ctx.invoke = AsyncMock() +        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') +        ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") + +    @async_test +    async def test_send_eval(self): +        """Test the send_eval function.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +        self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) +        self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) +        self.cog.get_status_emoji = MagicMock(return_value=':yay!:') +        self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + +        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        ctx.send.assert_called_once_with( +            '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' +        ) +        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) +        self.cog.format_output.assert_called_once_with('') + +    @async_test +    async def test_send_eval_with_paste_link(self): +        """Test the send_eval function with a too long output that generate a paste link.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +        self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) +        self.cog.get_status_emoji = MagicMock(return_value=':yay!:') +        self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) + +        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        ctx.send.assert_called_once_with( +            '@LemonLemonishBeard#0042 :yay!: Return code 0.' +            '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com' +        ) +        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.format_output.assert_called_once_with('Way too long beard') + +    @async_test +    async def test_send_eval_with_non_zero_eval(self): +        """Test the send_eval function with a code returning a non-zero code.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = '@LemonLemonishBeard#0042' +        ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None)) +        self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) +        self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) +        self.cog.get_status_emoji = MagicMock(return_value=':nope!:') +        self.cog.format_output = AsyncMock()  # This function isn't called + +        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        ctx.send.assert_called_once_with( +            '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' +        ) +        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) +        self.cog.format_output.assert_not_called() + +    @async_test +    async def test_continue_eval_does_continue(self): +        """Test that the continue_eval function does continue if required conditions are met.""" +        ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) +        response = MockMessage(delete=AsyncMock()) +        new_msg = MockMessage(content='!e NewCode') +        self.bot.wait_for.side_effect = ((None, new_msg), None) + +        actual = await self.cog.continue_eval(ctx, response) +        self.assertEqual(actual, 'NewCode') +        self.bot.wait_for.has_calls( +            call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10), +            call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +        ) +        ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) +        ctx.message.clear_reactions.assert_called_once() +        response.delete.assert_called_once() + +    @async_test +    async def test_continue_eval_does_not_continue(self): +        ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) +        self.bot.wait_for.side_effect = asyncio.TimeoutError + +        actual = await self.cog.continue_eval(ctx, MockMessage()) +        self.assertEqual(actual, None) +        ctx.message.clear_reactions.assert_called_once() + +    def test_predicate_eval_message_edit(self): +        """Test the predicate_eval_message_edit function.""" +        msg0 = MockMessage(id=1, content='abc') +        msg1 = MockMessage(id=2, content='abcdef') +        msg2 = MockMessage(id=1, content='abcdef') + +        cases = ( +            (msg0, msg0, False, 'same ID, same content'), +            (msg0, msg1, False, 'different ID, different content'), +            (msg0, msg2, True, 'same ID, different content') +        ) +        for ctx_msg, new_msg, expected, testname in cases: +            with self.subTest(msg=f'Messages with {testname} return {expected}'): +                ctx = MockContext(message=ctx_msg) +                actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) +                self.assertEqual(actual, expected) + +    def test_predicate_eval_emoji_reaction(self): +        """Test the predicate_eval_emoji_reaction function.""" +        valid_reaction = MockReaction(message=MockMessage(id=1)) +        valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI +        valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) +        valid_user = MockUser(id=2) + +        invalid_reaction_id = MockReaction(message=MockMessage(id=42)) +        invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI +        invalid_user_id = MockUser(id=42) +        invalid_reaction_str = MockReaction(message=MockMessage(id=1)) +        invalid_reaction_str.__str__.return_value = ':longbeard:' + +        cases = ( +            (invalid_reaction_id, valid_user, False, 'invalid reaction ID'), +            (valid_reaction, invalid_user_id, False, 'invalid user ID'), +            (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'), +            (valid_reaction, valid_user, True, 'matching attributes') +        ) +        for reaction, user, expected, testname in cases: +            with self.subTest(msg=f'Test with {testname} and expected return {expected}'): +                actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) +                self.assertEqual(actual, expected) + + +class SnekboxSetupTests(unittest.TestCase): +    """Tests setup of the `Snekbox` cog.""" + +    def test_setup(self): +        """Setup of the extension should call add_cog.""" +        bot = MockBot() +        snekbox.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/helpers.py b/tests/helpers.py index 9d9dd5da6..6f50f6ae3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -128,6 +128,18 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):          return super().__call__(*args, **kwargs) +class AsyncContextManagerMock(unittest.mock.MagicMock): +    def __init__(self, return_value: Any): +        super().__init__() +        self._return_value = return_value + +    async def __aenter__(self): +        return self._return_value + +    async def __aexit__(self, *args): +        pass + +  class AsyncIteratorMock:      """      A class to mock asynchronous iterators. | 
