aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts/internal_eval
diff options
context:
space:
mode:
authorGravatar bradtimmis <[email protected]>2021-08-30 22:09:44 -0400
committerGravatar bradtimmis <[email protected]>2021-08-30 22:09:44 -0400
commit8bc54b2e2aeaeef30efd9f7f684cce48b7b64daf (patch)
tree91e844d1431a127815d33406139e5798341b533a /bot/exts/internal_eval
parentNumerous syntax and bug fixes (diff)
parentMerge pull request #831 from brad90four/patch-1 (diff)
Update branch with main
Diffstat (limited to 'bot/exts/internal_eval')
-rw-r--r--bot/exts/internal_eval/__init__.py10
-rw-r--r--bot/exts/internal_eval/_helpers.py249
-rw-r--r--bot/exts/internal_eval/_internal_eval.py179
3 files changed, 438 insertions, 0 deletions
diff --git a/bot/exts/internal_eval/__init__.py b/bot/exts/internal_eval/__init__.py
new file mode 100644
index 00000000..695fa74d
--- /dev/null
+++ b/bot/exts/internal_eval/__init__.py
@@ -0,0 +1,10 @@
+from bot.bot import Bot
+
+
def setup(bot: Bot) -> None:
    """Register the Internal Eval cog on the given bot."""
    # The cog module is imported here, at call time, rather than at the top of
    # the file: importing it early has side effects (such as defining
    # RedisCache instances) that must not happen before the extension loads.
    from ._internal_eval import InternalEval as internal_eval_cog

    cog = internal_eval_cog(bot)
    bot.add_cog(cog)
diff --git a/bot/exts/internal_eval/_helpers.py b/bot/exts/internal_eval/_helpers.py
new file mode 100644
index 00000000..3a50b9f3
--- /dev/null
+++ b/bot/exts/internal_eval/_helpers.py
@@ -0,0 +1,249 @@
+import ast
+import collections
+import contextlib
+import functools
+import inspect
+import io
+import logging
+import sys
+import traceback
+import types
+import typing
+
+
log = logging.getLogger(__name__)

# A type alias to annotate the tuples returned from `sys.exc_info()`
ExcInfo = typing.Tuple[typing.Type[Exception], Exception, types.TracebackType]
# A type alias for a mapping of names to arbitrary values
Namespace = typing.Dict[str, typing.Any]

# This is used as a coroutine function wrapper for the code to be
# evaluated. The wrapper contains exactly one `pass` statement, which
# is replaced, using `ast`, with the code that we want to have
# evaluated (see `WrapEvalCodeTree.visit_Pass`).
#
# The function redirects stdout and captures exceptions that were
# raised in the code we evaluate. The latter is used to provide a
# meaningful traceback to the end user.
#
# The wrapper source refers to `contextlib`, `inspect`, `sys` and
# `_eval_context` without importing them; those names are supplied by
# `EvalContext.dependencies` in the namespace the wrapper is executed in.
# After the evaluated code ran, the wrapper:
#   - awaits `_value_last_expression` if it is awaitable and stores it on
#     the context (or stores `None` if no such local was assigned);
#   - stores `sys.exc_info()` on the context if an `Exception` was raised;
#   - always hands its `locals()` to the context's `locals` setter.
# The final statement registers the coroutine function on the context so
# that it can be awaited by `EvalContext.run_eval`.
EVAL_WRAPPER = """
async def _eval_wrapper_function():
    try:
        with contextlib.redirect_stdout(_eval_context.stdout):
            pass
        if '_value_last_expression' in locals():
            if inspect.isawaitable(_value_last_expression):
                _value_last_expression = await _value_last_expression
            _eval_context._value_last_expression = _value_last_expression
        else:
            _eval_context._value_last_expression = None
    except Exception:
        _eval_context.exc_info = sys.exc_info()
    finally:
        _eval_context.locals = locals()
_eval_context.function = _eval_wrapper_function
"""
# Filename passed to `ast.parse`/`compile` for both the wrapper and the eval
# code; it is how frames belonging to an evaluation are recognised later on.
INTERNAL_EVAL_FRAMENAME = "<internal eval>"
# Name of the coroutine function defined in `EVAL_WRAPPER`
EVAL_WRAPPER_FUNCTION_FRAMENAME = "_eval_wrapper_function"
+
+
def format_internal_eval_exception(exc_info: ExcInfo, code: str) -> str:
    """
    Format an exception caught while evaluating code as a traceback string.

    Frames whose filename is `INTERNAL_EVAL_FRAMENAME` have no source file on
    disk, so their source line is looked up in `code` instead. The frame of the
    wrapper coroutine is displayed under the name `INTERNAL_EVAL_FRAMENAME`
    rather than under the wrapper function's own name.

    `exc_info` is a `(type, value, traceback)` tuple as returned by
    `sys.exc_info()`; `code` is the source text that was evaluated.
    """
    exc_type, exc_value, tb = exc_info
    stack_summary = traceback.StackSummary.extract(traceback.walk_tb(tb))
    code_lines = code.split("\n")

    output = ["Traceback (most recent call last):"]
    for frame in stack_summary:
        name = frame.name

        if frame.filename == INTERNAL_EVAL_FRAMENAME:
            lineno = frame.lineno
            if lineno is not None and 1 <= lineno <= len(code_lines):
                line = code_lines[lineno - 1].lstrip()
            else:
                # The line number does not belong to this snippet: the frame
                # may point into the wrapper's own source (e.g. the `await` of
                # the trailing expression) or into a function defined by an
                # earlier evaluation. Indexing `code_lines` would raise an
                # IndexError and hide the real exception from the user.
                line = frame.line or ""

            if name == EVAL_WRAPPER_FUNCTION_FRAMENAME:
                name = INTERNAL_EVAL_FRAMENAME
        else:
            line = frame.line

        # Same layout as the interpreter's own tracebacks: two spaces before
        # "File", four spaces before the source line.
        output.append(
            f'  File "{frame.filename}", line {frame.lineno}, in {name}\n'
            f"    {line}"
        )

    output.extend(traceback.format_exception_only(exc_type, exc_value))
    return "\n".join(output)
