diff options
Diffstat (limited to 'bot/exts/utils/eval.py')
| -rw-r--r-- | bot/exts/utils/eval.py | 226 | 
1 files changed, 226 insertions, 0 deletions
| diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/eval.py new file mode 100644 index 000000000..6419b320e --- /dev/null +++ b/bot/exts/utils/eval.py @@ -0,0 +1,226 @@ +import contextlib +import inspect +import logging +import pprint +import re +import textwrap +import traceback +from io import StringIO +from typing import Any, Optional, Tuple + +import discord +from discord.ext.commands import Cog, Context, group, has_any_role + +from bot.bot import Bot +from bot.constants import Roles +from bot.interpreter import Interpreter +from bot.utils import find_nth_occurrence, send_to_paste_service + +log = logging.getLogger(__name__) + + +class CodeEval(Cog): +    """Owner and admin feature that evaluates code and returns the result to the channel.""" + +    def __init__(self, bot: Bot): +        self.bot = bot +        self.env = {} +        self.ln = 0 +        self.stdout = StringIO() + +        self.interpreter = Interpreter(bot) + +    def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]: +        """Format the eval output into a string & attempt to format it into an Embed.""" +        self._ = out + +        res = "" + +        # Erase temp input we made +        if inp.startswith("_ = "): +            inp = inp[4:] + +        # Get all non-empty lines +        lines = [line for line in inp.split("\n") if line.strip()] +        if len(lines) != 1: +            lines += [""] + +        # Create the input dialog +        for i, line in enumerate(lines): +            if i == 0: +                # Start dialog +                start = f"In [{self.ln}]: " + +            else: +                # Indent the 3 dots correctly; +                # Normally, it's something like +                # In [X]: +                #    ...: +                # +                # But if it's +                # In [XX]: +                #    ...: +                # +                # You can see it doesn't look right. +                # This code simply indents the dots +                # far enough to align them. +                # we first `str()` the line number +                # then we get the length +                # and use `str.rjust()` +                # to indent it. +                start = "...: ".rjust(len(str(self.ln)) + 7) + +            if i == len(lines) - 2: +                if line.startswith("return"): +                    line = line[6:].strip() + +            # Combine everything +            res += (start + line + "\n") + +        self.stdout.seek(0) +        text = self.stdout.read() +        self.stdout.close() +        self.stdout = StringIO() + +        if text: +            res += (text + "\n") + +        if out is None: +            # No output, return the input statement +            return (res, None) + +        res += f"Out[{self.ln}]: " + +        if isinstance(out, discord.Embed): +            # We made an embed? Send that as embed +            res += "<Embed>" +            res = (res, out) + +        else: +            if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")): +                # Leave out the traceback message +                out = "\n" + "\n".join(out.split("\n")[1:]) + +            if isinstance(out, str): +                pretty = out +            else: +                pretty = pprint.pformat(out, compact=True, width=60) + +            if pretty != str(out): +                # We're using the pretty version, start on the next line +                res += "\n" + +            if pretty.count("\n") > 20: +                # Text too long, shorten +                li = pretty.split("\n") + +                pretty = ("\n".join(li[:3])  # First 3 lines +                          + "\n ...\n"  # Ellipsis to indicate removed lines +                          + "\n".join(li[-3:]))  # last 3 lines + +            # Add the output +            res += pretty +            res = (res, None) + +        return res  # Return (text, embed) + +    async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]: +        """Eval the input code string & send an embed to the invoking context.""" +        self.ln += 1 + +        if code.startswith("exit"): +            self.ln = 0 +            self.env = {} +            return await ctx.send("```Reset history!```") + +        env = { +            "message": ctx.message, +            "author": ctx.message.author, +            "channel": ctx.channel, +            "guild": ctx.guild, +            "ctx": ctx, +            "self": self, +            "bot": self.bot, +            "inspect": inspect, +            "discord": discord, +            "contextlib": contextlib +        } + +        self.env.update(env) + +        # Ignore this code, it works +        code_ = """ +async def func():  # (None,) -> Any +    try: +        with contextlib.redirect_stdout(self.stdout): +{0} +        if '_' in locals(): +            if inspect.isawaitable(_): +                _ = await _ +            return _ +    finally: +        self.env.update(locals()) +""".format(textwrap.indent(code, '            ')) + +        try: +            exec(code_, self.env)  # noqa: B102,S102 +            func = self.env['func'] +            res = await func() + +        except Exception: +            res = traceback.format_exc() + +        out, embed = self._format(code, res) +        out = out.rstrip("\n")  # Strip empty lines from output + +        # Truncate output to max 15 lines or 1500 characters +        newline_truncate_index = find_nth_occurrence(out, "\n", 15) + +        if newline_truncate_index is None or newline_truncate_index > 1500: +            truncate_index = 1500 +        else: +            truncate_index = newline_truncate_index + +        if len(out) > truncate_index: +            paste_link = await send_to_paste_service(self.bot.http_session, out, extension="py") +            if paste_link is not None: +                paste_text = f"full contents at {paste_link}" +            else: +                paste_text = "failed to upload contents to paste service." + +            await ctx.send( +                f"```py\n{out[:truncate_index]}\n```" +                f"... response truncated; {paste_text}", +                embed=embed +            ) +            return + +        await ctx.send(f"```py\n{out}```", embed=embed) + +    @group(name='internal', aliases=('int',)) +    @has_any_role(Roles.owners, Roles.admins) +    async def internal_group(self, ctx: Context) -> None: +        """Internal commands. Top secret!""" +        if not ctx.invoked_subcommand: +            await ctx.send_help(ctx.command) + +    @internal_group.command(name='eval', aliases=('e',)) +    @has_any_role(Roles.admins, Roles.owners) +    async def eval(self, ctx: Context, *, code: str) -> None: +        """Run eval in a REPL-like format.""" +        code = code.strip("`") +        if re.match('py(thon)?\n', code): +            code = "\n".join(code.split("\n")[1:]) + +        if not re.search(  # Check if it's an expression +                r"^(return|import|for|while|def|class|" +                r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len( +                    code.split("\n")) == 1: +            code = "_ = " + code + +        await self._eval(ctx, code) + + +def setup(bot: Bot) -> None: +    """Load the CodeEval cog.""" +    bot.add_cog(CodeEval(bot)) | 
