diff options
| author | 2022-11-30 07:48:12 +0800 | |
|---|---|---|
| committer | 2022-11-30 07:48:12 +0800 | |
| commit | 51e7e17e06aedc9f05347a5196faf137dc5a00ae (patch) | |
| tree | 197f3abfd3654c1c171f178242f345504deacb32 | |
| parent | Update unit test (diff) | |
Refactors for EvalResult and EvalJob dataclasses
Diffstat (limited to '')
| -rw-r--r-- | bot/exts/utils/snekbox.py | 296 | ||||
| -rw-r--r-- | bot/exts/utils/snekio.py | 64 | 
2 files changed, 185 insertions, 175 deletions
| diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 75c6f2d3a..d8a3088a6 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -3,25 +3,22 @@ from __future__ import annotations  import asyncio  import contextlib  import re -from base64 import b64decode, b64encode -from collections.abc import Iterable  from dataclasses import dataclass, field  from functools import partial -from io import BytesIO  from operator import attrgetter -from pathlib import Path  from signal import Signals  from textwrap import dedent -from typing import Generic, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar +from typing import Literal, Optional, TYPE_CHECKING, Tuple  from botcore.utils import interactions  from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX -from discord import AllowedMentions, File, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui +from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui  from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only  from bot.bot import Bot  from bot.constants import Categories, Channels, MODERATION_ROLES, Roles, URLs  from bot.decorators import redirect_output +from bot.exts.utils.snekio import FileAttachment, sizeof_fmt, FILE_SIZE_LIMIT  from bot.log import get_logger  from bot.utils import send_to_paste_service  from bot.utils.lock import LockedResourceError, lock_arg @@ -88,89 +85,108 @@ SIGKILL = 9  REDO_EMOJI = '\U0001f501'  # :repeat:  REDO_TIMEOUT = 30 -# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers -FILE_SIZE_LIMIT = 8 * 1024 * 1024  # 8 MiB - -T = TypeVar("T") - - -def sizeof_fmt(num: int, suffix: str = "B") -> str: -    """Return a human-readable file size.""" -    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): -        if abs(num) < 1024: -            return f"{num:3.1f}{unit}{suffix}" -        num /= 1024 -    return f"{num:.1f}Yi{suffix}" +PythonVersion = Literal["3.10", "3.11"]  @dataclass -class FileAttachment(Generic[T]): -    """File Attachment from Snekbox eval.""" - -    path: str -    content: T +class EvalJob: +    """Job to be evaluated by snekbox.""" -    def __repr__(self) -> str: -        """Return the content as a string.""" -        content = self.content if isinstance(self.content, str) else "(...)" -        return f"FileAttachment(path={self.path}, content={content})" +    args: list[str] +    files: list[FileAttachment] = field(default_factory=list) +    name: str = "eval" +    version: PythonVersion = "3.11"      @classmethod -    def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment[bytes]: -        """Create a FileAttachment from a dict response.""" -        size = data.get("size") -        if (size and size > size_limit) or (len(data["content"]) > size_limit): -            raise ValueError("File size exceeds limit") +    def from_code(cls, code: str, path: str = "main.py") -> EvalJob: +        """Create an EvalJob from a code string.""" +        return cls( +            args=[path], +            files=[FileAttachment(path, code.encode())], +        ) -        content = b64decode(data["content"]) +    def as_version(self, version: PythonVersion) -> EvalJob: +        """Return a copy of the job with a different Python version.""" +        return EvalJob( +            args=self.args, +            files=self.files, +            name=self.name, +            version=version, +        ) -        if len(content) > size_limit: -            raise ValueError("File size exceeds limit") +    def to_dict(self) -> dict[str, list[str | dict[str, str]]]: +        """Convert the job to a dict.""" +        return { +            "args": self.args, +            "files": [file.to_dict() for file in self.files], +        } -        return cls(data["path"], content) -    def to_json(self) -> dict[str, str]: -        """Convert the attachment to a json dict.""" -        content = self.content -        if isinstance(content, str): -            content = content.encode("utf-8") +@dataclass(frozen=True) +class EvalResult: +    """The result of an eval job.""" -        return { -            "path": self.path, -            "content": b64encode(content).decode("ascii"), -        } +    stdout: str +    returncode: int | None +    files: list[FileAttachment] = field(default_factory=list) +    err_files: list[str] = field(default_factory=list) -    def to_file(self) -> File: -        """Convert to a discord.File.""" -        name = Path(self.path).name -        return File(BytesIO(self.content), filename=name) +    @property +    def status_emoji(self): +        """Return an emoji corresponding to the status code or lack of output in result.""" +        # If there are attachments, skip empty output warning +        if not self.stdout.strip() and not self.files:  # No output +            return ":warning:" +        elif self.returncode == 0:  # No error +            return ":white_check_mark:" +        else:  # Exception +            return ":x:" +    def message(self, job: EvalJob) -> tuple[str, str]: +        """Return a user-friendly message and error corresponding to the process's return code.""" +        msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}" +        error = "" -@dataclass -class EvalJob: -    """Represents a job to be evaluated by Snekbox.""" +        if self.returncode is None: +            msg = f"Your {job.version} {job.name} job has failed" +            error = self.stdout.strip() +        elif self.returncode == 128 + SIGKILL: +            msg = f"Your {job.version} {job.name} job timed out or ran out of memory" +        elif self.returncode == 255: +            msg = f"Your {job.version} {job.name} job has failed" +            error = "A fatal NsJail error occurred" +        else: +            # Try to append signal's name if one exists +            with contextlib.suppress(ValueError): +                name = Signals(self.returncode - 128).name +                msg = f"{msg} ({name})" -    args: list[str] -    files: list[FileAttachment] = field(default_factory=list) +        # Add error message for failed attachments +        if self.err_files: +            failed_files = f"({', '.join(self.err_files)})" +            msg += ( +                f".\n\n> Some attached files were not able to be uploaded {failed_files}." +                f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" +            ) -    def __repr__(self) -> str: -        """Return the job as a string.""" -        return f"EvalJob(args={self.args}, files={self.files})" +        return msg, error      @classmethod -    def from_code(cls, code: str, files: Iterable[FileAttachment] = (), name: str = "main.py") -> EvalJob: -        """Create an EvalJob from a code string.""" -        return cls( -            args=[name], -            files=[FileAttachment(name, code), *files], +    def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult: +        """Create an EvalResult from a dict.""" +        res = cls( +            stdout=data["stdout"], +            returncode=data["returncode"],          ) -    def to_json(self) -> dict[str, list[str | dict[str, str]]]: -        """Convert the job to a dict.""" -        return { -            "args": self.args, -            "files": [file.to_json() for file in self.files], -        } +        for file in data.get("files", []): +            try: +                res.files.append(FileAttachment.from_dict(file)) +            except ValueError as e: +                log.info(f"Failed to parse file from snekbox response: {e}") +                res.err_files.append(file["path"]) + +        return res  class CodeblockConverter(Converter): @@ -214,19 +230,17 @@ class PythonVersionSwitcherButton(ui.Button):      """A button that allows users to re-run their eval command in a different Python version."""      def __init__( -        self, -        job_name: str, -        version_to_switch_to: Literal["3.10", "3.11"], -        snekbox_cog: Snekbox, -        ctx: Context, -        job: EvalJob, +            self, +            version_to_switch_to: PythonVersion, +            snekbox_cog: Snekbox, +            ctx: Context, +            job: EvalJob,      ) -> None:          self.version_to_switch_to = version_to_switch_to          super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary)          self.snekbox_cog = snekbox_cog          self.ctx = ctx -        self.job_name = job_name          self.job = job      async def callback(self, interaction: Interaction) -> None: @@ -244,7 +258,7 @@ class PythonVersionSwitcherButton(ui.Button):              # The log arg on send_job will stop the actual job from running.              await interaction.message.delete() -        await self.snekbox_cog.run_job(self.job_name, self.ctx, self.version_to_switch_to, self.job) +        await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to))  class Snekbox(Cog): @@ -255,11 +269,10 @@ class Snekbox(Cog):          self.jobs = {}      def build_python_version_switcher_view( -        self, -        job_name: str, -        current_python_version: Literal["3.10", "3.11"], -        ctx: Context, -        job: EvalJob, +            self, +            current_python_version: PythonVersion, +            ctx: Context, +            job: EvalJob,      ) -> interactions.ViewWithUserAndRoleCheck:          """Return a view that allows the user to change what version of Python their code is run on."""          if current_python_version == "3.10": @@ -271,28 +284,25 @@ class Snekbox(Cog):              allowed_users=(ctx.author.id,),              allowed_roles=MODERATION_ROLES,          ) -        view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, job)) +        view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job))          view.add_item(interactions.DeleteMessageButton())          return view -    async def post_job( -        self, -        job: EvalJob, -        python_version: Literal["3.10", "3.11"], -    ) -> dict: +    async def post_job(self, job: EvalJob) -> EvalResult:          """Send a POST request to the Snekbox API to evaluate code and return the results.""" -        if python_version == "3.10": +        if job.version == "3.10":              url = URLs.snekbox_eval_api          else:              url = URLs.snekbox_311_eval_api -        data = {"args": job.args, "files": [f.to_json() for f in job.files]} +        data = job.to_dict()          async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:              return await resp.json() -    async def upload_output(self, output: str) -> Optional[str]: +    @staticmethod +    async def upload_output(output: str) -> Optional[str]:          """Upload the job's output to a paste service and return a URL to it if successful."""          log.trace("Uploading full output to paste service...") @@ -317,61 +327,6 @@ class Snekbox(Cog):          args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code])          return args -    @staticmethod -    def get_results_message( -        results: dict, job_name: str, python_version: Literal["3.10", "3.11"] -    ) -> Tuple[str, str, list[FileAttachment]]: -        """Return a user-friendly message and error corresponding to the process's return code.""" -        stdout, returncode = results["stdout"], results["returncode"] - -        attachments: list[FileAttachment] = [] -        failed_attachments: list[str] = [] -        for attachment in results.get("files", []): -            try: -                attachments.append(FileAttachment.from_dict(attachment)) -            except ValueError: -                failed_attachments.append(attachment["path"]) - -        msg = f"Your {python_version} {job_name} job has completed with return code {returncode}" -        error = "" - -        if returncode is None: -            msg = f"Your {python_version} {job_name} job has failed" -            error = stdout.strip() -        elif returncode == 128 + SIGKILL: -            msg = f"Your {python_version} {job_name} job timed out or ran out of memory" -        elif returncode == 255: -            msg = f"Your {python_version} {job_name} job has failed" -            error = "A fatal NsJail error occurred" -        else: -            # Try to append signal's name if one exists -            try: -                name = Signals(returncode - 128).name -                msg = f"{msg} ({name})" -            except ValueError: -                pass - -        # Add error message for failed attachments -        if failed_attachments: -            failed_files = f"({', '.join(failed_attachments)})" -            msg += ( -                f".\n\n> Some attached files were not able to be uploaded {failed_files}." -                f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" -            ) - -        return msg, error, attachments - -    @staticmethod -    def get_status_emoji(results: dict) -> str: -        """Return an emoji corresponding to the status code or lack of output in result.""" -        # If there are attachments, skip empty output warning -        if not results["stdout"].strip() and not results.get("files"):  # No output -            return ":warning:" -        elif results["returncode"] == 0:  # No error -            return ":white_check_mark:" -        else:  # Exception -            return ":x:" -      async def format_output(self, output: str) -> Tuple[str, Optional[str]]:          """          Format the output and return a tuple of the formatted output and a URL to the full output. @@ -419,40 +374,32 @@ class Snekbox(Cog):          return output, paste_link      @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) -    async def send_job( -        self, -        job_name: str, -        ctx: Context, -        python_version: Literal["3.10", "3.11"], -        job: EvalJob, -    ) -> Message: +    async def send_job(self, ctx: Context, job: EvalJob) -> Message:          """          Evaluate code, format it, and send the output to the corresponding channel.          Return the bot response.          """          async with ctx.typing(): -            results = await self.post_job(job, python_version) -            msg, error, attachments = self.get_results_message(results, job_name, python_version) +            result = await self.post_job(job) +            msg, error = result.message(job)              if error:                  output, paste_link = error, None              else:                  log.trace("Formatting output...") -                output, paste_link = await self.format_output(results["stdout"]) - -            icon = self.get_status_emoji(results) +                output, paste_link = await self.format_output(result.stdout) -            if attachments and output in ("[No output]", ""): -                msg = f"{ctx.author.mention} {icon} {msg}.\n" +            if result.files and output in ("[No output]", ""): +                msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n"              else: -                msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" +                msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n\n```\n{output}\n```"              if paste_link:                  msg = f"{msg}\nFull output: {paste_link}"              # Collect stats of job fails + successes -            if icon == ":x:": +            if result.returncode != 0:                  self.bot.stats.incr("snekbox.python.fail")              else:                  self.bot.stats.incr("snekbox.python.success") @@ -465,14 +412,14 @@ class Snekbox(Cog):                  response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")              else:                  allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) -                view = self.build_python_version_switcher_view(job_name, python_version, ctx, job) +                view = self.build_python_version_switcher_view(job.version, ctx, job) -                # Attach file if provided -                files = [atc.to_file() for atc in attachments] +                # Attach files if provided +                files = [f.to_file() for f in result.files]                  response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files)                  view.message = response -            log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") +            log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}")          return response      async def continue_job( @@ -549,9 +496,7 @@ class Snekbox(Cog):      async def run_job(          self, -        job_name: str,          ctx: Context, -        python_version: Literal["3.10", "3.11"],          job: EvalJob,      ) -> None:          """Handles checks, stats and re-evaluation of a snekbox job.""" @@ -571,7 +516,7 @@ class Snekbox(Cog):          while True:              try: -                response = await self.send_job(job_name, ctx, python_version, job) +                response = await self.send_job(ctx, job)              except LockedResourceError:                  await ctx.send(                      f"{ctx.author.mention} You've already got a job running - " @@ -584,7 +529,7 @@ class Snekbox(Cog):              # This can happen when a button is pressed and then original code is edited and re-run.              self.jobs[ctx.message.id] = response.id -            job = await self.continue_job(ctx, response, job_name) +            job = await self.continue_job(ctx, response, job.name)              if not job:                  break              log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}") @@ -601,7 +546,7 @@ class Snekbox(Cog):      async def eval_command(          self,          ctx: Context, -        python_version: Optional[Literal["3.10", "3.11"]], +        python_version: PythonVersion | None,          *,          code: CodeblockConverter      ) -> None: @@ -624,8 +569,8 @@ class Snekbox(Cog):          """          code: list[str]          python_version = python_version or "3.11" -        job = EvalJob.from_code("\n".join(code)) -        await self.run_job("eval", ctx, python_version, job) +        job = EvalJob.from_code("\n".join(code)).as_version(python_version) +        await self.run_job(ctx, job)      @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>")      @guild_only() @@ -639,7 +584,7 @@ class Snekbox(Cog):      async def timeit_command(          self,          ctx: Context, -        python_version: Optional[Literal["3.10", "3.11"]], +        python_version: PythonVersion | None,          *,          code: CodeblockConverter      ) -> None: @@ -663,8 +608,9 @@ class Snekbox(Cog):          code: list[str]          python_version = python_version or "3.11"          args = self.prepare_timeit_input(code) +        job = EvalJob(args, version=python_version, name="timeit") -        await self.run_job("timeit", ctx, python_version, EvalJob(args)) +        await self.run_job(ctx, job)  def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: diff --git a/bot/exts/utils/snekio.py b/bot/exts/utils/snekio.py new file mode 100644 index 000000000..7c5fba648 --- /dev/null +++ b/bot/exts/utils/snekio.py @@ -0,0 +1,64 @@ +"""I/O File protocols for snekbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path + +from discord import File + +# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers +FILE_SIZE_LIMIT = 8 * 1024 * 1024  # 8 MiB + + +def sizeof_fmt(num: int, suffix: str = "B") -> str: +    """Return a human-readable file size.""" +    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): +        if abs(num) < 1024: +            return f"{num:3.1f}{unit}{suffix}" +        num /= 1024 +    return f"{num:.1f}Yi{suffix}" + + +@dataclass +class FileAttachment: +    """File Attachment from Snekbox eval.""" + +    path: str +    content: bytes + +    def __repr__(self) -> str: +        """Return the content as a string.""" +        content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content +        return f"FileAttachment(path={self.path!r}, content={content})" + +    @classmethod +    def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: +        """Create a FileAttachment from a dict response.""" +        size = data.get("size") +        if (size and size > size_limit) or (len(data["content"]) > size_limit): +            raise ValueError("File size exceeds limit") + +        content = b64decode(data["content"]) + +        if len(content) > size_limit: +            raise ValueError("File size exceeds limit") + +        return cls(data["path"], content) + +    def to_dict(self) -> dict[str, str]: +        """Convert the attachment to a json dict.""" +        content = self.content +        if isinstance(content, str): +            content = content.encode("utf-8") + +        return { +            "path": self.path, +            "content": b64encode(content).decode("ascii"), +        } + +    def to_file(self) -> File: +        """Convert to a discord.File.""" +        name = Path(self.path).name +        return File(BytesIO(self.content), filename=name) | 