+
+
class EvalContext:
    """
    Represents the current `internal eval` context.

    The context remembers names set during earlier runs of `internal eval`. To
    clear the context, use the `.internal clear` command.

    Typical use is `prepare_eval(code)`, then, if that returned `None`,
    `await run_eval()` followed by `format_output()`.
    """

    def __init__(self, context_vars: Namespace, local_vars: Namespace) -> None:
        # Both mappings are copied, so later updates to this context do not
        # mutate the dictionaries owned by the caller.
        self._locals = dict(local_vars)
        self.context_vars = dict(context_vars)

        # The attributes below are filled in by `prepare_eval` and by the
        # wrapper coroutine while it runs.
        self.stdout = io.StringIO()
        self._value_last_expression = None
        self.exc_info = None
        self.code = ""
        self.function = None
        self.eval_tree = None

    @property
    def dependencies(self) -> typing.Dict[str, typing.Any]:
        """
        Return a mapping of the dependencies for the wrapper function.

        By using a property descriptor, the mapping can't be accidentally
        mutated during evaluation. This ensures the dependencies are always
        available.
        """
        # NOTE(review): `_` is bound to this instance's `_value_last_expression`,
        # which `__init__` sets to None; whether `_` can ever be non-None when
        # the code runs depends on how long callers keep an instance alive.
        return {
            "print": functools.partial(print, file=self.stdout),
            "contextlib": contextlib,
            "inspect": inspect,
            "sys": sys,
            "_eval_context": self,
            "_": self._value_last_expression,
        }

    @property
    def locals(self) -> typing.Dict[str, typing.Any]:
        """Return a mapping of names->values needed for evaluation."""
        # Earlier maps win in a ChainMap: dependencies shadow the context
        # variables, which in turn shadow the remembered locals.
        return {**collections.ChainMap(self.dependencies, self.context_vars, self._locals)}

    @locals.setter
    def locals(self, locals_: typing.Dict[str, typing.Any]) -> None:
        """Update the contextual mapping of names to values."""
        log.trace(f"Updating {self._locals} with {locals_}")
        self._locals.update(locals_)

    def prepare_eval(self, code: str) -> typing.Optional[str]:
        """
        Prepare an evaluation by processing the code and setting up the context.

        Return `None` when the code is ready to be run with `run_eval`, or a
        message to show to the user when there is nothing to evaluate or the
        code could not be parsed.
        """
        self.code = code

        if not self.code:
            log.debug("No code was attached to the evaluation command")
            return "[No code detected]"

        try:
            code_tree = ast.parse(code, filename=INTERNAL_EVAL_FRAMENAME)
        except SyntaxError:
            log.debug("Got a SyntaxError while parsing the eval code")
            return "".join(traceback.format_exception(*sys.exc_info(), limit=0))

        if not code_tree.body:
            # Code consisting solely of comments and/or whitespace parses to a
            # module without statements. There is no last node to capture and
            # nothing to put in the wrapper, so stop here instead of letting
            # the AST transformers fail on the empty tree.
            log.debug("The eval code does not contain any statements")
            return "[No code detected]"

        log.trace("Parsing the AST to see if there's a trailing expression we need to capture")
        code_tree = CaptureLastExpression(code_tree).capture()

        log.trace("Wrapping the AST in the AST of the wrapper coroutine")
        eval_tree = WrapEvalCodeTree(code_tree).wrap()

        self.eval_tree = eval_tree
        return None

    async def run_eval(self) -> Namespace:
        """
        Run the evaluation and return the updated locals.

        `prepare_eval` must have succeeded first, as this compiles the tree it
        stored in `self.eval_tree`.
        """
        log.trace("Compiling the AST to bytecode using `exec` mode")
        compiled_code = compile(self.eval_tree, filename=INTERNAL_EVAL_FRAMENAME, mode="exec")

        # Executing the module only *defines* the wrapper coroutine function
        # and registers it as `self.function`; the user's code runs once that
        # function is awaited below.
        log.trace("Executing the compiled code with the desired namespace environment")
        exec(compiled_code, self.locals)  # noqa: B102,S102

        log.trace("Awaiting the created evaluation wrapper coroutine.")
        await self.function()

        log.trace("Returning the updated captured locals.")
        return self._locals

    def format_output(self) -> str:
        """
        Format the output of the most recent evaluation.

        The result consists of, in order and each only if present: the captured
        stdout, the value of a trailing expression and the formatted exception.
        If none of these is present, "[No output]" is returned.
        """
        output = []

        log.trace(f"Getting output from stdout `{id(self.stdout)}`")
        stdout_text = self.stdout.getvalue()
        if stdout_text:
            log.trace("Appending output captured from stdout/print")
            output.append(stdout_text)

        if self._value_last_expression is not None:
            log.trace("Appending the output of a captured trialing expression")
            output.append(f"[Captured] {self._value_last_expression!r}")

        if self.exc_info:
            log.trace("Appending exception information")
            output.append(format_internal_eval_exception(self.exc_info, self.code))

        log.trace(f"Generated output: {output!r}")
        return "\n".join(output) or "[No output]"
