aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2019-06-22 15:06:16 +0200
committerGravatar GitHub <[email protected]>2019-06-22 15:06:16 +0200
commit25640adec9d042ccf249a91540fb09d354b04dfd (patch)
tree62cf2a21079510afd5d5a884006ddf04269780e7
parentMerge pull request #371 from python-discord/django-appeals (diff)
parentSnekbox: limit paste service uploads to 1000 characters (diff)
Merge pull request #372 from python-discord/snekbox
Replace RabbitMQ with Snekbox API
-rw-r--r--bot/__main__.py12
-rw-r--r--bot/cogs/rmq.py229
-rw-r--r--bot/cogs/snekbox.py261
-rw-r--r--bot/constants.py12
-rw-r--r--bot/utils/service_discovery.py22
-rw-r--r--config-default.yml10
6 files changed, 158 insertions, 388 deletions
diff --git a/bot/__main__.py b/bot/__main__.py
index 8687cc62c..8afec2718 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -8,7 +8,6 @@ from discord.ext.commands import Bot, when_mentioned_or
from bot.api import APIClient
from bot.constants import Bot as BotConfig, DEBUG_MODE
-from bot.utils.service_discovery import wait_for_rmq
log = logging.getLogger(__name__)
@@ -31,14 +30,6 @@ bot.http_session = ClientSession(
)
bot.api_client = APIClient(loop=asyncio.get_event_loop())
-log.info("Waiting for RabbitMQ...")
-has_rmq = wait_for_rmq()
-
-if has_rmq:
- log.info("RabbitMQ found")
-else:
- log.warning("Timed out while waiting for RabbitMQ")
-
# Internal/debug
bot.load_extension("bot.cogs.logging")
bot.load_extension("bot.cogs.security")
@@ -80,9 +71,6 @@ bot.load_extension("bot.cogs.token_remover")
bot.load_extension("bot.cogs.utils")
bot.load_extension("bot.cogs.wolfram")
-if has_rmq:
- bot.load_extension("bot.cogs.rmq")
-
bot.run(BotConfig.token)
bot.http_session.close() # Close the aiohttp session when the bot finishes running
diff --git a/bot/cogs/rmq.py b/bot/cogs/rmq.py
deleted file mode 100644
index 2742fb969..000000000
--- a/bot/cogs/rmq.py
+++ /dev/null
@@ -1,229 +0,0 @@
-import asyncio
-import datetime
-import json
-import logging
-import pprint
-
-import aio_pika
-from aio_pika import Message
-from dateutil import parser as date_parser
-from discord import Colour, Embed
-from discord.ext.commands import Bot
-from discord.utils import get
-
-from bot.constants import Channels, Guild, RabbitMQ
-
-log = logging.getLogger(__name__)
-
-LEVEL_COLOURS = {
- "debug": Colour.blue(),
- "info": Colour.green(),
- "warning": Colour.gold(),
- "error": Colour.red()
-}
-
-DEFAULT_LEVEL_COLOUR = Colour.greyple()
-EMBED_PARAMS = (
- "colour", "title", "url", "description", "timestamp"
-)
-
-CONSUME_TIMEOUT = datetime.timedelta(seconds=10)
-
-
-class RMQ:
- """
- RabbitMQ event handling
- """
-
- rmq = None # type: aio_pika.Connection
- channel = None # type: aio_pika.Channel
- queue = None # type: aio_pika.Queue
-
- def __init__(self, bot: Bot):
- self.bot = bot
-
- async def on_ready(self):
- self.rmq = await aio_pika.connect_robust(
- host=RabbitMQ.host, port=RabbitMQ.port, login=RabbitMQ.username, password=RabbitMQ.password
- )
-
- log.info("Connected to RabbitMQ")
-
- self.channel = await self.rmq.channel()
- self.queue = await self.channel.declare_queue("bot_events", durable=True)
-
- log.debug("Channel opened, queue declared")
-
- async for message in self.queue:
- with message.process():
- message.ack()
- await self.handle_message(message, message.body.decode())
-
- async def send_text(self, queue: str, data: str):
- message = Message(data.encode("utf-8"))
- await self.channel.default_exchange.publish(message, queue)
-
- async def send_json(self, queue: str, **data):
- message = Message(json.dumps(data).encode("utf-8"))
- await self.channel.default_exchange.publish(message, queue)
-
- async def consume(self, queue: str, **kwargs):
- queue_obj = await self.channel.declare_queue(queue, **kwargs)
-
- result = None
- start_time = datetime.datetime.now()
-
- while result is None:
- if datetime.datetime.now() - start_time >= CONSUME_TIMEOUT:
- result = "Timed out while waiting for a response."
- else:
- result = await queue_obj.get(timeout=5, fail=False)
- await asyncio.sleep(0.5)
-
- if result:
- result.ack()
-
- return result
-
- async def handle_message(self, message, data):
- log.debug(f"Message: {message}")
- log.debug(f"Data: {data}")
-
- try:
- data = json.loads(data)
- except Exception:
- await self.do_mod_log("error", "Unable to parse event", data)
- else:
- event = data["event"]
- event_data = data["data"]
-
- try:
- func = getattr(self, f"do_{event}")
- await func(**event_data)
- except Exception as e:
- await self.do_mod_log(
- "error", f"Unable to handle event: {event}",
- str(e)
- )
-
- async def do_mod_log(self, level: str, title: str, message: str):
- colour = LEVEL_COLOURS.get(level, DEFAULT_LEVEL_COLOUR)
- embed = Embed(
- title=title, description=f"```\n{message}\n```",
- colour=colour, timestamp=datetime.datetime.now()
- )
-
- await self.bot.get_channel(Channels.modlog).send(embed=embed)
- log.log(logging._nameToLevel[level.upper()], f"Modlog: {title} | {message}")
-
- async def do_send_message(self, target: int, message: str):
- channel = self.bot.get_channel(target)
-
- if channel is None:
- await self.do_mod_log(
- "error", "Failed: Send Message",
- f"Unable to find channel: {target}"
- )
- else:
- await channel.send(message)
-
- await self.do_mod_log(
- "info", "Succeeded: Send Embed",
- f"Message sent to channel {target}\n\n{message}"
- )
-
- async def do_send_embed(self, target: int, **embed_params):
- for param, value in list(embed_params.items()): # To keep a full copy
- if param not in EMBED_PARAMS:
- await self.do_mod_log(
- "warning", "Warning: Send Embed",
- f"Unknown embed parameter: {param}"
- )
- del embed_params[param]
-
- if param == "timestamp":
- embed_params[param] = date_parser.parse(value)
- elif param == "colour":
- embed_params[param] = Colour(value)
-
- channel = self.bot.get_channel(target)
-
- if channel is None:
- await self.do_mod_log(
- "error", "Failed: Send Embed",
- f"Unable to find channel: {target}"
- )
- else:
- await channel.send(embed=Embed(**embed_params))
-
- await self.do_mod_log(
- "info", "Succeeded: Send Embed",
- f"Embed sent to channel {target}\n\n{pprint.pformat(embed_params, 4)}"
- )
-
- async def do_add_role(self, target: int, role_id: int, reason: str):
- guild = self.bot.get_guild(Guild.id)
- member = guild.get_member(int(target))
-
- if member is None:
- return await self.do_mod_log(
- "error", "Failed: Add Role",
- f"Unable to find member: {target}"
- )
-
- role = get(guild.roles, id=int(role_id))
-
- if role is None:
- return await self.do_mod_log(
- "error", "Failed: Add Role",
- f"Unable to find role: {role_id}"
- )
-
- try:
- await member.add_roles(role, reason=reason)
- except Exception as e:
- await self.do_mod_log(
- "error", "Failed: Add Role",
- f"Error while adding role {role.name}: {e}"
- )
- else:
- await self.do_mod_log(
- "info", "Succeeded: Add Role",
- f"Role {role.name} added to member {target}"
- )
-
- async def do_remove_role(self, target: int, role_id: int, reason: str):
- guild = self.bot.get_guild(Guild.id)
- member = guild.get_member(int(target))
-
- if member is None:
- return await self.do_mod_log(
- "error", "Failed: Remove Role",
- f"Unable to find member: {target}"
- )
-
- role = get(guild.roles, id=int(role_id))
-
- if role is None:
- return await self.do_mod_log(
- "error", "Failed: Remove Role",
- f"Unable to find role: {role_id}"
- )
-
- try:
- await member.remove_roles(role, reason=reason)
- except Exception as e:
- await self.do_mod_log(
- "error", "Failed: Remove Role",
- f"Error while adding role {role.name}: {e}"
- )
- else:
- await self.do_mod_log(
- "info", "Succeeded: Remove Role",
- f"Role {role.name} removed from member {target}"
- )
-
-
-def setup(bot):
- bot.add_cog(RMQ(bot))
- log.info("Cog loaded: RMQ")
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index cb0454249..0f8d3e4b6 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -3,13 +3,14 @@ import logging
import random
import re
import textwrap
+from signal import Signals
+from typing import Optional, Tuple
from discord import Colour, Embed
from discord.ext.commands import (
Bot, CommandError, Context, NoPrivateMessage, command, guild_only
)
-from bot.cogs.rmq import RMQ
from bot.constants import Channels, ERROR_REPLIES, NEGATIVE_REPLIES, Roles, URLs
from bot.decorators import InChannelCheckFailure, in_channel
from bot.utils.messages import wait_for_deletion
@@ -17,22 +18,6 @@ from bot.utils.messages import wait_for_deletion
log = logging.getLogger(__name__)
-RMQ_ARGS = {
- "durable": False,
- "arguments": {"x-message-ttl": 5000},
- "auto_delete": True
-}
-
-CODE_TEMPLATE = """
-venv_file = "/snekbox/.venv/bin/activate_this.py"
-exec(open(venv_file).read(), dict(__file__=venv_file))
-
-try:
-{CODE}
-except Exception as e:
- print(e)
-"""
-
ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")
FORMATTED_CODE_REGEX = re.compile(
r"^\s*" # any leading whitespace from the beginning of the string
@@ -53,43 +38,47 @@ RAW_CODE_REGEX = re.compile(
)
BYPASS_ROLES = (Roles.owner, Roles.admin, Roles.moderator, Roles.helpers)
+MAX_PASTE_LEN = 1000
class Snekbox:
"""
- Safe evaluation using Snekbox
+ Safe evaluation of Python code using Snekbox
"""
def __init__(self, bot: Bot):
self.bot = bot
self.jobs = {}
- @property
- def rmq(self) -> RMQ:
- return self.bot.get_cog("RMQ")
-
- @command(name='eval', aliases=('e',))
- @guild_only()
- @in_channel(Channels.bot, bypass_roles=BYPASS_ROLES)
- async def eval_command(self, ctx: Context, *, code: str = None):
- """
- Run some code. get the result back. We've done our best to make this safe, but do let us know if you
- manage to find an issue with it!
+ async def post_eval(self, code: str) -> dict:
+ """Send a POST request to the Snekbox API to evaluate code and return the results."""
+ url = URLs.snekbox_eval_api
+ data = {"input": code}
+ async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
+ return await resp.json()
- This command supports multiple lines of code, including code wrapped inside a formatted code block.
- """
+ async def upload_output(self, output: str) -> Optional[str]:
+ """Upload the eval output to a paste service and return a URL to it if successful."""
+ log.trace("Uploading full output to paste service...")
- if ctx.author.id in self.jobs:
- await ctx.send(f"{ctx.author.mention} You've already got a job running - please wait for it to finish!")
- return
+ if len(output) > MAX_PASTE_LEN:
+ log.info("Full output is too long to upload")
+ return "too long to upload"
- if not code: # None or empty string
- return await ctx.invoke(self.bot.get_command("help"), "eval")
+ url = URLs.paste_service.format(key="documents")
+ try:
+ async with self.bot.http_session.post(url, data=output, raise_for_status=True) as resp:
+ data = await resp.json()
- log.info(f"Received code from {ctx.author.name}#{ctx.author.discriminator} for evaluation:\n{code}")
- self.jobs[ctx.author.id] = datetime.datetime.now()
+ if "key" in data:
+ return URLs.paste_service.format(key=data["key"])
+ except Exception:
+ # 400 (Bad Request) means there are too many characters
+ log.exception("Failed to upload full output to paste service!")
- # Strip whitespace and inline or block code markdown and extract the code and some formatting info
+ @staticmethod
+ def prepare_input(code: str) -> str:
+ """Extract code from the Markdown, format it, and insert it into the code template."""
match = FORMATTED_CODE_REGEX.fullmatch(code)
if match:
code, block, lang, delim = match.group("code", "block", "lang", "delim")
@@ -101,86 +90,140 @@ class Snekbox:
log.trace(f"Extracted {info} for evaluation:\n{code}")
else:
code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))
- log.trace(f"Eval message contains not or badly formatted code, stripping whitespace only:\n{code}")
+ log.trace(
+ f"Eval message contains unformatted or badly formatted code, "
+ f"stripping whitespace only:\n{code}"
+ )
- code = textwrap.indent(code, " ")
- code = CODE_TEMPLATE.replace("{CODE}", code)
+ return code
+
+ @staticmethod
+ def get_results_message(results: dict) -> Tuple[str, str]:
+ """Return a user-friendly message and error corresponding to the process's return code."""
+ stdout, returncode = results["stdout"], results["returncode"]
+ msg = f"Your eval job has completed with return code {returncode}"
+ error = ""
+
+ if returncode is None:
+ msg = "Your eval job has failed"
+ error = stdout.strip()
+ elif returncode == 128 + Signals.SIGKILL:
+ msg = "Your eval job timed out or ran out of memory"
+ elif returncode == 255:
+ msg = "Your eval job has failed"
+ error = "A fatal NsJail error occurred"
+ else:
+ # Try to append signal's name if one exists
+ try:
+ name = Signals(returncode - 128).name
+ msg = f"{msg} ({name})"
+ except ValueError:
+ pass
- try:
- await self.rmq.send_json(
- "input",
- snekid=str(ctx.author.id), message=code
- )
+ return msg, error
- async with ctx.typing():
- message = await self.rmq.consume(str(ctx.author.id), **RMQ_ARGS)
- paste_link = None
+ async def format_output(self, output: str) -> Tuple[str, Optional[str]]:
+ """
+ Format the output and return a tuple of the formatted output and a URL to the full output.
- if isinstance(message, str):
- output = str.strip(" \n")
- else:
- output = message.body.decode().strip(" \n")
+ Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters
+ and upload the full output to a paste service.
+ """
+ log.trace("Formatting output...")
- if "<@" in output:
- output = output.replace("<@", "<@\u200B") # Zero-width space
+ output = output.strip(" \n")
+ original_output = output # To be uploaded to a pasting service if needed
+ paste_link = None
- if "<!@" in output:
- output = output.replace("<!@", "<!@\u200B") # Zero-width space
+ if "<@" in output:
+ output = output.replace("<@", "<@\u200B") # Zero-width space
- if ESCAPE_REGEX.findall(output):
- output = "Code block escape attempt detected; will not output result"
- else:
- # the original output, to send to a pasting service if needed
- full_output = output
- truncated = False
- if output.count("\n") > 0:
- output = [f"{i:03d} | {line}" for i, line in enumerate(output.split("\n"), start=1)]
- output = "\n".join(output)
-
- if output.count("\n") > 10:
- output = "\n".join(output.split("\n")[:10])
-
- if len(output) >= 1000:
- output = f"{output[:1000]}\n... (truncated - too long, too many lines)"
- else:
- output = f"{output}\n... (truncated - too many lines)"
- truncated = True
-
- elif len(output) >= 1000:
- output = f"{output[:1000]}\n... (truncated - too long)"
- truncated = True
-
- if truncated:
- try:
- response = await self.bot.http_session.post(
- URLs.paste_service.format(key="documents"),
- data=full_output
- )
- data = await response.json()
- if "key" in data:
- paste_link = URLs.paste_service.format(key=data["key"])
- except Exception:
- log.exception("Failed to upload full output to paste service!")
-
- if output.strip():
- if paste_link:
- msg = f"{ctx.author.mention} Your eval job has completed.\n\n```py\n{output}\n```" \
- f"\nFull output: {paste_link}"
- else:
- msg = f"{ctx.author.mention} Your eval job has completed.\n\n```py\n{output}\n```"
-
- response = await ctx.send(msg)
- self.bot.loop.create_task(wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot))
+ if "<!@" in output:
+ output = output.replace("<!@", "<!@\u200B") # Zero-width space
- else:
- await ctx.send(
- f"{ctx.author.mention} Your eval job has completed.\n\n```py\n[No output]\n```"
- )
+ if ESCAPE_REGEX.findall(output):
+ return "Code block escape attempt detected; will not output result", paste_link
+ truncated = False
+ lines = output.count("\n")
+
+ if lines > 0:
+ output = output.split("\n")[:10] # Only first 10 cause the rest is truncated anyway
+ output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1))
+ output = "\n".join(output)
+
+ if lines > 10:
+ truncated = True
+ if len(output) >= 1000:
+ output = f"{output[:1000]}\n... (truncated - too long, too many lines)"
+ else:
+ output = f"{output}\n... (truncated - too many lines)"
+ elif len(output) >= 1000:
+ truncated = True
+ output = f"{output[:1000]}\n... (truncated - too long)"
+
+ if truncated:
+ paste_link = await self.upload_output(original_output)
+
+ output = output.strip()
+ if not output:
+ output = "[No output]"
+
+ return output, paste_link
+
+ @command(name="eval", aliases=("e",))
+ @guild_only()
+ @in_channel(Channels.bot, bypass_roles=BYPASS_ROLES)
+ async def eval_command(self, ctx: Context, *, code: str = None):
+ """
+ Run Python code and get the results.
+
+ This command supports multiple lines of code, including code wrapped inside a formatted code
+ block. We've done our best to make this safe, but do let us know if you manage to find an
+ issue with it!
+ """
+ if ctx.author.id in self.jobs:
+ return await ctx.send(
+ f"{ctx.author.mention} You've already got a job running - "
+ f"please wait for it to finish!"
+ )
+
+ if not code: # None or empty string
+ return await ctx.invoke(self.bot.get_command("help"), "eval")
+
+ log.info(
+ f"Received code from {ctx.author.name}#{ctx.author.discriminator} "
+ f"for evaluation:\n{code}"
+ )
+
+ self.jobs[ctx.author.id] = datetime.datetime.now()
+ code = self.prepare_input(code)
+
+ try:
+ async with ctx.typing():
+ results = await self.post_eval(code)
+ msg, error = self.get_results_message(results)
+
+ if error:
+ output, paste_link = error, None
+ else:
+ output, paste_link = await self.format_output(results["stdout"])
+
+ msg = f"{ctx.author.mention} {msg}.\n\n```py\n{output}\n```"
+ if paste_link:
+ msg = f"{msg}\nFull output: {paste_link}"
+
+ response = await ctx.send(msg)
+ self.bot.loop.create_task(
+ wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)
+ )
+
+ log.info(
+ f"{ctx.author.name}#{ctx.author.discriminator}'s job had a return code of "
+ f"{results['returncode']}"
+ )
+ finally:
del self.jobs[ctx.author.id]
- except Exception:
- del self.jobs[ctx.author.id]
- raise
@eval_command.error
async def eval_command_error(self, ctx: Context, error: CommandError):
diff --git a/bot/constants.py b/bot/constants.py
index d2e3bb315..0bd950a7d 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -374,18 +374,12 @@ class Keys(metaclass=YAMLGetter):
youtube: str
-class RabbitMQ(metaclass=YAMLGetter):
- section = "rabbitmq"
-
- host: str
- password: str
- port: int
- username: str
-
-
class URLs(metaclass=YAMLGetter):
section = "urls"
+ # Snekbox endpoints
+ snekbox_eval_api: str
+
# Discord API endpoints
discord_api: str
discord_invite_api: str
diff --git a/bot/utils/service_discovery.py b/bot/utils/service_discovery.py
deleted file mode 100644
index 8d79096bd..000000000
--- a/bot/utils/service_discovery.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import datetime
-import socket
-import time
-from contextlib import closing
-
-from bot.constants import RabbitMQ
-
-THIRTY_SECONDS = datetime.timedelta(seconds=30)
-
-
-def wait_for_rmq():
- start = datetime.datetime.now()
-
- while True:
- if datetime.datetime.now() - start > THIRTY_SECONDS:
- return False
-
- with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
- if sock.connect_ex((RabbitMQ.host, RabbitMQ.port)) == 0:
- return True
-
- time.sleep(0.5)
diff --git a/config-default.yml b/config-default.yml
index f6481cfcd..7854b5db9 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -204,13 +204,6 @@ keys:
youtube: !ENV "YOUTUBE_API_KEY"
-rabbitmq:
- host: "pdrmq"
- password: !ENV ["RABBITMQ_DEFAULT_PASS", "guest"]
- port: 5672
- username: !ENV ["RABBITMQ_DEFAULT_USER", "guest"]
-
-
urls:
# PyDis site vars
site: &DOMAIN "pythondiscord.com"
@@ -242,6 +235,9 @@ urls:
site_user_complete_api: !JOIN [*SCHEMA, *API, "/bot/users/complete"]
paste_service: !JOIN [*SCHEMA, *PASTE, "/{key}"]
+ # Snekbox
+ snekbox_eval_api: "http://localhost:8060/eval"
+
# Env vars
deploy: !ENV "DEPLOY_URL"
status: !ENV "STATUS_URL"