diff options
| -rw-r--r-- | bot/exts/internal_eval/__init__.py | 10 | ||||
| -rw-r--r-- | bot/exts/internal_eval/_helpers.py | 243 | ||||
| -rw-r--r-- | bot/exts/internal_eval/_internal_eval.py | 152 | 
3 files changed, 405 insertions, 0 deletions
| diff --git a/bot/exts/internal_eval/__init__.py b/bot/exts/internal_eval/__init__.py index e69de29b..695fa74d 100644 --- a/bot/exts/internal_eval/__init__.py +++ b/bot/exts/internal_eval/__init__.py @@ -0,0 +1,10 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: +    """Set up the Internal Eval extension.""" +    # Import the Cog at runtime to prevent side effects like defining +    # RedisCache instances too early. +    from ._internal_eval import InternalEval + +    bot.add_cog(InternalEval(bot)) diff --git a/bot/exts/internal_eval/_helpers.py b/bot/exts/internal_eval/_helpers.py new file mode 100644 index 00000000..5c602e4d --- /dev/null +++ b/bot/exts/internal_eval/_helpers.py @@ -0,0 +1,243 @@ +import ast +import collections +import contextlib +import functools +import inspect +import io +import logging +import sys +import traceback +import types +import typing + + +log = logging.getLogger("rattlesnake.exts.admin_tools.internal_eval") + +# A type alias to annotate the tuples returned from `sys.exc_info()` +ExcInfo = typing.Tuple[typing.Type[Exception], Exception, types.TracebackType] +Namespace = typing.Dict[str, typing.Any] + +# This will be used as an coroutine function wrapper for the code +# to be evaluated. The wrapper contains one `pass` statement which +# will be replaced with `ast` with the code that we want to have +# evaluated. +# The function redirects output and captures exceptions that were +# raised in the code we evaluate. The latter is used to provide a +# meaningful traceback to the end user. +EVAL_WRAPPER = """ +async def _eval_wrapper_function(): +    try: +        with contextlib.redirect_stdout(_eval_context.stdout): +            pass +        if '_value_last_expression' in locals(): +            if inspect.isawaitable(_value_last_expression): +                _value_last_expression = await _value_last_expression +            _eval_context._value_last_expression = _value_last_expression +        else: +            _eval_context._value_last_expression = None +    except Exception: +        _eval_context.exc_info = sys.exc_info() +    finally: +        _eval_context.locals = locals() +_eval_context.function = _eval_wrapper_function +""" + + +def format_internal_eval_exception(exc_info: ExcInfo, code: str) -> str: +    """Format an exception caught while evaluation code by inserting lines.""" +    exc_type, exc_value, tb = exc_info +    stack_summary = traceback.StackSummary.extract(traceback.walk_tb(tb)) +    code = code.split("\n") + +    output = ["Traceback (most recent call last):"] +    for frame in stack_summary: +        if frame.filename == "<internal eval>": +            line = code[frame.lineno - 1].lstrip() + +            if frame.name == "_eval_wrapper_function": +                name = "<internal eval>" +            else: +                name = frame.name +        else: +            line = frame.line +            name = frame.name + +        output.append( +            f'  File "{frame.filename}", line {frame.lineno}, in {name}\n' +            f"    {line}" +        ) + +    output.extend(traceback.format_exception_only(exc_type, exc_value)) +    return "\n".join(output) + + +class EvalContext: +    """ +    Represents the current `internal eval` context. +    The context remembers names set during earlier runs of `internal eval`. To +    clear the context, use the `?internal clear` command. +    """ + +    def __init__(self, context_vars: Namespace, local_vars: Namespace) -> None: +        self._locals = dict(local_vars) +        self.context_vars = dict(context_vars) + +        self.stdout = io.StringIO() +        self._value_last_expression = None +        self.exc_info = None +        self.code = "" +        self.function = None +        self.eval_tree = None + +    @property +    def dependencies(self) -> typing.Dict[str, typing.Any]: +        """ +        Return a mapping of the dependencies for the wrapper function. +        By using a property descriptor, the mapping can't be accidentally +        mutated during evaluation. This ensures the dependencies are always +        available. +        """ +        return { +            "print": functools.partial(print, file=self.stdout), +            "contextlib": contextlib, +            "inspect": inspect, +            "sys": sys, +            "_eval_context": self, +            "_": self._value_last_expression, +        } + +    @property +    def locals(self) -> typing.Dict[str, typing.Any]: +        """Return a mapping of names->values needed for evaluation.""" +        return {**collections.ChainMap(self.dependencies, self.context_vars, self._locals)} + +    @locals.setter +    def locals(self, locals_: typing.Dict[str, typing.Any]) -> None: +        """Update the contextual mapping of names to values.""" +        log.trace(f"Updating {self._locals} with {locals_}") +        self._locals.update(locals_) + +    def prepare_eval(self, code: str) -> typing.Optional[str]: +        """Prepare an evaluation by processing the code and setting up the context.""" +        self.code = code + +        if not self.code: +            log.debug("No code was attached to the evaluation command") +            return "[No code detected]" + +        try: +            code_tree = ast.parse(code, filename="<internal eval>") +        except SyntaxError: +            log.debug("Got a SyntaxError while parsing the eval code") +            return "".join(traceback.format_exception(*sys.exc_info(), limit=0)) + +        log.trace("Parsing the AST to see if there's a trailing expression we need to capture") +        code_tree = CaptureLastExpression(code_tree).capture() + +        log.trace("Wrapping the AST in the AST of the wrapper coroutine") +        eval_tree = WrapEvalCodeTree(code_tree).wrap() + +        self.eval_tree = eval_tree +        return None + +    async def run_eval(self) -> Namespace: +        """Run the evaluation and return the updated locals.""" +        log.trace("Compiling the AST to bytecode using `exec` mode") +        compiled_code = compile(self.eval_tree, filename="<internal eval>", mode="exec") + +        log.trace("Executing the compiled code with the desired namespace environment") +        exec(compiled_code, self.locals)  # noqa: B102,S102 + +        log.trace("Awaiting the created evaluation wrapper coroutine.") +        await self.function() + +        log.trace("Returning the updated captured locals.") +        return self._locals + +    def format_output(self) -> str: +        """Format the output of the most recent evaluation.""" +        output = [] + +        log.trace(f"Getting output from stdout `{id(self.stdout)}`") +        stdout_text = self.stdout.getvalue() +        if stdout_text: +            log.trace("Appending output captured from stdout/print") +            output.append(stdout_text) + +        if self._value_last_expression is not None: +            log.trace("Appending the output of a captured trialing expression") +            output.append(f"[Captured] {self._value_last_expression!r}") + +        if self.exc_info: +            log.trace("Appending exception information") +            output.append(format_internal_eval_exception(self.exc_info, self.code)) + +        log.trace(f"Generated output: {output!r}") +        return "\n".join(output) or "[No output]" + + +class WrapEvalCodeTree(ast.NodeTransformer): +    """Wraps the AST of eval code with the wrapper function.""" + +    def __init__(self, eval_code_tree: ast.AST, *args, **kwargs) -> None: +        super().__init__(*args, **kwargs) +        self.eval_code_tree = eval_code_tree + +        # To avoid mutable aliasing, parse the WRAPPER_FUNC for each wrapping +        self.wrapper = ast.parse(EVAL_WRAPPER, filename="<internal eval>") + +    def wrap(self) -> ast.AST: +        """Wrap the tree of the code by the tree of the wrapper function.""" +        new_tree = self.visit(self.wrapper) +        return ast.fix_missing_locations(new_tree) + +    def visit_Pass(self, node: ast.Pass) -> typing.List[ast.AST]:  # noqa: N802 +        """ +        Replace the `_ast.Pass` node in the wrapper function by the eval AST. +        This method works on the assumption that there's a single `pass` +        statement in the wrapper function. +        """ +        return list(ast.iter_child_nodes(self.eval_code_tree)) + + +class CaptureLastExpression(ast.NodeTransformer): +    """Captures the return value from a loose expression.""" + +    def __init__(self, tree: ast.AST, *args, **kwargs) -> None: +        super().__init__(*args, **kwargs) +        self.tree = tree +        self.last_node = list(ast.iter_child_nodes(tree))[-1] + +    def visit_Expr(self, node: ast.Expr) -> typing.Union[ast.Expr, ast.Assign]:  # noqa: N802 +        """ +        Replace the Expr node that is last child node of Module with an assignment. +        We use an assignment to capture the value of the last node, if it's a loose +        Expr node. Normally, the value of an Expr node is lost, meaning we don't get +        the output of such a last "loose" expression. By assigning it a name, we can +        retrieve it for our output. +        """ +        if node is not self.last_node: +            return node + +        log.trace("Found a trailing last expression in the evaluation code") + +        log.trace("Creating assignment statement with trailing expression as the right-hand side") +        right_hand_side = list(ast.iter_child_nodes(node))[0] + +        assignment = ast.Assign( +            targets=[ast.Name(id='_value_last_expression', ctx=ast.Store())], +            value=right_hand_side, +            lineno=node.lineno, +            col_offset=0, +        ) +        ast.fix_missing_locations(assignment) +        return assignment + +    def capture(self) -> ast.AST: +        """Capture the value of the last expression with an assignment.""" +        if not isinstance(self.last_node, ast.Expr): +            # We only have to replace a node if the very last node is an Expr node +            return self.tree + +        new_tree = self.visit(self.tree) +        return ast.fix_missing_locations(new_tree) diff --git a/bot/exts/internal_eval/_internal_eval.py b/bot/exts/internal_eval/_internal_eval.py new file mode 100644 index 00000000..f7a0946b --- /dev/null +++ b/bot/exts/internal_eval/_internal_eval.py @@ -0,0 +1,152 @@ +import logging +import re +import textwrap +import typing + +import discord +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import Roles +from bot.utils.decorators import with_role +from ._helpers import EvalContext + +__all__ = ["InternalEval"] + +log = logging.getLogger("rattlesnake.exts.admin_tools.internal_eval") + +CODEBLOCK_REGEX = re.compile(r"(^```(py(thon)?)?\n)|(```$)") + + +class InternalEval(commands.Cog): +    """Top secret code evaluation for admins and owners.""" + +    def __init__(self, bot: Bot): +        self.bot = bot +        self.locals = {} + +    @staticmethod +    def shorten_output( +            output: str, +            max_length: int = 1900, +            placeholder: str = "\n[output truncated]" +    ) -> str: +        """ +        Shorten the `output` so it's shorter than `max_length`. +        There are three tactics for this, tried in the following order: +        - Shorten the output on a line-by-line basis +        - Shorten the output on any whitespace character +        - Shorten the output solely on character count +        """ +        max_length = max_length - len(placeholder) + +        shortened_output = [] +        char_count = 0 +        for line in output.split("\n"): +            if char_count + len(line) > max_length: +                break +            shortened_output.append(line) +            char_count += len(line) + 1  # account for (possible) line ending + +        if shortened_output: +            shortened_output.append(placeholder) +            return "\n".join(shortened_output) + +        shortened_output = textwrap.shorten(output, width=max_length, placeholder=placeholder) + +        if shortened_output.strip() == placeholder.strip(): +            # `textwrap` was unable to find whitespace to shorten on, so it has +            # reduced the output to just the placeholder. Let's shorten based on +            # characters instead. +            shortened_output = output[:max_length] + placeholder + +        return shortened_output + +    async def _upload_output(self, output: str) -> typing.Optional[str]: +        """Upload `internal eval` output to our pastebin and return the url.""" +        try: +            async with self.bot.http_session.post( +                "https://paste.pythondiscord.com/documents", data=output, raise_for_status=True +            ) as resp: +                data = await resp.json() + +            if "key" in data: +                return f"https://paste.pythondiscord.com/{data['key']}" +        except Exception: +            # 400 (Bad Request) means there are too many characters +            log.exception("Failed to upload `internal eval` output to paste service!") + +    async def _send_output(self, ctx: commands.Context, output: str) -> None: +        """Send the `internal eval` output to the command invocation context.""" +        upload_message = "" +        if len(output) >= 1980: +            # The output is too long, let's truncate it for in-channel output and +            # upload the complete output to the paste service. +            url = await self._upload_output(output) + +            if url: +                upload_message = f"\nFull output here: {url}" +            else: +                upload_message = "\n:warning: Failed to upload full output!" + +            output = self.shorten_output(output) + +        await ctx.send(f"```py\n{output}```{upload_message}") + +    async def _eval(self, ctx: commands.Context, code: str) -> None: +        """Evaluate the `code` in the current evaluation context.""" +        if code.startswith("exit"): +            self.locals = {} +            await ctx.send("The evaluation context was reset.") +            return + +        context_vars = { +            "message": ctx.message, +            "author": ctx.message.author, +            "channel": ctx.channel, +            "guild": ctx.guild, +            "ctx": ctx, +            "self": self, +            "bot": self.bot, +            "discord": discord, +        } + +        eval_context = EvalContext(context_vars, self.locals) + +        log.trace("Preparing the evaluation by parsing the AST of the code") +        error = eval_context.prepare_eval(code) + +        if error: +            log.trace("The code can't be evaluated due to an error") +            await ctx.send(f"```py\n{error}\n```") +            return + +        log.trace("Evaluate the AST we've generated for the evaluation") +        new_locals = await eval_context.run_eval() + +        log.trace("Updating locals with those set during evaluation") +        self.locals.update(new_locals) + +        log.trace("Sending the formatted output back to the context") +        await self._send_output(ctx, eval_context.format_output()) + +    @commands.group(name='internal', aliases=('int',)) +    @with_role(Roles.admin) +    async def internal_group(self, ctx: commands.Context) -> None: +        """Internal commands. Top secret!""" +        if not ctx.invoked_subcommand: +            await ctx.send_help(ctx.command) + +    @internal_group.command(name='eval', aliases=('e',)) +    @with_role(Roles.admin) +    async def eval(self, ctx: commands.Context, *, code: str) -> None: +        """Run eval in a REPL-like format.""" +        code = CODEBLOCK_REGEX.sub("", code.strip()) +        await self._eval(ctx, code) + +    @internal_group.command(name='reset', aliases=("clear", "exit", "r", "c")) +    @with_role(Roles.admin) +    async def reset(self, ctx: commands.Context) -> None: +        """Run eval in a REPL-like format.""" +        self.locals = {} +        await ctx.send("The evaluation context was reset.") | 