+
+
class WrapEvalCodeTree(ast.NodeTransformer):
    """Wraps the AST of eval code with the wrapper function."""

    def __init__(self, eval_code_tree: ast.AST, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.eval_code_tree = eval_code_tree

        # To avoid mutable aliasing, parse the EVAL_WRAPPER for each wrapping
        self.wrapper = ast.parse(EVAL_WRAPPER, filename=INTERNAL_EVAL_FRAMENAME)

    def wrap(self) -> ast.AST:
        """Wrap the tree of the code by the tree of the wrapper function."""
        new_tree = self.visit(self.wrapper)
        return ast.fix_missing_locations(new_tree)

    def visit_Pass(self, node: ast.Pass) -> typing.Union[ast.AST, typing.List[ast.AST]]:  # noqa: N802
        """
        Replace the `_ast.Pass` node in the wrapper function by the eval AST.

        This method works on the assumption that there's a single `pass`
        statement in the wrapper function.

        If the eval tree has no child nodes, the `pass` node is returned
        unchanged: replacing it by an empty list would remove the only
        statement from its enclosing block and produce a tree that cannot
        be compiled.
        """
        replacement = list(ast.iter_child_nodes(self.eval_code_tree))
        if not replacement:
            return node
        return replacement
+
+
class CaptureLastExpression(ast.NodeTransformer):
    """Captures the return value from a loose expression."""

    def __init__(self, tree: ast.AST, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tree = tree

        # A tree without child nodes (e.g. parsed from comment-only code) has
        # no last node; indexing `[-1]` unconditionally would raise IndexError.
        children = list(ast.iter_child_nodes(tree))
        self.last_node = children[-1] if children else None

    def visit_Expr(self, node: ast.Expr) -> typing.Union[ast.Expr, ast.Assign]:  # noqa: N802
        """
        Replace the Expr node that is last child node of Module with an assignment.

        We use an assignment to capture the value of the last node, if it's a loose
        Expr node. Normally, the value of an Expr node is lost, meaning we don't get
        the output of such a last "loose" expression. By assigning it a name, we can
        retrieve it for our output.
        """
        if node is not self.last_node:
            return node

        log.trace("Found a trailing last expression in the evaluation code")

        log.trace("Creating assignment statement with trailing expression as the right-hand side")
        # The wrapped expression is the only child node of an Expr statement
        right_hand_side = node.value

        assignment = ast.Assign(
            targets=[ast.Name(id='_value_last_expression', ctx=ast.Store())],
            value=right_hand_side,
            lineno=node.lineno,
            col_offset=0,
        )
        ast.fix_missing_locations(assignment)
        return assignment

    def capture(self) -> ast.AST:
        """
        Capture the value of the last expression with an assignment.

        The tree is returned untouched when it is empty or when its last node
        is not a loose expression.
        """
        if not isinstance(self.last_node, ast.Expr):
            # We only have to replace a node if the very last node is an Expr node
            return self.tree

        new_tree = self.visit(self.tree)
        return ast.fix_missing_locations(new_tree)
diff --git a/bot/exts/internal_eval/_internal_eval.py b/bot/exts/internal_eval/_internal_eval.py
new file mode 100644
index 00000000..b7749144
--- /dev/null
+++ b/bot/exts/internal_eval/_internal_eval.py
@@ -0,0 +1,179 @@
+import logging
+import re
+import textwrap
+import typing
+
+import discord
+from discord.ext import commands
+
+from bot.bot import Bot
+from bot.constants import Client, Roles
+from bot.utils.decorators import with_role
+from bot.utils.extensions import invoke_help_command
+from ._helpers import EvalContext
+
# The cog class is the only name exported by `from ... import *`
__all__ = ["InternalEval"]

log = logging.getLogger(__name__)
+
# Pattern for code wrapped in Discord markdown, either inline code (one or two
# backticks) or a code block (three backticks, optionally with a language).
_FORMATTED_CODE_PATTERN = (
    r"(?P<delim>(?P<block>```)|``?)"  # opening delimiter; `block` is only set for three backticks
    r"(?(block)(?:(?P<lang>[a-z]+)\n)?)"  # inside a block: optional language (letters only) and a newline
    r"(?:[ \t]*\n)*"  # skip blank (empty or spaces/tabs only) lines in front of the code
    r"(?P<code>.*?)"  # the code itself, matched as lazily as possible
    r"\s*"  # whitespace between the code and the closing delimiter
    r"(?P=delim)"  # closing delimiter: must equal the opening one
)
# DOTALL lets "." span newlines; IGNORECASE makes the language match case-insensitive
FORMATTED_CODE_REGEX = re.compile(_FORMATTED_CODE_PATTERN, flags=re.DOTALL | re.IGNORECASE)

# Pattern for code that was sent without any markdown around it
_RAW_CODE_PATTERN = (
    r"^(?:[ \t]*\n)*"  # skip blank (empty or spaces/tabs only) lines in front of the code
    r"(?P<code>.*?)"  # everything else is code
    r"\s*$"  # except for whitespace trailing at the end of the string
)
# DOTALL lets "." span newlines
RAW_CODE_REGEX = re.compile(_RAW_CODE_PATTERN, flags=re.DOTALL)
+
+
class InternalEval(commands.Cog):
    """Top secret code evaluation for admins and owners."""

    def __init__(self, bot: Bot):
        self.bot = bot
        # Names remembered between evaluations; emptied by the `reset` command
        self.locals = {}

        if Client.debug:
            # In debug mode, additionally require the invoker to pass the
            # `commands.is_owner()` check for the whole command group.
            self.internal_group.add_check(commands.is_owner().predicate)

    @staticmethod
    def shorten_output(
        output: str,
        max_length: int = 1900,
        placeholder: str = "\n[output truncated]"
    ) -> str:
        """
        Shorten the `output` so it's shorter than `max_length`.

        There are three tactics for this, tried in the following order:
        - Shorten the output on a line-by-line basis
        - Shorten the output on any whitespace character
        - Shorten the output solely on character count

        The `placeholder` is appended to the shortened output; its length is
        reserved from `max_length` up front.
        """
        max_length = max_length - len(placeholder)

        # Tactic 1: keep as many whole lines as fit
        kept_lines = []
        char_count = 0
        for line in output.split("\n"):
            if char_count + len(line) > max_length:
                break
            kept_lines.append(line)
            char_count += len(line) + 1  # account for (possible) line ending

        if kept_lines:
            kept_lines.append(placeholder)
            return "\n".join(kept_lines)

        # Tactic 2: not even the first line fits, so cut on whitespace instead
        shortened_output = textwrap.shorten(output, width=max_length, placeholder=placeholder)

        if shortened_output.strip() == placeholder.strip():
            # Tactic 3: `textwrap` was unable to find whitespace to shorten on,
            # so it has reduced the output to just the placeholder. Let's
            # shorten based on characters instead.
            shortened_output = output[:max_length] + placeholder

        return shortened_output

    async def _upload_output(self, output: str) -> typing.Optional[str]:
        """
        Upload `internal eval` output to our pastebin and return the url.

        Return `None` if the upload failed or if the response of the paste
        service does not contain a "key".
        """
        try:
            async with self.bot.http_session.post(
                "https://paste.pythondiscord.com/documents", data=output, raise_for_status=True
            ) as resp:
                data = await resp.json()

            if "key" in data:
                return f"https://paste.pythondiscord.com/{data['key']}"
        except Exception:
            # Best effort: a failed upload is logged and reported to the user
            # by the caller, it should never break the eval command itself.
            # 400 (Bad Request) means there are too many characters
            log.exception("Failed to upload `internal eval` output to paste service!")

        return None

    async def _send_output(self, ctx: commands.Context, output: str) -> None:
        """
        Send the `internal eval` output to the command invocation context.

        Output of 1980 characters or more is uploaded to the paste service and
        only a shortened version is sent to the channel.
        """
        upload_message = ""
        if len(output) >= 1980:
            # The output is too long, let's truncate it for in-channel output and
            # upload the complete output to the paste service.
            url = await self._upload_output(output)

            if url:
                upload_message = f"\nFull output here: {url}"
            else:
                upload_message = "\n:warning: Failed to upload full output!"

            output = self.shorten_output(output)

        await ctx.send(f"```py\n{output}\n```{upload_message}")

    async def _eval(self, ctx: commands.Context, code: str) -> None:
        """
        Evaluate the `code` in the current evaluation context.

        The code gets access to a number of convenience names (`ctx`, `bot`,
        `message`, ...) on top of the names remembered from earlier runs.
        """
        context_vars = {
            "message": ctx.message,
            "author": ctx.author,
            "channel": ctx.channel,
            "guild": ctx.guild,
            "ctx": ctx,
            "self": self,
            "bot": self.bot,
            "discord": discord,
        }

        eval_context = EvalContext(context_vars, self.locals)

        log.trace("Preparing the evaluation by parsing the AST of the code")
        error = eval_context.prepare_eval(code)

        if error:
            log.trace("The code can't be evaluated due to an error")
            await ctx.send(f"```py\n{error}\n```")
            return

        log.trace("Evaluate the AST we've generated for the evaluation")
        new_locals = await eval_context.run_eval()

        log.trace("Updating locals with those set during evaluation")
        self.locals.update(new_locals)

        log.trace("Sending the formatted output back to the context")
        await self._send_output(ctx, eval_context.format_output())

    @commands.group(name="internal", aliases=("int",))
    @with_role(Roles.admin)
    async def internal_group(self, ctx: commands.Context) -> None:
        """Internal commands. Top secret!"""
        if not ctx.invoked_subcommand:
            await invoke_help_command(ctx)

    @internal_group.command(name="eval", aliases=("e",))
    @with_role(Roles.admin)
    async def eval(self, ctx: commands.Context, *, code: str) -> None:
        """Run eval in a REPL-like format."""
        formatted = list(FORMATTED_CODE_REGEX.finditer(code))
        if formatted:
            # Real code blocks (three backticks) take precedence over inline code
            blocks = [found for found in formatted if found.group("block")]

            if len(blocks) > 1:
                # Several code blocks are evaluated as one piece of code
                code = "\n".join(found.group("code") for found in blocks)
            else:
                # Either the single code block, or the first inline code span
                chosen = blocks[0] if blocks else formatted[0]
                code = chosen.group("code")
        else:
            code = RAW_CODE_REGEX.fullmatch(code).group("code")

        code = textwrap.dedent(code)
        await self._eval(ctx, code)

    @internal_group.command(name="reset", aliases=("clear", "exit", "r", "c"))
    @with_role(Roles.admin)
    async def reset(self, ctx: commands.Context) -> None:
        """Reset the context and locals of the eval session."""
        self.locals = {}
        await ctx.send("The evaluation context was reset.")