aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts/internal_eval/_helpers.py
blob: 3a50b9f313ed0c9946467034fc74be16b9585b0a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import ast
import collections
import contextlib
import functools
import inspect
import io
import logging
import sys
import traceback
import types
import typing


log = logging.getLogger(__name__)

# A type alias to annotate the tuples returned from `sys.exc_info()`
ExcInfo = typing.Tuple[typing.Type[Exception], Exception, types.TracebackType]
# A mapping of names to values, used as an evaluation namespace
Namespace = typing.Dict[str, typing.Any]

# This is used as a coroutine function wrapper for the code to be
# evaluated. The wrapper contains a single `pass` statement, which is
# replaced (by manipulating the `ast`) with the statements of the code
# we want to evaluate.
# The function redirects output and captures exceptions that were
# raised in the code we evaluate. The latter is used to provide a
# meaningful traceback to the end user.
EVAL_WRAPPER = """
async def _eval_wrapper_function():
    try:
        with contextlib.redirect_stdout(_eval_context.stdout):
            pass
        if '_value_last_expression' in locals():
            if inspect.isawaitable(_value_last_expression):
                _value_last_expression = await _value_last_expression
            _eval_context._value_last_expression = _value_last_expression
        else:
            _eval_context._value_last_expression = None
    except Exception:
        _eval_context.exc_info = sys.exc_info()
    finally:
        _eval_context.locals = locals()
_eval_context.function = _eval_wrapper_function
"""
INTERNAL_EVAL_FRAMENAME = "<internal eval>"
EVAL_WRAPPER_FUNCTION_FRAMENAME = "_eval_wrapper_function"


def format_internal_eval_exception(exc_info: ExcInfo, code: str) -> str:
    """
    Format an exception caught while evaluating code as a traceback string.

    Frames whose filename is `INTERNAL_EVAL_FRAMENAME` have their source line
    looked up in `code`. The frame of the wrapper function is renamed to
    `INTERNAL_EVAL_FRAMENAME`.

    Args:
        exc_info: The `(type, value, traceback)` tuple captured during evaluation.
        code: The source code that was evaluated.

    Returns:
        The formatted traceback, followed by the exception message.
    """
    exc_type, exc_value, tb = exc_info
    stack_summary = traceback.StackSummary.extract(traceback.walk_tb(tb))
    code_lines = code.split("\n")

    output = ["Traceback (most recent call last):"]
    for frame in stack_summary:
        if frame.filename == INTERNAL_EVAL_FRAMENAME:
            # Lines executed by the wrapper itself (e.g. awaiting a trailing
            # awaitable) carry wrapper line numbers, which may lie outside the
            # user's code; show no source line rather than raising IndexError.
            lineno = frame.lineno or 0
            if 1 <= lineno <= len(code_lines):
                line = code_lines[lineno - 1].strip()
            else:
                line = ""

            if frame.name == EVAL_WRAPPER_FUNCTION_FRAMENAME:
                name = INTERNAL_EVAL_FRAMENAME
            else:
                name = frame.name
        else:
            # `FrameSummary.line` can be None when no source is available.
            line = frame.line or ""
            name = frame.name

        output.append(
            f'  File "{frame.filename}", line {frame.lineno}, in {name}\n'
            f"    {line}"
        )

    output.extend(traceback.format_exception_only(exc_type, exc_value))
    return "\n".join(output)


