aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts/internal_eval/_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/exts/internal_eval/_helpers.py')
-rw-r--r--bot/exts/internal_eval/_helpers.py249
1 files changed, 249 insertions, 0 deletions
diff --git a/bot/exts/internal_eval/_helpers.py b/bot/exts/internal_eval/_helpers.py
new file mode 100644
index 00000000..3a50b9f3
--- /dev/null
+++ b/bot/exts/internal_eval/_helpers.py
@@ -0,0 +1,249 @@
+import ast
+import collections
+import contextlib
+import functools
+import inspect
+import io
+import logging
+import sys
+import traceback
+import types
+import typing
+
+
+log = logging.getLogger(__name__)
+
+# A type alias to annotate the tuples returned from `sys.exc_info()`
+ExcInfo = typing.Tuple[typing.Type[Exception], Exception, types.TracebackType]
+Namespace = typing.Dict[str, typing.Any]
+
+# This will be used as an coroutine function wrapper for the code
+# to be evaluated. The wrapper contains one `pass` statement which
+# will be replaced with `ast` with the code that we want to have
+# evaluated.
+# The function redirects output and captures exceptions that were
+# raised in the code we evaluate. The latter is used to provide a
+# meaningful traceback to the end user.
+EVAL_WRAPPER = """
+async def _eval_wrapper_function():
+ try:
+ with contextlib.redirect_stdout(_eval_context.stdout):
+ pass
+ if '_value_last_expression' in locals():
+ if inspect.isawaitable(_value_last_expression):
+ _value_last_expression = await _value_last_expression
+ _eval_context._value_last_expression = _value_last_expression
+ else:
+ _eval_context._value_last_expression = None
+ except Exception:
+ _eval_context.exc_info = sys.exc_info()
+ finally:
+ _eval_context.locals = locals()
+_eval_context.function = _eval_wrapper_function
+"""
+INTERNAL_EVAL_FRAMENAME = "<internal eval>"
+EVAL_WRAPPER_FUNCTION_FRAMENAME = "_eval_wrapper_function"
+
+
+def format_internal_eval_exception(exc_info: ExcInfo, code: str) -> str:
+ """Format an exception caught while evaluation code by inserting lines."""
+ exc_type, exc_value, tb = exc_info
+ stack_summary = traceback.StackSummary.extract(traceback.walk_tb(tb))
+ code = code.split("\n")
+
+ output = ["Traceback (most recent call last):"]
+ for frame in stack_summary:
+ if frame.filename == INTERNAL_EVAL_FRAMENAME:
+ line = code[frame.lineno - 1].lstrip()
+
+ if frame.name == EVAL_WRAPPER_FUNCTION_FRAMENAME:
+ name = INTERNAL_EVAL_FRAMENAME
+ else:
+ name = frame.name
+ else:
+ line = frame.line
+ name = frame.name
+
+ output.append(
+ f' File "{frame.filename}", line {frame.lineno}, in {name}\n'
+ f" {line}"
+ )
+
+ output.extend(traceback.format_exception_only(exc_type, exc_value))
+ return "\n".join(output)
+
+
+class EvalContext:
+ """
+ Represents the current `internal eval` context.
+
+ The context remembers names set during earlier runs of `internal eval`. To
+ clear the context, use the `.internal clear` command.
+ """
+
+ def __init__(self, context_vars: Namespace, local_vars: Namespace) -> None:
+ self._locals = dict(local_vars)
+ self.context_vars = dict(context_vars)
+
+ self.stdout = io.StringIO()
+ self._value_last_expression = None
+ self.exc_info = None
+ self.code = ""
+ self.function = None
+ self.eval_tree = None
+
+ @property
+ def dependencies(self) -> typing.Dict[str, typing.Any]:
+ """
+ Return a mapping of the dependencies for the wrapper function.
+
+ By using a property descriptor, the mapping can't be accidentally
+ mutated during evaluation. This ensures the dependencies are always
+ available.
+ """
+ return {
+ "print": functools.partial(print, file=self.stdout),
+ "contextlib": contextlib,
+ "inspect": inspect,
+ "sys": sys,
+ "_eval_context": self,
+ "_": self._value_last_expression,
+ }
+
+ @property
+ def locals(self) -> typing.Dict[str, typing.Any]:
+ """Return a mapping of names->values needed for evaluation."""
+ return {**collections.ChainMap(self.dependencies, self.context_vars, self._locals)}
+
+ @locals.setter
+ def locals(self, locals_: typing.Dict[str, typing.Any]) -> None:
+ """Update the contextual mapping of names to values."""
+ log.trace(f"Updating {self._locals} with {locals_}")
+ self._locals.update(locals_)
+
+ def prepare_eval(self, code: str) -> typing.Optional[str]:
+ """Prepare an evaluation by processing the code and setting up the context."""
+ self.code = code
+
+ if not self.code:
+ log.debug("No code was attached to the evaluation command")
+ return "[No code detected]"
+
+ try:
+ code_tree = ast.parse(code, filename=INTERNAL_EVAL_FRAMENAME)
+ except SyntaxError:
+ log.debug("Got a SyntaxError while parsing the eval code")
+ return "".join(traceback.format_exception(*sys.exc_info(), limit=0))
+
+ log.trace("Parsing the AST to see if there's a trailing expression we need to capture")
+ code_tree = CaptureLastExpression(code_tree).capture()
+
+ log.trace("Wrapping the AST in the AST of the wrapper coroutine")
+ eval_tree = WrapEvalCodeTree(code_tree).wrap()
+
+ self.eval_tree = eval_tree
+ return None
+
+ async def run_eval(self) -> Namespace:
+ """Run the evaluation and return the updated locals."""
+ log.trace("Compiling the AST to bytecode using `exec` mode")
+ compiled_code = compile(self.eval_tree, filename=INTERNAL_EVAL_FRAMENAME, mode="exec")
+
+ log.trace("Executing the compiled code with the desired namespace environment")
+ exec(compiled_code, self.locals) # noqa: B102,S102
+
+ log.trace("Awaiting the created evaluation wrapper coroutine.")
+ await self.function()
+
+ log.trace("Returning the updated captured locals.")
+ return self._locals
+
+ def format_output(self) -> str:
+ """Format the output of the most recent evaluation."""
+ output = []
+
+ log.trace(f"Getting output from stdout `{id(self.stdout)}`")
+ stdout_text = self.stdout.getvalue()
+ if stdout_text:
+ log.trace("Appending output captured from stdout/print")
+ output.append(stdout_text)
+
+ if self._value_last_expression is not None:
+ log.trace("Appending the output of a captured trialing expression")
+ output.append(f"[Captured] {self._value_last_expression!r}")
+
+ if self.exc_info:
+ log.trace("Appending exception information")
+ output.append(format_internal_eval_exception(self.exc_info, self.code))
+
+ log.trace(f"Generated output: {output!r}")
+ return "\n".join(output) or "[No output]"
+
+
+class WrapEvalCodeTree(ast.NodeTransformer):
+ """Wraps the AST of eval code with the wrapper function."""
+
+ def __init__(self, eval_code_tree: ast.AST, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.eval_code_tree = eval_code_tree
+
+ # To avoid mutable aliasing, parse the WRAPPER_FUNC for each wrapping
+ self.wrapper = ast.parse(EVAL_WRAPPER, filename=INTERNAL_EVAL_FRAMENAME)
+
+ def wrap(self) -> ast.AST:
+ """Wrap the tree of the code by the tree of the wrapper function."""
+ new_tree = self.visit(self.wrapper)
+ return ast.fix_missing_locations(new_tree)
+
+ def visit_Pass(self, node: ast.Pass) -> typing.List[ast.AST]: # noqa: N802
+ """
+ Replace the `_ast.Pass` node in the wrapper function by the eval AST.
+
+ This method works on the assumption that there's a single `pass`
+ statement in the wrapper function.
+ """
+ return list(ast.iter_child_nodes(self.eval_code_tree))
+
+
+class CaptureLastExpression(ast.NodeTransformer):
+ """Captures the return value from a loose expression."""
+
+ def __init__(self, tree: ast.AST, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.tree = tree
+ self.last_node = list(ast.iter_child_nodes(tree))[-1]
+
+ def visit_Expr(self, node: ast.Expr) -> typing.Union[ast.Expr, ast.Assign]: # noqa: N802
+ """
+ Replace the Expr node that is last child node of Module with an assignment.
+
+ We use an assignment to capture the value of the last node, if it's a loose
+ Expr node. Normally, the value of an Expr node is lost, meaning we don't get
+ the output of such a last "loose" expression. By assigning it a name, we can
+ retrieve it for our output.
+ """
+ if node is not self.last_node:
+ return node
+
+ log.trace("Found a trailing last expression in the evaluation code")
+
+ log.trace("Creating assignment statement with trailing expression as the right-hand side")
+ right_hand_side = list(ast.iter_child_nodes(node))[0]
+
+ assignment = ast.Assign(
+ targets=[ast.Name(id='_value_last_expression', ctx=ast.Store())],
+ value=right_hand_side,
+ lineno=node.lineno,
+ col_offset=0,
+ )
+ ast.fix_missing_locations(assignment)
+ return assignment
+
+ def capture(self) -> ast.AST:
+ """Capture the value of the last expression with an assignment."""
+ if not isinstance(self.last_node, ast.Expr):
+ # We only have to replace a node if the very last node is an Expr node
+ return self.tree
+
+ new_tree = self.visit(self.tree)
+ return ast.fix_missing_locations(new_tree)