aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ToxicKidz <[email protected]>2022-01-17 17:42:37 -0500
committerGravatar ToxicKidz <[email protected]>2022-01-17 17:42:37 -0500
commit54e4f3777372ef526667885f4392030bab1b5b07 (patch)
tree938e6b463c36eb0d689e2c0f06bd9d3ef41ca349
parentfix: Modify tests to correspond with Snekbox.continue_eval (diff)
chore: Apply suggestions and adjust tests
-rw-r--r--bot/exts/utils/snekbox.py52
-rw-r--r--tests/bot/exts/utils/test_snekbox.py24
2 files changed, 29 insertions, 47 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 49f1be17b..1d9646113 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -36,6 +36,8 @@ RAW_CODE_REGEX = re.compile(
re.DOTALL # "." also matches newlines
)
+# The timeit command should only output the very last line, so all other output should be suppressed.
+# This will be used as the setup code along with any setup code provided.
TIMEIT_SETUP_WRAPPER = """
import atexit
import sys
@@ -163,19 +165,19 @@ class Snekbox(Cog):
return code, args
@staticmethod
- def get_results_message(results: dict) -> Tuple[str, str]:
+ def get_results_message(results: dict, job_name: str) -> Tuple[str, str]:
"""Return a user-friendly message and error corresponding to the process's return code."""
stdout, returncode = results["stdout"], results["returncode"]
- msg = f"Your eval job has completed with return code {returncode}"
+ msg = f"Your {job_name} job has completed with return code {returncode}"
error = ""
if returncode is None:
- msg = "Your eval job has failed"
+ msg = f"Your {job_name} job has failed"
error = stdout.strip()
elif returncode == 128 + SIGKILL:
- msg = "Your eval job timed out or ran out of memory"
+ msg = f"Your {job_name} job timed out or ran out of memory"
elif returncode == 255:
- msg = "Your eval job has failed"
+ msg = f"Your {job_name} job has failed"
error = "A fatal NsJail error occurred"
else:
# Try to append signal's name if one exists
@@ -249,7 +251,7 @@ class Snekbox(Cog):
code: str,
*,
args: Optional[list[str]] = None,
- format_func: FormatFunc
+ job_name: str
) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
@@ -258,13 +260,13 @@ class Snekbox(Cog):
"""
async with ctx.typing():
results = await self.post_eval(code, args=args)
- msg, error = self.get_results_message(results)
+ msg, error = self.get_results_message(results, job_name)
if error:
output, paste_link = error, None
else:
log.trace("Formatting output...")
- output, paste_link = await format_func(results["stdout"])
+ output, paste_link = await self.format_output(results["stdout"])
icon = self.get_status_emoji(results)
msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```"
@@ -288,12 +290,12 @@ class Snekbox(Cog):
response = await ctx.send(msg, allowed_mentions=allowed_mentions)
scheduling.create_task(wait_for_deletion(response, (ctx.author.id,)), event_loop=self.bot.loop)
- log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
+ log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")
return response
async def continue_eval(
self, ctx: Context, response: Message, command: Command
- ) -> Optional[tuple[str, Optional[list[str]]]]:
+ ) -> tuple[Optional[str], Optional[list[str]]]:
"""
Check if the eval session should continue.
@@ -355,19 +357,15 @@ class Snekbox(Cog):
return code
- async def run_eval(
+ async def run_job(
self,
+ job_name: str,
ctx: Context,
code: str,
- format_func: FormatFunc,
*,
args: Optional[list[str]] = None,
) -> None:
- """
- Handles checks, stats and re-evaluation of an eval.
-
- `format_func` is an async callable that takes a string (the output) and formats it to show to the user.
- """
+ """Handles checks, stats and re-evaluation of a snekbox job."""
if ctx.author.id in self.jobs:
await ctx.send(
f"{ctx.author.mention} You've already got a job running - "
@@ -392,7 +390,7 @@ class Snekbox(Cog):
while True:
self.jobs[ctx.author.id] = datetime.datetime.now()
try:
- response = await self.send_eval(ctx, code, args=args, format_func=format_func)
+ response = await self.send_eval(ctx, code, args=args, job_name=job_name)
finally:
del self.jobs[ctx.author.id]
@@ -401,18 +399,6 @@ class Snekbox(Cog):
break
log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}")
- async def format_timeit_output(self, output: str) -> tuple[str, str]:
- """
- Parses the time from the end of the output given by timeit.
-
- If an error happened, then it won't contain the time and instead proceed with regular formatting.
- """
- split_output = output.rstrip("\n").rsplit("\n", 1)
- if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]):
- return split_output[1], None
-
- return await self.format_output(output)
-
@command(name="eval", aliases=("e",))
@guild_only()
@redirect_output(
@@ -434,7 +420,7 @@ class Snekbox(Cog):
issue with it!
"""
code = "\n".join(self.prepare_input(code))
- await self.run_eval(ctx, code, format_func=self.format_output)
+ await self.run_job("eval", ctx, code)
@command(name="timeit", aliases=("ti",))
@guild_only()
@@ -462,9 +448,7 @@ class Snekbox(Cog):
codeblocks = self.prepare_input(code)
code, args = self.prepare_timeit_input(codeblocks)
- await self.run_eval(
- ctx, code=code, format_func=self.format_timeit_output, args=args
- )
+ await self.run_job("timeit", ctx, code=code, args=args)
def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 4245de8a3..339cdaaa4 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -72,13 +72,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for stdout, returncode, expected in cases:
with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
- actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode})
+ actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval')
self.assertEqual(actual, expected)
@patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)
def test_get_results_message_invalid_signal(self, mock_signals: Mock):
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),
('Your eval job has completed with return code 127', '')
)
@@ -86,7 +86,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
def test_get_results_message_valid_signal(self, mock_signals: Mock):
mock_signals.return_value.name = 'SIGTEST'
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),
('Your eval job has completed with return code 127 (SIGTEST)', '')
)
@@ -164,9 +164,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
- self.cog.send_eval.assert_called_once_with(
- ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output
- )
+ self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval')
self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command)
async def test_eval_command_evaluate_twice(self):
@@ -182,7 +180,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
self.cog.send_eval.assert_called_with(
- ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output
+ ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval'
)
self.cog.continue_eval.assert_called_with(ctx, response, ctx.command)
@@ -214,7 +212,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_eval = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output)
+ await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -227,7 +225,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval')
self.cog.format_output.assert_called_once_with('')
async def test_send_eval_with_paste_link(self):
@@ -246,7 +244,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_eval = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output)
+ await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -257,7 +255,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval')
self.cog.format_output.assert_called_once_with('Way too long beard')
async def test_send_eval_with_non_zero_eval(self):
@@ -275,7 +273,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_eval = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output)
+ await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -285,7 +283,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
- self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval')
self.cog.format_output.assert_not_called()
@patch("bot.exts.utils.snekbox.partial")