class EvalContext:
    """
    Represents the current `internal eval` context.

    The context remembers names set during earlier runs of `internal eval`. To
    clear the context, use the `.internal clear` command.
    """

    def __init__(self, context_vars: Namespace, local_vars: Namespace) -> None:
        # Shallow copies, so the caller's mappings are not updated in place.
        self._locals = dict(local_vars)
        self.context_vars = dict(context_vars)

        # Receives output from the injected `print` and from the wrapper's
        # `contextlib.redirect_stdout`.
        self.stdout = io.StringIO()
        self._value_last_expression = None
        self.exc_info = None
        self.code = ""
        # Assigned by the compiled wrapper code when it is executed in `run_eval`.
        self.function = None
        self.eval_tree = None

    @property
    def dependencies(self) -> typing.Dict[str, typing.Any]:
        """
        Return a mapping of the dependencies for the wrapper function.

        By using a property descriptor, the mapping can't be accidentally
        mutated during evaluation. This ensures the dependencies are always
        available.
        """
        return {
            "print": functools.partial(print, file=self.stdout),
            "contextlib": contextlib,
            "inspect": inspect,
            "sys": sys,
            "_eval_context": self,
            "_": self._value_last_expression,
        }

    @property
    def locals(self) -> typing.Dict[str, typing.Any]:
        """
        Return a mapping of names->values needed for evaluation.

        Dependencies take precedence over context variables, which take
        precedence over names remembered from earlier evaluations.
        """
        return {**collections.ChainMap(self.dependencies, self.context_vars, self._locals)}

    @locals.setter
    def locals(self, locals_: typing.Dict[str, typing.Any]) -> None:
        """Update the contextual mapping of names to values."""
        # Log only the names: formatting the values would call `repr` on
        # arbitrary user objects, which may be huge or raise, and this setter
        # runs inside the wrapper's `finally` block.
        log.trace(f"Updating locals with names: {list(locals_)}")
        self._locals.update(locals_)

    def prepare_eval(self, code: str) -> typing.Optional[str]:
        """
        Prepare an evaluation by processing the code and setting up the context.

        Returns:
            A message to show instead of evaluating (no code, or a syntax
            error), or `None` when `run_eval` can proceed.
        """
        self.code = code

        if not self.code or self.code.isspace():
            log.debug("No code was attached to the evaluation command")
            return "[No code detected]"

        try:
            code_tree = ast.parse(code, filename=INTERNAL_EVAL_FRAMENAME)
        except SyntaxError:
            log.debug("Got a SyntaxError while parsing the eval code")
            return "".join(traceback.format_exception(*sys.exc_info(), limit=0))

        if not code_tree.body:
            # E.g. code consisting only of comments: there is nothing to wrap,
            # and an empty wrapper body would fail to compile.
            log.debug("The eval code contains no statements")
            return "[No code detected]"

        log.trace("Parsing the AST to see if there's a trailing expression we need to capture")
        code_tree = CaptureLastExpression(code_tree).capture()

        log.trace("Wrapping the AST in the AST of the wrapper coroutine")
        eval_tree = WrapEvalCodeTree(code_tree).wrap()

        self.eval_tree = eval_tree
        return None

    async def run_eval(self) -> Namespace:
        """Run the evaluation prepared by `prepare_eval` and return the updated locals."""
        log.trace("Compiling the AST to bytecode using `exec` mode")
        compiled_code = compile(self.eval_tree, filename=INTERNAL_EVAL_FRAMENAME, mode="exec")

        # Executing the module only defines the wrapper coroutine function and
        # stores it on `self.function`; the user's code runs when it is awaited.
        log.trace("Executing the compiled code with the desired namespace environment")
        exec(compiled_code, self.locals)  # noqa: B102,S102

        log.trace("Awaiting the created evaluation wrapper coroutine.")
        await self.function()

        log.trace("Returning the updated captured locals.")
        return self._locals

    def format_output(self) -> str:
        """Format the output (stdout, captured value, exception) of the most recent evaluation."""
        output = []

        log.trace(f"Getting output from stdout `{id(self.stdout)}`")
        stdout_text = self.stdout.getvalue()
        if stdout_text:
            log.trace("Appending output captured from stdout/print")
            output.append(stdout_text)

        if self._value_last_expression is not None:
            log.trace("Appending the output of a captured trailing expression")
            output.append(f"[Captured] {self._value_last_expression!r}")

        if self.exc_info:
            log.trace("Appending exception information")
            output.append(format_internal_eval_exception(self.exc_info, self.code))

        log.trace(f"Generated output: {output!r}")
        return "\n".join(output) or "[No output]"


class WrapEvalCodeTree(ast.NodeTransformer):
    """Wraps the AST of eval code with the wrapper function."""

    def __init__(self, eval_code_tree: ast.AST, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.eval_code_tree = eval_code_tree

        # To avoid mutable aliasing, parse EVAL_WRAPPER afresh for each wrapping
        self.wrapper = ast.parse(EVAL_WRAPPER, filename=INTERNAL_EVAL_FRAMENAME)

    def wrap(self) -> ast.AST:
        """Wrap the tree of the code by the tree of the wrapper function."""
        new_tree = self.visit(self.wrapper)
        return ast.fix_missing_locations(new_tree)

    def visit_Pass(self, node: ast.Pass) -> typing.List[ast.AST]:  # noqa: N802
        """
        Replace the `_ast.Pass` node in the wrapper function by the eval AST.

        This method works on the assumption that there's a single `pass`
        statement in the wrapper function. The returned nodes are not visited
        again, so `pass` statements inside the eval code are left alone.
        """
        statements = [
            child
            for child in ast.iter_child_nodes(self.eval_code_tree)
            if isinstance(child, ast.stmt)
        ]
        # Splicing in nothing would leave the enclosing block with an empty
        # body, which `compile` rejects; keep the placeholder `pass` instead.
        return statements or [node]


class CaptureLastExpression(ast.NodeTransformer):
    """Captures the return value from a loose expression."""

    def __init__(self, tree: ast.AST, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tree = tree
        # The tree may have no child nodes at all (e.g. code consisting only of
        # comments); `None` then makes `capture` return the tree unchanged
        # instead of failing here with an IndexError.
        children = list(ast.iter_child_nodes(tree))
        self.last_node = children[-1] if children else None

    def visit_Expr(self, node: ast.Expr) -> typing.Union[ast.Expr, ast.Assign]:  # noqa: N802
        """
        Replace the Expr node that is last child node of Module with an assignment.

        We use an assignment to capture the value of the last node, if it's a loose
        Expr node. Normally, the value of an Expr node is lost, meaning we don't get
        the output of such a last "loose" expression. By assigning it a name, we can
        retrieve it for our output.
        """
        if node is not self.last_node:
            return node

        log.trace("Found a trailing last expression in the evaluation code")

        log.trace("Creating assignment statement with trailing expression as the right-hand side")
        assignment = ast.Assign(
            targets=[ast.Name(id="_value_last_expression", ctx=ast.Store())],
            value=node.value,
        )
        # Give the assignment the source position of the expression it replaces.
        ast.copy_location(assignment, node)
        ast.fix_missing_locations(assignment)
        return assignment

    def capture(self) -> ast.AST:
        """Capture the value of the last expression with an assignment."""
        if not isinstance(self.last_node, ast.Expr):
            # We only have to replace a node if the very last node is an Expr node
            return self.tree

        new_tree = self.visit(self.tree)
        return ast.fix_missing_locations(new_tree)