aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts/utils/eval.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/exts/utils/eval.py')
-rw-r--r--bot/exts/utils/eval.py226
1 files changed, 226 insertions, 0 deletions
diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/eval.py
new file mode 100644
index 000000000..6419b320e
--- /dev/null
+++ b/bot/exts/utils/eval.py
@@ -0,0 +1,226 @@
+import contextlib
+import inspect
+import logging
+import pprint
+import re
+import textwrap
+import traceback
+from io import StringIO
+from typing import Any, Optional, Tuple
+
+import discord
+from discord.ext.commands import Cog, Context, group, has_any_role
+
+from bot.bot import Bot
+from bot.constants import Roles
+from bot.interpreter import Interpreter
+from bot.utils import find_nth_occurrence, send_to_paste_service
+
+log = logging.getLogger(__name__)
+
+
+class CodeEval(Cog):
+ """Owner and admin feature that evaluates code and returns the result to the channel."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self.env = {}
+ self.ln = 0
+ self.stdout = StringIO()
+
+ self.interpreter = Interpreter(bot)
+
+ def _format(self, inp: str, out: Any) -> Tuple[str, Optional[discord.Embed]]:
+ """Format the eval output into a string & attempt to format it into an Embed."""
+ self._ = out
+
+ res = ""
+
+ # Erase temp input we made
+ if inp.startswith("_ = "):
+ inp = inp[4:]
+
+ # Get all non-empty lines
+ lines = [line for line in inp.split("\n") if line.strip()]
+ if len(lines) != 1:
+ lines += [""]
+
+ # Create the input dialog
+ for i, line in enumerate(lines):
+ if i == 0:
+ # Start dialog
+ start = f"In [{self.ln}]: "
+
+ else:
+ # Indent the 3 dots correctly;
+ # Normally, it's something like
+ # In [X]:
+ # ...:
+ #
+ # But if it's
+ # In [XX]:
+ # ...:
+ #
+ # You can see it doesn't look right.
+ # This code simply indents the dots
+ # far enough to align them.
+ # we first `str()` the line number
+ # then we get the length
+ # and use `str.rjust()`
+ # to indent it.
+ start = "...: ".rjust(len(str(self.ln)) + 7)
+
+ if i == len(lines) - 2:
+ if line.startswith("return"):
+ line = line[6:].strip()
+
+ # Combine everything
+ res += (start + line + "\n")
+
+ self.stdout.seek(0)
+ text = self.stdout.read()
+ self.stdout.close()
+ self.stdout = StringIO()
+
+ if text:
+ res += (text + "\n")
+
+ if out is None:
+ # No output, return the input statement
+ return (res, None)
+
+ res += f"Out[{self.ln}]: "
+
+ if isinstance(out, discord.Embed):
+ # We made an embed? Send that as embed
+ res += "<Embed>"
+ res = (res, out)
+
+ else:
+ if (isinstance(out, str) and out.startswith("Traceback (most recent call last):\n")):
+ # Leave out the traceback message
+ out = "\n" + "\n".join(out.split("\n")[1:])
+
+ if isinstance(out, str):
+ pretty = out
+ else:
+ pretty = pprint.pformat(out, compact=True, width=60)
+
+ if pretty != str(out):
+ # We're using the pretty version, start on the next line
+ res += "\n"
+
+ if pretty.count("\n") > 20:
+ # Text too long, shorten
+ li = pretty.split("\n")
+
+ pretty = ("\n".join(li[:3]) # First 3 lines
+ + "\n ...\n" # Ellipsis to indicate removed lines
+ + "\n".join(li[-3:])) # last 3 lines
+
+ # Add the output
+ res += pretty
+ res = (res, None)
+
+ return res # Return (text, embed)
+
+ async def _eval(self, ctx: Context, code: str) -> Optional[discord.Message]:
+ """Eval the input code string & send an embed to the invoking context."""
+ self.ln += 1
+
+ if code.startswith("exit"):
+ self.ln = 0
+ self.env = {}
+ return await ctx.send("```Reset history!```")
+
+ env = {
+ "message": ctx.message,
+ "author": ctx.message.author,
+ "channel": ctx.channel,
+ "guild": ctx.guild,
+ "ctx": ctx,
+ "self": self,
+ "bot": self.bot,
+ "inspect": inspect,
+ "discord": discord,
+ "contextlib": contextlib
+ }
+
+ self.env.update(env)
+
+ # Ignore this code, it works
+ code_ = """
+async def func(): # (None,) -> Any
+ try:
+ with contextlib.redirect_stdout(self.stdout):
+{0}
+ if '_' in locals():
+ if inspect.isawaitable(_):
+ _ = await _
+ return _
+ finally:
+ self.env.update(locals())
+""".format(textwrap.indent(code, ' '))
+
+ try:
+ exec(code_, self.env) # noqa: B102,S102
+ func = self.env['func']
+ res = await func()
+
+ except Exception:
+ res = traceback.format_exc()
+
+ out, embed = self._format(code, res)
+ out = out.rstrip("\n") # Strip empty lines from output
+
+ # Truncate output to max 15 lines or 1500 characters
+ newline_truncate_index = find_nth_occurrence(out, "\n", 15)
+
+ if newline_truncate_index is None or newline_truncate_index > 1500:
+ truncate_index = 1500
+ else:
+ truncate_index = newline_truncate_index
+
+ if len(out) > truncate_index:
+ paste_link = await send_to_paste_service(self.bot.http_session, out, extension="py")
+ if paste_link is not None:
+ paste_text = f"full contents at {paste_link}"
+ else:
+ paste_text = "failed to upload contents to paste service."
+
+ await ctx.send(
+ f"```py\n{out[:truncate_index]}\n```"
+ f"... response truncated; {paste_text}",
+ embed=embed
+ )
+ return
+
+ await ctx.send(f"```py\n{out}```", embed=embed)
+
+ @group(name='internal', aliases=('int',))
+ @has_any_role(Roles.owners, Roles.admins)
+ async def internal_group(self, ctx: Context) -> None:
+ """Internal commands. Top secret!"""
+ if not ctx.invoked_subcommand:
+ await ctx.send_help(ctx.command)
+
+ @internal_group.command(name='eval', aliases=('e',))
+ @has_any_role(Roles.admins, Roles.owners)
+ async def eval(self, ctx: Context, *, code: str) -> None:
+ """Run eval in a REPL-like format."""
+ code = code.strip("`")
+ if re.match('py(thon)?\n', code):
+ code = "\n".join(code.split("\n")[1:])
+
+ if not re.search( # Check if it's an expression
+ r"^(return|import|for|while|def|class|"
+ r"from|exit|[a-zA-Z0-9]+\s*=)", code, re.M) and len(
+ code.split("\n")) == 1:
+ code = "_ = " + code
+
+ await self._eval(ctx, code)
+
+
+def setup(bot: Bot) -> None:
+ """Load the CodeEval cog."""
+ bot.add_cog(CodeEval(bot))