aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Shirayuki Nekomata <[email protected]>2020-02-29 14:14:45 +0700
committerGravatar GitHub <[email protected]>2020-02-29 14:14:45 +0700
commitf822cb52b497939770e51e96a8c55deda2727be1 (patch)
tree0b2388b99d4351020effdccb783b5c1305d76cb2
parentMerge branch 'master' into fuzzy_zero_div (diff)
parentMerge pull request #710 from python-discord/eval-enhancements (diff)
Merge branch 'master' into fuzzy_zero_div
-rw-r--r--bot/cogs/snekbox.py131
-rw-r--r--tests/bot/cogs/test_snekbox.py368
-rw-r--r--tests/helpers.py12
3 files changed, 479 insertions, 32 deletions
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index aef12546d..cff7c5786 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -1,10 +1,14 @@
+import asyncio
+import contextlib
import datetime
import logging
import re
import textwrap
+from functools import partial
from signal import Signals
from typing import Optional, Tuple
+from discord import HTTPException, Message, NotFound, Reaction, User
from discord.ext.commands import Cog, Context, command, guild_only
from bot.bot import Bot
@@ -36,6 +40,10 @@ RAW_CODE_REGEX = re.compile(
MAX_PASTE_LEN = 1000
EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
+SIGKILL = 9
+
+REEVAL_EMOJI = '\U0001f501' # :repeat:
+
class Snekbox(Cog):
"""Safe evaluation of Python code using Snekbox."""
@@ -101,7 +109,7 @@ class Snekbox(Cog):
if returncode is None:
msg = "Your eval job has failed"
error = stdout.strip()
- elif returncode == 128 + Signals.SIGKILL:
+ elif returncode == 128 + SIGKILL:
msg = "Your eval job timed out or ran out of memory"
elif returncode == 255:
msg = "Your eval job has failed"
@@ -135,7 +143,7 @@ class Snekbox(Cog):
"""
log.trace("Formatting output...")
- output = output.strip(" \n")
+ output = output.rstrip("\n")
original_output = output # To be uploaded to a pasting service if needed
paste_link = None
@@ -152,8 +160,8 @@ class Snekbox(Cog):
lines = output.count("\n")
if lines > 0:
- output = output.split("\n")[:10] # Only first 10 cause the rest is truncated anyway
- output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1))
+ output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)]
+ output = output[:11] # Limiting to only 11 lines
output = "\n".join(output)
if lines > 10:
@@ -169,12 +177,72 @@ class Snekbox(Cog):
if truncated:
paste_link = await self.upload_output(original_output)
- output = output.strip()
- if not output:
- output = "[No output]"
+ output = output or "[No output]"
return output, paste_link
+ async def send_eval(self, ctx: Context, code: str) -> Message:
+ """
+ Evaluate code, format it, and send the output to the corresponding channel.
+
+ Return the bot response.
+ """
+ async with ctx.typing():
+ results = await self.post_eval(code)
+ msg, error = self.get_results_message(results)
+
+ if error:
+ output, paste_link = error, None
+ else:
+ output, paste_link = await self.format_output(results["stdout"])
+
+ icon = self.get_status_emoji(results)
+ msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```"
+ if paste_link:
+ msg = f"{msg}\nFull output: {paste_link}"
+
+ response = await ctx.send(msg)
+ self.bot.loop.create_task(
+ wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)
+ )
+
+ log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
+ return response
+
+ async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]:
+ """
+ Check if the eval session should continue.
+
+ Return the new code to evaluate or None if the eval session should be terminated.
+ """
+ _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx)
+ _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx)
+
+ with contextlib.suppress(NotFound):
+ try:
+ _, new_message = await self.bot.wait_for(
+ 'message_edit',
+ check=_predicate_eval_message_edit,
+ timeout=10
+ )
+ await ctx.message.add_reaction(REEVAL_EMOJI)
+ await self.bot.wait_for(
+ 'reaction_add',
+ check=_predicate_emoji_reaction,
+ timeout=10
+ )
+
+ code = new_message.content.split(' ', maxsplit=1)[1]
+ await ctx.message.clear_reactions()
+ with contextlib.suppress(HTTPException):
+ await response.delete()
+
+ except asyncio.TimeoutError:
+ await ctx.message.clear_reactions()
+ return None
+
+ return code
+
@command(name="eval", aliases=("e",))
@guild_only()
@in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
@@ -183,7 +251,10 @@ class Snekbox(Cog):
Run Python code and get the results.
This command supports multiple lines of code, including code wrapped inside a formatted code
- block. We've done our best to make this safe, but do let us know if you manage to find an
+ block. Code can be re-evaluated by editing the original message within 10 seconds and
+ clicking the reaction that subsequently appears.
+
+ We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
if ctx.author.id in self.jobs:
@@ -199,32 +270,28 @@ class Snekbox(Cog):
log.info(f"Received code from {ctx.author} for evaluation:\n{code}")
- self.jobs[ctx.author.id] = datetime.datetime.now()
- code = self.prepare_input(code)
+ while True:
+ self.jobs[ctx.author.id] = datetime.datetime.now()
+ code = self.prepare_input(code)
+ try:
+ response = await self.send_eval(ctx, code)
+ finally:
+ del self.jobs[ctx.author.id]
+
+ code = await self.continue_eval(ctx, response)
+ if not code:
+ break
+ log.info(f"Re-evaluating message {ctx.message.id}")
+
+
+def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
+ """Return True if the edited message is the context message and the content was indeed modified."""
+ return new_msg.id == ctx.message.id and old_msg.content != new_msg.content
- try:
- async with ctx.typing():
- results = await self.post_eval(code)
- msg, error = self.get_results_message(results)
-
- if error:
- output, paste_link = error, None
- else:
- output, paste_link = await self.format_output(results["stdout"])
-
- icon = self.get_status_emoji(results)
- msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```"
- if paste_link:
- msg = f"{msg}\nFull output: {paste_link}"
-
- response = await ctx.send(msg)
- self.bot.loop.create_task(
- wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)
- )
- log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
- finally:
- del self.jobs[ctx.author.id]
+def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool:
+ """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message."""
+ return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI
def setup(bot: Bot) -> None:
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
new file mode 100644
index 000000000..985bc66a1
--- /dev/null
+++ b/tests/bot/cogs/test_snekbox.py
@@ -0,0 +1,368 @@
+import asyncio
+import logging
+import unittest
+from functools import partial
+from unittest.mock import MagicMock, Mock, call, patch
+
+from bot.cogs import snekbox
+from bot.cogs.snekbox import Snekbox
+from bot.constants import URLs
+from tests.helpers import (
+ AsyncContextManagerMock, AsyncMock, MockBot, MockContext, MockMessage, MockReaction, MockUser, async_test
+)
+
+
+class SnekboxTests(unittest.TestCase):
+ def setUp(self):
+ """Add mocked bot and cog to the instance."""
+ self.bot = MockBot()
+
+ self.mocked_post = MagicMock()
+ self.mocked_post.json = AsyncMock()
+ self.bot.http_session.post = MagicMock(return_value=AsyncContextManagerMock(self.mocked_post))
+
+ self.cog = Snekbox(bot=self.bot)
+
+ @async_test
+ async def test_post_eval(self):
+ """Post the eval code to the URLs.snekbox_eval_api endpoint."""
+ self.mocked_post.json.return_value = {'lemon': 'AI'}
+
+ self.assertEqual(await self.cog.post_eval("import random"), {'lemon': 'AI'})
+ self.bot.http_session.post.assert_called_once_with(
+ URLs.snekbox_eval_api,
+ json={"input": "import random"},
+ raise_for_status=True
+ )
+
+ @async_test
+ async def test_upload_output_reject_too_long(self):
+ """Reject output longer than MAX_PASTE_LEN."""
+ result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))
+ self.assertEqual(result, "too long to upload")
+
+ @async_test
+ async def test_upload_output(self):
+ """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""
+ key = "RainbowDash"
+ self.mocked_post.json.return_value = {"key": key}
+
+ self.assertEqual(
+ await self.cog.upload_output("My awesome output"),
+ URLs.paste_service.format(key=key)
+ )
+ self.bot.http_session.post.assert_called_once_with(
+ URLs.paste_service.format(key="documents"),
+ data="My awesome output",
+ raise_for_status=True
+ )
+
+ @async_test
+ async def test_upload_output_gracefully_fallback_if_exception_during_request(self):
+ """Output upload gracefully fallback if the upload fail."""
+ self.mocked_post.json.side_effect = Exception
+ log = logging.getLogger("bot.cogs.snekbox")
+ with self.assertLogs(logger=log, level='ERROR'):
+ await self.cog.upload_output('My awesome output!')
+
+ @async_test
+ async def test_upload_output_gracefully_fallback_if_no_key_in_response(self):
+ """Output upload gracefully fallback if there is no key entry in the response body."""
+ self.mocked_post.json.return_value = {}
+ self.assertEqual((await self.cog.upload_output('My awesome output!')), None)
+
+ def test_prepare_input(self):
+ cases = (
+ ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'),
+ ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
+ ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),
+ ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'),
+ )
+ for case, expected, testname in cases:
+ with self.subTest(msg=f'Extract code from {testname}.'):
+ self.assertEqual(self.cog.prepare_input(case), expected)
+
+ def test_get_results_message(self):
+ """Return error and message according to the eval result."""
+ cases = (
+ ('ERROR', None, ('Your eval job has failed', 'ERROR')),
+ ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')),
+ ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred'))
+ )
+ for stdout, returncode, expected in cases:
+ with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
+ actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode})
+ self.assertEqual(actual, expected)
+
+ @patch('bot.cogs.snekbox.Signals', side_effect=ValueError)
+ def test_get_results_message_invalid_signal(self, mock_Signals: Mock):
+ self.assertEqual(
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ ('Your eval job has completed with return code 127', '')
+ )
+
+ @patch('bot.cogs.snekbox.Signals')
+ def test_get_results_message_valid_signal(self, mock_Signals: Mock):
+ mock_Signals.return_value.name = 'SIGTEST'
+ self.assertEqual(
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ ('Your eval job has completed with return code 127 (SIGTEST)', '')
+ )
+
+ def test_get_status_emoji(self):
+ """Return emoji according to the eval result."""
+ cases = (
+ (' ', -1, ':warning:'),
+ ('Hello world!', 0, ':white_check_mark:'),
+ ('Invalid beard size', -1, ':x:')
+ )
+ for stdout, returncode, expected in cases:
+ with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
+ actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})
+ self.assertEqual(actual, expected)
+
+ @async_test
+ async def test_format_output(self):
+ """Test output formatting."""
+ self.cog.upload_output = AsyncMock(return_value='https://testificate.com/')
+
+ too_many_lines = (
+ '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n'
+ '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)'
+ )
+ too_long_too_many_lines = (
+ "\n".join(
+ f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1)
+ )[:1000] + "\n... (truncated - too long, too many lines)"
+ )
+
+ cases = (
+ ('', ('[No output]', None), 'No output'),
+ ('My awesome output', ('My awesome output', None), 'One line output'),
+ ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'),
+ ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'),
+ (
+ '\u202E\u202E\u202E',
+ ('Code block escape attempt detected; will not output result', None),
+ 'Detect RIGHT-TO-LEFT OVERRIDE'
+ ),
+ (
+ '\u200B\u200B\u200B',
+ ('Code block escape attempt detected; will not output result', None),
+ 'Detect ZERO WIDTH SPACE'
+ ),
+ ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'),
+ (
+ 'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd',
+ (too_many_lines, 'https://testificate.com/'),
+ '12 lines output'
+ ),
+ (
+ 'verylongbeard' * 100,
+ ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'),
+ '1300 characters output'
+ ),
+ (
+ ('verylongbeard' * 10 + '\n') * 15,
+ (too_long_too_many_lines, 'https://testificate.com/'),
+ '15 lines, 1965 characters output'
+ ),
+ )
+ for case, expected, testname in cases:
+ with self.subTest(msg=testname, case=case, expected=expected):
+ self.assertEqual(await self.cog.format_output(case), expected)
+
+ @async_test
+ async def test_eval_command_evaluate_once(self):
+ """Test the eval command procedure."""
+ ctx = MockContext()
+ response = MockMessage()
+ self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.send_eval = AsyncMock(return_value=response)
+ self.cog.continue_eval = AsyncMock(return_value=None)
+
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
+ self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.continue_eval.assert_called_once_with(ctx, response)
+
+ @async_test
+ async def test_eval_command_evaluate_twice(self):
+ """Test the eval and re-eval command procedure."""
+ ctx = MockContext()
+ response = MockMessage()
+ self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.send_eval = AsyncMock(return_value=response)
+ self.cog.continue_eval = AsyncMock()
+ self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None)
+
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
+ self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.continue_eval.assert_called_with(ctx, response)
+
+ @async_test
+ async def test_eval_command_reject_two_eval_at_the_same_time(self):
+ """Test if the eval command rejects an eval if the author already have a running eval."""
+ ctx = MockContext()
+ ctx.author.id = 42
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.send = AsyncMock()
+ self.cog.jobs = (42,)
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
+ )
+
+ @async_test
+ async def test_eval_command_call_help(self):
+ """Test if the eval command call the help command if no code is provided."""
+ ctx = MockContext()
+ ctx.invoke = AsyncMock()
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
+ ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval")
+
+ @async_test
+ async def test_send_eval(self):
+ """Test the send_eval function."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))
+ self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
+ self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ self.cog.format_output = AsyncMock(return_value=('[No output]', None))
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.format_output.assert_called_once_with('')
+
+ @async_test
+ async def test_send_eval_with_paste_link(self):
+ """Test the send_eval function with a too long output that generate a paste link."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))
+ self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
+ self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :yay!: Return code 0.'
+ '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.format_output.assert_called_once_with('Way too long beard')
+
+ @async_test
+ async def test_send_eval_with_non_zero_eval(self):
+ """Test the send_eval function with a code returning a non-zero code."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.typing = MagicMock(return_value=AsyncContextManagerMock(None))
+ self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
+ self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
+ self.cog.format_output = AsyncMock() # This function isn't called
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.format_output.assert_not_called()
+
+ @async_test
+ async def test_continue_eval_does_continue(self):
+ """Test that the continue_eval function does continue if required conditions are met."""
+ ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
+ response = MockMessage(delete=AsyncMock())
+ new_msg = MockMessage(content='!e NewCode')
+ self.bot.wait_for.side_effect = ((None, new_msg), None)
+
+ actual = await self.cog.continue_eval(ctx, response)
+ self.assertEqual(actual, 'NewCode')
+ self.bot.wait_for.has_calls(
+ call('message_edit', partial(snekbox.predicate_eval_message_edit, ctx), timeout=10),
+ call('reaction_add', partial(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
+ )
+ ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
+ ctx.message.clear_reactions.assert_called_once()
+ response.delete.assert_called_once()
+
+ @async_test
+ async def test_continue_eval_does_not_continue(self):
+ ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))
+ self.bot.wait_for.side_effect = asyncio.TimeoutError
+
+ actual = await self.cog.continue_eval(ctx, MockMessage())
+ self.assertEqual(actual, None)
+ ctx.message.clear_reactions.assert_called_once()
+
+ def test_predicate_eval_message_edit(self):
+ """Test the predicate_eval_message_edit function."""
+ msg0 = MockMessage(id=1, content='abc')
+ msg1 = MockMessage(id=2, content='abcdef')
+ msg2 = MockMessage(id=1, content='abcdef')
+
+ cases = (
+ (msg0, msg0, False, 'same ID, same content'),
+ (msg0, msg1, False, 'different ID, different content'),
+ (msg0, msg2, True, 'same ID, different content')
+ )
+ for ctx_msg, new_msg, expected, testname in cases:
+ with self.subTest(msg=f'Messages with {testname} return {expected}'):
+ ctx = MockContext(message=ctx_msg)
+ actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg)
+ self.assertEqual(actual, expected)
+
+ def test_predicate_eval_emoji_reaction(self):
+ """Test the predicate_eval_emoji_reaction function."""
+ valid_reaction = MockReaction(message=MockMessage(id=1))
+ valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI
+ valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))
+ valid_user = MockUser(id=2)
+
+ invalid_reaction_id = MockReaction(message=MockMessage(id=42))
+ invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI
+ invalid_user_id = MockUser(id=42)
+ invalid_reaction_str = MockReaction(message=MockMessage(id=1))
+ invalid_reaction_str.__str__.return_value = ':longbeard:'
+
+ cases = (
+ (invalid_reaction_id, valid_user, False, 'invalid reaction ID'),
+ (valid_reaction, invalid_user_id, False, 'invalid user ID'),
+ (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'),
+ (valid_reaction, valid_user, True, 'matching attributes')
+ )
+ for reaction, user, expected, testname in cases:
+ with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
+ actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user)
+ self.assertEqual(actual, expected)
+
+
+class SnekboxSetupTests(unittest.TestCase):
+ """Tests setup of the `Snekbox` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ snekbox.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/helpers.py b/tests/helpers.py
index 9d9dd5da6..6f50f6ae3 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -128,6 +128,18 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
return super().__call__(*args, **kwargs)
+class AsyncContextManagerMock(unittest.mock.MagicMock):
+ def __init__(self, return_value: Any):
+ super().__init__()
+ self._return_value = return_value
+
+ async def __aenter__(self):
+ return self._return_value
+
+ async def __aexit__(self, *args):
+ pass
+
+
class AsyncIteratorMock:
"""
A class to mock asynchronous iterators.