diff options
author | 2019-06-22 15:06:16 +0200 | |
---|---|---|
committer | 2019-06-22 15:06:16 +0200 | |
commit | 25640adec9d042ccf249a91540fb09d354b04dfd (patch) | |
tree | 62cf2a21079510afd5d5a884006ddf04269780e7 | |
parent | Merge pull request #371 from python-discord/django-appeals (diff) | |
parent | Snekbox: limit paste service uploads to 1000 characters (diff) |
Merge pull request #372 from python-discord/snekbox
Replace RabbitMQ with Snekbox API
-rw-r--r-- | bot/__main__.py | 12 | ||||
-rw-r--r-- | bot/cogs/rmq.py | 229 | ||||
-rw-r--r-- | bot/cogs/snekbox.py | 261 | ||||
-rw-r--r-- | bot/constants.py | 12 | ||||
-rw-r--r-- | bot/utils/service_discovery.py | 22 | ||||
-rw-r--r-- | config-default.yml | 10 |
6 files changed, 158 insertions, 388 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index 8687cc62c..8afec2718 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -8,7 +8,6 @@ from discord.ext.commands import Bot, when_mentioned_or from bot.api import APIClient from bot.constants import Bot as BotConfig, DEBUG_MODE -from bot.utils.service_discovery import wait_for_rmq log = logging.getLogger(__name__) @@ -31,14 +30,6 @@ bot.http_session = ClientSession( ) bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.info("Waiting for RabbitMQ...") -has_rmq = wait_for_rmq() - -if has_rmq: - log.info("RabbitMQ found") -else: - log.warning("Timed out while waiting for RabbitMQ") - # Internal/debug bot.load_extension("bot.cogs.logging") bot.load_extension("bot.cogs.security") @@ -80,9 +71,6 @@ bot.load_extension("bot.cogs.token_remover") bot.load_extension("bot.cogs.utils") bot.load_extension("bot.cogs.wolfram") -if has_rmq: - bot.load_extension("bot.cogs.rmq") - bot.run(BotConfig.token) bot.http_session.close() # Close the aiohttp session when the bot finishes running diff --git a/bot/cogs/rmq.py b/bot/cogs/rmq.py deleted file mode 100644 index 2742fb969..000000000 --- a/bot/cogs/rmq.py +++ /dev/null @@ -1,229 +0,0 @@ -import asyncio -import datetime -import json -import logging -import pprint - -import aio_pika -from aio_pika import Message -from dateutil import parser as date_parser -from discord import Colour, Embed -from discord.ext.commands import Bot -from discord.utils import get - -from bot.constants import Channels, Guild, RabbitMQ - -log = logging.getLogger(__name__) - -LEVEL_COLOURS = { - "debug": Colour.blue(), - "info": Colour.green(), - "warning": Colour.gold(), - "error": Colour.red() -} - -DEFAULT_LEVEL_COLOUR = Colour.greyple() -EMBED_PARAMS = ( - "colour", "title", "url", "description", "timestamp" -) - -CONSUME_TIMEOUT = datetime.timedelta(seconds=10) - - -class RMQ: - """ - RabbitMQ event handling - """ - - rmq = None # type: aio_pika.Connection - channel = None # type: aio_pika.Channel - queue = None # type: aio_pika.Queue - - def __init__(self, bot: Bot): - self.bot = bot - - async def on_ready(self): - self.rmq = await aio_pika.connect_robust( - host=RabbitMQ.host, port=RabbitMQ.port, login=RabbitMQ.username, password=RabbitMQ.password - ) - - log.info("Connected to RabbitMQ") - - self.channel = await self.rmq.channel() - self.queue = await self.channel.declare_queue("bot_events", durable=True) - - log.debug("Channel opened, queue declared") - - async for message in self.queue: - with message.process(): - message.ack() - await self.handle_message(message, message.body.decode()) - - async def send_text(self, queue: str, data: str): - message = Message(data.encode("utf-8")) - await self.channel.default_exchange.publish(message, queue) - - async def send_json(self, queue: str, **data): - message = Message(json.dumps(data).encode("utf-8")) - await self.channel.default_exchange.publish(message, queue) - - async def consume(self, queue: str, **kwargs): - queue_obj = await self.channel.declare_queue(queue, **kwargs) - - result = None - start_time = datetime.datetime.now() - - while result is None: - if datetime.datetime.now() - start_time >= CONSUME_TIMEOUT: - result = "Timed out while waiting for a response." - else: - result = await queue_obj.get(timeout=5, fail=False) - await asyncio.sleep(0.5) - - if result: - result.ack() - - return result - - async def handle_message(self, message, data): - log.debug(f"Message: {message}") - log.debug(f"Data: {data}") - - try: - data = json.loads(data) - except Exception: - await self.do_mod_log("error", "Unable to parse event", data) - else: - event = data["event"] - event_data = data["data"] - - try: - func = getattr(self, f"do_{event}") - await func(**event_data) - except Exception as e: - await self.do_mod_log( - "error", f"Unable to handle event: {event}", - str(e) - ) - - async def do_mod_log(self, level: str, title: str, message: str): - colour = LEVEL_COLOURS.get(level, DEFAULT_LEVEL_COLOUR) - embed = Embed( - title=title, description=f"```\n{message}\n```", - colour=colour, timestamp=datetime.datetime.now() - ) - - await self.bot.get_channel(Channels.modlog).send(embed=embed) - log.log(logging._nameToLevel[level.upper()], f"Modlog: {title} | {message}") - - async def do_send_message(self, target: int, message: str): - channel = self.bot.get_channel(target) - - if channel is None: - await self.do_mod_log( - "error", "Failed: Send Message", - f"Unable to find channel: {target}" - ) - else: - await channel.send(message) - - await self.do_mod_log( - "info", "Succeeded: Send Embed", - f"Message sent to channel {target}\n\n{message}" - ) - - async def do_send_embed(self, target: int, **embed_params): - for param, value in list(embed_params.items()): # To keep a full copy - if param not in EMBED_PARAMS: - await self.do_mod_log( - "warning", "Warning: Send Embed", - f"Unknown embed parameter: {param}" - ) - del embed_params[param] - - if param == "timestamp": - embed_params[param] = date_parser.parse(value) - elif param == "colour": - embed_params[param] = Colour(value) - - channel = self.bot.get_channel(target) - - if channel is None: - await self.do_mod_log( - "error", "Failed: Send Embed", - f"Unable to find channel: {target}" - ) - else: - await channel.send(embed=Embed(**embed_params)) - - await self.do_mod_log( - "info", "Succeeded: Send Embed", - f"Embed sent to channel {target}\n\n{pprint.pformat(embed_params, 4)}" - ) - - async def do_add_role(self, target: int, role_id: int, reason: str): - guild = self.bot.get_guild(Guild.id) - member = guild.get_member(int(target)) - - if member is None: - return await self.do_mod_log( - "error", "Failed: Add Role", - f"Unable to find member: {target}" - ) - - role = get(guild.roles, id=int(role_id)) - - if role is None: - return await self.do_mod_log( - "error", "Failed: Add Role", - f"Unable to find role: {role_id}" - ) - - try: - await member.add_roles(role, reason=reason) - except Exception as e: - await self.do_mod_log( - "error", "Failed: Add Role", - f"Error while adding role {role.name}: {e}" - ) - else: - await self.do_mod_log( - "info", "Succeeded: Add Role", - f"Role {role.name} added to member {target}" - ) - - async def do_remove_role(self, target: int, role_id: int, reason: str): - guild = self.bot.get_guild(Guild.id) - member = guild.get_member(int(target)) - - if member is None: - return await self.do_mod_log( - "error", "Failed: Remove Role", - f"Unable to find member: {target}" - ) - - role = get(guild.roles, id=int(role_id)) - - if role is None: - return await self.do_mod_log( - "error", "Failed: Remove Role", - f"Unable to find role: {role_id}" - ) - - try: - await member.remove_roles(role, reason=reason) - except Exception as e: - await self.do_mod_log( - "error", "Failed: Remove Role", - f"Error while adding role {role.name}: {e}" - ) - else: - await self.do_mod_log( - "info", "Succeeded: Remove Role", - f"Role {role.name} removed from member {target}" - ) - - -def setup(bot): - bot.add_cog(RMQ(bot)) - log.info("Cog loaded: RMQ") diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index cb0454249..0f8d3e4b6 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -3,13 +3,14 @@ import logging import random import re import textwrap +from signal import Signals +from typing import Optional, Tuple from discord import Colour, Embed from discord.ext.commands import ( Bot, CommandError, Context, NoPrivateMessage, command, guild_only ) -from bot.cogs.rmq import RMQ from bot.constants import Channels, ERROR_REPLIES, NEGATIVE_REPLIES, Roles, URLs from bot.decorators import InChannelCheckFailure, in_channel from bot.utils.messages import wait_for_deletion @@ -17,22 +18,6 @@ from bot.utils.messages import wait_for_deletion log = logging.getLogger(__name__) -RMQ_ARGS = { - "durable": False, - "arguments": {"x-message-ttl": 5000}, - "auto_delete": True -} - -CODE_TEMPLATE = """ -venv_file = "/snekbox/.venv/bin/activate_this.py" -exec(open(venv_file).read(), dict(__file__=venv_file)) - -try: -{CODE} -except Exception as e: - print(e) -""" - ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") FORMATTED_CODE_REGEX = re.compile( r"^\s*" # any leading whitespace from the beginning of the string @@ -53,43 +38,47 @@ RAW_CODE_REGEX = re.compile( ) BYPASS_ROLES = (Roles.owner, Roles.admin, Roles.moderator, Roles.helpers) +MAX_PASTE_LEN = 1000 class Snekbox: """ - Safe evaluation using Snekbox + Safe evaluation of Python code using Snekbox """ def __init__(self, bot: Bot): self.bot = bot self.jobs = {} - @property - def rmq(self) -> RMQ: - return self.bot.get_cog("RMQ") - - @command(name='eval', aliases=('e',)) - @guild_only() - @in_channel(Channels.bot, bypass_roles=BYPASS_ROLES) - async def eval_command(self, ctx: Context, *, code: str = None): - """ - Run some code. get the result back. We've done our best to make this safe, but do let us know if you - manage to find an issue with it! + async def post_eval(self, code: str) -> dict: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + url = URLs.snekbox_eval_api + data = {"input": code} + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return await resp.json() - This command supports multiple lines of code, including code wrapped inside a formatted code block. - """ + async def upload_output(self, output: str) -> Optional[str]: + """Upload the eval output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") - if ctx.author.id in self.jobs: - await ctx.send(f"{ctx.author.mention} You've already got a job running - please wait for it to finish!") - return + if len(output) > MAX_PASTE_LEN: + log.info("Full output is too long to upload") + return "too long to upload" - if not code: # None or empty string - return await ctx.invoke(self.bot.get_command("help"), "eval") + url = URLs.paste_service.format(key="documents") + try: + async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp: + data = await resp.json() - log.info(f"Received code from {ctx.author.name}#{ctx.author.discriminator} for evaluation:\n{code}") - self.jobs[ctx.author.id] = datetime.datetime.now() + if "key" in data: + return URLs.paste_service.format(key=data["key"]) + except Exception: + # 400 (Bad Request) means there are too many characters + log.exception("Failed to upload full output to paste service!") - # Strip whitespace and inline or block code markdown and extract the code and some formatting info + @staticmethod + def prepare_input(code: str) -> str: + """Extract code from the Markdown, format it, and insert it into the code template.""" match = FORMATTED_CODE_REGEX.fullmatch(code) if match: code, block, lang, delim = match.group("code", "block", "lang", "delim") @@ -101,86 +90,140 @@ class Snekbox: log.trace(f"Extracted {info} for evaluation:\n{code}") else: code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) - log.trace(f"Eval message contains not or badly formatted code, stripping whitespace only:\n{code}") + log.trace( + f"Eval message contains unformatted or badly formatted code, " + f"stripping whitespace only:\n{code}" + ) - code = textwrap.indent(code, " ") - code = CODE_TEMPLATE.replace("{CODE}", code) + return code + + @staticmethod + def get_results_message(results: dict) -> Tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + stdout, returncode = results["stdout"], results["returncode"] + msg = f"Your eval job has completed with return code {returncode}" + error = "" + + if returncode is None: + msg = "Your eval job has failed" + error = stdout.strip() + elif returncode == 128 + Signals.SIGKILL: + msg = "Your eval job timed out or ran out of memory" + elif returncode == 255: + msg = "Your eval job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + try: + name = Signals(returncode - 128).name + msg = f"{msg} ({name})" + except ValueError: + pass - try: - await self.rmq.send_json( - "input", - snekid=str(ctx.author.id), message=code - ) + return msg, error - async with ctx.typing(): - message = await self.rmq.consume(str(ctx.author.id), **RMQ_ARGS) - paste_link = None + async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + """ + Format the output and return a tuple of the formatted output and a URL to the full output. - if isinstance(message, str): - output = str.strip(" \n") - else: - output = message.body.decode().strip(" \n") + Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters + and upload the full output to a paste service. + """ + log.trace("Formatting output...") - if "<@" in output: - output = output.replace("<@", "<@\u200B") # Zero-width space + output = output.strip(" \n") + original_output = output # To be uploaded to a pasting service if needed + paste_link = None - if "<!@" in output: - output = output.replace("<!@", "<!@\u200B") # Zero-width space + if "<@" in output: + output = output.replace("<@", "<@\u200B") # Zero-width space - if ESCAPE_REGEX.findall(output): - output = "Code block escape attempt detected; will not output result" - else: - # the original output, to send to a pasting service if needed - full_output = output - truncated = False - if output.count("\n") > 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split("\n"), start=1)] - output = "\n".join(output) - - if output.count("\n") > 10: - output = "\n".join(output.split("\n")[:10]) - - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" - else: - output = f"{output}\n... (truncated - too many lines)" - truncated = True - - elif len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long)" - truncated = True - - if truncated: - try: - response = await self.bot.http_session.post( - URLs.paste_service.format(key="documents"), - data=full_output - ) - data = await response.json() - if "key" in data: - paste_link = URLs.paste_service.format(key=data["key"]) - except Exception: - log.exception("Failed to upload full output to paste service!") - - if output.strip(): - if paste_link: - msg = f"{ctx.author.mention} Your eval job has completed.\n\n```py\n{output}\n```" \ - f"\nFull output: {paste_link}" - else: - msg = f"{ctx.author.mention} Your eval job has completed.\n\n```py\n{output}\n```" - - response = await ctx.send(msg) - self.bot.loop.create_task(wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)) + if "<!@" in output: + output = output.replace("<!@", "<!@\u200B") # Zero-width space - else: - await ctx.send( - f"{ctx.author.mention} Your eval job has completed.\n\n```py\n[No output]\n```" - ) + if ESCAPE_REGEX.findall(output): + return "Code block escape attempt detected; will not output result", paste_link + truncated = False + lines = output.count("\n") + + if lines > 0: + output = output.split("\n")[:10] # Only first 10 cause the rest is truncated anyway + output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1)) + output = "\n".join(output) + + if lines > 10: + truncated = True + if len(output) >= 1000: + output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + else: + output = f"{output}\n... (truncated - too many lines)" + elif len(output) >= 1000: + truncated = True + output = f"{output[:1000]}\n... (truncated - too long)" + + if truncated: + paste_link = await self.upload_output(original_output) + + output = output.strip() + if not output: + output = "[No output]" + + return output, paste_link + + @command(name="eval", aliases=("e",)) + @guild_only() + @in_channel(Channels.bot, bypass_roles=BYPASS_ROLES) + async def eval_command(self, ctx: Context, *, code: str = None): + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. We've done our best to make this safe, but do let us know if you manage to find an + issue with it! + """ + if ctx.author.id in self.jobs: + return await ctx.send( + f"{ctx.author.mention} You've already got a job running - " + f"please wait for it to finish!" + ) + + if not code: # None or empty string + return await ctx.invoke(self.bot.get_command("help"), "eval") + + log.info( + f"Received code from {ctx.author.name}#{ctx.author.discriminator} " + f"for evaluation:\n{code}" + ) + + self.jobs[ctx.author.id] = datetime.datetime.now() + code = self.prepare_input(code) + + try: + async with ctx.typing(): + results = await self.post_eval(code) + msg, error = self.get_results_message(results) + + if error: + output, paste_link = error, None + else: + output, paste_link = await self.format_output(results["stdout"]) + + msg = f"{ctx.author.mention} {msg}.\n\n```py\n{output}\n```" + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + response = await ctx.send(msg) + self.bot.loop.create_task( + wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) + ) + + log.info( + f"{ctx.author.name}#{ctx.author.discriminator}'s job had a return code of " + f"{results['returncode']}" + ) + finally: del self.jobs[ctx.author.id] - except Exception: - del self.jobs[ctx.author.id] - raise @eval_command.error async def eval_command_error(self, ctx: Context, error: CommandError): diff --git a/bot/constants.py b/bot/constants.py index d2e3bb315..0bd950a7d 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -374,18 +374,12 @@ class Keys(metaclass=YAMLGetter): youtube: str -class RabbitMQ(metaclass=YAMLGetter): - section = "rabbitmq" - - host: str - password: str - port: int - username: str - - class URLs(metaclass=YAMLGetter): section = "urls" + # Snekbox endpoints + snekbox_eval_api: str + # Discord API endpoints discord_api: str discord_invite_api: str diff --git a/bot/utils/service_discovery.py b/bot/utils/service_discovery.py deleted file mode 100644 index 8d79096bd..000000000 --- a/bot/utils/service_discovery.py +++ /dev/null @@ -1,22 +0,0 @@ -import datetime -import socket -import time -from contextlib import closing - -from bot.constants import RabbitMQ - -THIRTY_SECONDS = datetime.timedelta(seconds=30) - - -def wait_for_rmq(): - start = datetime.datetime.now() - - while True: - if datetime.datetime.now() - start > THIRTY_SECONDS: - return False - - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - if sock.connect_ex((RabbitMQ.host, RabbitMQ.port)) == 0: - return True - - time.sleep(0.5) diff --git a/config-default.yml b/config-default.yml index f6481cfcd..7854b5db9 100644 --- a/config-default.yml +++ b/config-default.yml @@ -204,13 +204,6 @@ keys: youtube: !ENV "YOUTUBE_API_KEY" -rabbitmq: - host: "pdrmq" - password: !ENV ["RABBITMQ_DEFAULT_PASS", "guest"] - port: 5672 - username: !ENV ["RABBITMQ_DEFAULT_USER", "guest"] - - urls: # PyDis site vars site: &DOMAIN "pythondiscord.com" @@ -242,6 +235,9 @@ urls: site_user_complete_api: !JOIN [*SCHEMA, *API, "/bot/users/complete"] paste_service: !JOIN [*SCHEMA, *PASTE, "/{key}"] + # Snekbox + snekbox_eval_api: "http://localhost:8060/eval" + # Env vars deploy: !ENV "DEPLOY_URL" status: !ENV "STATUS_URL" |