diff options
| -rw-r--r-- | bot/__main__.py | 12 | ||||
| -rw-r--r-- | bot/cogs/rmq.py | 229 | ||||
| -rw-r--r-- | bot/cogs/snekbox.py | 261 | ||||
| -rw-r--r-- | bot/constants.py | 12 | ||||
| -rw-r--r-- | bot/utils/service_discovery.py | 22 | ||||
| -rw-r--r-- | config-default.yml | 10 | 
6 files changed, 158 insertions, 388 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index 8687cc62c..8afec2718 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -8,7 +8,6 @@ from discord.ext.commands import Bot, when_mentioned_or  from bot.api import APIClient  from bot.constants import Bot as BotConfig, DEBUG_MODE -from bot.utils.service_discovery import wait_for_rmq  log = logging.getLogger(__name__) @@ -31,14 +30,6 @@ bot.http_session = ClientSession(  )  bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.info("Waiting for RabbitMQ...") -has_rmq = wait_for_rmq() - -if has_rmq: -    log.info("RabbitMQ found") -else: -    log.warning("Timed out while waiting for RabbitMQ") -  # Internal/debug  bot.load_extension("bot.cogs.logging")  bot.load_extension("bot.cogs.security") @@ -80,9 +71,6 @@ bot.load_extension("bot.cogs.token_remover")  bot.load_extension("bot.cogs.utils")  bot.load_extension("bot.cogs.wolfram") -if has_rmq: -    bot.load_extension("bot.cogs.rmq") -  bot.run(BotConfig.token)  bot.http_session.close()  # Close the aiohttp session when the bot finishes running diff --git a/bot/cogs/rmq.py b/bot/cogs/rmq.py deleted file mode 100644 index 2742fb969..000000000 --- a/bot/cogs/rmq.py +++ /dev/null @@ -1,229 +0,0 @@ -import asyncio -import datetime -import json -import logging -import pprint - -import aio_pika -from aio_pika import Message -from dateutil import parser as date_parser -from discord import Colour, Embed -from discord.ext.commands import Bot -from discord.utils import get - -from bot.constants import Channels, Guild, RabbitMQ - -log = logging.getLogger(__name__) - -LEVEL_COLOURS = { -    "debug": Colour.blue(), -    "info": Colour.green(), -    "warning": Colour.gold(), -    "error": Colour.red() -} - -DEFAULT_LEVEL_COLOUR = Colour.greyple() -EMBED_PARAMS = ( -    "colour", "title", "url", "description", "timestamp" -) - -CONSUME_TIMEOUT = datetime.timedelta(seconds=10) - - -class RMQ: -    """ -    RabbitMQ event handling -    """ - -    rmq = None  # type: aio_pika.Connection -    channel = None  # type: aio_pika.Channel -    queue = None  # type: aio_pika.Queue - -    def __init__(self, bot: Bot): -        self.bot = bot - -    async def on_ready(self): -        self.rmq = await aio_pika.connect_robust( -            host=RabbitMQ.host, port=RabbitMQ.port, login=RabbitMQ.username, password=RabbitMQ.password -        ) - -        log.info("Connected to RabbitMQ") - -        self.channel = await self.rmq.channel() -        self.queue = await self.channel.declare_queue("bot_events", durable=True) - -        log.debug("Channel opened, queue declared") - -        async for message in self.queue: -            with message.process(): -                message.ack() -                await self.handle_message(message, message.body.decode()) - -    async def send_text(self, queue: str, data: str): -        message = Message(data.encode("utf-8")) -        await self.channel.default_exchange.publish(message, queue) - -    async def send_json(self, queue: str, **data): -        message = Message(json.dumps(data).encode("utf-8")) -        await self.channel.default_exchange.publish(message, queue) - -    async def consume(self, queue: str, **kwargs): -        queue_obj = await self.channel.declare_queue(queue, **kwargs) - -        result = None -        start_time = datetime.datetime.now() - -        while result is None: -            if datetime.datetime.now() - start_time >= CONSUME_TIMEOUT: -                result = "Timed out while waiting for a response." -            else: -                result = await queue_obj.get(timeout=5, fail=False) -                await asyncio.sleep(0.5) - -        if result: -            result.ack() - -        return result - -    async def handle_message(self, message, data): -        log.debug(f"Message: {message}") -        log.debug(f"Data: {data}") - -        try: -            data = json.loads(data) -        except Exception: -            await self.do_mod_log("error", "Unable to parse event", data) -        else: -            event = data["event"] -            event_data = data["data"] - -            try: -                func = getattr(self, f"do_{event}") -                await func(**event_data) -            except Exception as e: -                await self.do_mod_log( -                    "error", f"Unable to handle event: {event}", -                    str(e) -                ) - -    async def do_mod_log(self, level: str, title: str, message: str): -        colour = LEVEL_COLOURS.get(level, DEFAULT_LEVEL_COLOUR) -        embed = Embed( -            title=title, description=f"```\n{message}\n```", -            colour=colour, timestamp=datetime.datetime.now() -        ) - -        await self.bot.get_channel(Channels.modlog).send(embed=embed) -        log.log(logging._nameToLevel[level.upper()], f"Modlog: {title} | {message}") - -    async def do_send_message(self, target: int, message: str): -        channel = self.bot.get_channel(target) - -        if channel is None: -            await self.do_mod_log( -                "error", "Failed: Send Message", -                f"Unable to find channel: {target}" -            ) -        else: -            await channel.send(message) - -            await self.do_mod_log( -                "info", "Succeeded: Send Embed", -                f"Message sent to channel {target}\n\n{message}" -            ) - -    async def do_send_embed(self, target: int, **embed_params): -        for param, value in list(embed_params.items()):  # To keep a full copy -            if param not in EMBED_PARAMS: -                await self.do_mod_log( -                    "warning", "Warning: Send Embed", -                    f"Unknown embed parameter: {param}" -                ) -                del embed_params[param] - -            if param == "timestamp": -                embed_params[param] = date_parser.parse(value) -            elif param == "colour": -                embed_params[param] = Colour(value) - -        channel = self.bot.get_channel(target) - -        if channel is None: -            await self.do_mod_log( -                "error", "Failed: Send Embed", -                f"Unable to find channel: {target}" -            ) -        else: -            await channel.send(embed=Embed(**embed_params)) - -            await self.do_mod_log( -                "info", "Succeeded: Send Embed", -                f"Embed sent to channel {target}\n\n{pprint.pformat(embed_params, 4)}" -            ) - -    async def do_add_role(self, target: int, role_id: int, reason: str): -        guild = self.bot.get_guild(Guild.id) -        member = guild.get_member(int(target)) - -        if member is None: -            return await self.do_mod_log( -                "error", "Failed: Add Role", -                f"Unable to find member: {target}" -            ) - -        role = get(guild.roles, id=int(role_id)) - -        if role is None: -            return await self.do_mod_log( -                "error", "Failed: Add Role", -                f"Unable to find role: {role_id}" -            ) - -        try: -            await member.add_roles(role, reason=reason) -        except Exception as e: -            await self.do_mod_log( -                "error", "Failed: Add Role", -                f"Error while adding role {role.name}: {e}" -            ) -        else: -            await self.do_mod_log( -                "info", "Succeeded: Add Role", -                f"Role {role.name} added to member {target}" -            ) - -    async def do_remove_role(self, target: int, role_id: int, reason: str): -        guild = self.bot.get_guild(Guild.id) -        member = guild.get_member(int(target)) - -        if member is None: -            return await self.do_mod_log( -                "error", "Failed: Remove Role", -                f"Unable to find member: {target}" -            ) - -        role = get(guild.roles, id=int(role_id)) - -        if role is None: -            return await self.do_mod_log( -                "error", "Failed: Remove Role", -                f"Unable to find role: {role_id}" -            ) - -        try: -            await member.remove_roles(role, reason=reason) -        except Exception as e: -            await self.do_mod_log( -                "error", "Failed: Remove Role", -                f"Error while adding role {role.name}: {e}" -            ) -        else: -            await self.do_mod_log( -                "info", "Succeeded: Remove Role", -                f"Role {role.name} removed from member {target}" -            ) - - -def setup(bot): -    bot.add_cog(RMQ(bot)) -    log.info("Cog loaded: RMQ") diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index cb0454249..0f8d3e4b6 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -3,13 +3,14 @@ import logging  import random  import re  import textwrap +from signal import Signals +from typing import Optional, Tuple  from discord import Colour, Embed  from discord.ext.commands import (      Bot, CommandError, Context, NoPrivateMessage, command, guild_only  ) -from bot.cogs.rmq import RMQ  from bot.constants import Channels, ERROR_REPLIES, NEGATIVE_REPLIES, Roles, URLs  from bot.decorators import InChannelCheckFailure, in_channel  from bot.utils.messages import wait_for_deletion @@ -17,22 +18,6 @@ from bot.utils.messages import wait_for_deletion  log = logging.getLogger(__name__) -RMQ_ARGS = { -    "durable": False, -    "arguments": {"x-message-ttl": 5000}, -    "auto_delete": True -} - -CODE_TEMPLATE = """ -venv_file = "/snekbox/.venv/bin/activate_this.py" -exec(open(venv_file).read(), dict(__file__=venv_file)) - -try: -{CODE} -except Exception as e: -    print(e) -""" -  ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")  FORMATTED_CODE_REGEX = re.compile(      r"^\s*"                                 # any leading whitespace from the beginning of the string @@ -53,43 +38,47 @@ RAW_CODE_REGEX = re.compile(  )  BYPASS_ROLES = (Roles.owner, Roles.admin, Roles.moderator, Roles.helpers) +MAX_PASTE_LEN = 1000  class Snekbox:      """ -    Safe evaluation using Snekbox +    Safe evaluation of Python code using Snekbox      """      def __init__(self, bot: Bot):          self.bot = bot          self.jobs = {} -    @property -    def rmq(self) -> RMQ: -        return self.bot.get_cog("RMQ") - -    @command(name='eval', aliases=('e',)) -    @guild_only() -    @in_channel(Channels.bot, bypass_roles=BYPASS_ROLES) -    async def eval_command(self, ctx: Context, *, code: str = None): -        """ -        Run some code. get the result back. We've done our best to make this safe, but do let us know if you -        manage to find an issue with it! +    async def post_eval(self, code: str) -> dict: +        """Send a POST request to the Snekbox API to evaluate code and return the results.""" +        url = URLs.snekbox_eval_api +        data = {"input": code} +        async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: +            return await resp.json() -        This command supports multiple lines of code, including code wrapped inside a formatted code block. -        """ +    async def upload_output(self, output: str) -> Optional[str]: +        """Upload the eval output to a paste service and return a URL to it if successful.""" +        log.trace("Uploading full output to paste service...") -        if ctx.author.id in self.jobs: -            await ctx.send(f"{ctx.author.mention} You've already got a job running - please wait for it to finish!") -            return +        if len(output) > MAX_PASTE_LEN: +            log.info("Full output is too long to upload") +            return "too long to upload" -        if not code:  # None or empty string -            return await ctx.invoke(self.bot.get_command("help"), "eval") +        url = URLs.paste_service.format(key="documents") +        try: +            async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: +                data = await resp.json() -        log.info(f"Received code from {ctx.author.name}#{ctx.author.discriminator} for evaluation:\n{code}") -        self.jobs[ctx.author.id] = datetime.datetime.now() +            if "key" in data: +                return URLs.paste_service.format(key=data["key"]) +        except Exception: +            # 400 (Bad Request) means there are too many characters +            log.exception("Failed to upload full output to paste service!") -        # Strip whitespace and inline or block code markdown and extract the code and some formatting info +    @staticmethod +    def prepare_input(code: str) -> str: +        """Extract code from the Markdown, format it, and insert it into the code template."""          match = FORMATTED_CODE_REGEX.fullmatch(code)          if match:              code, block, lang, delim = match.group("code", "block", "lang", "delim") @@ -101,86 +90,140 @@ class Snekbox:              log.trace(f"Extracted {info} for evaluation:\n{code}")          else:              code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) -            log.trace(f"Eval message contains not or badly formatted code, stripping whitespace only:\n{code}") +            log.trace( +                f"Eval message contains unformatted or badly formatted code, " +                f"stripping whitespace only:\n{code}" +            ) -        code = textwrap.indent(code, "    ") -        code = CODE_TEMPLATE.replace("{CODE}", code) +        return code + +    @staticmethod +    def get_results_message(results: dict) -> Tuple[str, str]: +        """Return a user-friendly message and error corresponding to the process's return code.""" +        stdout, returncode = results["stdout"], results["returncode"] +        msg = f"Your eval job has completed with return code {returncode}" +        error = "" + +        if returncode is None: +            msg = "Your eval job has failed" +            error = stdout.strip() +        elif returncode == 128 + Signals.SIGKILL: +            msg = "Your eval job timed out or ran out of memory" +        elif returncode == 255: +            msg = "Your eval job has failed" +            error = "A fatal NsJail error occurred" +        else: +            # Try to append signal's name if one exists +            try: +                name = Signals(returncode - 128).name +                msg = f"{msg} ({name})" +            except ValueError: +                pass -        try: -            await self.rmq.send_json( -                "input", -                snekid=str(ctx.author.id), message=code -            ) +        return msg, error -            async with ctx.typing(): -                message = await self.rmq.consume(str(ctx.author.id), **RMQ_ARGS) -                paste_link = None +    async def format_output(self, output: str) -> Tuple[str, Optional[str]]: +        """ +        Format the output and return a tuple of the formatted output and a URL to the full output. -                if isinstance(message, str): -                    output = str.strip(" \n") -                else: -                    output = message.body.decode().strip(" \n") +        Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters +        and upload the full output to a paste service. +        """ +        log.trace("Formatting output...") -                if "<@" in output: -                    output = output.replace("<@", "<@\u200B")  # Zero-width space +        output = output.strip(" \n") +        original_output = output  # To be uploaded to a pasting service if needed +        paste_link = None -                if "<!@" in output: -                    output = output.replace("<!@", "<!@\u200B")  # Zero-width space +        if "<@" in output: +            output = output.replace("<@", "<@\u200B")  # Zero-width space -                if ESCAPE_REGEX.findall(output): -                    output = "Code block escape attempt detected; will not output result" -                else: -                    # the original output, to send to a pasting service if needed -                    full_output = output -                    truncated = False -                    if output.count("\n") > 0: -                        output = [f"{i:03d} | {line}" for i, line in enumerate(output.split("\n"), start=1)] -                        output = "\n".join(output) - -                    if output.count("\n") > 10: -                        output = "\n".join(output.split("\n")[:10]) - -                        if len(output) >= 1000: -                            output = f"{output[:1000]}\n... (truncated - too long, too many lines)" -                        else: -                            output = f"{output}\n... (truncated - too many lines)" -                        truncated = True - -                    elif len(output) >= 1000: -                        output = f"{output[:1000]}\n... (truncated - too long)" -                        truncated = True - -                    if truncated: -                        try: -                            response = await self.bot.http_session.post( -                                URLs.paste_service.format(key="documents"), -                                data=full_output -                            ) -                            data = await response.json() -                            if "key" in data: -                                paste_link = URLs.paste_service.format(key=data["key"]) -                        except Exception: -                            log.exception("Failed to upload full output to paste service!") - -                if output.strip(): -                    if paste_link: -                        msg = f"{ctx.author.mention} Your eval job has completed.\n\n```py\n{output}\n```" \ -                              f"\nFull output: {paste_link}" -                    else: -                        msg = f"{ctx.author.mention} Your eval job has completed.\n\n```py\n{output}\n```" - -                    response = await ctx.send(msg) -                    self.bot.loop.create_task(wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)) +        if "<!@" in output: +            output = output.replace("<!@", "<!@\u200B")  # Zero-width space -                else: -                    await ctx.send( -                        f"{ctx.author.mention} Your eval job has completed.\n\n```py\n[No output]\n```" -                    ) +        if ESCAPE_REGEX.findall(output): +            return "Code block escape attempt detected; will not output result", paste_link +        truncated = False +        lines = output.count("\n") + +        if lines > 0: +            output = output.split("\n")[:10]  # Only first 10 cause the rest is truncated anyway +            output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1)) +            output = "\n".join(output) + +        if lines > 10: +            truncated = True +            if len(output) >= 1000: +                output = f"{output[:1000]}\n... (truncated - too long, too many lines)" +            else: +                output = f"{output}\n... (truncated - too many lines)" +        elif len(output) >= 1000: +            truncated = True +            output = f"{output[:1000]}\n... (truncated - too long)" + +        if truncated: +            paste_link = await self.upload_output(original_output) + +        output = output.strip() +        if not output: +            output = "[No output]" + +        return output, paste_link + +    @command(name="eval", aliases=("e",)) +    @guild_only() +    @in_channel(Channels.bot, bypass_roles=BYPASS_ROLES) +    async def eval_command(self, ctx: Context, *, code: str = None): +        """ +        Run Python code and get the results. + +        This command supports multiple lines of code, including code wrapped inside a formatted code +        block. We've done our best to make this safe, but do let us know if you manage to find an +        issue with it! +        """ +        if ctx.author.id in self.jobs: +            return await ctx.send( +                f"{ctx.author.mention} You've already got a job running - " +                f"please wait for it to finish!" +            ) + +        if not code:  # None or empty string +            return await ctx.invoke(self.bot.get_command("help"), "eval") + +        log.info( +            f"Received code from {ctx.author.name}#{ctx.author.discriminator} " +            f"for evaluation:\n{code}" +        ) + +        self.jobs[ctx.author.id] = datetime.datetime.now() +        code = self.prepare_input(code) + +        try: +            async with ctx.typing(): +                results = await self.post_eval(code) +                msg, error = self.get_results_message(results) + +                if error: +                    output, paste_link = error, None +                else: +                    output, paste_link = await self.format_output(results["stdout"]) + +                msg = f"{ctx.author.mention} {msg}.\n\n```py\n{output}\n```" +                if paste_link: +                    msg = f"{msg}\nFull output: {paste_link}" + +                response = await ctx.send(msg) +                self.bot.loop.create_task( +                    wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) +                ) + +                log.info( +                    f"{ctx.author.name}#{ctx.author.discriminator}'s job had a return code of " +                    f"{results['returncode']}" +                ) +        finally:              del self.jobs[ctx.author.id] -        except Exception: -            del self.jobs[ctx.author.id] -            raise      @eval_command.error      async def eval_command_error(self, ctx: Context, error: CommandError): diff --git a/bot/constants.py b/bot/constants.py index d2e3bb315..0bd950a7d 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -374,18 +374,12 @@ class Keys(metaclass=YAMLGetter):      youtube: str -class RabbitMQ(metaclass=YAMLGetter): -    section = "rabbitmq" - -    host: str -    password: str -    port: int -    username: str - -  class URLs(metaclass=YAMLGetter):      section = "urls" +    # Snekbox endpoints +    snekbox_eval_api: str +      # Discord API endpoints      discord_api: str      discord_invite_api: str diff --git a/bot/utils/service_discovery.py b/bot/utils/service_discovery.py deleted file mode 100644 index 8d79096bd..000000000 --- a/bot/utils/service_discovery.py +++ /dev/null @@ -1,22 +0,0 @@ -import datetime -import socket -import time -from contextlib import closing - -from bot.constants import RabbitMQ - -THIRTY_SECONDS = datetime.timedelta(seconds=30) - - -def wait_for_rmq(): -    start = datetime.datetime.now() - -    while True: -        if datetime.datetime.now() - start > THIRTY_SECONDS: -            return False - -        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: -            if sock.connect_ex((RabbitMQ.host, RabbitMQ.port)) == 0: -                return True - -        time.sleep(0.5) diff --git a/config-default.yml b/config-default.yml index f6481cfcd..7854b5db9 100644 --- a/config-default.yml +++ b/config-default.yml @@ -204,13 +204,6 @@ keys:      youtube:     !ENV "YOUTUBE_API_KEY" -rabbitmq: -    host:          "pdrmq" -    password: !ENV ["RABBITMQ_DEFAULT_PASS", "guest"] -    port:          5672 -    username: !ENV ["RABBITMQ_DEFAULT_USER", "guest"] - -  urls:      # PyDis site vars      site:        &DOMAIN       "pythondiscord.com" @@ -242,6 +235,9 @@ urls:      site_user_complete_api:             !JOIN [*SCHEMA, *API, "/bot/users/complete"]      paste_service:                      !JOIN [*SCHEMA, *PASTE, "/{key}"] +    # Snekbox +    snekbox_eval_api: "http://localhost:8060/eval" +      # Env vars      deploy: !ENV "DEPLOY_URL"      status: !ENV "STATUS_URL" | 
