diff options
author | 2024-02-03 17:12:12 -0800 | |
---|---|---|
committer | 2024-03-26 06:30:28 -0700 | |
commit | 9a8520178b4a966dffc140d46bbe83466a3cf39e (patch) | |
tree | 9ce700046373dcbf984d429a43c23de2474383e1 | |
parent | Attempt to fetch help post from cache before making an API request (diff) |
Snekbox: truncate blocked file extensions
Avoid Discord's character limit for messages.
Fix #2464
-rw-r--r-- | bot/exts/utils/snekbox/_cog.py | 131 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/_eval.py | 2 | ||||
-rw-r--r-- | tests/bot/exts/utils/snekbox/test_snekbox.py | 19 |
3 files changed, 90 insertions, 62 deletions
diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py index db4181d68..f26bf1000 100644 --- a/bot/exts/utils/snekbox/_cog.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib import re +from collections.abc import Iterable from functools import partial from operator import attrgetter from textwrap import dedent @@ -289,6 +290,69 @@ class Snekbox(Cog): return output, paste_link + async def format_file_text(self, text_files: list[FileAttachment], output: str) -> str: + # Inline until budget, then upload to paste service + # Budget is shared with stdout, so subtract what we've already used + budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) + budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output) + msg = "" + + for file in text_files: + file_text = file.content.decode("utf-8", errors="replace") or "[Empty]" + # Override to always allow 1 line and <= 50 chars, since this is less than a link + if len(file_text) <= 50 and not file_text.count("\n"): + msg += f"\n`{file.name}`\n```\n{file_text}\n```" + # otherwise, use budget + else: + format_text, link_text = await self.format_output( + file_text, + budget_lines, + budget_chars, + line_nums=False, + output_default="[Empty]" + ) + # With any link, use it (don't use budget) + if link_text: + msg += f"\n`{file.name}`\n{link_text}" + else: + msg += f"\n`{file.name}`\n```\n{format_text}\n```" + budget_lines -= format_text.count("\n") + 1 + budget_chars -= len(file_text) + + return msg + + def format_blocked_extensions(self, blocked: list[FileAttachment]) -> str: + # Sort by length and then lexicographically to fit as many as possible before truncating. + blocked_sorted = sorted(set(f.suffix for f in blocked), key=lambda e: (len(e), e)) + + # Only no extension + if len(blocked_sorted) == 1 and blocked_sorted[0] == "": + blocked_msg = "Files with no extension can't be uploaded." + # Both + elif "" in blocked_sorted: + blocked_str = self.join_blocked_extensions(ext for ext in blocked_sorted if ext) + blocked_msg = ( + f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + ) + else: + blocked_str = self.join_blocked_extensions(blocked_sorted) + blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + + return f"\n{Emojis.failed_file} {blocked_msg}" + + def join_blocked_extensions(self, extensions: Iterable, delimiter: str = ", ", char_limit: int = 100) -> str: + joined = "" + for ext in extensions: + cur_delimiter = delimiter if joined else "" + if len(joined) + len(cur_delimiter) + len(ext) >= char_limit: + joined += f"{cur_delimiter}..." + break + + joined += f"{cur_delimiter}{ext}" + + return joined + + def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles: """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" # Filter files into allowed and blocked @@ -318,16 +382,18 @@ class Snekbox(Cog): """ async with ctx.typing(): result = await self.post_job(job) - msg = result.get_message(job) - error = result.error_message - - if error: - output, paste_link = error, None + # Collect stats of job fails + successes + if result.returncode != 0: + self.bot.stats.incr("snekbox.python.fail") else: - log.trace("Formatting output...") - output, paste_link = await self.format_output(result.stdout) + self.bot.stats.incr("snekbox.python.success") + + log.trace("Formatting output...") + output = result.error_message if result.error_message else result.stdout + output, paste_link = await self.format_output(output) - msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n" + status_msg = result.get_status_message(job) + msg = f"{ctx.author.mention} {result.status_emoji} {status_msg}.\n" # This is done to make sure the last line of output contains the error # and the error is not manually printed by the author with a syntax error. @@ -345,39 +411,9 @@ class Snekbox(Cog): if files_error := result.files_error_message: msg += f"\n{files_error}" - # Collect stats of job fails + successes - if result.returncode != 0: - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - # Split text files text_files = [f for f in result.files if f.suffix in TXT_LIKE_FILES] - # Inline until budget, then upload to paste service - # Budget is shared with stdout, so subtract what we've already used - budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) - budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output) - for file in text_files: - file_text = file.content.decode("utf-8", errors="replace") or "[Empty]" - # Override to always allow 1 line and <= 50 chars, since this is less than a link - if len(file_text) <= 50 and not file_text.count("\n"): - msg += f"\n`{file.name}`\n```\n{file_text}\n```" - # otherwise, use budget - else: - format_text, link_text = await self.format_output( - file_text, - budget_lines, - budget_chars, - line_nums=False, - output_default="[Empty]" - ) - # With any link, use it (don't use budget) - if link_text: - msg += f"\n`{file.name}`\n{link_text}" - else: - msg += f"\n`{file.name}`\n```\n{format_text}\n```" - budget_lines -= format_text.count("\n") + 1 - budget_chars -= len(file_text) + msg += await self.format_file_text(text_files, output) filter_cog: Filtering | None = self.bot.get_cog("Filtering") blocked_exts = set() @@ -392,23 +428,8 @@ class Snekbox(Cog): # Filter file extensions allowed, blocked = self._filter_files(ctx, result.files, blocked_exts) blocked.extend(self._filter_files(ctx, failed_files, blocked_exts).blocked) - # Add notice if any files were blocked if blocked: - blocked_sorted = sorted(set(f.suffix for f in blocked)) - # Only no extension - if len(blocked_sorted) == 1 and blocked_sorted[0] == "": - blocked_msg = "Files with no extension can't be uploaded." - # Both - elif "" in blocked_sorted: - blocked_str = ", ".join(ext for ext in blocked_sorted if ext) - blocked_msg = ( - f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" - ) - else: - blocked_str = ", ".join(blocked_sorted) - blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" - - msg += f"\n{Emojis.failed_file} {blocked_msg}" + msg += self.format_blocked_extensions(blocked) # Upload remaining non-text files files = [f.to_file() for f in allowed if f not in text_files] diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py index d3d1e7a18..3867b81de 100644 --- a/bot/exts/utils/snekbox/_eval.py +++ b/bot/exts/utils/snekbox/_eval.py @@ -141,7 +141,7 @@ class EvalResult: text = escape_mentions(text) return text - def get_message(self, job: EvalJob) -> str: + def get_status_message(self, job: EvalJob) -> str: """Return a user-friendly message corresponding to the process's return code.""" msg = f"Your {job.version} {job.name} job" diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 8ee0f46ff..d057b284d 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -113,7 +113,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): result = EvalResult(stdout=stdout, returncode=returncode) job = EvalJob([]) # Check all 3 message types - msg = result.get_message(job) + msg = result.get_status_message(job) self.assertEqual(msg, exp_msg) error = result.error_message self.assertEqual(error, exp_err) @@ -166,7 +166,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( - result.get_message(EvalJob([], version="3.10")), + result.get_status_message(EvalJob([], version="3.10")), "Your 3.10 eval job has completed with return code 127" ) self.assertEqual(result.error_message, "") @@ -177,7 +177,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mock_signals.return_value.name = "SIGTEST" result = EvalResult(stdout="", returncode=127) self.assertEqual( - result.get_message(EvalJob([], version="3.12")), + result.get_status_message(EvalJob([], version="3.12")), "Your 3.12 eval job has completed with return code 127 (SIGTEST)" ) @@ -386,12 +386,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author.mention = "@user#7700" - eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")]) + files = [ + FileAttachment("test.disallowed2", b"test"), + FileAttachment("test.disallowed", b"test"), + FileAttachment("test.allowed", b"test"), + FileAttachment("test." + ("a" * 100), b"test") + ] + eval_result = EvalResult("", 0, files=files) self.cog.post_job = AsyncMock(return_value=eval_result) self.cog.upload_output = AsyncMock() # This function isn't called + disallowed_exts = [".disallowed", "." + ("a" * 100), ".disallowed2"] mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"])) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, disallowed_exts)) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.12") @@ -402,7 +409,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertTrue( res.startswith("@user#7700 :white_check_mark: Your 3.12 eval job has completed with return code 0.") ) - self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res) + self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed, .disallowed2, ...**", res) self.cog.post_job.assert_called_once_with(job) self.cog.upload_output.assert_not_called() |