aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/snekbox.py22
-rw-r--r--tests/bot/cogs/test_snekbox.py45
2 files changed, 59 insertions, 8 deletions
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index cff7c5786..454836921 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -232,7 +232,7 @@ class Snekbox(Cog):
timeout=10
)
- code = new_message.content.split(' ', maxsplit=1)[1]
+ code = await self.get_code(new_message)
await ctx.message.clear_reactions()
with contextlib.suppress(HTTPException):
await response.delete()
@@ -243,6 +243,26 @@ class Snekbox(Cog):
return code
+ async def get_code(self, message: Message) -> Optional[str]:
+ """
+ Return the code from `message` to be evaluated.
+
+ If the message is an invocation of the eval command, return the first argument or None if it
+ doesn't exist. Otherwise, return the full content of the message.
+ """
+ log.trace(f"Getting context for message {message.id}.")
+ new_ctx = await self.bot.get_context(message)
+
+ if new_ctx.command is self.eval_command:
+ log.trace(f"Message {message.id} invokes eval command.")
+ split = message.content.split(maxsplit=1)
+ code = split[1] if len(split) > 1 else None
+ else:
+ log.trace(f"Message {message.id} does not invoke eval command.")
+ code = message.content
+
+ return code
+
@command(name="eval", aliases=("e",))
@guild_only()
@in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
index fd9468829..1dec0ccaf 100644
--- a/tests/bot/cogs/test_snekbox.py
+++ b/tests/bot/cogs/test_snekbox.py
@@ -1,11 +1,13 @@
import asyncio
import logging
import unittest
-from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch
+from discord.ext import commands
+
+from bot import constants
from bot.cogs import snekbox
from bot.cogs.snekbox import Snekbox
-from bot.constants import URLs
from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser
@@ -23,7 +25,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(await self.cog.post_eval("import random"), "return")
self.bot.http_session.post.assert_called_with(
- URLs.snekbox_eval_api,
+ constants.URLs.snekbox_eval_api,
json={"input": "import random"},
raise_for_status=True
)
@@ -43,10 +45,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
await self.cog.upload_output("My awesome output"),
- URLs.paste_service.format(key=key)
+ constants.URLs.paste_service.format(key=key)
)
self.bot.http_session.post.assert_called_with(
- URLs.paste_service.format(key="documents"),
+ constants.URLs.paste_service.format(key="documents"),
data="My awesome output",
raise_for_status=True
)
@@ -279,11 +281,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Test that the continue_eval function does continue if required conditions are met."""
ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
response = MockMessage(delete=AsyncMock())
- new_msg = MockMessage(content='!e NewCode')
+ new_msg = MockMessage()
self.bot.wait_for.side_effect = ((None, new_msg), None)
+ expected = "NewCode"
+ self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)
actual = await self.cog.continue_eval(ctx, response)
- self.assertEqual(actual, 'NewCode')
+ self.cog.get_code.assert_awaited_once_with(new_msg)
+ self.assertEqual(actual, expected)
self.bot.wait_for.assert_has_awaits(
(
call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10),
@@ -302,6 +307,32 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(actual, None)
ctx.message.clear_reactions.assert_called_once()
+ async def test_get_code(self):
+ """Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
+ prefix = constants.Bot.prefix
+ subtests = (
+ (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"),
+ (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None),
+ (MagicMock(spec=commands.Command), f"{prefix}tags get foo"),
+ (None, "print(123)")
+ )
+
+ for command, content, *expected_code in subtests:
+ if not expected_code:
+ expected_code = content
+ else:
+ [expected_code] = expected_code
+
+ with self.subTest(content=content, expected_code=expected_code):
+ self.bot.get_context.reset_mock()
+ self.bot.get_context.return_value = MockContext(command=command)
+ message = MockMessage(content=content)
+
+ actual_code = await self.cog.get_code(message)
+
+ self.bot.get_context.assert_awaited_once_with(message)
+ self.assertEqual(actual_code, expected_code)
+
def test_predicate_eval_message_edit(self):
"""Test the predicate_eval_message_edit function."""
msg0 = MockMessage(id=1, content='abc')