diff options
-rw-r--r-- | bot/exts/filters/filtering.py | 4 | ||||
-rw-r--r-- | bot/exts/utils/snekbox.py | 266 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 69 |
3 files changed, 225 insertions, 114 deletions
diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 1f83acf9b..599302576 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -267,9 +267,9 @@ class Filtering(Cog): # Update time when alert sent await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - async def filter_eval(self, result: str, msg: Message) -> bool: + async def filter_snekbox_job(self, result: str, msg: Message) -> bool: """ - Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. + Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. Also requires the original message, to check whether to filter and for mod logs. Returns whether a filter was triggered or not. diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index cc3a2e1d7..15599208f 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -2,14 +2,14 @@ import asyncio import contextlib import datetime import re -import textwrap from functools import partial from signal import Signals +from textwrap import dedent from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs @@ -22,17 +22,50 @@ log = get_logger(__name__) ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") +# The timeit command should only output the very last line, so all other output should be suppressed. +# This will be used as the setup code along with any setup code provided. +TIMEIT_SETUP_WRAPPER = """ +import atexit +import sys +from collections import deque + +if not hasattr(sys, "_setup_finished"): + class Writer(deque): + def __init__(self): + super().__init__(maxlen=1) + + def write(self, string): + if string.strip(): + self.append(string) + + def read(self): + return self.pop() + + def flush(self): + pass + + sys.stdout = Writer() + + def print_last_line(): + if sys.stdout: + print(sys.stdout.read(), file=sys.__stdout__) + + atexit.register(print_last_line) + sys._setup_finished = None +{setup} +""" + MAX_PASTE_LEN = 10000 -# `!eval` command whitelists and blacklists. -NO_EVAL_CHANNELS = (Channels.python_general,) -NO_EVAL_CATEGORIES = () -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) +# The Snekbox commands' whitelists and blacklists. +NO_SNEKBOX_CHANNELS = (Channels.python_general,) +NO_SNEKBOX_CATEGORIES = () +SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) SIGKILL = 9 -REEVAL_EMOJI = '\U0001f501' # :repeat: -REEVAL_TIMEOUT = 30 +REDO_EMOJI = '\U0001f501' # :repeat: +REDO_TIMEOUT = 30 class Snekbox(Cog): @@ -42,15 +75,19 @@ class Snekbox(Cog): self.bot = bot self.jobs = {} - async def post_eval(self, code: str) -> dict: + async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: """Send a POST request to the Snekbox API to evaluate code and return the results.""" url = URLs.snekbox_eval_api data = {"input": code} + + if args is not None: + data["args"] = args + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: return await resp.json() async def upload_output(self, output: str) -> Optional[str]: - """Upload the eval output to a paste service and return a URL to it if successful.""" + """Upload the job's output to a paste service and return a URL to it if successful.""" log.trace("Uploading full output to paste service...") if len(output) > MAX_PASTE_LEN: @@ -59,49 +96,70 @@ class Snekbox(Cog): return await send_to_paste_service(output, extension="txt") @staticmethod - def prepare_input(code: str) -> str: + def prepare_input(code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. If there is any code block, ignore text outside the code block. Use the first code block, but prefer a fenced code block. If there are several fenced code blocks, concatenate only the fenced code blocks. + + Return a list of code blocks if any, otherwise return a list with a single string of code. """ if match := list(FORMATTED_CODE_REGEX.finditer(code)): blocks = [block for block in match if block.group("block")] if len(blocks) > 1: - code = '\n'.join(block.group("code") for block in blocks) + codeblocks = [block.group("code") for block in blocks] info = "several code blocks" else: match = match[0] if len(blocks) == 0 else blocks[0] code, block, lang, delim = match.group("code", "block", "lang", "delim") + codeblocks = [dedent(code)] if block: info = (f"'{lang}' highlighted" if lang else "plain") + " code block" else: info = f"{delim}-enclosed inline code" else: - code = RAW_CODE_REGEX.fullmatch(code).group("code") + codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] info = "unformatted or badly formatted code" - code = textwrap.dedent(code) + code = "\n".join(codeblocks) log.trace(f"Extracted {info} for evaluation:\n{code}") - return code + return codeblocks + + @staticmethod + def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: + """ + Join the codeblocks into a single string, then return the code and the arguments in a tuple. + + If there are multiple codeblocks, insert the first one into the wrapped setup code. + """ + args = ["-m", "timeit"] + setup = "" + if len(codeblocks) > 1: + setup = codeblocks.pop(0) + + code = "\n".join(codeblocks) + + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + + return code, args @staticmethod - def get_results_message(results: dict) -> Tuple[str, str]: + def get_results_message(results: dict, job_name: str) -> Tuple[str, str]: """Return a user-friendly message and error corresponding to the process's return code.""" stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your eval job has completed with return code {returncode}" + msg = f"Your {job_name} job has completed with return code {returncode}" error = "" if returncode is None: - msg = "Your eval job has failed" + msg = f"Your {job_name} job has failed" error = stdout.strip() elif returncode == 128 + SIGKILL: - msg = "Your eval job timed out or ran out of memory" + msg = f"Your {job_name} job timed out or ran out of memory" elif returncode == 255: - msg = "Your eval job has failed" + msg = f"Your {job_name} job has failed" error = "A fatal NsJail error occurred" else: # Try to append signal's name if one exists @@ -130,8 +188,6 @@ class Snekbox(Cog): Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters and upload the full output to a paste service. """ - log.trace("Formatting output...") - output = output.rstrip("\n") original_output = output # To be uploaded to a pasting service if needed paste_link = None @@ -171,19 +227,27 @@ class Snekbox(Cog): return output, paste_link - async def send_eval(self, ctx: Context, code: str) -> Message: + async def send_job( + self, + ctx: Context, + code: str, + *, + args: Optional[list[str]] = None, + job_name: str + ) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. Return the bot response. """ async with ctx.typing(): - results = await self.post_eval(code) - msg, error = self.get_results_message(results) + results = await self.post_job(code, args=args) + msg, error = self.get_results_message(results, job_name) if error: output, paste_link = error, None else: + log.trace("Formatting output...") output, paste_link = await self.format_output(results["stdout"]) icon = self.get_status_emoji(results) @@ -191,7 +255,7 @@ class Snekbox(Cog): if paste_link: msg = f"{msg}\nFull output: {paste_link}" - # Collect stats of eval fails + successes + # Collect stats of job fails + successes if icon == ":x:": self.bot.stats.incr("snekbox.python.fail") else: @@ -200,7 +264,7 @@ class Snekbox(Cog): filter_cog = self.bot.get_cog("Filtering") filter_triggered = False if filter_cog: - filter_triggered = await filter_cog.filter_eval(msg, ctx.message) + filter_triggered = await filter_cog.filter_snekbox_job(msg, ctx.message) if filter_triggered: response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: @@ -208,83 +272,85 @@ class Snekbox(Cog): response = await ctx.send(msg, allowed_mentions=allowed_mentions) scheduling.create_task(wait_for_deletion(response, (ctx.author.id,)), event_loop=self.bot.loop) - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") return response - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + async def continue_job( + self, ctx: Context, response: Message, command: Command + ) -> tuple[Optional[str], Optional[list[str]]]: """ - Check if the eval session should continue. + Check if the job's session should continue. - Return the new code to evaluate or None if the eval session should be terminated. + If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command. + Otherwise return (None, None) if the job's session should be terminated. """ - _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + _predicate_message_edit = partial(predicate_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) with contextlib.suppress(NotFound): try: _, new_message = await self.bot.wait_for( 'message_edit', - check=_predicate_eval_message_edit, - timeout=REEVAL_TIMEOUT + check=_predicate_message_edit, + timeout=REDO_TIMEOUT ) - await ctx.message.add_reaction(REEVAL_EMOJI) + await ctx.message.add_reaction(REDO_EMOJI) await self.bot.wait_for( 'reaction_add', check=_predicate_emoji_reaction, timeout=10 ) - code = await self.get_code(new_message) - await ctx.message.clear_reaction(REEVAL_EMOJI) + code = await self.get_code(new_message, ctx.command) + await ctx.message.clear_reaction(REDO_EMOJI) with contextlib.suppress(HTTPException): await response.delete() + if code is None: + return None, None + except asyncio.TimeoutError: - await ctx.message.clear_reaction(REEVAL_EMOJI) - return None + await ctx.message.clear_reaction(REDO_EMOJI) + return None, None + + codeblocks = self.prepare_input(code) + + if command is self.timeit_command: + return self.prepare_timeit_input(codeblocks) + else: + return "\n".join(codeblocks), None - return code + return None, None - async def get_code(self, message: Message) -> Optional[str]: + async def get_code(self, message: Message, command: Command) -> Optional[str]: """ Return the code from `message` to be evaluated. - If the message is an invocation of the eval command, return the first argument or None if it + If the message is an invocation of the command, return the first argument or None if it doesn't exist. Otherwise, return the full content of the message. """ log.trace(f"Getting context for message {message.id}.") new_ctx = await self.bot.get_context(message) - if new_ctx.command is self.eval_command: - log.trace(f"Message {message.id} invokes eval command.") + if new_ctx.command is command: + log.trace(f"Message {message.id} invokes {command} command.") split = message.content.split(maxsplit=1) code = split[1] if len(split) > 1 else None else: - log.trace(f"Message {message.id} does not invoke eval command.") + log.trace(f"Message {message.id} does not invoke {command} command.") code = message.content return code - @command(name="eval", aliases=("e",)) - @guild_only() - @redirect_output( - destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, - ping_user=False - ) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: - """ - Run Python code and get the results. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ + async def run_job( + self, + job_name: str, + ctx: Context, + code: str, + *, + args: Optional[list[str]] = None, + ) -> None: + """Handles checks, stats and re-evaluation of a snekbox job.""" if ctx.author.id in self.jobs: await ctx.send( f"{ctx.author.mention} You've already got a job running - " @@ -292,10 +358,6 @@ class Snekbox(Cog): ) return - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - if Roles.helpers in (role.id for role in ctx.author.roles): self.bot.stats.incr("snekbox_usages.roles.helpers") else: @@ -312,26 +374,76 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) try: - response = await self.send_eval(ctx, code) + response = await self.send_job(ctx, code, args=args, job_name=job_name) finally: del self.jobs[ctx.author.id] - code = await self.continue_eval(ctx, response) + code, args = await self.continue_job(ctx, response, ctx.command) if not code: break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + @command(name="eval", aliases=("e",)) + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, + ping_user=False + ) + async def eval_command(self, ctx: Context, *, code: str) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code = "\n".join(self.prepare_input(code)) + await self.run_job("eval", ctx, code) + + @command(name="timeit", aliases=("ti",)) + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, + ping_user=False + ) + async def timeit_command(self, ctx: Context, *, code: str) -> str: + """ + Profile Python Code to find execution time. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + If multiple formatted codeblocks are provided, the first one will be the setup code, which will + not be timed. The remaining codeblocks will be joined together and timed. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + codeblocks = self.prepare_input(code) + code, args = self.prepare_timeit_input(codeblocks) + + await self.run_job("timeit", ctx, code=code, args=args) + -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: +def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: """Return True if the edited message is the context message and the content was indeed modified.""" return new_msg.id == ctx.message.id and old_msg.content != new_msg.content -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI +def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 8bdeedd27..5d213a883 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -61,7 +61,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) + self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -72,13 +72,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval') self.assertEqual(actual, expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), ('Your eval job has completed with return code 127', '') ) @@ -86,7 +86,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_get_results_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}), + self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'), ('Your eval job has completed with return code 127 (SIGTEST)', '') ) @@ -156,28 +156,33 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + ctx.command = MagicMock() + + self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=None) + self.cog.continue_eval = AsyncMock(return_value=(None, None)) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') - self.cog.continue_eval.assert_called_once_with(ctx, response) + self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval') + self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() + ctx.command = MagicMock() self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') - self.cog.continue_eval.assert_called_with(ctx, response) + self.cog.send_eval.assert_called_with( + ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' + ) + self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -191,12 +196,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_eval_command_call_help(self): - """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext(command="sentinel") - await self.cog.eval_command(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with(ctx.command) - async def test_send_eval(self): """Test the send_eval function.""" ctx = MockContext() @@ -210,10 +209,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('[No output]', None)) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -224,9 +223,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') async def test_send_eval_with_paste_link(self): @@ -242,10 +241,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -254,9 +253,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_eval_with_non_zero_eval(self): @@ -271,10 +270,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_eval = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -282,9 +281,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") @@ -297,9 +296,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response) - self.cog.get_code.assert_awaited_once_with(new_msg) - self.assertEqual(actual, expected) + actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) + self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( @@ -318,8 +317,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage()) - self.assertEqual(actual, None) + actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command) + self.assertEqual(actual, (None, None)) ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) async def test_get_code(self): @@ -343,7 +342,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.return_value = MockContext(command=command) message = MockMessage(content=content) - actual_code = await self.cog.get_code(message) + actual_code = await self.cog.get_code(message, self.cog.eval_command) self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) |