aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2023-03-10 02:04:56 +0200
committerGravatar GitHub <[email protected]>2023-03-10 02:04:56 +0200
commite21cd3ecea9afc7cf5d08bf75c944dd6049f4d28 (patch)
treeaa1f471ea6b0ced910ac2725773c295ae9efb891
parentMerge #2459: Update nested delimiter to double underscore (diff)
parentMerge branch 'main' into snekbox-files (diff)
Merge pull request #2326 from python-discord/snekbox-filesfix-eval-message-limit
Support `eval` (snekbox) file system and attachment display
-rw-r--r--bot/constants.py1
-rw-r--r--bot/exts/utils/snekbox/__init__.py12
-rw-r--r--bot/exts/utils/snekbox/_cog.py (renamed from bot/exts/utils/snekbox.py)373
-rw-r--r--bot/exts/utils/snekbox/_eval.py183
-rw-r--r--bot/exts/utils/snekbox/_io.py102
-rw-r--r--tests/bot/exts/utils/snekbox/__init__.py0
-rw-r--r--tests/bot/exts/utils/snekbox/test_io.py34
-rw-r--r--tests/bot/exts/utils/snekbox/test_snekbox.py (renamed from tests/bot/exts/utils/test_snekbox.py)266
8 files changed, 727 insertions, 244 deletions
diff --git a/bot/constants.py b/bot/constants.py
index a4d5761be..3aacd0a16 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -584,6 +584,7 @@ class _Emojis(EnvConfig):
defcon_update = "<:defconsettingsupdated:470326274082996224>" # noqa: E704
failmail = "<:failmail:633660039931887616>"
+ failed_file = "<:failed_file:1073298441968562226>"
incident_actioned = "<:incident_actioned:714221559279255583>"
incident_investigating = "<:incident_investigating:714224190928191551>"
diff --git a/bot/exts/utils/snekbox/__init__.py b/bot/exts/utils/snekbox/__init__.py
new file mode 100644
index 000000000..cd1d3b059
--- /dev/null
+++ b/bot/exts/utils/snekbox/__init__.py
@@ -0,0 +1,12 @@
+from bot.bot import Bot
+from bot.exts.utils.snekbox._cog import CodeblockConverter, Snekbox
+from bot.exts.utils.snekbox._eval import EvalJob, EvalResult
+
+__all__ = ("CodeblockConverter", "Snekbox", "EvalJob", "EvalResult")
+
+
+async def setup(bot: Bot) -> None:
+ """Load the Snekbox cog."""
+ # Defer import to reduce side effects from importing the codeblock package.
+ from bot.exts.utils.snekbox._cog import Snekbox
+ await bot.add_cog(Snekbox(bot))
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox/_cog.py
index ddcbe01fa..b48fcf592 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox/_cog.py
@@ -1,11 +1,12 @@
+from __future__ import annotations
+
import asyncio
import contextlib
import re
from functools import partial
from operator import attrgetter
-from signal import Signals
from textwrap import dedent
-from typing import Literal, Optional, Tuple
+from typing import Literal, NamedTuple, Optional, TYPE_CHECKING
from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui
from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only
@@ -13,14 +14,21 @@ from pydis_core.utils import interactions
from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX
from bot.bot import Bot
-from bot.constants import Channels, MODERATION_ROLES, Roles, URLs
+from bot.constants import Channels, Emojis, Filter, MODERATION_ROLES, Roles, URLs
from bot.decorators import redirect_output
+from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME
+from bot.exts.filters.antimalware import TXT_LIKE_FILES
from bot.exts.help_channels._channel import is_help_forum_post
+from bot.exts.utils.snekbox._eval import EvalJob, EvalResult
+from bot.exts.utils.snekbox._io import FileAttachment
from bot.log import get_logger
from bot.utils import send_to_paste_service
from bot.utils.lock import LockedResourceError, lock_arg
from bot.utils.services import PasteTooLongError, PasteUploadError
+if TYPE_CHECKING:
+ from bot.exts.filters.filtering import Filtering
+
log = get_logger(__name__)
ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")
@@ -68,17 +76,23 @@ if not hasattr(sys, "_setup_finished"):
"""
MAX_PASTE_LENGTH = 10_000
+# Max to display in a codeblock before sending to a paste service
+# This also applies to text files
+MAX_OUTPUT_BLOCK_LINES = 10
+MAX_OUTPUT_BLOCK_CHARS = 1000
# The Snekbox commands' whitelists and blacklists.
NO_SNEKBOX_CHANNELS = (Channels.python_general,)
NO_SNEKBOX_CATEGORIES = ()
SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
-SIGKILL = 9
-
REDO_EMOJI = '\U0001f501' # :repeat:
REDO_TIMEOUT = 30
+PythonVersion = Literal["3.10", "3.11"]
+
+FilteredFiles = NamedTuple("FilteredFiles", [("allowed", list[FileAttachment]), ("blocked", list[FileAttachment])])
+
class CodeblockConverter(Converter):
"""Attempts to extract code from a codeblock, if provided."""
@@ -122,21 +136,17 @@ class PythonVersionSwitcherButton(ui.Button):
def __init__(
self,
- job_name: str,
- version_to_switch_to: Literal["3.10", "3.11"],
- snekbox_cog: "Snekbox",
+ version_to_switch_to: PythonVersion,
+ snekbox_cog: Snekbox,
ctx: Context,
- code: str,
- args: Optional[list[str]] = None
+ job: EvalJob,
) -> None:
self.version_to_switch_to = version_to_switch_to
super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary)
self.snekbox_cog = snekbox_cog
self.ctx = ctx
- self.job_name = job_name
- self.code = code
- self.args = args
+ self.job = job
async def callback(self, interaction: Interaction) -> None:
"""
@@ -149,13 +159,11 @@ class PythonVersionSwitcherButton(ui.Button):
await interaction.response.defer()
with contextlib.suppress(NotFound):
- # Suppress this delete to cover the case where a user re-runs code and very quickly clicks the button.
+ # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button.
# The log arg on send_job will stop the actual job from running.
await interaction.message.delete()
- await self.snekbox_cog.run_job(
- self.job_name, self.ctx, self.version_to_switch_to, self.code, args=self.args
- )
+ await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to))
class Snekbox(Cog):
@@ -167,13 +175,12 @@ class Snekbox(Cog):
def build_python_version_switcher_view(
self,
- job_name: str,
- current_python_version: Literal["3.10", "3.11"],
+ current_python_version: PythonVersion,
ctx: Context,
- code: str,
- args: Optional[list[str]] = None
- ) -> None:
+ job: EvalJob,
+ ) -> interactions.ViewWithUserAndRoleCheck:
"""Return a view that allows the user to change what version of Python their code is run on."""
+ alt_python_version: PythonVersion
if current_python_version == "3.10":
alt_python_version = "3.11"
else:
@@ -183,33 +190,25 @@ class Snekbox(Cog):
allowed_users=(ctx.author.id,),
allowed_roles=MODERATION_ROLES,
)
- view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, code, args))
+ view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job))
view.add_item(interactions.DeleteMessageButton())
return view
- async def post_job(
- self,
- code: str,
- python_version: Literal["3.10", "3.11"],
- *,
- args: Optional[list[str]] = None
- ) -> dict:
+ async def post_job(self, job: EvalJob) -> EvalResult:
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
- if python_version == "3.10":
+ if job.version == "3.10":
url = URLs.snekbox_eval_api
else:
url = URLs.snekbox_311_eval_api
- data = {"input": code}
-
- if args is not None:
- data["args"] = args
+ data = job.to_dict()
async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
- return await resp.json()
+ return EvalResult.from_dict(await resp.json())
- async def upload_output(self, output: str) -> Optional[str]:
+ @staticmethod
+ async def upload_output(output: str) -> Optional[str]:
"""Upload the job's output to a paste service and return a URL to it if successful."""
log.trace("Uploading full output to paste service...")
@@ -221,59 +220,27 @@ class Snekbox(Cog):
return "unable to upload"
@staticmethod
- def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]:
+ def prepare_timeit_input(codeblocks: list[str]) -> list[str]:
"""
- Join the codeblocks into a single string, then return the code and the arguments in a tuple.
+ Join the codeblocks into a single string, then return the arguments in a list.
If there are multiple codeblocks, insert the first one into the wrapped setup code.
"""
args = ["-m", "timeit"]
- setup = ""
- if len(codeblocks) > 1:
- setup = codeblocks.pop(0)
-
+ setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else ""
code = "\n".join(codeblocks)
- args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)])
-
- return code, args
-
- @staticmethod
- def get_results_message(results: dict, job_name: str, python_version: Literal["3.10", "3.11"]) -> Tuple[str, str]:
- """Return a user-friendly message and error corresponding to the process's return code."""
- stdout, returncode = results["stdout"], results["returncode"]
- msg = f"Your {python_version} {job_name} job has completed with return code {returncode}"
- error = ""
-
- if returncode is None:
- msg = f"Your {python_version} {job_name} job has failed"
- error = stdout.strip()
- elif returncode == 128 + SIGKILL:
- msg = f"Your {python_version} {job_name} job timed out or ran out of memory"
- elif returncode == 255:
- msg = f"Your {python_version} {job_name} job has failed"
- error = "A fatal NsJail error occurred"
- else:
- # Try to append signal's name if one exists
- try:
- name = Signals(returncode - 128).name
- msg = f"{msg} ({name})"
- except ValueError:
- pass
-
- return msg, error
+ args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code])
+ return args
- @staticmethod
- def get_status_emoji(results: dict) -> str:
- """Return an emoji corresponding to the status code or lack of output in result."""
- if not results["stdout"].strip(): # No output
- return ":warning:"
- elif results["returncode"] == 0: # No error
- return ":white_check_mark:"
- else: # Exception
- return ":x:"
-
- async def format_output(self, output: str) -> Tuple[str, Optional[str]]:
+ async def format_output(
+ self,
+ output: str,
+ max_lines: int = MAX_OUTPUT_BLOCK_LINES,
+ max_chars: int = MAX_OUTPUT_BLOCK_CHARS,
+ line_nums: bool = True,
+ output_default: str = "[No output]",
+ ) -> tuple[str, str | None]:
"""
Format the output and return a tuple of the formatted output and a URL to the full output.
@@ -295,96 +262,182 @@ class Snekbox(Cog):
return "Code block escape attempt detected; will not output result", paste_link
truncated = False
- lines = output.count("\n")
+ lines = output.splitlines()
- if lines > 0:
- output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)]
- output = output[:11] # Limiting to only 11 lines
- output = "\n".join(output)
+ if len(lines) > 1:
+ if line_nums:
+ lines = [f"{i:03d} | {line}" for i, line in enumerate(lines, 1)]
+ lines = lines[:max_lines+1] # Limiting to max+1 lines
+ output = "\n".join(lines)
- if lines > 10:
+ if len(lines) > max_lines:
truncated = True
- if len(output) >= 1000:
- output = f"{output[:1000]}\n... (truncated - too long, too many lines)"
+ if len(output) >= max_chars:
+ output = f"{output[:max_chars]}\n... (truncated - too long, too many lines)"
else:
output = f"{output}\n... (truncated - too many lines)"
- elif len(output) >= 1000:
+ elif len(output) >= max_chars:
truncated = True
- output = f"{output[:1000]}\n... (truncated - too long)"
+ output = f"{output[:max_chars]}\n... (truncated - too long)"
if truncated:
paste_link = await self.upload_output(original_output)
- output = output or "[No output]"
+ if output_default and not output:
+ output = output_default
return output, paste_link
+ def get_extensions_whitelist(self) -> set[str]:
+ """Return a set of whitelisted file extensions."""
+ return set(self.bot.filter_list_cache['FILE_FORMAT.True'].keys()) | TXT_LIKE_FILES
+
+ def _filter_files(self, ctx: Context, files: list[FileAttachment]) -> FilteredFiles:
+ """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists."""
+ # Check if user is staff, if is, return
+ # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
+ if hasattr(ctx.author, "roles") and any(role.id in Filter.role_whitelist for role in ctx.author.roles):
+ return FilteredFiles(files, [])
+ # Ignore code jam channels
+ if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME:
+ return FilteredFiles(files, [])
+
+ # Get whitelisted extensions
+ whitelist = self.get_extensions_whitelist()
+
+ # Filter files into allowed and blocked
+ blocked = []
+ allowed = []
+ for file in files:
+ if file.suffix in whitelist:
+ allowed.append(file)
+ else:
+ blocked.append(file)
+
+ if blocked:
+ blocked_str = ", ".join(f.suffix for f in blocked)
+ log.info(
+ f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}",
+ extra={"attachment_list": [f.path for f in files]}
+ )
+
+ return FilteredFiles(allowed, blocked)
+
@lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True)
- async def send_job(
- self,
- ctx: Context,
- python_version: Literal["3.10", "3.11"],
- code: str,
- *,
- args: Optional[list[str]] = None,
- job_name: str
- ) -> Message:
+ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
Return the bot response.
"""
async with ctx.typing():
- results = await self.post_job(code, python_version, args=args)
- msg, error = self.get_results_message(results, job_name, python_version)
+ result = await self.post_job(job)
+ msg = result.get_message(job)
+ error = result.error_message
if error:
output, paste_link = error, None
else:
log.trace("Formatting output...")
- output, paste_link = await self.format_output(results["stdout"])
+ output, paste_link = await self.format_output(result.stdout)
- warning_message = ""
+ msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n"
# This is done to make sure the last line of output contains the error
# and the error is not manually printed by the author with a syntax error.
- if results["stdout"].rstrip().endswith("EOFError: EOF when reading a line") and results["returncode"] == 1:
- warning_message += ":warning: Note: `input` is not supported by the bot :warning:\n\n"
+ if result.stdout.rstrip().endswith("EOFError: EOF when reading a line") and result.returncode == 1:
+ msg += ":warning: Note: `input` is not supported by the bot :warning:\n\n"
+
+ # Skip output if it's empty and there are file uploads
+ if result.stdout or not result.has_files:
+ msg += f"\n```\n{output}\n```"
- icon = self.get_status_emoji(results)
- msg = f"{ctx.author.mention} {icon} {msg}.\n\n{warning_message}```\n{output}\n```"
if paste_link:
- msg = f"{msg}\nFull output: {paste_link}"
+ msg += f"\nFull output: {paste_link}"
+
+ # Additional files error message after output
+ if files_error := result.files_error_message:
+ msg += f"\n{files_error}"
# Collect stats of job fails + successes
- if icon == ":x:":
+ if result.returncode != 0:
self.bot.stats.incr("snekbox.python.fail")
else:
self.bot.stats.incr("snekbox.python.success")
- filter_cog = self.bot.get_cog("Filtering")
- filter_triggered = False
- if filter_cog:
- filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message)
- if filter_triggered:
- response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")
- else:
- allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
- view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args)
- response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view)
- view.message = response
-
- log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")
+ # Filter file extensions
+ allowed, blocked = self._filter_files(ctx, result.files)
+ # Also scan failed files for blocked extensions
+ failed_files = [FileAttachment(name, b"") for name in result.failed_files]
+ blocked.extend(self._filter_files(ctx, failed_files).blocked)
+ # Add notice if any files were blocked
+ if blocked:
+ blocked_sorted = sorted(set(f.suffix for f in blocked))
+ # Only no extension
+ if len(blocked_sorted) == 1 and blocked_sorted[0] == "":
+ blocked_msg = "Files with no extension can't be uploaded."
+ # Both
+ elif "" in blocked_sorted:
+ blocked_str = ", ".join(ext for ext in blocked_sorted if ext)
+ blocked_msg = (
+ f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**"
+ )
+ else:
+ blocked_str = ", ".join(blocked_sorted)
+ blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**"
+
+ msg += f"\n{Emojis.failed_file} {blocked_msg}"
+
+ # Split text files
+ text_files = [f for f in allowed if f.suffix in TXT_LIKE_FILES]
+ # Inline until budget, then upload to paste service
+ # Budget is shared with stdout, so subtract what we've already used
+ budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1)
+ budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output)
+ for file in text_files:
+ file_text = file.content.decode("utf-8", errors="replace") or "[Empty]"
+ # Override to always allow 1 line and <= 50 chars, since this is less than a link
+ if len(file_text) <= 50 and not file_text.count("\n"):
+ msg += f"\n`{file.name}`\n```\n{file_text}\n```"
+ # otherwise, use budget
+ else:
+ format_text, link_text = await self.format_output(
+ file_text,
+ budget_lines,
+ budget_chars,
+ line_nums=False,
+ output_default="[Empty]"
+ )
+ # With any link, use it (don't use budget)
+ if link_text:
+ msg += f"\n`{file.name}`\n{link_text}"
+ else:
+ msg += f"\n`{file.name}`\n```\n{format_text}\n```"
+ budget_lines -= format_text.count("\n") + 1
+ budget_chars -= len(file_text)
+
+ filter_cog: Filtering | None = self.bot.get_cog("Filtering")
+ if filter_cog and (await filter_cog.filter_snekbox_output(msg, ctx.message)):
+ return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")
+
+ # Upload remaining non-text files
+ files = [f.to_file() for f in allowed if f not in text_files]
+ allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
+ view = self.build_python_version_switcher_view(job.version, ctx, job)
+ response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files)
+ view.message = response
+
+ log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}")
return response
async def continue_job(
self, ctx: Context, response: Message, job_name: str
- ) -> tuple[Optional[str], Optional[list[str]]]:
+ ) -> EvalJob | None:
"""
Check if the job's session should continue.
- If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command.
- Otherwise return (None, None) if the job's session should be terminated.
+ If the code is to be re-evaluated, return the new EvalJob.
+ Otherwise, return None if the job's session should be terminated.
"""
_predicate_message_edit = partial(predicate_message_edit, ctx)
_predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx)
@@ -406,7 +459,7 @@ class Snekbox(Cog):
# Ensure the response that's about to be edited is still the most recent.
# This could have already been updated via a button press to switch to an alt Python version.
if self.jobs[ctx.message.id] != response.id:
- return None, None
+ return None
code = await self.get_code(new_message, ctx.command)
with contextlib.suppress(HTTPException):
@@ -414,21 +467,21 @@ class Snekbox(Cog):
await response.delete()
if code is None:
- return None, None
+ return None
except asyncio.TimeoutError:
with contextlib.suppress(HTTPException):
await ctx.message.clear_reaction(REDO_EMOJI)
- return None, None
+ return None
codeblocks = await CodeblockConverter.convert(ctx, code)
if job_name == "timeit":
- return self.prepare_timeit_input(codeblocks)
+ return EvalJob(self.prepare_timeit_input(codeblocks))
else:
- return "\n".join(codeblocks), None
+ return EvalJob.from_code("\n".join(codeblocks))
- return None, None
+ return None
async def get_code(self, message: Message, command: Command) -> Optional[str]:
"""
@@ -452,12 +505,8 @@ class Snekbox(Cog):
async def run_job(
self,
- job_name: str,
ctx: Context,
- python_version: Literal["3.10", "3.11"],
- code: str,
- *,
- args: Optional[list[str]] = None,
+ job: EvalJob,
) -> None:
"""Handles checks, stats and re-evaluation of a snekbox job."""
if Roles.helpers in (role.id for role in ctx.author.roles):
@@ -472,11 +521,11 @@ class Snekbox(Cog):
else:
self.bot.stats.incr("snekbox_usages.channels.topical")
- log.info(f"Received code from {ctx.author} for evaluation:\n{code}")
+ log.info(f"Received code from {ctx.author} for evaluation:\n{job}")
while True:
try:
- response = await self.send_job(ctx, python_version, code, args=args, job_name=job_name)
+ response = await self.send_job(ctx, job)
except LockedResourceError:
await ctx.send(
f"{ctx.author.mention} You've already got a job running - "
@@ -489,10 +538,10 @@ class Snekbox(Cog):
# This can happen when a button is pressed and then original code is edited and re-run.
self.jobs[ctx.message.id] = response.id
- code, args = await self.continue_job(ctx, response, job_name)
- if not code:
+ job = await self.continue_job(ctx, response, job.name)
+ if not job:
break
- log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}")
+ log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}")
@command(name="eval", aliases=("e",), usage="[python_version] <code, ...>")
@guild_only()
@@ -506,29 +555,32 @@ class Snekbox(Cog):
async def eval_command(
self,
ctx: Context,
- python_version: Optional[Literal["3.10", "3.11"]],
+ python_version: PythonVersion | None,
*,
code: CodeblockConverter
) -> None:
"""
Run Python code and get the results.
- This command supports multiple lines of code, including code wrapped inside a formatted code
- block. Code can be re-evaluated by editing the original message within 10 seconds and
+ This command supports multiple lines of code, including formatted code blocks.
+ Code can be re-evaluated by editing the original message within 10 seconds and
clicking the reaction that subsequently appears.
+ The starting working directory `/home`, is a writeable temporary file system.
+ Files created, excluding names with leading underscores, will be uploaded in the response.
+
If multiple codeblocks are in a message, all of them will be joined and evaluated,
- ignoring the text outside of them.
+ ignoring the text outside them.
- By default your code is run on Python's 3.11 beta release, to assist with testing. If you
- run into issues related to this Python version, you can request the bot to use Python
- 3.10 by specifying the `python_version` arg and setting it to `3.10`.
+ By default, your code is run on Python 3.11. A `python_version` arg of `3.10` can also be specified.
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
+ code: list[str]
python_version = python_version or "3.11"
- await self.run_job("eval", ctx, python_version, "\n".join(code))
+ job = EvalJob.from_code("\n".join(code)).as_version(python_version)
+ await self.run_job(ctx, job)
@command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>")
@guild_only()
@@ -542,7 +594,7 @@ class Snekbox(Cog):
async def timeit_command(
self,
ctx: Context,
- python_version: Optional[Literal["3.10", "3.11"]],
+ python_version: PythonVersion | None,
*,
code: CodeblockConverter
) -> None:
@@ -556,17 +608,17 @@ class Snekbox(Cog):
If multiple formatted codeblocks are provided, the first one will be the setup code, which will
not be timed. The remaining codeblocks will be joined together and timed.
- By default your code is run on Python's 3.11 beta release, to assist with testing. If you
- run into issues related to this Python version, you can request the bot to use Python
- 3.10 by specifying the `python_version` arg and setting it to `3.10`.
+ By default, your code is run on Python 3.11. A `python_version` arg of `3.10` can also be specified.
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
+ code: list[str]
python_version = python_version or "3.11"
- code, args = self.prepare_timeit_input(code)
+ args = self.prepare_timeit_input(code)
+ job = EvalJob(args, version=python_version, name="timeit")
- await self.run_job("timeit", ctx, python_version, code=code, args=args)
+ await self.run_job(ctx, job)
def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
@@ -577,8 +629,3 @@ def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) ->
def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool:
"""Return True if the reaction REDO_EMOJI was added by the context message author on this message."""
return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI
-
-
-async def setup(bot: Bot) -> None:
- """Load the Snekbox cog."""
- await bot.add_cog(Snekbox(bot))
diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py
new file mode 100644
index 000000000..2f61b5924
--- /dev/null
+++ b/bot/exts/utils/snekbox/_eval.py
@@ -0,0 +1,183 @@
+from __future__ import annotations
+
+import contextlib
+from dataclasses import dataclass, field
+from signal import Signals
+from typing import TYPE_CHECKING
+
+from discord.utils import escape_markdown, escape_mentions
+
+from bot.constants import Emojis
+from bot.exts.utils.snekbox._io import FILE_COUNT_LIMIT, FILE_SIZE_LIMIT, FileAttachment, sizeof_fmt
+from bot.log import get_logger
+
+if TYPE_CHECKING:
+ from bot.exts.utils.snekbox._cog import PythonVersion
+
+log = get_logger(__name__)
+
+SIGKILL = 9
+
+
+@dataclass(frozen=True)
+class EvalJob:
+ """Job to be evaluated by snekbox."""
+
+ args: list[str]
+ files: list[FileAttachment] = field(default_factory=list)
+ name: str = "eval"
+ version: PythonVersion = "3.11"
+
+ @classmethod
+ def from_code(cls, code: str, path: str = "main.py") -> EvalJob:
+ """Create an EvalJob from a code string."""
+ return cls(
+ args=[path],
+ files=[FileAttachment(path, code.encode())],
+ )
+
+ def as_version(self, version: PythonVersion) -> EvalJob:
+ """Return a copy of the job with a different Python version."""
+ return EvalJob(
+ args=self.args,
+ files=self.files,
+ name=self.name,
+ version=version,
+ )
+
+ def to_dict(self) -> dict[str, list[str | dict[str, str]]]:
+ """Convert the job to a dict."""
+ return {
+ "args": self.args,
+ "files": [file.to_dict() for file in self.files],
+ }
+
+
+@dataclass(frozen=True)
+class EvalResult:
+ """The result of an eval job."""
+
+ stdout: str
+ returncode: int | None
+ files: list[FileAttachment] = field(default_factory=list)
+ failed_files: list[str] = field(default_factory=list)
+
+ @property
+ def has_output(self) -> bool:
+ """True if the result has any output (stdout, files, or failed files)."""
+ return bool(self.stdout.strip() or self.files or self.failed_files)
+
+ @property
+ def has_files(self) -> bool:
+ """True if the result has any files or failed files."""
+ return bool(self.files or self.failed_files)
+
+ @property
+ def status_emoji(self) -> str:
+ """Return an emoji corresponding to the status code or lack of output in result."""
+ if not self.has_output:
+ return ":warning:"
+ elif self.returncode == 0: # No error
+ return ":white_check_mark:"
+ else: # Exception
+ return ":x:"
+
+ @property
+ def error_message(self) -> str:
+ """Return an error message corresponding to the process's return code."""
+ error = ""
+ if self.returncode is None:
+ error = self.stdout.strip()
+ elif self.returncode == 255:
+ error = "A fatal NsJail error occurred"
+ return error
+
+ @property
+ def files_error_message(self) -> str:
+ """Return an error message corresponding to the failed files."""
+ if not self.failed_files:
+ return ""
+
+ failed_files = f"({self.get_failed_files_str()})"
+
+ n_failed = len(self.failed_files)
+ s_upload = "uploads" if n_failed > 1 else "upload"
+
+ msg = f"{Emojis.failed_file} {n_failed} file {s_upload} {failed_files} failed"
+
+ # Exceeded file count limit
+ if (n_failed + len(self.files)) > FILE_COUNT_LIMIT:
+ s_it = "they" if n_failed > 1 else "it"
+ msg += f" as {s_it} exceeded the {FILE_COUNT_LIMIT} file limit."
+ # Exceeded file size limit
+ else:
+ s_each_file = "each file's" if n_failed > 1 else "its file"
+ msg += f" because {s_each_file} size exceeds {sizeof_fmt(FILE_SIZE_LIMIT)}."
+
+ return msg
+
+ def get_failed_files_str(self, char_max: int = 85) -> str:
+ """
+ Return a string containing the names of failed files, truncated char_max.
+
+ Will truncate on whole file names if less than 3 characters remaining.
+ """
+ names = []
+ for file in self.failed_files:
+ # Only attempt to truncate name if more than 3 chars remaining
+ if char_max < 3:
+ names.append("...")
+ break
+
+ if len(file) > char_max:
+ names.append(file[:char_max] + "...")
+ break
+ char_max -= len(file)
+ names.append(file)
+
+ text = ", ".join(names)
+ # Since the file names are provided by user
+ text = escape_markdown(text)
+ text = escape_mentions(text)
+ return text
+
+ def get_message(self, job: EvalJob) -> str:
+ """Return a user-friendly message corresponding to the process's return code."""
+ msg = f"Your {job.version} {job.name} job"
+
+ if self.returncode is None:
+ msg += " has failed"
+ elif self.returncode == 128 + SIGKILL:
+ msg += " timed out or ran out of memory"
+ elif self.returncode == 255:
+ msg += " has failed"
+ else:
+ msg += f" has completed with return code {self.returncode}"
+ # Try to append signal's name if one exists
+ with contextlib.suppress(ValueError):
+ name = Signals(self.returncode - 128).name
+ msg += f" ({name})"
+
+ return msg
+
+ @classmethod
+ def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult:
+ """Create an EvalResult from a dict."""
+ res = cls(
+ stdout=data["stdout"],
+ returncode=data["returncode"],
+ )
+
+ files = iter(data["files"])
+ for i, file in enumerate(files):
+ # Limit to FILE_COUNT_LIMIT files
+ if i >= FILE_COUNT_LIMIT:
+ res.failed_files.extend(file["path"] for file in files)
+ break
+ try:
+ res.files.append(FileAttachment.from_dict(file))
+ except ValueError as e:
+ log.info(f"Failed to parse file from snekbox response: {e}")
+ res.failed_files.append(file["path"])
+
+ return res
diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py
new file mode 100644
index 000000000..9be396335
--- /dev/null
+++ b/bot/exts/utils/snekbox/_io.py
@@ -0,0 +1,102 @@
+"""I/O File protocols for snekbox."""
+from __future__ import annotations
+
+from base64 import b64decode, b64encode
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import PurePosixPath
+
+import regex
+from discord import File
+
+# Note discord bot upload limit is 8 MiB per file,
+# or 50 MiB for lvl 2 boosted servers
+FILE_SIZE_LIMIT = 8 * 1024 * 1024
+
+# Discord currently has a 10-file limit per message
+FILE_COUNT_LIMIT = 10
+
+
+# ANSI escape sequences
+RE_ANSI = regex.compile(r"\\u.*\[(.*?)m")
+# Characters with a leading backslash
+RE_BACKSLASH = regex.compile(r"\\.")
+# Discord disallowed file name characters
+RE_DISCORD_FILE_NAME_DISALLOWED = regex.compile(r"[^a-zA-Z0-9._-]+")
+
+
+def sizeof_fmt(num: int | float, suffix: str = "B") -> str:
+ """Return a human-readable file size."""
+ num = float(num)
+ for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
+ if abs(num) < 1024:
+ num_str = f"{int(num)}" if num.is_integer() else f"{num:3.1f}"
+ return f"{num_str} {unit}{suffix}"
+ num /= 1024
+ num_str = f"{int(num)}" if num.is_integer() else f"{num:3.1f}"
+ return f"{num_str} Yi{suffix}"
+
+
+def normalize_discord_file_name(name: str) -> str:
+ """Return a normalized valid discord file name."""
+ # Discord file names only allow A-Z, a-z, 0-9, underscores, dashes, and dots
+ # https://discord.com/developers/docs/reference#uploading-files
+ # Server will remove any other characters, but we'll get a 400 error for \ escaped chars
+ name = RE_ANSI.sub("_", name)
+ name = RE_BACKSLASH.sub("_", name)
+ # Replace any disallowed character with an underscore
+ name = RE_DISCORD_FILE_NAME_DISALLOWED.sub("_", name)
+ return name
+
+
+@dataclass(frozen=True)
+class FileAttachment:
+ """File Attachment from Snekbox eval."""
+
+ path: str
+ content: bytes
+
+ def __repr__(self) -> str:
+ """Return the content as a string."""
+ content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content
+ return f"FileAttachment(path={self.path!r}, content={content})"
+
+ @property
+ def suffix(self) -> str:
+ """Return the file suffix."""
+ return PurePosixPath(self.path).suffix
+
+ @property
+ def name(self) -> str:
+ """Return the file name."""
+ return PurePosixPath(self.path).name
+
+ @classmethod
+ def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment:
+ """Create a FileAttachment from a dict response."""
+ size = data.get("size")
+ if (size and size > size_limit) or (len(data["content"]) > size_limit):
+ raise ValueError("File size exceeds limit")
+
+ content = b64decode(data["content"])
+
+ if len(content) > size_limit:
+ raise ValueError("File size exceeds limit")
+
+ return cls(data["path"], content)
+
+ def to_dict(self) -> dict[str, str]:
+ """Convert the attachment to a json dict."""
+ content = self.content
+ if isinstance(content, str):
+ content = content.encode("utf-8")
+
+ return {
+ "path": self.path,
+ "content": b64encode(content).decode("ascii"),
+ }
+
+ def to_file(self) -> File:
+ """Convert to a discord.File."""
+ name = normalize_discord_file_name(self.name)
+ return File(BytesIO(self.content), filename=name)
diff --git a/tests/bot/exts/utils/snekbox/__init__.py b/tests/bot/exts/utils/snekbox/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/exts/utils/snekbox/__init__.py
diff --git a/tests/bot/exts/utils/snekbox/test_io.py b/tests/bot/exts/utils/snekbox/test_io.py
new file mode 100644
index 000000000..bcf1162b8
--- /dev/null
+++ b/tests/bot/exts/utils/snekbox/test_io.py
@@ -0,0 +1,34 @@
+from unittest import TestCase
+
+# noinspection PyProtectedMember
+from bot.exts.utils.snekbox import _io
+
+
+class SnekboxIOTests(TestCase):
+ # noinspection SpellCheckingInspection
+ def test_normalize_file_name(self):
+ """Invalid file names should be normalized."""
+ cases = [
+ # ANSI escape sequences -> underscore
+ (r"\u001b[31mText", "_Text"),
+ # (Multiple consecutive should be collapsed to one underscore)
+ (r"a\u001b[35m\u001b[37mb", "a_b"),
+ # Backslash escaped chars -> underscore
+ (r"\n", "_"),
+ (r"\r", "_"),
+ (r"A\0\tB", "A__B"),
+ # Any other disallowed chars -> underscore
+ (r"\\.txt", "_.txt"),
+ (r"A!@#$%^&*B, C()[]{}+=D.txt", "A_B_C_D.txt"), # noqa: P103
+ (" ", "_"),
+ # Normal file names should be unchanged
+ ("legal_file-name.txt", "legal_file-name.txt"),
+ ("_-.", "_-."),
+ ]
+ for name, expected in cases:
+ with self.subTest(name=name, expected=expected):
+ # Test function directly
+ self.assertEqual(_io.normalize_discord_file_name(name), expected)
+ # Test FileAttachment.to_file()
+ obj = _io.FileAttachment(name, b"")
+ self.assertEqual(obj.to_file().filename, expected)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py
index b1f32c210..9dcf7fd8c 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/snekbox/test_snekbox.py
@@ -1,5 +1,6 @@
import asyncio
import unittest
+from base64 import b64encode
from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch
from discord import AllowedMentions
@@ -8,7 +9,8 @@ from discord.ext import commands
from bot import constants
from bot.errors import LockedResourceError
from bot.exts.utils import snekbox
-from bot.exts.utils.snekbox import Snekbox
+from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox
+from bot.exts.utils.snekbox._io import FileAttachment
from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser
@@ -17,34 +19,55 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Add mocked bot and cog to the instance."""
self.bot = MockBot()
self.cog = Snekbox(bot=self.bot)
+ self.job = EvalJob.from_code("import random")
+
+ @staticmethod
+ def code_args(code: str) -> tuple[EvalJob]:
+ """Converts code to a tuple of arguments expected."""
+ return EvalJob.from_code(code),
async def test_post_job(self):
"""Post the eval code to the URLs.snekbox_eval_api endpoint."""
resp = MagicMock()
- resp.json = AsyncMock(return_value="return")
+ resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []})
context_manager = MagicMock()
context_manager.__aenter__.return_value = resp
self.bot.http_session.post.return_value = context_manager
- self.assertEqual(await self.cog.post_job("import random", "3.10"), "return")
+ job = EvalJob.from_code("import random").as_version("3.10")
+ self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137))
+
+ expected = {
+ "args": ["main.py"],
+ "files": [
+ {
+ "path": "main.py",
+ "content": b64encode("import random".encode()).decode()
+ }
+ ]
+ }
self.bot.http_session.post.assert_called_with(
constants.URLs.snekbox_eval_api,
- json={"input": "import random"},
+ json=expected,
raise_for_status=True
)
resp.json.assert_awaited_once()
async def test_upload_output_reject_too_long(self):
"""Reject output longer than MAX_PASTE_LENGTH."""
- result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1))
+ result = await self.cog.upload_output("-" * (snekbox._cog.MAX_PASTE_LENGTH + 1))
self.assertEqual(result, "too long to upload")
- @patch("bot.exts.utils.snekbox.send_to_paste_service")
+ @patch("bot.exts.utils.snekbox._cog.send_to_paste_service")
async def test_upload_output(self, mock_paste_util):
"""Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""
await self.cog.upload_output("Test output.")
- mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH)
+ mock_paste_util.assert_called_once_with(
+ "Test output.",
+ extension="txt",
+ max_length=snekbox._cog.MAX_PASTE_LENGTH
+ )
async def test_codeblock_converter(self):
ctx = MockContext()
@@ -76,40 +99,94 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
(['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code')
)
- for case, setup_code, testname in cases:
- setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code)
- expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup])
- with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
+ for case, setup_code, test_name in cases:
+ setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code)
+ expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)]
+ with self.subTest(msg=f'Test with {test_name} and expected return {expected}'):
self.assertEqual(self.cog.prepare_timeit_input(case), expected)
- def test_get_results_message(self):
- """Return error and message according to the eval result."""
+ def test_eval_result_message(self):
+ """EvalResult.get_message(), should return message."""
cases = (
- ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')),
- ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')),
- ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred'))
+ ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')),
+ ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')),
+ ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', ''))
)
for stdout, returncode, expected in cases:
+ exp_msg, exp_err, exp_files_err = expected
with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
- actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11')
- self.assertEqual(actual, expected)
-
- @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)
- def test_get_results_message_invalid_signal(self, mock_signals: Mock):
+ result = EvalResult(stdout=stdout, returncode=returncode)
+ job = EvalJob([])
+ # Check all 3 message types
+ msg = result.get_message(job)
+ self.assertEqual(msg, exp_msg)
+ error = result.error_message
+ self.assertEqual(error, exp_err)
+ files_error = result.files_error_message
+ self.assertEqual(files_error, exp_files_err)
+
+ @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2)
+ def test_eval_result_files_error_message(self):
+ """EvalResult.files_error_message, should return files error message."""
+ cases = [
+ ([], ["abc"], (
+ "1 file upload (abc) failed because its file size exceeds 8 MiB."
+ )),
+ ([], ["file1.bin", "f2.bin"], (
+ "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB."
+ )),
+ (["a", "b"], ["c"], (
+ "1 file upload (c) failed as it exceeded the 2 file limit."
+ )),
+ (["a"], ["b", "c"], (
+ "2 file uploads (b, c) failed as they exceeded the 2 file limit."
+ )),
+ ]
+ for files, failed_files, expected_msg in cases:
+ with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg):
+ result = EvalResult("", 0, files, failed_files)
+ msg = result.files_error_message
+ self.assertIn(expected_msg, msg)
+
+ @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2)
+ def test_eval_result_files_error_str(self):
+ """EvalResult.files_error_message, should return files error message."""
+ cases = [
+ # Normal
+ (["x.ini"], "x.ini"),
+ (["123456", "879"], "123456, 879"),
+ # Break on whole name if less than 3 characters remaining
+ (["12345678", "9"], "12345678, ..."),
+ # Otherwise break on max chars
+ (["123", "345", "67890000"], "123, 345, 6789..."),
+ (["abcdefg1234567"], "abcdefg123..."),
+ ]
+ for failed_files, expected in cases:
+ with self.subTest(failed_files=failed_files, expected=expected):
+ result = EvalResult("", 0, [], failed_files)
+ msg = result.get_failed_files_str(char_max=10)
+ self.assertEqual(msg, expected)
+
+ @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError)
+ def test_eval_result_message_invalid_signal(self, _mock_signals: Mock):
+ result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'),
- ('Your 3.11 eval job has completed with return code 127', '')
+ result.get_message(EvalJob([], version="3.10")),
+ "Your 3.10 eval job has completed with return code 127"
)
+ self.assertEqual(result.error_message, "")
+ self.assertEqual(result.files_error_message, "")
- @patch('bot.exts.utils.snekbox.Signals')
- def test_get_results_message_valid_signal(self, mock_signals: Mock):
- mock_signals.return_value.name = 'SIGTEST'
+ @patch('bot.exts.utils.snekbox._eval.Signals')
+ def test_eval_result_message_valid_signal(self, mock_signals: Mock):
+ mock_signals.return_value.name = "SIGTEST"
+ result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'),
- ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '')
+ result.get_message(EvalJob([], version="3.11")),
+ "Your 3.11 eval job has completed with return code 127 (SIGTEST)"
)
- def test_get_status_emoji(self):
+ def test_eval_result_status_emoji(self):
"""Return emoji according to the eval result."""
cases = (
(' ', -1, ':warning:'),
@@ -118,8 +195,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for stdout, returncode, expected in cases:
with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
- actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})
- self.assertEqual(actual, expected)
+ result = EvalResult(stdout=stdout, returncode=returncode)
+ self.assertEqual(result.status_emoji, expected)
async def test_format_output(self):
"""Test output formatting."""
@@ -178,10 +255,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.command = MagicMock()
self.cog.send_job = AsyncMock(return_value=response)
- self.cog.continue_job = AsyncMock(return_value=(None, None))
+ self.cog.continue_job = AsyncMock(return_value=None)
await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode'])
- self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval')
+ job = EvalJob.from_code("MyAwesomeCode")
+ self.cog.send_job.assert_called_once_with(ctx, job)
self.cog.continue_job.assert_called_once_with(ctx, response, 'eval')
async def test_eval_command_evaluate_twice(self):
@@ -191,13 +269,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.command = MagicMock()
self.cog.send_job = AsyncMock(return_value=response)
self.cog.continue_job = AsyncMock()
- self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None))
+ self.cog.continue_job.side_effect = (EvalJob.from_code('MyAwesomeFormattedCode'), None)
await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode'])
- self.cog.send_job.assert_called_with(
- ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval'
- )
- self.cog.continue_job.assert_called_with(ctx, response, 'eval')
+
+ expected_job = EvalJob.from_code("MyAwesomeFormattedCode")
+ self.cog.send_job.assert_called_with(ctx, expected_job)
+ self.cog.continue_job.assert_called_with(ctx, response, "eval")
async def test_eval_command_reject_two_eval_at_the_same_time(self):
"""Test if the eval command rejects an eval if the author already have a running eval."""
@@ -212,8 +290,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect)
with self.assertRaises(LockedResourceError):
await asyncio.gather(
- self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'),
- self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'),
+ self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")),
+ self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")),
)
async def test_send_job(self):
@@ -223,30 +301,31 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send = AsyncMock()
ctx.author = MockUser(mention='@LemonLemonishBeard#0042')
- self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0})
- self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
- self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ eval_result = EvalResult("", 0)
+ self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.format_output = AsyncMock(return_value=('[No output]', None))
+ self.cog.upload_output = AsyncMock() # Should not be called
mocked_filter_cog = MagicMock()
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
+ job = EvalJob.from_code('MyAwesomeCode')
+ await self.cog.send_job(ctx, job),
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```'
+ '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed '
+ 'with return code 0.\n\n```\n[No output]\n```'
)
allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions']
expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict())
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
- self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11')
+ self.cog.post_job.assert_called_once_with(job)
self.cog.format_output.assert_called_once_with('')
+ self.cog.upload_output.assert_not_called()
async def test_send_job_with_paste_link(self):
"""Test the send_job function with a too long output that generate a paste link."""
@@ -255,29 +334,26 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
- self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
- self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ eval_result = EvalResult("Way too long beard", 0)
+ self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
mocked_filter_cog = MagicMock()
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
+ job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
+ await self.cog.send_job(ctx, job),
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- '@LemonLemonishBeard#0042 :yay!: Return code 0.'
+ '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job '
+ 'has completed with return code 0.'
'\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'
)
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
- self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with(
- {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11'
- )
+ self.cog.post_job.assert_called_once_with(job)
self.cog.format_output.assert_called_once_with('Way too long beard')
async def test_send_job_with_non_zero_eval(self):
@@ -286,29 +362,57 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
- self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
- self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
- self.cog.format_output = AsyncMock() # This function isn't called
+
+ eval_result = EvalResult("ERROR", 127)
+ self.cog.post_job = AsyncMock(return_value=eval_result)
+ self.cog.upload_output = AsyncMock() # This function isn't called
mocked_filter_cog = MagicMock()
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
+ job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
+ await self.cog.send_job(ctx, job),
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'
+ '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.'
+ '\n\n```\nERROR\n```'
+ )
+
+ self.cog.post_job.assert_called_once_with(job)
+ self.cog.upload_output.assert_not_called()
+
+ async def test_send_job_with_disallowed_file_ext(self):
+ """Test send_job with disallowed file extensions."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = "@user#7700"
+
+ eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")])
+ self.cog.post_job = AsyncMock(return_value=eval_result)
+ self.cog.upload_output = AsyncMock() # This function isn't called
+
+ mocked_filter_cog = MagicMock()
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
+ self.bot.get_cog.return_value = mocked_filter_cog
+
+ job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
+ await self.cog.send_job(ctx, job),
+
+ ctx.send.assert_called_once()
+ res = ctx.send.call_args.args[0]
+ self.assertTrue(
+ res.startswith("@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.")
)
+ self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res)
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
- self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
- self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11')
- self.cog.format_output.assert_not_called()
+ self.cog.post_job.assert_called_once_with(job)
+ self.cog.upload_output.assert_not_called()
- @patch("bot.exts.utils.snekbox.partial")
+ @patch("bot.exts.utils.snekbox._cog.partial")
async def test_continue_job_does_continue(self, partial_mock):
"""Test that the continue_job function does continue if required conditions are met."""
ctx = MockContext(
@@ -328,19 +432,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
actual = await self.cog.continue_job(ctx, response, self.cog.eval_command)
self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command)
- self.assertEqual(actual, (expected, None))
+ self.assertEqual(actual, EvalJob.from_code(expected))
self.bot.wait_for.assert_has_awaits(
(
call(
'message_edit',
- check=partial_mock(snekbox.predicate_message_edit, ctx),
- timeout=snekbox.REDO_TIMEOUT,
+ check=partial_mock(snekbox._cog.predicate_message_edit, ctx),
+ timeout=snekbox._cog.REDO_TIMEOUT,
),
- call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10)
+ call('reaction_add', check=partial_mock(snekbox._cog.predicate_emoji_reaction, ctx), timeout=10)
)
)
- ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
- ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
+ ctx.message.add_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI)
+ ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI)
response.delete.assert_called_once()
async def test_continue_job_does_not_continue(self):
@@ -348,8 +452,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.wait_for.side_effect = asyncio.TimeoutError
actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command)
- self.assertEqual(actual, (None, None))
- ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
+ self.assertEqual(actual, None)
+ ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI)
async def test_get_code(self):
"""Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
@@ -391,18 +495,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
for ctx_msg, new_msg, expected, testname in cases:
with self.subTest(msg=f'Messages with {testname} return {expected}'):
ctx = MockContext(message=ctx_msg)
- actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg)
+ actual = snekbox._cog.predicate_message_edit(ctx, ctx_msg, new_msg)
self.assertEqual(actual, expected)
def test_predicate_emoji_reaction(self):
"""Test the predicate_emoji_reaction function."""
valid_reaction = MockReaction(message=MockMessage(id=1))
- valid_reaction.__str__.return_value = snekbox.REDO_EMOJI
+ valid_reaction.__str__.return_value = snekbox._cog.REDO_EMOJI
valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))
valid_user = MockUser(id=2)
invalid_reaction_id = MockReaction(message=MockMessage(id=42))
- invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI
+ invalid_reaction_id.__str__.return_value = snekbox._cog.REDO_EMOJI
invalid_user_id = MockUser(id=42)
invalid_reaction_str = MockReaction(message=MockMessage(id=1))
invalid_reaction_str.__str__.return_value = ':longbeard:'
@@ -415,7 +519,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for reaction, user, expected, testname in cases:
with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
- actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user)
+ actual = snekbox._cog.predicate_emoji_reaction(valid_ctx, reaction, user)
self.assertEqual(actual, expected)