aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/snekbox/_cog.py131
-rw-r--r--bot/exts/utils/snekbox/_eval.py2
-rw-r--r--tests/bot/exts/utils/snekbox/test_snekbox.py19
3 files changed, 90 insertions, 62 deletions
diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py
index db4181d68..f26bf1000 100644
--- a/bot/exts/utils/snekbox/_cog.py
+++ b/bot/exts/utils/snekbox/_cog.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import contextlib
import re
+from collections.abc import Iterable
from functools import partial
from operator import attrgetter
from textwrap import dedent
@@ -289,6 +290,69 @@ class Snekbox(Cog):
return output, paste_link
+ async def format_file_text(self, text_files: list[FileAttachment], output: str) -> str:
+ # Inline until budget, then upload to paste service
+ # Budget is shared with stdout, so subtract what we've already used
+ budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1)
+ budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output)
+ msg = ""
+
+ for file in text_files:
+ file_text = file.content.decode("utf-8", errors="replace") or "[Empty]"
+ # Override to always allow 1 line and <= 50 chars, since this is less than a link
+ if len(file_text) <= 50 and not file_text.count("\n"):
+ msg += f"\n`{file.name}`\n```\n{file_text}\n```"
+ # otherwise, use budget
+ else:
+ format_text, link_text = await self.format_output(
+ file_text,
+ budget_lines,
+ budget_chars,
+ line_nums=False,
+ output_default="[Empty]"
+ )
+ # With any link, use it (don't use budget)
+ if link_text:
+ msg += f"\n`{file.name}`\n{link_text}"
+ else:
+ msg += f"\n`{file.name}`\n```\n{format_text}\n```"
+ budget_lines -= format_text.count("\n") + 1
+ budget_chars -= len(file_text)
+
+ return msg
+
+ def format_blocked_extensions(self, blocked: list[FileAttachment]) -> str:
+ # Sort by length and then lexicographically to fit as many as possible before truncating.
+ blocked_sorted = sorted(set(f.suffix for f in blocked), key=lambda e: (len(e), e))
+
+ # Only no extension
+ if len(blocked_sorted) == 1 and blocked_sorted[0] == "":
+ blocked_msg = "Files with no extension can't be uploaded."
+ # Both
+ elif "" in blocked_sorted:
+ blocked_str = self.join_blocked_extensions(ext for ext in blocked_sorted if ext)
+ blocked_msg = (
+ f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**"
+ )
+ else:
+ blocked_str = self.join_blocked_extensions(blocked_sorted)
+ blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**"
+
+ return f"\n{Emojis.failed_file} {blocked_msg}"
+
+ def join_blocked_extensions(self, extensions: Iterable, delimiter: str = ", ", char_limit: int = 100) -> str:
+ joined = ""
+ for ext in extensions:
+ cur_delimiter = delimiter if joined else ""
+ if len(joined) + len(cur_delimiter) + len(ext) >= char_limit:
+ joined += f"{cur_delimiter}..."
+ break
+
+ joined += f"{cur_delimiter}{ext}"
+
+ return joined
+
+
def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles:
"""Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists."""
# Filter files into allowed and blocked
@@ -318,16 +382,18 @@ class Snekbox(Cog):
"""
async with ctx.typing():
result = await self.post_job(job)
- msg = result.get_message(job)
- error = result.error_message
-
- if error:
- output, paste_link = error, None
+ # Collect stats of job fails + successes
+ if result.returncode != 0:
+ self.bot.stats.incr("snekbox.python.fail")
else:
- log.trace("Formatting output...")
- output, paste_link = await self.format_output(result.stdout)
+ self.bot.stats.incr("snekbox.python.success")
+
+ log.trace("Formatting output...")
+ output = result.error_message if result.error_message else result.stdout
+ output, paste_link = await self.format_output(output)
- msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n"
+ status_msg = result.get_status_message(job)
+ msg = f"{ctx.author.mention} {result.status_emoji} {status_msg}.\n"
# This is done to make sure the last line of output contains the error
# and the error is not manually printed by the author with a syntax error.
@@ -345,39 +411,9 @@ class Snekbox(Cog):
if files_error := result.files_error_message:
msg += f"\n{files_error}"
- # Collect stats of job fails + successes
- if result.returncode != 0:
- self.bot.stats.incr("snekbox.python.fail")
- else:
- self.bot.stats.incr("snekbox.python.success")
-
# Split text files
text_files = [f for f in result.files if f.suffix in TXT_LIKE_FILES]
- # Inline until budget, then upload to paste service
- # Budget is shared with stdout, so subtract what we've already used
- budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1)
- budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output)
- for file in text_files:
- file_text = file.content.decode("utf-8", errors="replace") or "[Empty]"
- # Override to always allow 1 line and <= 50 chars, since this is less than a link
- if len(file_text) <= 50 and not file_text.count("\n"):
- msg += f"\n`{file.name}`\n```\n{file_text}\n```"
- # otherwise, use budget
- else:
- format_text, link_text = await self.format_output(
- file_text,
- budget_lines,
- budget_chars,
- line_nums=False,
- output_default="[Empty]"
- )
- # With any link, use it (don't use budget)
- if link_text:
- msg += f"\n`{file.name}`\n{link_text}"
- else:
- msg += f"\n`{file.name}`\n```\n{format_text}\n```"
- budget_lines -= format_text.count("\n") + 1
- budget_chars -= len(file_text)
+ msg += await self.format_file_text(text_files, output)
filter_cog: Filtering | None = self.bot.get_cog("Filtering")
blocked_exts = set()
@@ -392,23 +428,8 @@ class Snekbox(Cog):
# Filter file extensions
allowed, blocked = self._filter_files(ctx, result.files, blocked_exts)
blocked.extend(self._filter_files(ctx, failed_files, blocked_exts).blocked)
- # Add notice if any files were blocked
if blocked:
- blocked_sorted = sorted(set(f.suffix for f in blocked))
- # Only no extension
- if len(blocked_sorted) == 1 and blocked_sorted[0] == "":
- blocked_msg = "Files with no extension can't be uploaded."
- # Both
- elif "" in blocked_sorted:
- blocked_str = ", ".join(ext for ext in blocked_sorted if ext)
- blocked_msg = (
- f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**"
- )
- else:
- blocked_str = ", ".join(blocked_sorted)
- blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**"
-
- msg += f"\n{Emojis.failed_file} {blocked_msg}"
+ msg += self.format_blocked_extensions(blocked)
# Upload remaining non-text files
files = [f.to_file() for f in allowed if f not in text_files]
diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py
index d3d1e7a18..3867b81de 100644
--- a/bot/exts/utils/snekbox/_eval.py
+++ b/bot/exts/utils/snekbox/_eval.py
@@ -141,7 +141,7 @@ class EvalResult:
text = escape_mentions(text)
return text
- def get_message(self, job: EvalJob) -> str:
+ def get_status_message(self, job: EvalJob) -> str:
"""Return a user-friendly message corresponding to the process's return code."""
msg = f"Your {job.version} {job.name} job"
diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py
index 8ee0f46ff..d057b284d 100644
--- a/tests/bot/exts/utils/snekbox/test_snekbox.py
+++ b/tests/bot/exts/utils/snekbox/test_snekbox.py
@@ -113,7 +113,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
result = EvalResult(stdout=stdout, returncode=returncode)
job = EvalJob([])
# Check all 3 message types
- msg = result.get_message(job)
+ msg = result.get_status_message(job)
self.assertEqual(msg, exp_msg)
error = result.error_message
self.assertEqual(error, exp_err)
@@ -166,7 +166,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
def test_eval_result_message_invalid_signal(self, _mock_signals: Mock):
result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- result.get_message(EvalJob([], version="3.10")),
+ result.get_status_message(EvalJob([], version="3.10")),
"Your 3.10 eval job has completed with return code 127"
)
self.assertEqual(result.error_message, "")
@@ -177,7 +177,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mock_signals.return_value.name = "SIGTEST"
result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- result.get_message(EvalJob([], version="3.12")),
+ result.get_status_message(EvalJob([], version="3.12")),
"Your 3.12 eval job has completed with return code 127 (SIGTEST)"
)
@@ -386,12 +386,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send = AsyncMock()
ctx.author.mention = "@user#7700"
- eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")])
+ files = [
+ FileAttachment("test.disallowed2", b"test"),
+ FileAttachment("test.disallowed", b"test"),
+ FileAttachment("test.allowed", b"test"),
+ FileAttachment("test." + ("a" * 100), b"test")
+ ]
+ eval_result = EvalResult("", 0, files=files)
self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.upload_output = AsyncMock() # This function isn't called
+ disallowed_exts = [".disallowed", "." + ("a" * 100), ".disallowed2"]
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"]))
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, disallowed_exts))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code("MyAwesomeCode").as_version("3.12")
@@ -402,7 +409,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue(
res.startswith("@user#7700 :white_check_mark: Your 3.12 eval job has completed with return code 0.")
)
- self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res)
+ self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed, .disallowed2, ...**", res)
self.cog.post_job.assert_called_once_with(job)
self.cog.upload_output.assert_not_called()