diff options
author | 2022-01-27 10:31:12 -0500 | |
---|---|---|
committer | 2022-01-27 10:31:12 -0500 | |
commit | 85a6f430aa68f59ce6958ecb6450eca0736628e4 (patch) | |
tree | bcc01e2997ad2a4c1337599cb0ad851562a1ddde | |
parent | Merge branch 'main' of https://github.com/python-discord/bot into feat/timeit... (diff) |
chore: Switch Snekbox.prepare_input with a CodeblockConverter
As per @Numerlor's suggestion
-rw-r--r-- | bot/exts/filters/filtering.py | 2 | ||||
-rw-r--r-- | bot/exts/utils/snekbox.py | 74 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 105 |
3 files changed, 91 insertions, 90 deletions
diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 599302576..375e9dca8 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -267,7 +267,7 @@ class Filtering(Cog): # Update time when alert sent await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) - async def filter_snekbox_job(self, result: str, msg: Message) -> bool: + async def filter_snekbox_output(self, result: str, msg: Message) -> bool: """ Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 15599208f..41f6bf8ad 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -9,7 +9,7 @@ from typing import Optional, Tuple from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Command, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs @@ -68,35 +68,11 @@ REDO_EMOJI = '\U0001f501' # :repeat: REDO_TIMEOUT = 30 -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - url = URLs.snekbox_eval_api - data = {"input": code} - - if args is not None: - data["args"] = args +class CodeblockConverter(Converter): + """Attempts to extract code from a codeblock, if provided.""" - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() - - async def upload_output(self, output: str) -> Optional[str]: - """Upload the job's output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - if len(output) > MAX_PASTE_LEN: - log.info("Full output is too long to upload") - return "too long to upload" - return await send_to_paste_service(output, extension="txt") - - @staticmethod - def prepare_input(code: str) -> list[str]: + @classmethod + async def convert(cls, ctx: Context, code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. @@ -128,6 +104,34 @@ class Snekbox(Cog): log.trace(f"Extracted {info} for evaluation:\n{code}") return codeblocks + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + + if args is not None: + data["args"] = args + + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() + + async def upload_output(self, output: str) -> Optional[str]: + """Upload the job's output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" + return await send_to_paste_service(output, extension="txt") + @staticmethod def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: """ @@ -313,7 +317,7 @@ class Snekbox(Cog): await ctx.message.clear_reaction(REDO_EMOJI) return None, None - codeblocks = self.prepare_input(code) + codeblocks = await CodeblockConverter.convert(ctx, code) if command is self.timeit_command: return self.prepare_timeit_input(codeblocks) @@ -393,7 +397,7 @@ class Snekbox(Cog): channels=NO_SNEKBOX_CHANNELS, ping_user=False ) - async def eval_command(self, ctx: Context, *, code: str) -> None: + async def eval_command(self, ctx: Context, *, code: CodeblockConverter) -> None: """ Run Python code and get the results. @@ -404,8 +408,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = "\n".join(self.prepare_input(code)) - await self.run_job("eval", ctx, code) + await self.run_job("eval", ctx, "\n".join(code)) @command(name="timeit", aliases=("ti",)) @guild_only() @@ -416,7 +419,7 @@ class Snekbox(Cog): channels=NO_SNEKBOX_CHANNELS, ping_user=False ) - async def timeit_command(self, ctx: Context, *, code: str) -> str: + async def timeit_command(self, ctx: Context, *, code: CodeblockConverter) -> str: """ Profile Python Code to find execution time. @@ -430,8 +433,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - codeblocks = self.prepare_input(code) - code, args = self.prepare_timeit_input(codeblocks) + code, args = self.prepare_timeit_input(code) await self.run_job("timeit", ctx, code=code, args=args) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 5d213a883..75da0c860 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -17,7 +17,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = Snekbox(bot=self.bot) - async def test_post_eval(self): + async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() resp.json = AsyncMock(return_value="return") @@ -26,7 +26,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_eval("import random"), "return") + self.assertEqual(await self.cog.post_job("import random"), "return") self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, json={"input": "import random"}, @@ -45,7 +45,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.upload_output("Test output.") mock_paste_util.assert_called_once_with("Test output.", extension="txt") - def test_prepare_input(self): + async def test_codeblock_converter(self): + ctx = MockContext() cases = ( ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), @@ -61,7 +62,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) + self.assertEqual( + '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + ) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -158,31 +161,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): response = MockMessage() ctx.command = MagicMock() - self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock(return_value=(None, None)) + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock(return_value=(None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval') - self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command) + await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval') + self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command) async def test_eval_command_evaluate_twice(self): """Test the eval and re-eval command procedure.""" ctx = MockContext() response = MockMessage() ctx.command = MagicMock() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') - self.cog.send_eval = AsyncMock(return_value=response) - self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + self.cog.send_job = AsyncMock(return_value=response) + self.cog.continue_job = AsyncMock() + self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) - await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') - self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with( + await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) + self.cog.send_job.assert_called_with( ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' ) - self.cog.continue_eval.assert_called_with(ctx, response, ctx.command) + self.cog.continue_job.assert_called_with(ctx, response, ctx.command) async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -196,14 +195,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_send_eval(self): - """Test the send_eval function.""" + async def test_send_job(self): + """Test the send_job function.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) @@ -212,7 +211,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -223,19 +222,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('') - async def test_send_eval_with_paste_link(self): - """Test the send_eval function with a too long output that generate a paste link.""" + async def test_send_job_with_paste_link(self): + """Test the send_job function with a too long output that generate a paste link.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) @@ -244,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -253,18 +252,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval') self.cog.format_output.assert_called_once_with('Way too long beard') - async def test_send_eval_with_non_zero_eval(self): - """Test the send_eval function with a code returning a non-zero code.""" + async def test_send_job_with_non_zero_eval(self): + """Test the send_job function with a code returning a non-zero code.""" ctx = MockContext() ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called @@ -273,7 +272,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval') + await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval') ctx.send.assert_called_once() self.assertEqual( @@ -281,14 +280,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) + self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval') self.cog.format_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") - async def test_continue_eval_does_continue(self, partial_mock): - """Test that the continue_eval function does continue if required conditions are met.""" + async def test_continue_job_does_continue(self, partial_mock): + """Test that the continue_job function does continue if required conditions are met.""" ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) response = MockMessage(delete=AsyncMock()) new_msg = MockMessage() @@ -296,30 +295,30 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected = "NewCode" self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command) + actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) self.assertEqual(actual, (expected, None)) self.bot.wait_for.assert_has_awaits( ( call( 'message_edit', - check=partial_mock(snekbox.predicate_eval_message_edit, ctx), - timeout=snekbox.REEVAL_TIMEOUT, + check=partial_mock(snekbox.predicate_message_edit, ctx), + timeout=snekbox.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) response.delete.assert_called_once() - async def test_continue_eval_does_not_continue(self): + async def test_continue_job_does_not_continue(self): ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) self.bot.wait_for.side_effect = asyncio.TimeoutError - actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command) + actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) self.assertEqual(actual, (None, None)) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -347,8 +346,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) - def test_predicate_eval_message_edit(self): - """Test the predicate_eval_message_edit function.""" + def test_predicate_message_edit(self): + """Test the predicate_message_edit function.""" msg0 = MockMessage(id=1, content='abc') msg1 = MockMessage(id=2, content='abcdef') msg2 = MockMessage(id=1, content='abcdef') @@ -361,18 +360,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) - def test_predicate_eval_emoji_reaction(self): - """Test the predicate_eval_emoji_reaction function.""" + def test_predicate_emoji_reaction(self): + """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_reaction.__str__.return_value = snekbox.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -385,7 +384,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) |