aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/snekbox.py306
-rw-r--r--bot/exts/utils/snekio.py64
-rw-r--r--tests/bot/exts/utils/test_snekbox.py138
3 files changed, 322 insertions, 186 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 8a2e68b28..cd090ed79 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -1,11 +1,14 @@
+from __future__ import annotations
+
import asyncio
import contextlib
import re
+from dataclasses import dataclass, field
from functools import partial
from operator import attrgetter
from signal import Signals
from textwrap import dedent
-from typing import Literal, Optional, Tuple
+from typing import Literal, Optional, TYPE_CHECKING, Tuple
from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui
from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only
@@ -15,12 +18,16 @@ from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX
from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES, Roles, URLs
from bot.decorators import redirect_output
+from bot.exts.utils.snekio import FileAttachment, sizeof_fmt, FILE_SIZE_LIMIT
from bot.exts.help_channels._channel import is_help_forum_post
from bot.log import get_logger
from bot.utils import send_to_paste_service
from bot.utils.lock import LockedResourceError, lock_arg
from bot.utils.services import PasteTooLongError, PasteUploadError
+if TYPE_CHECKING:
+ from bot.exts.filters.filtering import Filtering
+
log = get_logger(__name__)
ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")
@@ -79,6 +86,109 @@ SIGKILL = 9
REDO_EMOJI = '\U0001f501' # :repeat:
REDO_TIMEOUT = 30
+PythonVersion = Literal["3.10", "3.11"]
+
+
+@dataclass
+class EvalJob:
+ """Job to be evaluated by snekbox."""
+
+ args: list[str]
+ files: list[FileAttachment] = field(default_factory=list)
+ name: str = "eval"
+ version: PythonVersion = "3.11"
+
+ @classmethod
+ def from_code(cls, code: str, path: str = "main.py") -> EvalJob:
+ """Create an EvalJob from a code string."""
+ return cls(
+ args=[path],
+ files=[FileAttachment(path, code.encode())],
+ )
+
+ def as_version(self, version: PythonVersion) -> EvalJob:
+ """Return a copy of the job with a different Python version."""
+ return EvalJob(
+ args=self.args,
+ files=self.files,
+ name=self.name,
+ version=version,
+ )
+
+ def to_dict(self) -> dict[str, list[str | dict[str, str]]]:
+ """Convert the job to a dict."""
+ return {
+ "args": self.args,
+ "files": [file.to_dict() for file in self.files],
+ }
+
+
+@dataclass(frozen=True)
+class EvalResult:
+ """The result of an eval job."""
+
+ stdout: str
+ returncode: int | None
+ files: list[FileAttachment] = field(default_factory=list)
+ err_files: list[str] = field(default_factory=list)
+
+ @property
+ def status_emoji(self):
+ """Return an emoji corresponding to the status code or lack of output in result."""
+ # If there are attachments, skip empty output warning
+ if not self.stdout.strip() and not self.files: # No output
+ return ":warning:"
+ elif self.returncode == 0: # No error
+ return ":white_check_mark:"
+ else: # Exception
+ return ":x:"
+
+ def message(self, job: EvalJob) -> tuple[str, str]:
+ """Return a user-friendly message and error corresponding to the process's return code."""
+ msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}"
+ error = ""
+
+ if self.returncode is None:
+ msg = f"Your {job.version} {job.name} job has failed"
+ error = self.stdout.strip()
+ elif self.returncode == 128 + SIGKILL:
+ msg = f"Your {job.version} {job.name} job timed out or ran out of memory"
+ elif self.returncode == 255:
+ msg = f"Your {job.version} {job.name} job has failed"
+ error = "A fatal NsJail error occurred"
+ else:
+ # Try to append signal's name if one exists
+ with contextlib.suppress(ValueError):
+ name = Signals(self.returncode - 128).name
+ msg = f"{msg} ({name})"
+
+ # Add error message for failed attachments
+ if self.err_files:
+ failed_files = f"({', '.join(self.err_files)})"
+ msg += (
+ f".\n\n> Some attached files were not able to be uploaded {failed_files}."
+ f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}"
+ )
+
+ return msg, error
+
+ @classmethod
+ def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult:
+ """Create an EvalResult from a dict."""
+ res = cls(
+ stdout=data["stdout"],
+ returncode=data["returncode"],
+ )
+
+ for file in data.get("files", []):
+ try:
+ res.files.append(FileAttachment.from_dict(file))
+ except ValueError as e:
+ log.info(f"Failed to parse file from snekbox response: {e}")
+ res.err_files.append(file["path"])
+
+ return res
+
class CodeblockConverter(Converter):
"""Attempts to extract code from a codeblock, if provided."""
@@ -121,22 +231,18 @@ class PythonVersionSwitcherButton(ui.Button):
"""A button that allows users to re-run their eval command in a different Python version."""
def __init__(
- self,
- job_name: str,
- version_to_switch_to: Literal["3.10", "3.11"],
- snekbox_cog: "Snekbox",
- ctx: Context,
- code: str,
- args: Optional[list[str]] = None
+ self,
+ version_to_switch_to: PythonVersion,
+ snekbox_cog: Snekbox,
+ ctx: Context,
+ job: EvalJob,
) -> None:
self.version_to_switch_to = version_to_switch_to
super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary)
self.snekbox_cog = snekbox_cog
self.ctx = ctx
- self.job_name = job_name
- self.code = code
- self.args = args
+ self.job = job
async def callback(self, interaction: Interaction) -> None:
"""
@@ -149,13 +255,11 @@ class PythonVersionSwitcherButton(ui.Button):
await interaction.response.defer()
with contextlib.suppress(NotFound):
- # Suppress this delete to cover the case where a user re-runs code and very quickly clicks the button.
+ # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button.
# The log arg on send_job will stop the actual job from running.
await interaction.message.delete()
- await self.snekbox_cog.run_job(
- self.job_name, self.ctx, self.version_to_switch_to, self.code, args=self.args
- )
+ await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to))
class Snekbox(Cog):
@@ -166,13 +270,11 @@ class Snekbox(Cog):
self.jobs = {}
def build_python_version_switcher_view(
- self,
- job_name: str,
- current_python_version: Literal["3.10", "3.11"],
- ctx: Context,
- code: str,
- args: Optional[list[str]] = None
- ) -> None:
+ self,
+ current_python_version: PythonVersion,
+ ctx: Context,
+ job: EvalJob,
+ ) -> interactions.ViewWithUserAndRoleCheck:
"""Return a view that allows the user to change what version of Python their code is run on."""
if current_python_version == "3.10":
alt_python_version = "3.11"
@@ -183,33 +285,25 @@ class Snekbox(Cog):
allowed_users=(ctx.author.id,),
allowed_roles=MODERATION_ROLES,
)
- view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, code, args))
+ view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job))
view.add_item(interactions.DeleteMessageButton())
return view
- async def post_job(
- self,
- code: str,
- python_version: Literal["3.10", "3.11"],
- *,
- args: Optional[list[str]] = None
- ) -> dict:
+ async def post_job(self, job: EvalJob) -> EvalResult:
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
- if python_version == "3.10":
+ if job.version == "3.10":
url = URLs.snekbox_eval_api
else:
url = URLs.snekbox_311_eval_api
- data = {"input": code}
-
- if args is not None:
- data["args"] = args
+ data = job.to_dict()
async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
return await resp.json()
- async def upload_output(self, output: str) -> Optional[str]:
+ @staticmethod
+ async def upload_output(output: str) -> Optional[str]:
"""Upload the job's output to a paste service and return a URL to it if successful."""
log.trace("Uploading full output to paste service...")
@@ -221,57 +315,18 @@ class Snekbox(Cog):
return "unable to upload"
@staticmethod
- def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]:
+ def prepare_timeit_input(codeblocks: list[str]) -> list[str]:
"""
Join the codeblocks into a single string, then return the code and the arguments in a tuple.
If there are multiple codeblocks, insert the first one into the wrapped setup code.
"""
args = ["-m", "timeit"]
- setup = ""
- if len(codeblocks) > 1:
- setup = codeblocks.pop(0)
-
+ setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else ""
code = "\n".join(codeblocks)
- args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)])
-
- return code, args
-
- @staticmethod
- def get_results_message(results: dict, job_name: str, python_version: Literal["3.10", "3.11"]) -> Tuple[str, str]:
- """Return a user-friendly message and error corresponding to the process's return code."""
- stdout, returncode = results["stdout"], results["returncode"]
- msg = f"Your {python_version} {job_name} job has completed with return code {returncode}"
- error = ""
-
- if returncode is None:
- msg = f"Your {python_version} {job_name} job has failed"
- error = stdout.strip()
- elif returncode == 128 + SIGKILL:
- msg = f"Your {python_version} {job_name} job timed out or ran out of memory"
- elif returncode == 255:
- msg = f"Your {python_version} {job_name} job has failed"
- error = "A fatal NsJail error occurred"
- else:
- # Try to append signal's name if one exists
- try:
- name = Signals(returncode - 128).name
- msg = f"{msg} ({name})"
- except ValueError:
- pass
-
- return msg, error
-
- @staticmethod
- def get_status_emoji(results: dict) -> str:
- """Return an emoji corresponding to the status code or lack of output in result."""
- if not results["stdout"].strip(): # No output
- return ":warning:"
- elif results["returncode"] == 0: # No error
- return ":white_check_mark:"
- else: # Exception
- return ":x:"
+ args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code])
+ return args
async def format_output(self, output: str) -> Tuple[str, Optional[str]]:
"""
@@ -320,42 +375,37 @@ class Snekbox(Cog):
return output, paste_link
@lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True)
- async def send_job(
- self,
- ctx: Context,
- python_version: Literal["3.10", "3.11"],
- code: str,
- *,
- args: Optional[list[str]] = None,
- job_name: str
- ) -> Message:
+ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
Return the bot response.
"""
async with ctx.typing():
- results = await self.post_job(code, python_version, args=args)
- msg, error = self.get_results_message(results, job_name, python_version)
+ result = await self.post_job(job)
+ msg, error = result.message(job)
if error:
output, paste_link = error, None
else:
log.trace("Formatting output...")
- output, paste_link = await self.format_output(results["stdout"])
+ output, paste_link = await self.format_output(result.stdout)
+
+ if result.files and output in ("[No output]", ""):
+ msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n"
+ else:
+ msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n\n```\n{output}\n```"
- icon = self.get_status_emoji(results)
- msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```"
if paste_link:
msg = f"{msg}\nFull output: {paste_link}"
# Collect stats of job fails + successes
- if icon == ":x:":
+ if result.returncode != 0:
self.bot.stats.incr("snekbox.python.fail")
else:
self.bot.stats.incr("snekbox.python.success")
- filter_cog = self.bot.get_cog("Filtering")
+ filter_cog: Filtering | None = self.bot.get_cog("Filtering")
filter_triggered = False
if filter_cog:
filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message)
@@ -363,21 +413,24 @@ class Snekbox(Cog):
response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")
else:
allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
- view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args)
- response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view)
+ view = self.build_python_version_switcher_view(job.version, ctx, job)
+
+ # Attach files if provided
+ files = [f.to_file() for f in result.files]
+ response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files)
view.message = response
- log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")
+ log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}")
return response
async def continue_job(
self, ctx: Context, response: Message, job_name: str
- ) -> tuple[Optional[str], Optional[list[str]]]:
+ ) -> EvalJob | None:
"""
Check if the job's session should continue.
- If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command.
- Otherwise return (None, None) if the job's session should be terminated.
+ If the code is to be re-evaluated, return the new EvalJob.
+ Otherwise, return None if the job's session should be terminated.
"""
_predicate_message_edit = partial(predicate_message_edit, ctx)
_predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx)
@@ -399,7 +452,7 @@ class Snekbox(Cog):
# Ensure the response that's about to be edited is still the most recent.
# This could have already been updated via a button press to switch to an alt Python version.
if self.jobs[ctx.message.id] != response.id:
- return None, None
+ return None
code = await self.get_code(new_message, ctx.command)
with contextlib.suppress(HTTPException):
@@ -407,21 +460,20 @@ class Snekbox(Cog):
await response.delete()
if code is None:
- return None, None
+ return None
except asyncio.TimeoutError:
- with contextlib.suppress(HTTPException):
- await ctx.message.clear_reaction(REDO_EMOJI)
- return None, None
+ await ctx.message.clear_reaction(REDO_EMOJI)
+ return None
codeblocks = await CodeblockConverter.convert(ctx, code)
if job_name == "timeit":
- return self.prepare_timeit_input(codeblocks)
+ return EvalJob(self.prepare_timeit_input(codeblocks))
else:
- return "\n".join(codeblocks), None
+ return EvalJob.from_code("\n".join(codeblocks))
- return None, None
+ return None
async def get_code(self, message: Message, command: Command) -> Optional[str]:
"""
@@ -445,12 +497,8 @@ class Snekbox(Cog):
async def run_job(
self,
- job_name: str,
ctx: Context,
- python_version: Literal["3.10", "3.11"],
- code: str,
- *,
- args: Optional[list[str]] = None,
+ job: EvalJob,
) -> None:
"""Handles checks, stats and re-evaluation of a snekbox job."""
if Roles.helpers in (role.id for role in ctx.author.roles):
@@ -465,11 +513,11 @@ class Snekbox(Cog):
else:
self.bot.stats.incr("snekbox_usages.channels.topical")
- log.info(f"Received code from {ctx.author} for evaluation:\n{code}")
+ log.info(f"Received code from {ctx.author} for evaluation:\n{job}")
while True:
try:
- response = await self.send_job(ctx, python_version, code, args=args, job_name=job_name)
+ response = await self.send_job(ctx, job)
except LockedResourceError:
await ctx.send(
f"{ctx.author.mention} You've already got a job running - "
@@ -477,15 +525,15 @@ class Snekbox(Cog):
)
return
- # Store the bot's response message id per invocation, to ensure the `wait_for`s in `continue_job`
+ # Store the bots response message id per invocation, to ensure the `wait_for`s in `continue_job`
# don't trigger if the response has already been replaced by a new response.
# This can happen when a button is pressed and then original code is edited and re-run.
self.jobs[ctx.message.id] = response.id
- code, args = await self.continue_job(ctx, response, job_name)
- if not code:
+ job = await self.continue_job(ctx, response, job.name)
+ if not job:
break
- log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}")
+ log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}")
@command(name="eval", aliases=("e",), usage="[python_version] <code, ...>")
@guild_only()
@@ -499,7 +547,7 @@ class Snekbox(Cog):
async def eval_command(
self,
ctx: Context,
- python_version: Optional[Literal["3.10", "3.11"]],
+ python_version: PythonVersion | None,
*,
code: CodeblockConverter
) -> None:
@@ -511,17 +559,19 @@ class Snekbox(Cog):
clicking the reaction that subsequently appears.
If multiple codeblocks are in a message, all of them will be joined and evaluated,
- ignoring the text outside of them.
+ ignoring the text outside them.
- By default your code is run on Python's 3.11 beta release, to assist with testing. If you
+ By default, your code is run on Python's 3.11 beta release, to assist with testing. If you
run into issues related to this Python version, you can request the bot to use Python
3.10 by specifying the `python_version` arg and setting it to `3.10`.
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
+ code: list[str]
python_version = python_version or "3.11"
- await self.run_job("eval", ctx, python_version, "\n".join(code))
+ job = EvalJob.from_code("\n".join(code)).as_version(python_version)
+ await self.run_job(ctx, job)
@command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>")
@guild_only()
@@ -535,7 +585,7 @@ class Snekbox(Cog):
async def timeit_command(
self,
ctx: Context,
- python_version: Optional[Literal["3.10", "3.11"]],
+ python_version: PythonVersion | None,
*,
code: CodeblockConverter
) -> None:
@@ -556,10 +606,12 @@ class Snekbox(Cog):
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
+ code: list[str]
python_version = python_version or "3.11"
- code, args = self.prepare_timeit_input(code)
+ args = self.prepare_timeit_input(code)
+ job = EvalJob(args, version=python_version, name="timeit")
- await self.run_job("timeit", ctx, python_version, code=code, args=args)
+ await self.run_job(ctx, job)
def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
diff --git a/bot/exts/utils/snekio.py b/bot/exts/utils/snekio.py
new file mode 100644
index 000000000..7c5fba648
--- /dev/null
+++ b/bot/exts/utils/snekio.py
@@ -0,0 +1,64 @@
+"""I/O File protocols for snekbox."""
+from __future__ import annotations
+
+from base64 import b64decode, b64encode
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+
+from discord import File
+
+# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers
+FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB
+
+
+def sizeof_fmt(num: int, suffix: str = "B") -> str:
+ """Return a human-readable file size."""
+ for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
+ if abs(num) < 1024:
+ return f"{num:3.1f}{unit}{suffix}"
+ num /= 1024
+ return f"{num:.1f}Yi{suffix}"
+
+
+@dataclass
+class FileAttachment:
+ """File Attachment from Snekbox eval."""
+
+ path: str
+ content: bytes
+
+ def __repr__(self) -> str:
+ """Return the content as a string."""
+ content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content
+ return f"FileAttachment(path={self.path!r}, content={content})"
+
+ @classmethod
+ def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment:
+ """Create a FileAttachment from a dict response."""
+ size = data.get("size")
+ if (size and size > size_limit) or (len(data["content"]) > size_limit):
+ raise ValueError("File size exceeds limit")
+
+ content = b64decode(data["content"])
+
+ if len(content) > size_limit:
+ raise ValueError("File size exceeds limit")
+
+ return cls(data["path"], content)
+
+ def to_dict(self) -> dict[str, str]:
+ """Convert the attachment to a json dict."""
+ content = self.content
+ if isinstance(content, str):
+ content = content.encode("utf-8")
+
+ return {
+ "path": self.path,
+ "content": b64encode(content).decode("ascii"),
+ }
+
+ def to_file(self) -> File:
+ """Convert to a discord.File."""
+ name = Path(self.path).name
+ return File(BytesIO(self.content), filename=name)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index b1f32c210..b52159101 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -1,5 +1,6 @@
import asyncio
import unittest
+from base64 import b64encode
from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch
from discord import AllowedMentions
@@ -8,7 +9,7 @@ from discord.ext import commands
from bot import constants
from bot.errors import LockedResourceError
from bot.exts.utils import snekbox
-from bot.exts.utils.snekbox import Snekbox
+from bot.exts.utils.snekbox import EvalJob, Snekbox, EvalResult
from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser
@@ -17,6 +18,12 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Add mocked bot and cog to the instance."""
self.bot = MockBot()
self.cog = Snekbox(bot=self.bot)
+ self.job = EvalJob.from_code("import random")
+
+ @staticmethod
+ def code_args(code: str) -> tuple[EvalJob]:
+ """Converts code to a tuple of arguments expected."""
+ return EvalJob.from_code(code),
async def test_post_job(self):
"""Post the eval code to the URLs.snekbox_eval_api endpoint."""
@@ -27,10 +34,21 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
context_manager.__aenter__.return_value = resp
self.bot.http_session.post.return_value = context_manager
- self.assertEqual(await self.cog.post_job("import random", "3.10"), "return")
+ job = EvalJob.from_code("import random")
+ self.assertEqual(await self.cog.post_job(job), "return")
+
+ expected = {
+ "args": ["main.py"],
+ "files": [
+ {
+ "path": "main.py",
+ "content": b64encode("import random".encode()).decode()
+ }
+ ]
+ }
self.bot.http_session.post.assert_called_with(
constants.URLs.snekbox_eval_api,
- json={"input": "import random"},
+ json=expected,
raise_for_status=True
)
resp.json.assert_awaited_once()
@@ -76,14 +94,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
(['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code')
)
- for case, setup_code, testname in cases:
+ for case, setup_code, test_name in cases:
setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code)
- expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup])
- with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
+ expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)]
+ with self.subTest(msg=f'Test with {test_name} and expected return {expected}'):
self.assertEqual(self.cog.prepare_timeit_input(case), expected)
- def test_get_results_message(self):
- """Return error and message according to the eval result."""
+ def test_eval_result_message(self):
+ """EvalResult.message, should return error and message."""
cases = (
('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')),
('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')),
@@ -91,25 +109,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for stdout, returncode, expected in cases:
with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
- actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11')
- self.assertEqual(actual, expected)
+ result = EvalResult(stdout=stdout, returncode=returncode)
+ job = EvalJob([])
+ self.assertEqual(result.message(job), expected)
@patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)
- def test_get_results_message_invalid_signal(self, mock_signals: Mock):
+ def test_eval_result_message_invalid_signal(self, _mock_signals: Mock):
+ result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'),
- ('Your 3.11 eval job has completed with return code 127', '')
+ result.message(EvalJob([], version="3.10")),
+ ("Your 3.10 eval job has completed with return code 127", "")
)
@patch('bot.exts.utils.snekbox.Signals')
- def test_get_results_message_valid_signal(self, mock_signals: Mock):
- mock_signals.return_value.name = 'SIGTEST'
+ def test_eval_result_message_valid_signal(self, mock_signals: Mock):
+ mock_signals.return_value.name = "SIGTEST"
+ result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'),
- ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '')
+ result.message(EvalJob([], version="3.11")),
+ ("Your 3.11 eval job has completed with return code 127 (SIGTEST)", "")
)
- def test_get_status_emoji(self):
+ def test_eval_result_status_emoji(self):
"""Return emoji according to the eval result."""
cases = (
(' ', -1, ':warning:'),
@@ -118,8 +139,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for stdout, returncode, expected in cases:
with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
- actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})
- self.assertEqual(actual, expected)
+ result = EvalResult(stdout=stdout, returncode=returncode)
+ self.assertEqual(result.status_emoji, expected)
async def test_format_output(self):
"""Test output formatting."""
@@ -178,10 +199,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.command = MagicMock()
self.cog.send_job = AsyncMock(return_value=response)
- self.cog.continue_job = AsyncMock(return_value=(None, None))
+ self.cog.continue_job = AsyncMock(return_value=None)
await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode'])
- self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval')
+ job = EvalJob.from_code("MyAwesomeCode")
+ self.cog.send_job.assert_called_once_with(ctx, job)
self.cog.continue_job.assert_called_once_with(ctx, response, 'eval')
async def test_eval_command_evaluate_twice(self):
@@ -191,13 +213,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.command = MagicMock()
self.cog.send_job = AsyncMock(return_value=response)
self.cog.continue_job = AsyncMock()
- self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None))
+ self.cog.continue_job.side_effect = (EvalJob.from_code('MyAwesomeFormattedCode'), None)
await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode'])
- self.cog.send_job.assert_called_with(
- ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval'
- )
- self.cog.continue_job.assert_called_with(ctx, response, 'eval')
+
+ expected_job = EvalJob.from_code("MyAwesomeFormattedCode")
+ self.cog.send_job.assert_called_with(ctx, expected_job)
+ self.cog.continue_job.assert_called_with(ctx, response, "eval")
async def test_eval_command_reject_two_eval_at_the_same_time(self):
"""Test if the eval command rejects an eval if the author already have a running eval."""
@@ -212,8 +234,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect)
with self.assertRaises(LockedResourceError):
await asyncio.gather(
- self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'),
- self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'),
+ self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")),
+ self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")),
)
async def test_send_job(self):
@@ -223,30 +245,31 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send = AsyncMock()
ctx.author = MockUser(mention='@LemonLemonishBeard#0042')
- self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0})
- self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
- self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ eval_result = EvalResult("", 0)
+ self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.format_output = AsyncMock(return_value=('[No output]', None))
+ self.cog.upload_output = AsyncMock() # Should not be called
mocked_filter_cog = MagicMock()
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
+ job = EvalJob.from_code('MyAwesomeCode')
+ await self.cog.send_job(ctx, job),
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```'
+ '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed '
+ 'with return code 0.\n\n```\n[No output]\n```'
)
allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions']
expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict())
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
- self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11')
+ self.cog.post_job.assert_called_once_with(job)
self.cog.format_output.assert_called_once_with('')
+ self.cog.upload_output.assert_not_called()
async def test_send_job_with_paste_link(self):
"""Test the send_job function with a too long output that generate a paste link."""
@@ -255,29 +278,26 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
- self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
- self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ eval_result = EvalResult("Way too long beard", 0)
+ self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
mocked_filter_cog = MagicMock()
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
+ job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
+ await self.cog.send_job(ctx, job),
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- '@LemonLemonishBeard#0042 :yay!: Return code 0.'
+ '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job '
+ 'has completed with return code 0.'
'\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'
)
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
- self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
- self.cog.get_results_message.assert_called_once_with(
- {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11'
- )
+ self.cog.post_job.assert_called_once_with(job)
self.cog.format_output.assert_called_once_with('Way too long beard')
async def test_send_job_with_non_zero_eval(self):
@@ -286,27 +306,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author.mention = '@LemonLemonishBeard#0042'
- self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
- self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
- self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
- self.cog.format_output = AsyncMock() # This function isn't called
+
+ eval_result = EvalResult("ERROR", 127)
+ self.cog.post_job = AsyncMock(return_value=eval_result)
+ self.cog.upload_output = AsyncMock() # This function isn't called
mocked_filter_cog = MagicMock()
mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval')
+ job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
+ await self.cog.send_job(ctx, job),
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'
+ '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.'
+ '\n\n```\nERROR\n```'
)
- self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None)
- self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
- self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11')
- self.cog.format_output.assert_not_called()
+ self.cog.post_job.assert_called_once_with(job)
+ self.cog.upload_output.assert_not_called()
@patch("bot.exts.utils.snekbox.partial")
async def test_continue_job_does_continue(self, partial_mock):
@@ -328,7 +348,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
actual = await self.cog.continue_job(ctx, response, self.cog.eval_command)
self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command)
- self.assertEqual(actual, (expected, None))
+ self.assertEqual(actual, EvalJob.from_code(expected))
self.bot.wait_for.assert_has_awaits(
(
call(
@@ -348,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.wait_for.side_effect = asyncio.TimeoutError
actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command)
- self.assertEqual(actual, (None, None))
+ self.assertEqual(actual, None)
ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)
async def test_get_code(self):