diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/utils/snekbox.py | 52 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 24 | 
2 files changed, 29 insertions, 47 deletions
| diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 49f1be17b..1d9646113 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -36,6 +36,8 @@ RAW_CODE_REGEX = re.compile(      re.DOTALL                               # "." also matches newlines  ) +# The timeit command should only output the very last line, so all other output should be suppressed. +# This will be used as the setup code along with any setup code provided.  TIMEIT_SETUP_WRAPPER = """  import atexit  import sys @@ -163,19 +165,19 @@ class Snekbox(Cog):          return code, args      @staticmethod -    def get_results_message(results: dict) -> Tuple[str, str]: +    def get_results_message(results: dict, job_name: str) -> Tuple[str, str]:          """Return a user-friendly message and error corresponding to the process's return code."""          stdout, returncode = results["stdout"], results["returncode"] -        msg = f"Your eval job has completed with return code {returncode}" +        msg = f"Your {job_name} job has completed with return code {returncode}"          error = ""          if returncode is None: -            msg = "Your eval job has failed" +            msg = f"Your {job_name} job has failed"              error = stdout.strip()          elif returncode == 128 + SIGKILL: -            msg = "Your eval job timed out or ran out of memory" +            msg = f"Your {job_name} job timed out or ran out of memory"          elif returncode == 255: -            msg = "Your eval job has failed" +            msg = f"Your {job_name} job has failed"              error = "A fatal NsJail error occurred"          else:              # Try to append signal's name if one exists @@ -249,7 +251,7 @@ class Snekbox(Cog):          code: str,          *,          args: Optional[list[str]] = None, -        format_func: FormatFunc +        job_name: str      ) -> Message:          """          Evaluate code, format it, and send the output to the corresponding channel. @@ -258,13 +260,13 @@ class Snekbox(Cog):          """          async with ctx.typing():              results = await self.post_eval(code, args=args) -            msg, error = self.get_results_message(results) +            msg, error = self.get_results_message(results, job_name)              if error:                  output, paste_link = error, None              else:                  log.trace("Formatting output...") -                output, paste_link = await format_func(results["stdout"]) +                output, paste_link = await self.format_output(results["stdout"])              icon = self.get_status_emoji(results)              msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" @@ -288,12 +290,12 @@ class Snekbox(Cog):                  response = await ctx.send(msg, allowed_mentions=allowed_mentions)              scheduling.create_task(wait_for_deletion(response, (ctx.author.id,)), event_loop=self.bot.loop) -            log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") +            log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")          return response      async def continue_eval(          self, ctx: Context, response: Message, command: Command -    ) -> Optional[tuple[str, Optional[list[str]]]]: +    ) -> tuple[Optional[str], Optional[list[str]]]:          """          Check if the eval session should continue. @@ -355,19 +357,15 @@ class Snekbox(Cog):          return code -    async def run_eval( +    async def run_job(          self, +        job_name: str,          ctx: Context,          code: str, -        format_func: FormatFunc,          *,          args: Optional[list[str]] = None,      ) -> None: -        """ -        Handles checks, stats and re-evaluation of an eval. - -        `format_func` is an async callable that takes a string (the output) and formats it to show to the user. -        """ +        """Handles checks, stats and re-evaluation of a snekbox job."""          if ctx.author.id in self.jobs:              await ctx.send(                  f"{ctx.author.mention} You've already got a job running - " @@ -392,7 +390,7 @@ class Snekbox(Cog):          while True:              self.jobs[ctx.author.id] = datetime.datetime.now()              try: -                response = await self.send_eval(ctx, code, args=args, format_func=format_func) +                response = await self.send_eval(ctx, code, args=args, job_name=job_name)              finally:                  del self.jobs[ctx.author.id] @@ -401,18 +399,6 @@ class Snekbox(Cog):                  break              log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") -    async def format_timeit_output(self, output: str) -> tuple[str, str]: -        """ -        Parses the time from the end of the output given by timeit. - -        If an error happened, then it won't contain the time and instead proceed with regular formatting. -        """ -        split_output = output.rstrip("\n").rsplit("\n", 1) -        if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]): -            return split_output[1], None - -        return await self.format_output(output) -      @command(name="eval", aliases=("e",))      @guild_only()      @redirect_output( @@ -434,7 +420,7 @@ class Snekbox(Cog):          issue with it!          """          code = "\n".join(self.prepare_input(code)) -        await self.run_eval(ctx, code, format_func=self.format_output) +        await self.run_job("eval", ctx, code)      @command(name="timeit", aliases=("ti",))      @guild_only() @@ -462,9 +448,7 @@ class Snekbox(Cog):          codeblocks = self.prepare_input(code)          code, args = self.prepare_timeit_input(codeblocks) -        await self.run_eval( -            ctx, code=code, format_func=self.format_timeit_output, args=args -        ) +        await self.run_job("timeit", ctx, code=code, args=args)  def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 4245de8a3..339cdaaa4 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -72,13 +72,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          )          for stdout, returncode, expected in cases:              with self.subTest(stdout=stdout, returncode=returncode, expected=expected): -                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) +                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval')                  self.assertEqual(actual, expected)      @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)      def test_get_results_message_invalid_signal(self, mock_signals: Mock):          self.assertEqual( -            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),              ('Your eval job has completed with return code 127', '')          ) @@ -86,7 +86,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      def test_get_results_message_valid_signal(self, mock_signals: Mock):          mock_signals.return_value.name = 'SIGTEST'          self.assertEqual( -            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),              ('Your eval job has completed with return code 127 (SIGTEST)', '')          ) @@ -164,9 +164,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')          self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') -        self.cog.send_eval.assert_called_once_with( -            ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output -        ) +        self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval')          self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command)      async def test_eval_command_evaluate_twice(self): @@ -182,7 +180,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')          self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))          self.cog.send_eval.assert_called_with( -            ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output +            ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval'          )          self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) @@ -214,7 +212,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          mocked_filter_cog.filter_eval = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) +        await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')          ctx.send.assert_called_once()          self.assertEqual( @@ -227,7 +225,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)          self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) -        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval')          self.cog.format_output.assert_called_once_with('')      async def test_send_eval_with_paste_link(self): @@ -246,7 +244,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          mocked_filter_cog.filter_eval = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) +        await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')          ctx.send.assert_called_once()          self.assertEqual( @@ -257,7 +255,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)          self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) -        self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval')          self.cog.format_output.assert_called_once_with('Way too long beard')      async def test_send_eval_with_non_zero_eval(self): @@ -275,7 +273,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          mocked_filter_cog.filter_eval = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) +        await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')          ctx.send.assert_called_once()          self.assertEqual( @@ -285,7 +283,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)          self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) -        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval')          self.cog.format_output.assert_not_called()      @patch("bot.exts.utils.snekbox.partial") | 
