aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ToxicKidz <[email protected]>2022-01-27 10:31:12 -0500
committerGravatar ToxicKidz <[email protected]>2022-01-27 10:31:12 -0500
commit85a6f430aa68f59ce6958ecb6450eca0736628e4 (patch)
treebcc01e2997ad2a4c1337599cb0ad851562a1ddde
parentMerge branch 'main' of https://github.com/python-discord/bot into feat/timeit... (diff)
chore: Switch Snekbox.prepare_input with a CodeblockConverter
As per @Numerlor's suggestion
-rw-r--r--bot/exts/filters/filtering.py2
-rw-r--r--bot/exts/utils/snekbox.py74
-rw-r--r--tests/bot/exts/utils/test_snekbox.py105
3 files changed, 91 insertions, 90 deletions
diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py
index 599302576..375e9dca8 100644
--- a/bot/exts/filters/filtering.py
+++ b/bot/exts/filters/filtering.py
@@ -267,7 +267,7 @@ class Filtering(Cog):
# Update time when alert sent
await self.name_alerts.set(member.id, arrow.utcnow().timestamp())
- async def filter_snekbox_job(self, result: str, msg: Message) -> bool:
+ async def filter_snekbox_output(self, result: str, msg: Message) -> bool:
"""
Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly.
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 15599208f..41f6bf8ad 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -9,7 +9,7 @@ from typing import Optional, Tuple
from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX
from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User
-from discord.ext.commands import Cog, Command, Context, command, guild_only
+from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only
from bot.bot import Bot
from bot.constants import Categories, Channels, Roles, URLs
@@ -68,35 +68,11 @@ REDO_EMOJI = '\U0001f501' # :repeat:
REDO_TIMEOUT = 30
-class Snekbox(Cog):
- """Safe evaluation of Python code using Snekbox."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
- self.jobs = {}
-
- async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict:
- """Send a POST request to the Snekbox API to evaluate code and return the results."""
- url = URLs.snekbox_eval_api
- data = {"input": code}
-
- if args is not None:
- data["args"] = args
+class CodeblockConverter(Converter):
+ """Attempts to extract code from a codeblock, if provided."""
- async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
- return await resp.json()
-
- async def upload_output(self, output: str) -> Optional[str]:
- """Upload the job's output to a paste service and return a URL to it if successful."""
- log.trace("Uploading full output to paste service...")
-
- if len(output) > MAX_PASTE_LEN:
- log.info("Full output is too long to upload")
- return "too long to upload"
- return await send_to_paste_service(output, extension="txt")
-
- @staticmethod
- def prepare_input(code: str) -> list[str]:
+ @classmethod
+ async def convert(cls, ctx: Context, code: str) -> list[str]:
"""
Extract code from the Markdown, format it, and insert it into the code template.
@@ -128,6 +104,34 @@ class Snekbox(Cog):
log.trace(f"Extracted {info} for evaluation:\n{code}")
return codeblocks
+
+class Snekbox(Cog):
+ """Safe evaluation of Python code using Snekbox."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self.jobs = {}
+
+ async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict:
+ """Send a POST request to the Snekbox API to evaluate code and return the results."""
+ url = URLs.snekbox_eval_api
+ data = {"input": code}
+
+ if args is not None:
+ data["args"] = args
+
+ async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
+ return await resp.json()
+
+ async def upload_output(self, output: str) -> Optional[str]:
+ """Upload the job's output to a paste service and return a URL to it if successful."""
+ log.trace("Uploading full output to paste service...")
+
+ if len(output) > MAX_PASTE_LEN:
+ log.info("Full output is too long to upload")
+ return "too long to upload"
+ return await send_to_paste_service(output, extension="txt")
+
@staticmethod
def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]:
"""
@@ -313,7 +317,7 @@ class Snekbox(Cog):
await ctx.message.clear_reaction(REDO_EMOJI)
return None, None
- codeblocks = self.prepare_input(code)
+ codeblocks = await CodeblockConverter.convert(ctx, code)
if command is self.timeit_command:
return self.prepare_timeit_input(codeblocks)
@@ -393,7 +397,7 @@ class Snekbox(Cog):
channels=NO_SNEKBOX_CHANNELS,
ping_user=False
)
- async def eval_command(self, ctx: Context, *, code: str) -> None:
+ async def eval_command(self, ctx: Context, *, code: CodeblockConverter) -> None:
"""
Run Python code and get the results.
@@ -404,8 +408,7 @@ class Snekbox(Cog):
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
- code = "\n".join(self.prepare_input(code))
- await self.run_job("eval", ctx, code)
+ await self.run_job("eval", ctx, "\n".join(code))
@command(name="timeit", aliases=("ti",))
@guild_only()
@@ -416,7 +419,7 @@ class Snekbox(Cog):
channels=NO_SNEKBOX_CHANNELS,
ping_user=False
)
- async def timeit_command(self, ctx: Context, *, code: str) -> str:
+ async def timeit_command(self, ctx: Context, *, code: CodeblockConverter) -> str:
"""
Profile Python Code to find execution time.
@@ -430,8 +433,7 @@ class Snekbox(Cog):
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
- codeblocks = self.prepare_input(code)
- code, args = self.prepare_timeit_input(codeblocks)
+ code, args = self.prepare_timeit_input(code)
await self.run_job("timeit", ctx, code=code, args=args)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 5d213a883..75da0c860 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -17,7 +17,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot = MockBot()
self.cog = Snekbox(bot=self.bot)
- async def test_post_eval(self):
+ async def test_post_job(self):
"""Post the eval code to the URLs.snekbox_eval_api endpoint."""
resp = MagicMock()
resp.json = AsyncMock(return_value="return")
@@ -26,7 +26,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
context_manager.__aenter__.return_value = resp
self.bot.http_session.post.return_value = context_manager
- self.assertEqual(await self.cog.post_eval("import random"), "return")
+ self.assertEqual(await self.cog.post_job("import random"), "return")
self.bot.http_session.post.assert_called_with(
constants.URLs.snekbox_eval_api,
json={"input": "import random"},
@@ -45,7 +45,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
await self.cog.upload_output("Test output.")
mock_paste_util.assert_called_once_with("Test output.", extension="txt")
- def test_prepare_input(self):
+ async def test_codeblock_converter(self):
+ ctx = MockContext()
cases = (
('print("Hello world!")', 'print("Hello world!")', 'non-formatted'),
('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
@@ -61,7 +62,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for case, expected, testname in cases:
with self.subTest(msg=f'Extract code from {testname}.'):
- self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected)
+ self.assertEqual(
+ '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected
+ )
def test_get_results_message(self):
"""Return error and message according to the eval result."""
@@ -158,31 +161,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
response = MockMessage()
ctx.command = MagicMock()
- self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode'])
- self.cog.send_eval = AsyncMock(return_value=response)
- self.cog.continue_eval = AsyncMock(return_value=(None, None))
+ self.cog.send_job = AsyncMock(return_value=response)
+ self.cog.continue_job = AsyncMock(return_value=(None, None))
- await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
- self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
- self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval')
- self.cog.continue_eval.assert_called_once_with(ctx, response, ctx.command)
+ await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode'])
+ self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval')
+ self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command)
async def test_eval_command_evaluate_twice(self):
"""Test the eval and re-eval command procedure."""
ctx = MockContext()
response = MockMessage()
ctx.command = MagicMock()
- self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
- self.cog.send_eval = AsyncMock(return_value=response)
- self.cog.continue_eval = AsyncMock()
- self.cog.continue_eval.side_effect = (('MyAwesomeFormattedCode', None), (None, None))
+ self.cog.send_job = AsyncMock(return_value=response)
+ self.cog.continue_job = AsyncMock()
+ self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None))
- await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
- self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
- self.cog.send_eval.assert_called_with(
+ await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode'])
+ self.cog.send_job.assert_called_with(
ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval'
)
- self.cog.continue_eval.assert_called_with(ctx, response, ctx.command)
+ self.cog.continue_job.assert_called_with(ctx, response, ctx.command)
async def test_eval_command_reject_two_eval_at_the_same_time(self):
"""Test if the eval command rejects an eval if the author already have a running eval."""
@@ -196,14 +195,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
)
- async def test_send_eval(self):
- """Test the send_eval function."""
+ async def test_send_job(self):
+ """Test the send_job function."""
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author = MockUser(mention='@LemonLemonishBeard#0042')
- self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})
+ self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0})
self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
self.cog.format_output = AsyncMock(return_value=('[No output]', None))
@@ -212,7 +211,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')
+ await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -223,19 +222,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict())
- self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
+ self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval')
self.cog.format_output.assert_called_once_with('')
- async def test_send_eval_with_paste_link(self):
- """Test the send_eval function with a too long output that generate a paste link."""
+ async def test_send_job_with_paste_link(self):
+ """Test the send_job function with a too long output that generate a paste link."""
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
@@ -244,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')
+ await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -253,18 +252,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
'\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'
)
- self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
+ self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval')
self.cog.format_output.assert_called_once_with('Way too long beard')
- async def test_send_eval_with_non_zero_eval(self):
- """Test the send_eval function with a code returning a non-zero code."""
+ async def test_send_job_with_non_zero_eval(self):
+ """Test the send_job function with a code returning a non-zero code."""
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
+ self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
self.cog.format_output = AsyncMock() # This function isn't called
@@ -273,7 +272,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_snekbox_job = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode', job_name='eval')
+ await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')
ctx.send.assert_called_once()
self.assertEqual(
@@ -281,14 +280,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
'@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'
)
- self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
+ self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval')
self.cog.format_output.assert_not_called()
@patch("bot.exts.utils.snekbox.partial")
- async def test_continue_eval_does_continue(self, partial_mock):
- """Test that the continue_eval function does continue if required conditions are met."""
+ async def test_continue_job_does_continue(self, partial_mock):
+ """Test that the continue_job function does continue if required conditions are met."""
ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
response = MockMessage(delete=AsyncMock())
new_msg = MockMessage()
@@ -296,30 +295,30 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
expected = "NewCode"
self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)
- actual = await self.cog.continue_eval(ctx, response, self.cog.eval_command)
+ actual = await self.cog.continue_job(ctx, response, self.cog.eval_command)
self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command)
self.assertEqual(actual, (expected, None))
self.bot.wait_for.assert_has_awaits(
(
call(
'message_edit',
- check=partial_mock(snekbox.predicate_eval_message_edit, ctx),
- timeout=snekbox.REEVAL_TIMEOUT,
+ check=partial_mock(snekbox.predicate_message_edit, ctx),
+ timeout=snekbox.REDO_TIMEOUT,
),
- call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
+ call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10)
)
)
- ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
- ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
+ ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
+ ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
response.delete.assert_called_once()
- async def test_continue_eval_does_not_continue(self):
+ async def test_continue_job_does_not_continue(self):
ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))
self.bot.wait_for.side_effect = asyncio.TimeoutError
- actual = await self.cog.continue_eval(ctx, MockMessage(), self.cog.eval_command)
+ actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command)
self.assertEqual(actual, (None, None))
- ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
+ ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
async def test_get_code(self):
"""Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
@@ -347,8 +346,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.get_context.assert_awaited_once_with(message)
self.assertEqual(actual_code, expected_code)
- def test_predicate_eval_message_edit(self):
- """Test the predicate_eval_message_edit function."""
+ def test_predicate_message_edit(self):
+ """Test the predicate_message_edit function."""
msg0 = MockMessage(id=1, content='abc')
msg1 = MockMessage(id=2, content='abcdef')
msg2 = MockMessage(id=1, content='abcdef')
@@ -361,18 +360,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
for ctx_msg, new_msg, expected, testname in cases:
with self.subTest(msg=f'Messages with {testname} return {expected}'):
ctx = MockContext(message=ctx_msg)
- actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg)
+ actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg)
self.assertEqual(actual, expected)
- def test_predicate_eval_emoji_reaction(self):
- """Test the predicate_eval_emoji_reaction function."""
+ def test_predicate_emoji_reaction(self):
+ """Test the predicate_emoji_reaction function."""
valid_reaction = MockReaction(message=MockMessage(id=1))
- valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI
+ valid_reaction.__str__.return_value = snekbox.REDO_EMOJI
valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))
valid_user = MockUser(id=2)
invalid_reaction_id = MockReaction(message=MockMessage(id=42))
- invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI
+ invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI
invalid_user_id = MockUser(id=42)
invalid_reaction_str = MockReaction(message=MockMessage(id=1))
invalid_reaction_str.__str__.return_value = ':longbeard:'
@@ -385,7 +384,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for reaction, user, expected, testname in cases:
with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
- actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user)
+ actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user)
self.assertEqual(actual, expected)