aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/snekbox.py194
1 files changed, 135 insertions, 59 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 1223b89ca..93941ed4c 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -4,14 +4,15 @@ import asyncio
import contextlib
import re
import zlib
-from base64 import b64decode
-from dataclasses import dataclass
+from base64 import b64decode, b64encode
+from collections.abc import Iterable
+from dataclasses import dataclass, field
from functools import partial
from io import BytesIO
from operator import attrgetter
from signal import Signals
from textwrap import dedent
-from typing import Literal, Optional, Tuple
+from typing import Generic, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar
from botcore.utils import interactions
from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX
@@ -26,6 +27,9 @@ from bot.utils import send_to_paste_service
from bot.utils.lock import LockedResourceError, lock_arg
from bot.utils.services import PasteTooLongError, PasteUploadError
+if TYPE_CHECKING:
+ from bot.exts.filters.filtering import Filtering
+
log = get_logger(__name__)
ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")
@@ -84,29 +88,100 @@ SIGKILL = 9
REDO_EMOJI = '\U0001f501' # :repeat:
REDO_TIMEOUT = 30
+# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers
+FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB
+
+T = TypeVar("T")
+
+
+def sizeof_fmt(num: int, suffix: str = "B") -> str:
+ """Return a human-readable file size."""
+ for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
+ if abs(num) < 1024:
+ return f"{num:3.1f}{unit}{suffix}"
+ num /= 1024
+ return f"{num:.1f}Yi{suffix}"
+
@dataclass
-class FileAttachment:
+class FileAttachment(Generic[T]):
"""File Attachment from Snekbox eval."""
name: str
- mime: str
- content: bytes
+ content: T
+
+ def __repr__(self) -> str:
+ """Return the content as a string."""
+ content = self.content if isinstance(self.content, str) else "(...)"
+ return f"FileAttachment(name={self.name}, content={content})"
@classmethod
- def from_dict(cls, data: dict) -> FileAttachment:
+ def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment[bytes]:
"""Create a FileAttachment from a dict response."""
- return cls(
- data["name"],
- data["mime"],
- zlib.decompress(b64decode(data["content"])),
- )
+ size = data.get("size")
+ if (size and size > size_limit) or (len(data["content"]) > size_limit):
+ raise ValueError("File size exceeds limit")
+
+ match data.get("content-encoding"):
+ case "base64+zlib":
+ content = zlib.decompress(b64decode(data["content"]))
+ case "base64":
+ content = b64decode(data["content"])
+ case _:
+ content = data["content"]
+
+ if len(content) > size_limit:
+ raise ValueError("File size exceeds limit")
+
+ return cls(data["name"], content)
+
+ def to_json(self) -> dict[str, str]:
+ """Convert the attachment to a json dict."""
+ if isinstance(self.content, bytes):
+ content = b64encode(self.content).decode("ascii")
+ encoding = "base64"
+ else:
+ content = self.content
+ encoding = ""
+
+ return {
+ "name": self.name,
+ "content-encoding": encoding,
+ "content": content,
+ }
def to_file(self) -> File:
"""Convert to a discord.File."""
return File(BytesIO(self.content), filename=self.name)
+@dataclass
+class EvalJob:
+ """Represents a job to be evaluated by Snekbox."""
+
+ args: list[str]
+ files: list[FileAttachment] = field(default_factory=list)
+
+ def __str__(self) -> str:
+ """Return the job as a string."""
+ return f"EvalJob(args={self.args}, files={self.files})"
+
+ @classmethod
+ def from_code(cls, code: str, files: Iterable[FileAttachment] = (), name: str = "main.py") -> EvalJob:
+ """Create an EvalJob from a code string."""
+ return cls(
+ args=[name],
+ files=[FileAttachment(name, code), *files],
+ )
+
+ def to_json(self) -> dict[str, list[str | dict[str, str]]]:
+ """Convert the job to a dict."""
+ return {
+ "args": self.args,
+ "files": [file.to_json() for file in self.files],
+ }
+
+
class CodeblockConverter(Converter):
"""Attempts to extract code from a codeblock, if provided."""
@@ -151,10 +226,9 @@ class PythonVersionSwitcherButton(ui.Button):
self,
job_name: str,
version_to_switch_to: Literal["3.10", "3.11"],
- snekbox_cog: "Snekbox",
+ snekbox_cog: Snekbox,
ctx: Context,
- code: str,
- args: Optional[list[str]] = None
+ job: EvalJob,
) -> None:
self.version_to_switch_to = version_to_switch_to
super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary)
@@ -162,8 +236,7 @@ class PythonVersionSwitcherButton(ui.Button):
self.snekbox_cog = snekbox_cog
self.ctx = ctx
self.job_name = job_name
- self.code = code
- self.args = args
+ self.job = job
async def callback(self, interaction: Interaction) -> None:
"""
@@ -176,13 +249,11 @@ class PythonVersionSwitcherButton(ui.Button):
await interaction.response.defer()
with contextlib.suppress(NotFound):
- # Suppress this delete to cover the case where a user re-runs code and very quickly clicks the button.
+ # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button.
# The log arg on send_job will stop the actual job from running.
await interaction.message.delete()
- await self.snekbox_cog.run_job(
- self.job_name, self.ctx, self.version_to_switch_to, self.code, args=self.args
- )
+ await self.snekbox_cog.run_job(self.job_name, self.ctx, self.version_to_switch_to, self.job)
class Snekbox(Cog):
@@ -197,8 +268,7 @@ class Snekbox(Cog):
job_name: str,
current_python_version: Literal["3.10", "3.11"],
ctx: Context,
- code: str,
- args: Optional[list[str]] = None
+ job: EvalJob,
) -> interactions.ViewWithUserAndRoleCheck:
"""Return a view that allows the user to change what version of Python their code is run on."""
if current_python_version == "3.10":
@@ -210,17 +280,15 @@ class Snekbox(Cog):
allowed_users=(ctx.author.id,),
allowed_roles=MODERATION_ROLES,
)
- view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, code, args))
+ view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, job))
view.add_item(interactions.DeleteMessageButton())
return view
async def post_job(
self,
- code: str,
+ job: EvalJob,
python_version: Literal["3.10", "3.11"],
- *,
- args: Optional[list[str]] = None
) -> dict:
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
if python_version == "3.10":
@@ -228,10 +296,7 @@ class Snekbox(Cog):
else:
url = URLs.snekbox_311_eval_api
- data = {"input": code}
-
- if args is not None:
- data["args"] = args
+ data = {"args": job.args, "files": [f.to_json() for f in job.files]}
async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
return await resp.json()
@@ -248,30 +313,34 @@ class Snekbox(Cog):
return "unable to upload"
@staticmethod
- def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]:
+ def prepare_timeit_input(codeblocks: list[str]) -> list[str]:
"""
Join the codeblocks into a single string, then return the code and the arguments in a tuple.
If there are multiple codeblocks, insert the first one into the wrapped setup code.
"""
args = ["-m", "timeit"]
- setup = ""
- if len(codeblocks) > 1:
- setup = codeblocks.pop(0)
-
+ setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else ""
code = "\n".join(codeblocks)
- args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)])
-
- return code, args
+ args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code])
+ return args
@staticmethod
def get_results_message(
- results: dict, job_name: str, python_version: Literal["3.10", "3.11"]
+ results: dict, job_name: str, python_version: Literal["3.10", "3.11"]
) -> Tuple[str, str, list[FileAttachment]]:
"""Return a user-friendly message and error corresponding to the process's return code."""
stdout, returncode = results["stdout"], results["returncode"]
- attachments = [FileAttachment.from_dict(d) for d in results["attachments"]]
+
+ attachments: list[FileAttachment] = []
+ failed_attachments: list[str] = []
+ for attachment in results["attachments"]:
+ try:
+ attachments.append(FileAttachment.from_dict(attachment))
+ except ValueError:
+ failed_attachments.append(attachment["name"])
+
msg = f"Your {python_version} {job_name} job has completed with return code {returncode}"
error = ""
@@ -291,6 +360,14 @@ class Snekbox(Cog):
except ValueError:
pass
+ # Add error message for failed attachments
+ if failed_attachments:
+ failed_files = f"({', '.join(failed_attachments)})"
+ msg += (
+ f".\n\n> Some attached files were not able to be uploaded {failed_files}."
+ f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}"
+ )
+
return msg, error, attachments
@staticmethod
@@ -352,12 +429,10 @@ class Snekbox(Cog):
@lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True)
async def send_job(
self,
+ job_name: str,
ctx: Context,
python_version: Literal["3.10", "3.11"],
- code: str,
- *,
- args: Optional[list[str]] = None,
- job_name: str
+ job: EvalJob,
) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
@@ -365,7 +440,7 @@ class Snekbox(Cog):
Return the bot response.
"""
async with ctx.typing():
- results = await self.post_job(code, python_version, args=args)
+ results = await self.post_job(job, python_version)
msg, error, attachments = self.get_results_message(results, job_name, python_version)
if error:
@@ -390,7 +465,7 @@ class Snekbox(Cog):
else:
self.bot.stats.incr("snekbox.python.success")
- filter_cog = self.bot.get_cog("Filtering")
+ filter_cog: Filtering | None = self.bot.get_cog("Filtering")
filter_triggered = False
if filter_cog:
filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message)
@@ -398,7 +473,7 @@ class Snekbox(Cog):
response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")
else:
allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
- view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args)
+ view = self.build_python_version_switcher_view(job_name, python_version, ctx, job)
# Attach file if provided
files = [atc.to_file() for atc in attachments]
@@ -485,9 +560,7 @@ class Snekbox(Cog):
job_name: str,
ctx: Context,
python_version: Literal["3.10", "3.11"],
- code: str,
- *,
- args: Optional[list[str]] = None,
+ job: EvalJob,
) -> None:
"""Handles checks, stats and re-evaluation of a snekbox job."""
if Roles.helpers in (role.id for role in ctx.author.roles):
@@ -502,11 +575,11 @@ class Snekbox(Cog):
else:
self.bot.stats.incr("snekbox_usages.channels.topical")
- log.info(f"Received code from {ctx.author} for evaluation:\n{code}")
+ log.info(f"Received code from {ctx.author} for evaluation:\n{job}")
while True:
try:
- response = await self.send_job(ctx, python_version, code, args=args, job_name=job_name)
+ response = await self.send_job(job_name, ctx, python_version, job)
except LockedResourceError:
await ctx.send(
f"{ctx.author.mention} You've already got a job running - "
@@ -514,7 +587,7 @@ class Snekbox(Cog):
)
return
- # Store the bot's response message id per invocation, to ensure the `wait_for`s in `continue_job`
+ # Store the bots response message id per invocation, to ensure the `wait_for`s in `continue_job`
# don't trigger if the response has already been replaced by a new response.
# This can happen when a button is pressed and then original code is edited and re-run.
self.jobs[ctx.message.id] = response.id
@@ -548,17 +621,19 @@ class Snekbox(Cog):
clicking the reaction that subsequently appears.
If multiple codeblocks are in a message, all of them will be joined and evaluated,
- ignoring the text outside of them.
+ ignoring the text outside them.
- By default your code is run on Python's 3.11 beta release, to assist with testing. If you
+ By default, your code is run on Python's 3.11 beta release, to assist with testing. If you
run into issues related to this Python version, you can request the bot to use Python
3.10 by specifying the `python_version` arg and setting it to `3.10`.
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
+ code: list[str]
python_version = python_version or "3.11"
- await self.run_job("eval", ctx, python_version, "\n".join(code))
+ job = EvalJob.from_code("\n".join(code))
+ await self.run_job("eval", ctx, python_version, job)
@command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>")
@guild_only()
@@ -593,10 +668,11 @@ class Snekbox(Cog):
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
+ code: list[str]
python_version = python_version or "3.11"
- code, args = self.prepare_timeit_input(code)
+ args = self.prepare_timeit_input(code)
- await self.run_job("timeit", ctx, python_version, code=code, args=args)
+ await self.run_job("timeit", ctx, python_version, EvalJob(args))
def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: