From b065b2493b6adeb066aec1976eccde4a3bdbec2e Mon Sep 17 00:00:00 2001 From: wookie184 Date: Sat, 22 Oct 2022 13:56:25 +0100 Subject: Fix tests --- tests/bot/exts/recruitment/talentpool/test_review.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/recruitment/talentpool/test_review.py b/tests/bot/exts/recruitment/talentpool/test_review.py index ed9b66e12..295b0e221 100644 --- a/tests/bot/exts/recruitment/talentpool/test_review.py +++ b/tests/bot/exts/recruitment/talentpool/test_review.py @@ -1,6 +1,6 @@ import unittest from datetime import datetime, timedelta, timezone -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from bot.exts.recruitment.talentpool import _review from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel @@ -65,6 +65,7 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase): MockMessage(author=self.bot_user, content="Not a review", created_at=not_too_recent), MockMessage(author=self.bot_user, content="Not a review", created_at=not_too_recent), ], + not_too_recent.timestamp(), True, ), @@ -75,6 +76,7 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase): MockMessage(author=self.bot_user, content="Zig for Helper!", created_at=not_too_recent), MockMessage(author=self.bot_user, content="Scaleios for Helper!", created_at=not_too_recent), ], + not_too_recent.timestamp(), False, ), @@ -83,6 +85,7 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase): [ MockMessage(author=self.bot_user, content="Chrisjl for Helper!", created_at=too_recent), ], + too_recent.timestamp(), False, ), @@ -94,18 +97,25 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase): MockMessage(author=self.bot_user, content="wookie for Helper!", created_at=not_too_recent), MockMessage(author=self.bot_user, content="Not a review", created_at=not_too_recent), ], + not_too_recent.timestamp(), True, ), # No messages, so ready. - ([], True), + ([], None, True), ) - for messages, expected in cases: + for messages, last_review_timestamp, expected in cases: with self.subTest(messages=messages, expected=expected): self.voting_channel.history = AsyncIterator(messages) + + cache_get_mock = AsyncMock(return_value=last_review_timestamp) + self.reviewer.status_cache.get = cache_get_mock + res = await self.reviewer.is_ready_for_review() + self.assertIs(res, expected) + cache_get_mock.assert_called_with("last_vote_date") @patch("bot.exts.recruitment.talentpool._review.MIN_NOMINATION_TIME", timedelta(days=7)) async def test_get_user_for_review(self): -- cgit v1.2.3 From f6139c68ecfc39cc24a3c8075082b47a509d8bc5 Mon Sep 17 00:00:00 2001 From: Ionite Date: Sun, 20 Nov 2022 22:37:58 -0500 Subject: Update unit tests --- tests/bot/exts/utils/test_snekbox.py | 73 ++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 28 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index b1f32c210..9e3143776 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -8,7 +8,7 @@ from discord.ext import commands from bot import constants from bot.errors import LockedResourceError from bot.exts.utils import snekbox -from bot.exts.utils.snekbox import Snekbox +from bot.exts.utils.snekbox import EvalJob, Snekbox from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser @@ -18,6 +18,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot = MockBot() self.cog = Snekbox(bot=self.bot) + @staticmethod + def code_args(code: str) -> tuple[EvalJob]: + """Converts code to a tuple of arguments expected.""" + return EvalJob.from_code(code), + async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() @@ -27,10 +32,22 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") + job = EvalJob.from_code("import random") + self.assertEqual(await self.cog.post_job(job, "3.10"), "return") + + expected = { + "args": ["main.py"], + "files": [ + { + "name": "main.py", + "content-encoding": "utf-8", + "content": "import random" + } + ] + } self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, - json={"input": "import random"}, + json=expected, raise_for_status=True ) resp.json.assert_awaited_once() @@ -76,18 +93,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') ) - for case, setup_code, testname in cases: + for case, setup_code, test_name in cases: setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) - expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) - with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)] + with self.subTest(msg=f'Test with {test_name} and expected return {expected}'): self.assertEqual(self.cog.prepare_timeit_input(case), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" cases = ( - ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), - ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', [])), + ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', [])), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', [])) ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): @@ -98,7 +115,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_get_results_message_invalid_signal(self, mock_signals: Mock): self.assertEqual( self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), - ('Your 3.11 eval job has completed with return code 127', '') + ('Your 3.11 eval job has completed with return code 127', '', []) ) @patch('bot.exts.utils.snekbox.Signals') @@ -106,7 +123,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mock_signals.return_value.name = 'SIGTEST' self.assertEqual( self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), - ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') + ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '', []) ) def test_get_status_emoji(self): @@ -178,10 +195,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.command = MagicMock() self.cog.send_job = AsyncMock(return_value=response) - self.cog.continue_job = AsyncMock(return_value=(None, None)) + self.cog.continue_job = AsyncMock(return_value=None) await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) - self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') + self.cog.send_job.assert_called_once_with('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')) self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): @@ -191,11 +208,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.command = MagicMock() self.cog.send_job = AsyncMock(return_value=response) self.cog.continue_job = AsyncMock() - self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + self.cog.continue_job.side_effect = (EvalJob.from_code('MyAwesomeFormattedCode'), None) await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) self.cog.send_job.assert_called_with( - ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' + 'eval', ctx, '3.11', *self.code_args('MyAwesomeFormattedCode') ) self.cog.continue_job.assert_called_with(ctx, response, 'eval') @@ -212,8 +229,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) with self.assertRaises(LockedResourceError): await asyncio.gather( - self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), - self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), + self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), ) async def test_send_job(self): @@ -224,7 +241,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.author = MockUser(mention='@LemonLemonishBeard#0042') self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) - self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) + self.cog.get_results_message = MagicMock(return_value=('Return code 0', '', [])) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('[No output]', None)) @@ -232,7 +249,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + await self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), ctx.send.assert_called_once() self.assertEqual( @@ -243,7 +260,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) + self.cog.post_job.assert_called_once_with(*self.code_args('MyAwesomeCode'), '3.11') self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') self.cog.format_output.assert_called_once_with('') @@ -256,7 +273,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.author.mention = '@LemonLemonishBeard#0042' self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) + self.cog.get_results_message = MagicMock(return_value=('Return code 0', '', [])) self.cog.get_status_emoji = MagicMock(return_value=':yay!:') self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) @@ -264,7 +281,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + await self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), ctx.send.assert_called_once() self.assertEqual( @@ -273,7 +290,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) + self.cog.post_job.assert_called_once_with(*self.code_args('MyAwesomeCode'), '3.11') self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with( {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' @@ -287,7 +304,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) + self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval', [])) self.cog.get_status_emoji = MagicMock(return_value=':nope!:') self.cog.format_output = AsyncMock() # This function isn't called @@ -295,7 +312,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + await self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), ctx.send.assert_called_once() self.assertEqual( @@ -303,7 +320,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) + self.cog.post_job.assert_called_once_with(*self.code_args('MyAwesomeCode'), '3.11') self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') self.cog.format_output.assert_not_called() @@ -328,7 +345,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, (expected, None)) + self.assertEqual(actual, EvalJob.from_code(expected)) self.bot.wait_for.assert_has_awaits( ( call( @@ -348,7 +365,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.wait_for.side_effect = asyncio.TimeoutError actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) - self.assertEqual(actual, (None, None)) + self.assertEqual(actual, None) ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) async def test_get_code(self): -- cgit v1.2.3 From c7d6d05f29bd6f9958198b49f22a4c9623f7425f Mon Sep 17 00:00:00 2001 From: ionite34 Date: Thu, 24 Nov 2022 15:43:53 +0800 Subject: Update unit test --- tests/bot/exts/utils/test_snekbox.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 9e3143776..1f226a6ce 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -1,5 +1,6 @@ import asyncio import unittest +from base64 import b64encode from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch from discord import AllowedMentions @@ -39,9 +40,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "args": ["main.py"], "files": [ { - "name": "main.py", - "content-encoding": "utf-8", - "content": "import random" + "path": "main.py", + "content": b64encode("import random".encode()).decode() } ] } -- cgit v1.2.3 From 7711e2d5b53a3da5481bc74105f7a1c0dbf99d6c Mon Sep 17 00:00:00 2001 From: ionite34 Date: Wed, 30 Nov 2022 07:48:28 +0800 Subject: Update unit tests for snekbox --- tests/bot/exts/utils/test_snekbox.py | 111 ++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 54 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 1f226a6ce..b52159101 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -9,7 +9,7 @@ from discord.ext import commands from bot import constants from bot.errors import LockedResourceError from bot.exts.utils import snekbox -from bot.exts.utils.snekbox import EvalJob, Snekbox +from bot.exts.utils.snekbox import EvalJob, Snekbox, EvalResult from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser @@ -18,6 +18,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Add mocked bot and cog to the instance.""" self.bot = MockBot() self.cog = Snekbox(bot=self.bot) + self.job = EvalJob.from_code("import random") @staticmethod def code_args(code: str) -> tuple[EvalJob]: @@ -34,7 +35,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.http_session.post.return_value = context_manager job = EvalJob.from_code("import random") - self.assertEqual(await self.cog.post_job(job, "3.10"), "return") + self.assertEqual(await self.cog.post_job(job), "return") expected = { "args": ["main.py"], @@ -99,34 +100,37 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): with self.subTest(msg=f'Test with {test_name} and expected return {expected}'): self.assertEqual(self.cog.prepare_timeit_input(case), expected) - def test_get_results_message(self): - """Return error and message according to the eval result.""" + def test_eval_result_message(self): + """EvalResult.message, should return error and message.""" cases = ( - ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', [])), - ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', [])), - ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', [])) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') - self.assertEqual(actual, expected) + result = EvalResult(stdout=stdout, returncode=returncode) + job = EvalJob([]) + self.assertEqual(result.message(job), expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_signals: Mock): + def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): + result = EvalResult(stdout="", returncode=127) self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), - ('Your 3.11 eval job has completed with return code 127', '', []) + result.message(EvalJob([], version="3.10")), + ("Your 3.10 eval job has completed with return code 127", "") ) @patch('bot.exts.utils.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = 'SIGTEST' + def test_eval_result_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = "SIGTEST" + result = EvalResult(stdout="", returncode=127) self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), - ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '', []) + result.message(EvalJob([], version="3.11")), + ("Your 3.11 eval job has completed with return code 127 (SIGTEST)", "") ) - def test_get_status_emoji(self): + def test_eval_result_status_emoji(self): """Return emoji according to the eval result.""" cases = ( (' ', -1, ':warning:'), @@ -135,8 +139,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) + result = EvalResult(stdout=stdout, returncode=returncode) + self.assertEqual(result.status_emoji, expected) async def test_format_output(self): """Test output formatting.""" @@ -198,7 +202,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.continue_job = AsyncMock(return_value=None) await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) - self.cog.send_job.assert_called_once_with('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')) + job = EvalJob.from_code("MyAwesomeCode") + self.cog.send_job.assert_called_once_with(ctx, job) self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): @@ -211,10 +216,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.continue_job.side_effect = (EvalJob.from_code('MyAwesomeFormattedCode'), None) await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) - self.cog.send_job.assert_called_with( - 'eval', ctx, '3.11', *self.code_args('MyAwesomeFormattedCode') - ) - self.cog.continue_job.assert_called_with(ctx, response, 'eval') + + expected_job = EvalJob.from_code("MyAwesomeFormattedCode") + self.cog.send_job.assert_called_with(ctx, expected_job) + self.cog.continue_job.assert_called_with(ctx, response, "eval") async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -229,8 +234,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) with self.assertRaises(LockedResourceError): await asyncio.gather( - self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), - self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), + self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), + self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), ) async def test_send_job(self): @@ -240,30 +245,31 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) - self.cog.get_results_message = MagicMock(return_value=('Return code 0', '', [])) - self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + eval_result = EvalResult("", 0) + self.cog.post_job = AsyncMock(return_value=eval_result) self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + self.cog.upload_output = AsyncMock() # Should not be called mocked_filter_cog = MagicMock() mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), + job = EvalJob.from_code('MyAwesomeCode') + await self.cog.send_job(ctx, job), ctx.send.assert_called_once() self.assertEqual( ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' + '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed ' + 'with return code 0.\n\n```\n[No output]\n```' ) allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions'] expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_job.assert_called_once_with(*self.code_args('MyAwesomeCode'), '3.11') - self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') + self.cog.post_job.assert_called_once_with(job) self.cog.format_output.assert_called_once_with('') + self.cog.upload_output.assert_not_called() async def test_send_job_with_paste_link(self): """Test the send_job function with a too long output that generate a paste link.""" @@ -272,29 +278,26 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message = MagicMock(return_value=('Return code 0', '', [])) - self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + eval_result = EvalResult("Way too long beard", 0) + self.cog.post_job = AsyncMock(return_value=eval_result) self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), ctx.send.assert_called_once() self.assertEqual( ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :yay!: Return code 0.' + '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job ' + 'has completed with return code 0.' '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_job.assert_called_once_with(*self.code_args('MyAwesomeCode'), '3.11') - self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with( - {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' - ) + self.cog.post_job.assert_called_once_with(job) self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_job_with_non_zero_eval(self): @@ -303,27 +306,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval', [])) - self.cog.get_status_emoji = MagicMock(return_value=':nope!:') - self.cog.format_output = AsyncMock() # This function isn't called + + eval_result = EvalResult("ERROR", 127) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job('eval', ctx, '3.11', *self.code_args('MyAwesomeCode')), + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), ctx.send.assert_called_once() self.assertEqual( ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' + '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.' + '\n\n```\nERROR\n```' ) - self.cog.post_job.assert_called_once_with(*self.code_args('MyAwesomeCode'), '3.11') - self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') - self.cog.format_output.assert_not_called() + self.cog.post_job.assert_called_once_with(job) + self.cog.upload_output.assert_not_called() @patch("bot.exts.utils.snekbox.partial") async def test_continue_job_does_continue(self, partial_mock): -- cgit v1.2.3 From 5eba5306a21e373b55a43828b4455395bbc671fc Mon Sep 17 00:00:00 2001 From: ionite34 Date: Wed, 30 Nov 2022 08:40:41 +0800 Subject: Reorder imports --- bot/exts/utils/snekbox.py | 4 ++-- tests/bot/exts/utils/test_snekbox.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index cd090ed79..1d003fb9a 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -18,8 +18,8 @@ from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Roles, URLs from bot.decorators import redirect_output -from bot.exts.utils.snekio import FileAttachment, sizeof_fmt, FILE_SIZE_LIMIT from bot.exts.help_channels._channel import is_help_forum_post +from bot.exts.utils.snekio import FILE_SIZE_LIMIT, FileAttachment, sizeof_fmt from bot.log import get_logger from bot.utils import send_to_paste_service from bot.utils.lock import LockedResourceError, lock_arg @@ -133,7 +133,7 @@ class EvalResult: err_files: list[str] = field(default_factory=list) @property - def status_emoji(self): + def status_emoji(self) -> str: """Return an emoji corresponding to the status code or lack of output in result.""" # If there are attachments, skip empty output warning if not self.stdout.strip() and not self.files: # No output diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index b52159101..3f9789031 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -9,7 +9,7 @@ from discord.ext import commands from bot import constants from bot.errors import LockedResourceError from bot.exts.utils import snekbox -from bot.exts.utils.snekbox import EvalJob, Snekbox, EvalResult +from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser -- cgit v1.2.3 From bf5327305717ea87568d72f8ed6d8e79bd8969c6 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Wed, 14 Dec 2022 11:45:44 +0800 Subject: Fix test_post_job unit test --- tests/bot/exts/utils/test_snekbox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 3f9789031..f8222761a 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -28,14 +28,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() - resp.json = AsyncMock(return_value="return") + resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137}) context_manager = MagicMock() context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager job = EvalJob.from_code("import random") - self.assertEqual(await self.cog.post_job(job), "return") + self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137)) expected = { "args": ["main.py"], -- cgit v1.2.3 From e248e9e1e0fd48b4c76a3806d89d995f2df1a512 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Wed, 14 Dec 2022 16:50:55 +0800 Subject: Update function name `get_message` --- bot/exts/utils/snekbox.py | 4 ++-- tests/bot/exts/utils/test_snekbox.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 018417005..b5ba7335b 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -143,7 +143,7 @@ class EvalResult: else: # Exception return ":x:" - def message(self, job: EvalJob) -> tuple[str, str]: + def get_message(self, job: EvalJob) -> tuple[str, str]: """Return a user-friendly message and error corresponding to the process's return code.""" msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}" error = "" @@ -383,7 +383,7 @@ class Snekbox(Cog): """ async with ctx.typing(): result = await self.post_job(job) - msg, error = result.message(job) + msg, error = result.get_message(job) if error: output, paste_link = error, None diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index f8222761a..e54e80732 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -111,13 +111,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): with self.subTest(stdout=stdout, returncode=returncode, expected=expected): result = EvalResult(stdout=stdout, returncode=returncode) job = EvalJob([]) - self.assertEqual(result.message(job), expected) + self.assertEqual(result.get_message(job), expected) @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( - result.message(EvalJob([], version="3.10")), + result.get_message(EvalJob([], version="3.10")), ("Your 3.10 eval job has completed with return code 127", "") ) @@ -126,7 +126,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mock_signals.return_value.name = "SIGTEST" result = EvalResult(stdout="", returncode=127) self.assertEqual( - result.message(EvalJob([], version="3.11")), + result.get_message(EvalJob([], version="3.11")), ("Your 3.11 eval job has completed with return code 127 (SIGTEST)", "") ) -- cgit v1.2.3 From a1d1926de98393324894e8392c7d55d49a0273a1 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Wed, 14 Dec 2022 16:56:03 +0800 Subject: Update test_post_job to use 3.10 snekbox --- tests/bot/exts/utils/test_snekbox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index e54e80732..722c5c569 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -34,7 +34,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - job = EvalJob.from_code("import random") + job = EvalJob.from_code("import random").as_version("3.10") self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137)) expected = { -- cgit v1.2.3 From 727a146f2de0f37c43d6939dc4368ef780373cd4 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Wed, 14 Dec 2022 17:30:33 +0800 Subject: Refactor to move snekbox to module --- bot/exts/utils/snekbox.py | 629 ----------------------------------- bot/exts/utils/snekbox/__init__.py | 12 + bot/exts/utils/snekbox/_cog.py | 519 +++++++++++++++++++++++++++++ bot/exts/utils/snekbox/_eval.py | 117 +++++++ bot/exts/utils/snekbox/_io.py | 65 ++++ bot/exts/utils/snekio.py | 65 ---- tests/bot/exts/utils/test_snekbox.py | 40 ++- 7 files changed, 735 insertions(+), 712 deletions(-) delete mode 100644 bot/exts/utils/snekbox.py create mode 100644 bot/exts/utils/snekbox/__init__.py create mode 100644 bot/exts/utils/snekbox/_cog.py create mode 100644 bot/exts/utils/snekbox/_eval.py create mode 100644 bot/exts/utils/snekbox/_io.py delete mode 100644 bot/exts/utils/snekio.py (limited to 'tests') diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py deleted file mode 100644 index b89882a65..000000000 --- a/bot/exts/utils/snekbox.py +++ /dev/null @@ -1,629 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import re -from dataclasses import dataclass, field -from functools import partial -from operator import attrgetter -from signal import Signals -from textwrap import dedent -from typing import Literal, Optional, TYPE_CHECKING, Tuple - -from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui -from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only -from pydis_core.utils import interactions -from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX - -from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, Roles, URLs -from bot.decorators import redirect_output -from bot.exts.help_channels._channel import is_help_forum_post -from bot.exts.utils.snekio import FILE_SIZE_LIMIT, FileAttachment, sizeof_fmt -from bot.log import get_logger -from bot.utils import send_to_paste_service -from bot.utils.lock import LockedResourceError, lock_arg -from bot.utils.services import PasteTooLongError, PasteUploadError - -if TYPE_CHECKING: - from bot.exts.filters.filtering import Filtering - -log = get_logger(__name__) - -ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") - -# The timeit command should only output the very last line, so all other output should be suppressed. -# This will be used as the setup code along with any setup code provided. -TIMEIT_SETUP_WRAPPER = """ -import atexit -import sys -from collections import deque - -if not hasattr(sys, "_setup_finished"): - class Writer(deque): - '''A single-item deque wrapper for sys.stdout that will return the last line when read() is called.''' - - def __init__(self): - super().__init__(maxlen=1) - - def write(self, string): - '''Append the line to the queue if it is not empty.''' - if string.strip(): - self.append(string) - - def read(self): - '''This method will be called when print() is called. - - The queue is emptied as we don't need the output later. - ''' - return self.pop() - - def flush(self): - '''This method will be called eventually, but we don't need to do anything here.''' - pass - - sys.stdout = Writer() - - def print_last_line(): - if sys.stdout: # If the deque is empty (i.e. an error happened), calling read() will raise an error - # Use sys.__stdout__ here because sys.stdout is set to a Writer() instance - print(sys.stdout.read(), file=sys.__stdout__) - - atexit.register(print_last_line) # When exiting, print the last line (hopefully it will be the timeit output) - sys._setup_finished = None -{setup} -""" - -MAX_PASTE_LENGTH = 10_000 - -# The Snekbox commands' whitelists and blacklists. -NO_SNEKBOX_CHANNELS = (Channels.python_general,) -NO_SNEKBOX_CATEGORIES = () -SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) - -SIGKILL = 9 - -REDO_EMOJI = '\U0001f501' # :repeat: -REDO_TIMEOUT = 30 - -PythonVersion = Literal["3.10", "3.11"] - - -@dataclass -class EvalJob: - """Job to be evaluated by snekbox.""" - - args: list[str] - files: list[FileAttachment] = field(default_factory=list) - name: str = "eval" - version: PythonVersion = "3.11" - - @classmethod - def from_code(cls, code: str, path: str = "main.py") -> EvalJob: - """Create an EvalJob from a code string.""" - return cls( - args=[path], - files=[FileAttachment(path, code.encode())], - ) - - def as_version(self, version: PythonVersion) -> EvalJob: - """Return a copy of the job with a different Python version.""" - return EvalJob( - args=self.args, - files=self.files, - name=self.name, - version=version, - ) - - def to_dict(self) -> dict[str, list[str | dict[str, str]]]: - """Convert the job to a dict.""" - return { - "args": self.args, - "files": [file.to_dict() for file in self.files], - } - - -@dataclass(frozen=True) -class EvalResult: - """The result of an eval job.""" - - stdout: str - returncode: int | None - files: list[FileAttachment] = field(default_factory=list) - err_files: list[str] = field(default_factory=list) - - @property - def status_emoji(self) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - # If there are attachments, skip empty output warning - if not self.stdout.strip() and not self.files: # No output - return ":warning:" - elif self.returncode == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - - def get_message(self, job: EvalJob) -> tuple[str, str]: - """Return a user-friendly message and error corresponding to the process's return code.""" - msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}" - error = "" - - if self.returncode is None: - msg = f"Your {job.version} {job.name} job has failed" - error = self.stdout.strip() - elif self.returncode == 128 + SIGKILL: - msg = f"Your {job.version} {job.name} job timed out or ran out of memory" - elif self.returncode == 255: - msg = f"Your {job.version} {job.name} job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - with contextlib.suppress(ValueError): - name = Signals(self.returncode - 128).name - msg = f"{msg} ({name})" - - # Add error message for failed attachments - if self.err_files: - failed_files = f"({', '.join(self.err_files)})" - msg += ( - f".\n\n> Some attached files were not able to be uploaded {failed_files}." - f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" - ) - - return msg, error - - @classmethod - def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult: - """Create an EvalResult from a dict.""" - res = cls( - stdout=data["stdout"], - returncode=data["returncode"], - ) - - for file in data.get("files", []): - try: - res.files.append(FileAttachment.from_dict(file)) - except ValueError as e: - log.info(f"Failed to parse file from snekbox response: {e}") - res.err_files.append(file["path"]) - - return res - - -class CodeblockConverter(Converter): - """Attempts to extract code from a codeblock, if provided.""" - - @classmethod - async def convert(cls, ctx: Context, code: str) -> list[str]: - """ - Extract code from the Markdown, format it, and insert it into the code template. - - If there is any code block, ignore text outside the code block. - Use the first code block, but prefer a fenced code block. - If there are several fenced code blocks, concatenate only the fenced code blocks. - - Return a list of code blocks if any, otherwise return a list with a single string of code. - """ - if match := list(FORMATTED_CODE_REGEX.finditer(code)): - blocks = [block for block in match if block.group("block")] - - if len(blocks) > 1: - codeblocks = [block.group("code") for block in blocks] - info = "several code blocks" - else: - match = match[0] if len(blocks) == 0 else blocks[0] - code, block, lang, delim = match.group("code", "block", "lang", "delim") - codeblocks = [dedent(code)] - if block: - info = (f"'{lang}' highlighted" if lang else "plain") + " code block" - else: - info = f"{delim}-enclosed inline code" - else: - codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] - info = "unformatted or badly formatted code" - - code = "\n".join(codeblocks) - log.trace(f"Extracted {info} for evaluation:\n{code}") - return codeblocks - - -class PythonVersionSwitcherButton(ui.Button): - """A button that allows users to re-run their eval command in a different Python version.""" - - def __init__( - self, - version_to_switch_to: PythonVersion, - snekbox_cog: Snekbox, - ctx: Context, - job: EvalJob, - ) -> None: - self.version_to_switch_to = version_to_switch_to - super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary) - - self.snekbox_cog = snekbox_cog - self.ctx = ctx - self.job = job - - async def callback(self, interaction: Interaction) -> None: - """ - Tell snekbox to re-run the user's code in the alternative Python version. - - Use a task calling snekbox, as run_job is blocking while it waits for edit/reaction on the message. - """ - # Defer response here so that the Discord UI doesn't mark this interaction as failed if the job - # takes too long to run. - await interaction.response.defer() - - with contextlib.suppress(NotFound): - # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button. - # The log arg on send_job will stop the actual job from running. - await interaction.message.delete() - - await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to)) - - -class Snekbox(Cog): - """Safe evaluation of Python code using Snekbox.""" - - def __init__(self, bot: Bot): - self.bot = bot - self.jobs = {} - - def build_python_version_switcher_view( - self, - current_python_version: PythonVersion, - ctx: Context, - job: EvalJob, - ) -> interactions.ViewWithUserAndRoleCheck: - """Return a view that allows the user to change what version of Python their code is run on.""" - if current_python_version == "3.10": - alt_python_version = "3.11" - else: - alt_python_version = "3.10" - - view = interactions.ViewWithUserAndRoleCheck( - allowed_users=(ctx.author.id,), - allowed_roles=MODERATION_ROLES, - ) - view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job)) - view.add_item(interactions.DeleteMessageButton()) - - return view - - async def post_job(self, job: EvalJob) -> EvalResult: - """Send a POST request to the Snekbox API to evaluate code and return the results.""" - if job.version == "3.10": - url = URLs.snekbox_eval_api - else: - url = URLs.snekbox_311_eval_api - - data = job.to_dict() - - async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return EvalResult.from_dict(await resp.json()) - - @staticmethod - async def upload_output(output: str) -> Optional[str]: - """Upload the job's output to a paste service and return a URL to it if successful.""" - log.trace("Uploading full output to paste service...") - - try: - return await send_to_paste_service(output, extension="txt", max_length=MAX_PASTE_LENGTH) - except PasteTooLongError: - return "too long to upload" - except PasteUploadError: - return "unable to upload" - - @staticmethod - def prepare_timeit_input(codeblocks: list[str]) -> list[str]: - """ - Join the codeblocks into a single string, then return the arguments in a list. - - If there are multiple codeblocks, insert the first one into the wrapped setup code. - """ - args = ["-m", "timeit"] - setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else "" - code = "\n".join(codeblocks) - - args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code]) - return args - - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: - """ - Format the output and return a tuple of the formatted output and a URL to the full output. - - Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters - and upload the full output to a paste service. - """ - output = output.rstrip("\n") - original_output = output # To be uploaded to a pasting service if needed - paste_link = None - - if "<@" in output: - output = output.replace("<@", "<@\u200B") # Zero-width space - - if " 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] - output = output[:11] # Limiting to only 11 lines - output = "\n".join(output) - - if lines > 10: - truncated = True - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" - else: - output = f"{output}\n... (truncated - too many lines)" - elif len(output) >= 1000: - truncated = True - output = f"{output[:1000]}\n... (truncated - too long)" - - if truncated: - paste_link = await self.upload_output(original_output) - - output = output or "[No output]" - - return output, paste_link - - @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) - async def send_job(self, ctx: Context, job: EvalJob) -> Message: - """ - Evaluate code, format it, and send the output to the corresponding channel. - - Return the bot response. - """ - async with ctx.typing(): - result = await self.post_job(job) - msg, error = result.get_message(job) - - if error: - output, paste_link = error, None - else: - log.trace("Formatting output...") - output, paste_link = await self.format_output(result.stdout) - - msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n" - if not result.files or output not in ("[No output]", ""): - msg += f"\n```\n{output}\n```" - - if paste_link: - msg = f"{msg}\nFull output: {paste_link}" - - # Collect stats of job fails + successes - if result.returncode != 0: - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - - filter_cog: Filtering | None = self.bot.get_cog("Filtering") - filter_triggered = False - if filter_cog: - filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) - if filter_triggered: - response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") - else: - allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) - view = self.build_python_version_switcher_view(job.version, ctx, job) - - # Attach files if provided - files = [f.to_file() for f in result.files] - response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files) - view.message = response - - log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}") - return response - - async def continue_job( - self, ctx: Context, response: Message, job_name: str - ) -> EvalJob | None: - """ - Check if the job's session should continue. - - If the code is to be re-evaluated, return the new EvalJob. - Otherwise, return None if the job's session should be terminated. - """ - _predicate_message_edit = partial(predicate_message_edit, ctx) - _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) - - with contextlib.suppress(NotFound): - try: - _, new_message = await self.bot.wait_for( - 'message_edit', - check=_predicate_message_edit, - timeout=REDO_TIMEOUT - ) - await ctx.message.add_reaction(REDO_EMOJI) - await self.bot.wait_for( - 'reaction_add', - check=_predicate_emoji_reaction, - timeout=10 - ) - - # Ensure the response that's about to be edited is still the most recent. - # This could have already been updated via a button press to switch to an alt Python version. - if self.jobs[ctx.message.id] != response.id: - return None - - code = await self.get_code(new_message, ctx.command) - with contextlib.suppress(HTTPException): - await ctx.message.clear_reaction(REDO_EMOJI) - await response.delete() - - if code is None: - return None - - except asyncio.TimeoutError: - with contextlib.suppress(HTTPException): - await ctx.message.clear_reaction(REDO_EMOJI) - return None - - codeblocks = await CodeblockConverter.convert(ctx, code) - - if job_name == "timeit": - return EvalJob(self.prepare_timeit_input(codeblocks)) - else: - return EvalJob.from_code("\n".join(codeblocks)) - - return None - - async def get_code(self, message: Message, command: Command) -> Optional[str]: - """ - Return the code from `message` to be evaluated. - - If the message is an invocation of the command, return the first argument or None if it - doesn't exist. Otherwise, return the full content of the message. - """ - log.trace(f"Getting context for message {message.id}.") - new_ctx = await self.bot.get_context(message) - - if new_ctx.command is command: - log.trace(f"Message {message.id} invokes {command} command.") - split = message.content.split(maxsplit=1) - code = split[1] if len(split) > 1 else None - else: - log.trace(f"Message {message.id} does not invoke {command} command.") - code = message.content - - return code - - async def run_job( - self, - ctx: Context, - job: EvalJob, - ) -> None: - """Handles checks, stats and re-evaluation of a snekbox job.""" - if Roles.helpers in (role.id for role in ctx.author.roles): - self.bot.stats.incr("snekbox_usages.roles.helpers") - else: - self.bot.stats.incr("snekbox_usages.roles.developers") - - if is_help_forum_post(ctx.channel): - self.bot.stats.incr("snekbox_usages.channels.help") - elif ctx.channel.id == Channels.bot_commands: - self.bot.stats.incr("snekbox_usages.channels.bot_commands") - else: - self.bot.stats.incr("snekbox_usages.channels.topical") - - log.info(f"Received code from {ctx.author} for evaluation:\n{job}") - - while True: - try: - response = await self.send_job(ctx, job) - except LockedResourceError: - await ctx.send( - f"{ctx.author.mention} You've already got a job running - " - "please wait for it to finish!" - ) - return - - # Store the bot's response message id per invocation, to ensure the `wait_for`s in `continue_job` - # don't trigger if the response has already been replaced by a new response. - # This can happen when a button is pressed and then original code is edited and re-run. - self.jobs[ctx.message.id] = response.id - - job = await self.continue_job(ctx, response, job.name) - if not job: - break - log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}") - - @command(name="eval", aliases=("e",), usage="[python_version] ") - @guild_only() - @redirect_output( - destination_channel=Channels.bot_commands, - bypass_roles=SNEKBOX_ROLES, - categories=NO_SNEKBOX_CATEGORIES, - channels=NO_SNEKBOX_CHANNELS, - ping_user=False - ) - async def eval_command( - self, - ctx: Context, - python_version: PythonVersion | None, - *, - code: CodeblockConverter - ) -> None: - """ - Run Python code and get the results. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - If multiple codeblocks are in a message, all of them will be joined and evaluated, - ignoring the text outside them. - - By default, your code is run on Python's 3.11 beta release, to assist with testing. If you - run into issues related to this Python version, you can request the bot to use Python - 3.10 by specifying the `python_version` arg and setting it to `3.10`. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ - code: list[str] - python_version = python_version or "3.11" - job = EvalJob.from_code("\n".join(code)).as_version(python_version) - await self.run_job(ctx, job) - - @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] ") - @guild_only() - @redirect_output( - destination_channel=Channels.bot_commands, - bypass_roles=SNEKBOX_ROLES, - categories=NO_SNEKBOX_CATEGORIES, - channels=NO_SNEKBOX_CHANNELS, - ping_user=False - ) - async def timeit_command( - self, - ctx: Context, - python_version: PythonVersion | None, - *, - code: CodeblockConverter - ) -> None: - """ - Profile Python Code to find execution time. - - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - If multiple formatted codeblocks are provided, the first one will be the setup code, which will - not be timed. The remaining codeblocks will be joined together and timed. - - By default your code is run on Python's 3.11 beta release, to assist with testing. If you - run into issues related to this Python version, you can request the bot to use Python - 3.10 by specifying the `python_version` arg and setting it to `3.10`. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! - """ - code: list[str] - python_version = python_version or "3.11" - args = self.prepare_timeit_input(code) - job = EvalJob(args, version=python_version, name="timeit") - - await self.run_job(ctx, job) - - -def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: - """Return True if the edited message is the context message and the content was indeed modified.""" - return new_msg.id == ctx.message.id and old_msg.content != new_msg.content - - -def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: - """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI - - -async def setup(bot: Bot) -> None: - """Load the Snekbox cog.""" - await bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/snekbox/__init__.py b/bot/exts/utils/snekbox/__init__.py new file mode 100644 index 000000000..cd1d3b059 --- /dev/null +++ b/bot/exts/utils/snekbox/__init__.py @@ -0,0 +1,12 @@ +from bot.bot import Bot +from bot.exts.utils.snekbox._cog import CodeblockConverter, Snekbox +from bot.exts.utils.snekbox._eval import EvalJob, EvalResult + +__all__ = ("CodeblockConverter", "Snekbox", "EvalJob", "EvalResult") + + +async def setup(bot: Bot) -> None: + """Load the Snekbox cog.""" + # Defer import to reduce side effects from importing the codeblock package. + from bot.exts.utils.snekbox._cog import Snekbox + await bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py new file mode 100644 index 000000000..9abbbcfc4 --- /dev/null +++ b/bot/exts/utils/snekbox/_cog.py @@ -0,0 +1,519 @@ +from __future__ import annotations + +import asyncio +import contextlib +import re +from functools import partial +from operator import attrgetter +from textwrap import dedent +from typing import Literal, Optional, TYPE_CHECKING, Tuple + +from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui +from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only +from pydis_core.utils import interactions +from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX + +from bot.bot import Bot +from bot.constants import Channels, MODERATION_ROLES, Roles, URLs +from bot.decorators import redirect_output +from bot.exts.help_channels._channel import is_help_forum_post +from bot.exts.utils.snekbox._eval import EvalJob, EvalResult +from bot.log import get_logger +from bot.utils import send_to_paste_service +from bot.utils.lock import LockedResourceError, lock_arg +from bot.utils.services import PasteTooLongError, PasteUploadError + +if TYPE_CHECKING: + from bot.exts.filters.filtering import Filtering + +log = get_logger(__name__) + +ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") + +# The timeit command should only output the very last line, so all other output should be suppressed. +# This will be used as the setup code along with any setup code provided. +TIMEIT_SETUP_WRAPPER = """ +import atexit +import sys +from collections import deque + +if not hasattr(sys, "_setup_finished"): + class Writer(deque): + '''A single-item deque wrapper for sys.stdout that will return the last line when read() is called.''' + + def __init__(self): + super().__init__(maxlen=1) + + def write(self, string): + '''Append the line to the queue if it is not empty.''' + if string.strip(): + self.append(string) + + def read(self): + '''This method will be called when print() is called. + + The queue is emptied as we don't need the output later. + ''' + return self.pop() + + def flush(self): + '''This method will be called eventually, but we don't need to do anything here.''' + pass + + sys.stdout = Writer() + + def print_last_line(): + if sys.stdout: # If the deque is empty (i.e. an error happened), calling read() will raise an error + # Use sys.__stdout__ here because sys.stdout is set to a Writer() instance + print(sys.stdout.read(), file=sys.__stdout__) + + atexit.register(print_last_line) # When exiting, print the last line (hopefully it will be the timeit output) + sys._setup_finished = None +{setup} +""" + +MAX_PASTE_LENGTH = 10_000 + +# The Snekbox commands' whitelists and blacklists. +NO_SNEKBOX_CHANNELS = (Channels.python_general,) +NO_SNEKBOX_CATEGORIES = () +SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) + +REDO_EMOJI = '\U0001f501' # :repeat: +REDO_TIMEOUT = 30 + +PythonVersion = Literal["3.10", "3.11"] + + +class CodeblockConverter(Converter): + """Attempts to extract code from a codeblock, if provided.""" + + @classmethod + async def convert(cls, ctx: Context, code: str) -> list[str]: + """ + Extract code from the Markdown, format it, and insert it into the code template. + + If there is any code block, ignore text outside the code block. + Use the first code block, but prefer a fenced code block. + If there are several fenced code blocks, concatenate only the fenced code blocks. + + Return a list of code blocks if any, otherwise return a list with a single string of code. + """ + if match := list(FORMATTED_CODE_REGEX.finditer(code)): + blocks = [block for block in match if block.group("block")] + + if len(blocks) > 1: + codeblocks = [block.group("code") for block in blocks] + info = "several code blocks" + else: + match = match[0] if len(blocks) == 0 else blocks[0] + code, block, lang, delim = match.group("code", "block", "lang", "delim") + codeblocks = [dedent(code)] + if block: + info = (f"'{lang}' highlighted" if lang else "plain") + " code block" + else: + info = f"{delim}-enclosed inline code" + else: + codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] + info = "unformatted or badly formatted code" + + code = "\n".join(codeblocks) + log.trace(f"Extracted {info} for evaluation:\n{code}") + return codeblocks + + +class PythonVersionSwitcherButton(ui.Button): + """A button that allows users to re-run their eval command in a different Python version.""" + + def __init__( + self, + version_to_switch_to: PythonVersion, + snekbox_cog: Snekbox, + ctx: Context, + job: EvalJob, + ) -> None: + self.version_to_switch_to = version_to_switch_to + super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary) + + self.snekbox_cog = snekbox_cog + self.ctx = ctx + self.job = job + + async def callback(self, interaction: Interaction) -> None: + """ + Tell snekbox to re-run the user's code in the alternative Python version. + + Use a task calling snekbox, as run_job is blocking while it waits for edit/reaction on the message. + """ + # Defer response here so that the Discord UI doesn't mark this interaction as failed if the job + # takes too long to run. + await interaction.response.defer() + + with contextlib.suppress(NotFound): + # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button. + # The log arg on send_job will stop the actual job from running. + await interaction.message.delete() + + await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to)) + + +class Snekbox(Cog): + """Safe evaluation of Python code using Snekbox.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.jobs = {} + + def build_python_version_switcher_view( + self, + current_python_version: PythonVersion, + ctx: Context, + job: EvalJob, + ) -> interactions.ViewWithUserAndRoleCheck: + """Return a view that allows the user to change what version of Python their code is run on.""" + if current_python_version == "3.10": + alt_python_version = "3.11" + else: + alt_python_version = "3.10" + + view = interactions.ViewWithUserAndRoleCheck( + allowed_users=(ctx.author.id,), + allowed_roles=MODERATION_ROLES, + ) + view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job)) + view.add_item(interactions.DeleteMessageButton()) + + return view + + async def post_job(self, job: EvalJob) -> EvalResult: + """Send a POST request to the Snekbox API to evaluate code and return the results.""" + if job.version == "3.10": + url = URLs.snekbox_eval_api + else: + url = URLs.snekbox_311_eval_api + + data = job.to_dict() + + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: + return EvalResult.from_dict(await resp.json()) + + @staticmethod + async def upload_output(output: str) -> Optional[str]: + """Upload the job's output to a paste service and return a URL to it if successful.""" + log.trace("Uploading full output to paste service...") + + try: + return await send_to_paste_service(output, extension="txt", max_length=MAX_PASTE_LENGTH) + except PasteTooLongError: + return "too long to upload" + except PasteUploadError: + return "unable to upload" + + @staticmethod + def prepare_timeit_input(codeblocks: list[str]) -> list[str]: + """ + Join the codeblocks into a single string, then return the arguments in a list. + + If there are multiple codeblocks, insert the first one into the wrapped setup code. + """ + args = ["-m", "timeit"] + setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else "" + code = "\n".join(codeblocks) + + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code]) + return args + + async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + """ + Format the output and return a tuple of the formatted output and a URL to the full output. + + Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters + and upload the full output to a paste service. + """ + output = output.rstrip("\n") + original_output = output # To be uploaded to a pasting service if needed + paste_link = None + + if "<@" in output: + output = output.replace("<@", "<@\u200B") # Zero-width space + + if " 0: + output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] + output = output[:11] # Limiting to only 11 lines + output = "\n".join(output) + + if lines > 10: + truncated = True + if len(output) >= 1000: + output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + else: + output = f"{output}\n... (truncated - too many lines)" + elif len(output) >= 1000: + truncated = True + output = f"{output[:1000]}\n... (truncated - too long)" + + if truncated: + paste_link = await self.upload_output(original_output) + + output = output or "[No output]" + + return output, paste_link + + @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) + async def send_job(self, ctx: Context, job: EvalJob) -> Message: + """ + Evaluate code, format it, and send the output to the corresponding channel. + + Return the bot response. + """ + async with ctx.typing(): + result = await self.post_job(job) + msg, error = result.get_message(job) + + if error: + output, paste_link = error, None + else: + log.trace("Formatting output...") + output, paste_link = await self.format_output(result.stdout) + + msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n" + if not result.files or output not in ("[No output]", ""): + msg += f"\n```\n{output}\n```" + + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + # Collect stats of job fails + successes + if result.returncode != 0: + self.bot.stats.incr("snekbox.python.fail") + else: + self.bot.stats.incr("snekbox.python.success") + + filter_cog: Filtering | None = self.bot.get_cog("Filtering") + filter_triggered = False + if filter_cog: + filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) + if filter_triggered: + response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + else: + allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) + view = self.build_python_version_switcher_view(job.version, ctx, job) + + # Attach files if provided + files = [f.to_file() for f in result.files] + response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files) + view.message = response + + log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}") + return response + + async def continue_job( + self, ctx: Context, response: Message, job_name: str + ) -> EvalJob | None: + """ + Check if the job's session should continue. + + If the code is to be re-evaluated, return the new EvalJob. + Otherwise, return None if the job's session should be terminated. + """ + _predicate_message_edit = partial(predicate_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) + + with contextlib.suppress(NotFound): + try: + _, new_message = await self.bot.wait_for( + 'message_edit', + check=_predicate_message_edit, + timeout=REDO_TIMEOUT + ) + await ctx.message.add_reaction(REDO_EMOJI) + await self.bot.wait_for( + 'reaction_add', + check=_predicate_emoji_reaction, + timeout=10 + ) + + # Ensure the response that's about to be edited is still the most recent. + # This could have already been updated via a button press to switch to an alt Python version. + if self.jobs[ctx.message.id] != response.id: + return None + + code = await self.get_code(new_message, ctx.command) + with contextlib.suppress(HTTPException): + await ctx.message.clear_reaction(REDO_EMOJI) + await response.delete() + + if code is None: + return None + + except asyncio.TimeoutError: + with contextlib.suppress(HTTPException): + await ctx.message.clear_reaction(REDO_EMOJI) + return None + + codeblocks = await CodeblockConverter.convert(ctx, code) + + if job_name == "timeit": + return EvalJob(self.prepare_timeit_input(codeblocks)) + else: + return EvalJob.from_code("\n".join(codeblocks)) + + return None + + async def get_code(self, message: Message, command: Command) -> Optional[str]: + """ + Return the code from `message` to be evaluated. + + If the message is an invocation of the command, return the first argument or None if it + doesn't exist. Otherwise, return the full content of the message. + """ + log.trace(f"Getting context for message {message.id}.") + new_ctx = await self.bot.get_context(message) + + if new_ctx.command is command: + log.trace(f"Message {message.id} invokes {command} command.") + split = message.content.split(maxsplit=1) + code = split[1] if len(split) > 1 else None + else: + log.trace(f"Message {message.id} does not invoke {command} command.") + code = message.content + + return code + + async def run_job( + self, + ctx: Context, + job: EvalJob, + ) -> None: + """Handles checks, stats and re-evaluation of a snekbox job.""" + if Roles.helpers in (role.id for role in ctx.author.roles): + self.bot.stats.incr("snekbox_usages.roles.helpers") + else: + self.bot.stats.incr("snekbox_usages.roles.developers") + + if is_help_forum_post(ctx.channel): + self.bot.stats.incr("snekbox_usages.channels.help") + elif ctx.channel.id == Channels.bot_commands: + self.bot.stats.incr("snekbox_usages.channels.bot_commands") + else: + self.bot.stats.incr("snekbox_usages.channels.topical") + + log.info(f"Received code from {ctx.author} for evaluation:\n{job}") + + while True: + try: + response = await self.send_job(ctx, job) + except LockedResourceError: + await ctx.send( + f"{ctx.author.mention} You've already got a job running - " + "please wait for it to finish!" + ) + return + + # Store the bot's response message id per invocation, to ensure the `wait_for`s in `continue_job` + # don't trigger if the response has already been replaced by a new response. + # This can happen when a button is pressed and then original code is edited and re-run. + self.jobs[ctx.message.id] = response.id + + job = await self.continue_job(ctx, response, job.name) + if not job: + break + log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}") + + @command(name="eval", aliases=("e",), usage="[python_version] ") + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, + ping_user=False + ) + async def eval_command( + self, + ctx: Context, + python_version: PythonVersion | None, + *, + code: CodeblockConverter + ) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + If multiple codeblocks are in a message, all of them will be joined and evaluated, + ignoring the text outside them. + + By default, your code is run on Python's 3.11 beta release, to assist with testing. If you + run into issues related to this Python version, you can request the bot to use Python + 3.10 by specifying the `python_version` arg and setting it to `3.10`. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code: list[str] + python_version = python_version or "3.11" + job = EvalJob.from_code("\n".join(code)).as_version(python_version) + await self.run_job(ctx, job) + + @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] ") + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=SNEKBOX_ROLES, + categories=NO_SNEKBOX_CATEGORIES, + channels=NO_SNEKBOX_CHANNELS, + ping_user=False + ) + async def timeit_command( + self, + ctx: Context, + python_version: PythonVersion | None, + *, + code: CodeblockConverter + ) -> None: + """ + Profile Python Code to find execution time. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + If multiple formatted codeblocks are provided, the first one will be the setup code, which will + not be timed. The remaining codeblocks will be joined together and timed. + + By default your code is run on Python's 3.11 beta release, to assist with testing. If you + run into issues related to this Python version, you can request the bot to use Python + 3.10 by specifying the `python_version` arg and setting it to `3.10`. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code: list[str] + python_version = python_version or "3.11" + args = self.prepare_timeit_input(code) + job = EvalJob(args, version=python_version, name="timeit") + + await self.run_job(ctx, job) + + +def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: + """Return True if the edited message is the context message and the content was indeed modified.""" + return new_msg.id == ctx.message.id and old_msg.content != new_msg.content + + +def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py new file mode 100644 index 000000000..784de5a10 --- /dev/null +++ b/bot/exts/utils/snekbox/_eval.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from signal import Signals +from typing import TYPE_CHECKING + +from bot.exts.utils.snekbox._io import FILE_SIZE_LIMIT, FileAttachment, sizeof_fmt +from bot.log import get_logger + +if TYPE_CHECKING: + from bot.exts.utils.snekbox._cog import PythonVersion + +log = get_logger(__name__) + +SIGKILL = 9 + + +@dataclass +class EvalJob: + """Job to be evaluated by snekbox.""" + + args: list[str] + files: list[FileAttachment] = field(default_factory=list) + name: str = "eval" + version: PythonVersion = "3.11" + + @classmethod + def from_code(cls, code: str, path: str = "main.py") -> EvalJob: + """Create an EvalJob from a code string.""" + return cls( + args=[path], + files=[FileAttachment(path, code.encode())], + ) + + def as_version(self, version: PythonVersion) -> EvalJob: + """Return a copy of the job with a different Python version.""" + return EvalJob( + args=self.args, + files=self.files, + name=self.name, + version=version, + ) + + def to_dict(self) -> dict[str, list[str | dict[str, str]]]: + """Convert the job to a dict.""" + return { + "args": self.args, + "files": [file.to_dict() for file in self.files], + } + + +@dataclass(frozen=True) +class EvalResult: + """The result of an eval job.""" + + stdout: str + returncode: int | None + files: list[FileAttachment] = field(default_factory=list) + err_files: list[str] = field(default_factory=list) + + @property + def status_emoji(self) -> str: + """Return an emoji corresponding to the status code or lack of output in result.""" + # If there are attachments, skip empty output warning + if not self.stdout.strip() and not self.files: # No output + return ":warning:" + elif self.returncode == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + + def get_message(self, job: EvalJob) -> tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}" + error = "" + + if self.returncode is None: + msg = f"Your {job.version} {job.name} job has failed" + error = self.stdout.strip() + elif self.returncode == 128 + SIGKILL: + msg = f"Your {job.version} {job.name} job timed out or ran out of memory" + elif self.returncode == 255: + msg = f"Your {job.version} {job.name} job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + with contextlib.suppress(ValueError): + name = Signals(self.returncode - 128).name + msg = f"{msg} ({name})" + + # Add error message for failed attachments + if self.err_files: + failed_files = f"({', '.join(self.err_files)})" + msg += ( + f".\n\n> Some attached files were not able to be uploaded {failed_files}." + f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" + ) + + return msg, error + + @classmethod + def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult: + """Create an EvalResult from a dict.""" + res = cls( + stdout=data["stdout"], + returncode=data["returncode"], + ) + + for file in data.get("files", []): + try: + res.files.append(FileAttachment.from_dict(file)) + except ValueError as e: + log.info(f"Failed to parse file from snekbox response: {e}") + res.err_files.append(file["path"]) + + return res diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py new file mode 100644 index 000000000..a7f84a241 --- /dev/null +++ b/bot/exts/utils/snekbox/_io.py @@ -0,0 +1,65 @@ +"""I/O File protocols for snekbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path + +from discord import File + +# Note discord bot upload limit is 8 MiB per file, +# or 50 MiB for lvl 2 boosted servers +FILE_SIZE_LIMIT = 8 * 1024 * 1024 + + +def sizeof_fmt(num: int, suffix: str = "B") -> str: + """Return a human-readable file size.""" + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024: + return f"{num:3.1f}{unit}{suffix}" + num /= 1024 + return f"{num:.1f}Yi{suffix}" + + +@dataclass +class FileAttachment: + """File Attachment from Snekbox eval.""" + + path: str + content: bytes + + def __repr__(self) -> str: + """Return the content as a string.""" + content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content + return f"FileAttachment(path={self.path!r}, content={content})" + + @classmethod + def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: + """Create a FileAttachment from a dict response.""" + size = data.get("size") + if (size and size > size_limit) or (len(data["content"]) > size_limit): + raise ValueError("File size exceeds limit") + + content = b64decode(data["content"]) + + if len(content) > size_limit: + raise ValueError("File size exceeds limit") + + return cls(data["path"], content) + + def to_dict(self) -> dict[str, str]: + """Convert the attachment to a json dict.""" + content = self.content + if isinstance(content, str): + content = content.encode("utf-8") + + return { + "path": self.path, + "content": b64encode(content).decode("ascii"), + } + + def to_file(self) -> File: + """Convert to a discord.File.""" + name = Path(self.path).name + return File(BytesIO(self.content), filename=name) diff --git a/bot/exts/utils/snekio.py b/bot/exts/utils/snekio.py deleted file mode 100644 index a7f84a241..000000000 --- a/bot/exts/utils/snekio.py +++ /dev/null @@ -1,65 +0,0 @@ -"""I/O File protocols for snekbox.""" -from __future__ import annotations - -from base64 import b64decode, b64encode -from dataclasses import dataclass -from io import BytesIO -from pathlib import Path - -from discord import File - -# Note discord bot upload limit is 8 MiB per file, -# or 50 MiB for lvl 2 boosted servers -FILE_SIZE_LIMIT = 8 * 1024 * 1024 - - -def sizeof_fmt(num: int, suffix: str = "B") -> str: - """Return a human-readable file size.""" - for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): - if abs(num) < 1024: - return f"{num:3.1f}{unit}{suffix}" - num /= 1024 - return f"{num:.1f}Yi{suffix}" - - -@dataclass -class FileAttachment: - """File Attachment from Snekbox eval.""" - - path: str - content: bytes - - def __repr__(self) -> str: - """Return the content as a string.""" - content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content - return f"FileAttachment(path={self.path!r}, content={content})" - - @classmethod - def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: - """Create a FileAttachment from a dict response.""" - size = data.get("size") - if (size and size > size_limit) or (len(data["content"]) > size_limit): - raise ValueError("File size exceeds limit") - - content = b64decode(data["content"]) - - if len(content) > size_limit: - raise ValueError("File size exceeds limit") - - return cls(data["path"], content) - - def to_dict(self) -> dict[str, str]: - """Convert the attachment to a json dict.""" - content = self.content - if isinstance(content, str): - content = content.encode("utf-8") - - return { - "path": self.path, - "content": b64encode(content).decode("ascii"), - } - - def to_file(self) -> File: - """Convert to a discord.File.""" - name = Path(self.path).name - return File(BytesIO(self.content), filename=name) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 722c5c569..31b1ca260 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -55,14 +55,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_upload_output_reject_too_long(self): """Reject output longer than MAX_PASTE_LENGTH.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) + result = await self.cog.upload_output("-" * (snekbox._cog.MAX_PASTE_LENGTH + 1)) self.assertEqual(result, "too long to upload") - @patch("bot.exts.utils.snekbox.send_to_paste_service") + @patch("bot.exts.utils.snekbox._cog.send_to_paste_service") async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) + mock_paste_util.assert_called_once_with( + "Test output.", + extension="txt", + max_length=snekbox._cog.MAX_PASTE_LENGTH + ) async def test_codeblock_converter(self): ctx = MockContext() @@ -95,7 +99,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, setup_code, test_name in cases: - setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)] with self.subTest(msg=f'Test with {test_name} and expected return {expected}'): self.assertEqual(self.cog.prepare_timeit_input(case), expected) @@ -104,7 +108,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """EvalResult.message, should return error and message.""" cases = ( ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), + ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) ) for stdout, returncode, expected in cases: @@ -113,7 +117,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): job = EvalJob([]) self.assertEqual(result.get_message(job), expected) - @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) + @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( @@ -121,7 +125,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ("Your 3.10 eval job has completed with return code 127", "") ) - @patch('bot.exts.utils.snekbox.Signals') + @patch('bot.exts.utils.snekbox._eval.Signals') def test_eval_result_message_valid_signal(self, mock_signals: Mock): mock_signals.return_value.name = "SIGTEST" result = EvalResult(stdout="", returncode=127) @@ -328,7 +332,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job.assert_called_once_with(job) self.cog.upload_output.assert_not_called() - @patch("bot.exts.utils.snekbox.partial") + @patch("bot.exts.utils.snekbox._cog.partial") async def test_continue_job_does_continue(self, partial_mock): """Test that the continue_job function does continue if required conditions are met.""" ctx = MockContext( @@ -353,14 +357,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ( call( 'message_edit', - check=partial_mock(snekbox.predicate_message_edit, ctx), - timeout=snekbox.REDO_TIMEOUT, + check=partial_mock(snekbox._cog.predicate_message_edit, ctx), + timeout=snekbox._cog.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox._cog.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) response.delete.assert_called_once() async def test_continue_job_does_not_continue(self): @@ -369,7 +373,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) self.assertEqual(actual, None) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -411,18 +415,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox._cog.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) def test_predicate_emoji_reaction(self): """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REDO_EMOJI + valid_reaction.__str__.return_value = snekbox._cog.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI + invalid_reaction_id.__str__.return_value = snekbox._cog.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -435,7 +439,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox._cog.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) -- cgit v1.2.3 From 75d1fde1ed516b5698be2b652297765f1ba5ccfe Mon Sep 17 00:00:00 2001 From: ionite34 Date: Sat, 17 Dec 2022 19:17:49 +0800 Subject: Update unit tests for EvalResult message change --- tests/bot/exts/utils/test_snekbox.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 31b1ca260..3ce832771 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -28,7 +28,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() - resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137}) + resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []}) context_manager = MagicMock() context_manager.__aenter__.return_value = resp @@ -107,23 +107,32 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_eval_result_message(self): """EvalResult.message, should return error and message.""" cases = ( - ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), - ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), - ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')), + ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', '')) ) for stdout, returncode, expected in cases: + exp_msg, exp_err, exp_files_err = expected with self.subTest(stdout=stdout, returncode=returncode, expected=expected): result = EvalResult(stdout=stdout, returncode=returncode) job = EvalJob([]) - self.assertEqual(result.get_message(job), expected) + # Check all 3 message types + msg = result.get_message(job) + self.assertEqual(msg, exp_msg) + error = result.error_message + self.assertEqual(error, exp_err) + files_error = result.files_error_message + self.assertEqual(files_error, exp_files_err) @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( result.get_message(EvalJob([], version="3.10")), - ("Your 3.10 eval job has completed with return code 127", "") + "Your 3.10 eval job has completed with return code 127" ) + self.assertEqual(result.error_message, "") + self.assertEqual(result.files_error_message, "") @patch('bot.exts.utils.snekbox._eval.Signals') def test_eval_result_message_valid_signal(self, mock_signals: Mock): @@ -131,7 +140,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): result = EvalResult(stdout="", returncode=127) self.assertEqual( result.get_message(EvalJob([], version="3.11")), - ("Your 3.11 eval job has completed with return code 127 (SIGTEST)", "") + "Your 3.11 eval job has completed with return code 127 (SIGTEST)" ) def test_eval_result_status_emoji(self): -- cgit v1.2.3 From a82c15dcb3643856ca1276679b4ba5e0a3854a3a Mon Sep 17 00:00:00 2001 From: ionite34 Date: Sat, 17 Dec 2022 20:07:45 +0800 Subject: Add unit test for files_error_message --- tests/bot/exts/utils/test_snekbox.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 3ce832771..afe48dceb 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -105,7 +105,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(self.cog.prepare_timeit_input(case), expected) def test_eval_result_message(self): - """EvalResult.message, should return error and message.""" + """EvalResult.get_message(), should return message.""" cases = ( ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')), ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')), @@ -124,6 +124,33 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): files_error = result.files_error_message self.assertEqual(files_error, exp_files_err) + @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) + def test_eval_result_files_error_message(self): + """EvalResult.files_error_message, should return files error message.""" + cases = [ + ([], ["abc"], ( + "Failed to upload 1 file (abc)." + " File sizes should each not exceed 8 MiB." + )), + ([], ["file1.bin", "f2.bin"], ( + "Failed to upload 2 files (file1.bin, f2.bin)." + " File sizes should each not exceed 8 MiB." + )), + (["a", "b"], ["c"], ( + "Failed to upload 1 file (c)" + " as it exceeded the 2 file limit." + )), + (["a"], ["b", "c"], ( + "Failed to upload 2 files (b, c)" + " as they exceeded the 2 file limit." + )), + ] + for files, failed_files, expected_msg in cases: + with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): + result = EvalResult("", 0, files, failed_files) + msg = result.files_error_message + self.assertEqual(msg, expected_msg) + @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) -- cgit v1.2.3 From d2de465e8fb3659eb1fe40aa2d1c9e9cb80e0d11 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Sat, 17 Dec 2022 20:16:57 +0800 Subject: Add unit tests for EvalResult.files_error_str --- bot/exts/utils/snekbox/_eval.py | 2 +- tests/bot/exts/utils/test_snekbox.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py index 748b58a3b..56a02d981 100644 --- a/bot/exts/utils/snekbox/_eval.py +++ b/bot/exts/utils/snekbox/_eval.py @@ -105,7 +105,7 @@ class EvalResult: names = [] for file in self.failed_files: char_max -= len(file) - if char_max <= 0 or len(names) >= file_max: + if char_max < 0 or len(names) >= file_max: names.append("...") break names.append(file) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index afe48dceb..5e13ac4bb 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -151,6 +151,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): msg = result.files_error_message self.assertEqual(msg, expected_msg) + @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) + def test_eval_result_files_error_str(self): + """EvalResult.files_error_message, should return files error message.""" + max_file_name = "a" * 32 + cases = [ + (["x.ini"], "x.ini"), + (["dog.py", "cat.py"], "dog.py, cat.py"), + # 3 files limit + (["a", "b", "c"], "a, b, c"), + (["a", "b", "c", "d"], "a, b, c, ..."), + (["x", "y", "z"] + ["a"] * 100, "x, y, z, ..."), + # 32 char limit + ([max_file_name], max_file_name), + ([max_file_name, "b"], f"{max_file_name}, ..."), + ([max_file_name + "a"], "...") + ] + for failed_files, expected in cases: + result = EvalResult("", 0, [], failed_files) + msg = result.failed_files_str(char_max=32, file_max=3) + self.assertEqual(msg, expected) + @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) -- cgit v1.2.3 From d7722b3c335af13d43d3b958c34802d2a98d0279 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Tue, 20 Dec 2022 12:39:56 +0800 Subject: Rename method get_failed_files_str --- bot/exts/utils/snekbox/_eval.py | 4 ++-- tests/bot/exts/utils/test_snekbox.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py index 56a02d981..afbb4a32d 100644 --- a/bot/exts/utils/snekbox/_eval.py +++ b/bot/exts/utils/snekbox/_eval.py @@ -86,7 +86,7 @@ class EvalResult: if not self.failed_files: return "" - failed_files = f"({self.failed_files_str()})" + failed_files = f"({self.get_failed_files_str()})" n_failed = len(self.failed_files) files = f"file{'s' if n_failed > 1 else ''}" @@ -100,7 +100,7 @@ class EvalResult: return msg - def failed_files_str(self, char_max: int = 85, file_max: int = 5) -> str: + def get_failed_files_str(self, char_max: int = 85, file_max: int = 5) -> str: """Return a string containing the names of failed files, truncated to lower of char_max and file_max.""" names = [] for file in self.failed_files: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 5e13ac4bb..b129bfcdb 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -169,7 +169,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ] for failed_files, expected in cases: result = EvalResult("", 0, [], failed_files) - msg = result.failed_files_str(char_max=32, file_max=3) + msg = result.get_failed_files_str(char_max=32, file_max=3) self.assertEqual(msg, expected) @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) -- cgit v1.2.3 From 7affd4816d7bbe1cfc92221c3a47553fed9d0cd8 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Tue, 20 Dec 2022 17:53:21 +0800 Subject: Refactor snekbox tests to module --- tests/bot/exts/utils/snekbox/__init__.py | 0 tests/bot/exts/utils/snekbox/test_snekbox.py | 510 +++++++++++++++++++++++++++ tests/bot/exts/utils/test_snekbox.py | 510 --------------------------- 3 files changed, 510 insertions(+), 510 deletions(-) create mode 100644 tests/bot/exts/utils/snekbox/__init__.py create mode 100644 tests/bot/exts/utils/snekbox/test_snekbox.py delete mode 100644 tests/bot/exts/utils/test_snekbox.py (limited to 'tests') diff --git a/tests/bot/exts/utils/snekbox/__init__.py b/tests/bot/exts/utils/snekbox/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py new file mode 100644 index 000000000..b129bfcdb --- /dev/null +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -0,0 +1,510 @@ +import asyncio +import unittest +from base64 import b64encode +from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch + +from discord import AllowedMentions +from discord.ext import commands + +from bot import constants +from bot.errors import LockedResourceError +from bot.exts.utils import snekbox +from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox +from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + self.job = EvalJob.from_code("import random") + + @staticmethod + def code_args(code: str) -> tuple[EvalJob]: + """Converts code to a tuple of arguments expected.""" + return EvalJob.from_code(code), + + async def test_post_job(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []}) + + context_manager = MagicMock() + context_manager.__aenter__.return_value = resp + self.bot.http_session.post.return_value = context_manager + + job = EvalJob.from_code("import random").as_version("3.10") + self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137)) + + expected = { + "args": ["main.py"], + "files": [ + { + "path": "main.py", + "content": b64encode("import random".encode()).decode() + } + ] + } + self.bot.http_session.post.assert_called_with( + constants.URLs.snekbox_eval_api, + json=expected, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LENGTH.""" + result = await self.cog.upload_output("-" * (snekbox._cog.MAX_PASTE_LENGTH + 1)) + self.assertEqual(result, "too long to upload") + + @patch("bot.exts.utils.snekbox._cog.send_to_paste_service") + async def test_upload_output(self, mock_paste_util): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + await self.cog.upload_output("Test output.") + mock_paste_util.assert_called_once_with( + "Test output.", + extension="txt", + max_length=snekbox._cog.MAX_PASTE_LENGTH + ) + + async def test_codeblock_converter(self): + ctx = MockContext() + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'), + ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', + 'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'), + ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', + 'print("How\'s it going?")', 'code block preceded by inline code'), + ('`print("Hello world!")`\ntext\n`print("Hello world!")`', + 'print("Hello world!")', 'one inline code block of two') + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual( + '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + ) + + def test_prepare_timeit_input(self): + """Test the prepare_timeit_input codeblock detection.""" + base_args = ('-m', 'timeit', '-s') + cases = ( + (['print("Hello World")'], '', 'single block of code'), + (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), + (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') + ) + + for case, setup_code, test_name in cases: + setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)] + with self.subTest(msg=f'Test with {test_name} and expected return {expected}'): + self.assertEqual(self.cog.prepare_timeit_input(case), expected) + + def test_eval_result_message(self): + """EvalResult.get_message(), should return message.""" + cases = ( + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')), + ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', '')) + ) + for stdout, returncode, expected in cases: + exp_msg, exp_err, exp_files_err = expected + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + result = EvalResult(stdout=stdout, returncode=returncode) + job = EvalJob([]) + # Check all 3 message types + msg = result.get_message(job) + self.assertEqual(msg, exp_msg) + error = result.error_message + self.assertEqual(error, exp_err) + files_error = result.files_error_message + self.assertEqual(files_error, exp_files_err) + + @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) + def test_eval_result_files_error_message(self): + """EvalResult.files_error_message, should return files error message.""" + cases = [ + ([], ["abc"], ( + "Failed to upload 1 file (abc)." + " File sizes should each not exceed 8 MiB." + )), + ([], ["file1.bin", "f2.bin"], ( + "Failed to upload 2 files (file1.bin, f2.bin)." + " File sizes should each not exceed 8 MiB." + )), + (["a", "b"], ["c"], ( + "Failed to upload 1 file (c)" + " as it exceeded the 2 file limit." + )), + (["a"], ["b", "c"], ( + "Failed to upload 2 files (b, c)" + " as they exceeded the 2 file limit." + )), + ] + for files, failed_files, expected_msg in cases: + with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): + result = EvalResult("", 0, files, failed_files) + msg = result.files_error_message + self.assertEqual(msg, expected_msg) + + @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) + def test_eval_result_files_error_str(self): + """EvalResult.files_error_message, should return files error message.""" + max_file_name = "a" * 32 + cases = [ + (["x.ini"], "x.ini"), + (["dog.py", "cat.py"], "dog.py, cat.py"), + # 3 files limit + (["a", "b", "c"], "a, b, c"), + (["a", "b", "c", "d"], "a, b, c, ..."), + (["x", "y", "z"] + ["a"] * 100, "x, y, z, ..."), + # 32 char limit + ([max_file_name], max_file_name), + ([max_file_name, "b"], f"{max_file_name}, ..."), + ([max_file_name + "a"], "...") + ] + for failed_files, expected in cases: + result = EvalResult("", 0, [], failed_files) + msg = result.get_failed_files_str(char_max=32, file_max=3) + self.assertEqual(msg, expected) + + @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) + def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): + result = EvalResult(stdout="", returncode=127) + self.assertEqual( + result.get_message(EvalJob([], version="3.10")), + "Your 3.10 eval job has completed with return code 127" + ) + self.assertEqual(result.error_message, "") + self.assertEqual(result.files_error_message, "") + + @patch('bot.exts.utils.snekbox._eval.Signals') + def test_eval_result_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = "SIGTEST" + result = EvalResult(stdout="", returncode=127) + self.assertEqual( + result.get_message(EvalJob([], version="3.11")), + "Your 3.11 eval job has completed with return code 127 (SIGTEST)" + ) + + def test_eval_result_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + result = EvalResult(stdout=stdout, returncode=returncode) + self.assertEqual(result.status_emoji, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... (truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + (' dict: + """Delay the post_job call to ensure the job runs long enough to conflict.""" + await asyncio.sleep(1) + return {'stdout': '', 'returncode': 0} + + self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) + with self.assertRaises(LockedResourceError): + await asyncio.gather( + self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), + self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), + ) + + async def test_send_job(self): + """Test the send_job function.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author = MockUser(mention='@LemonLemonishBeard#0042') + + eval_result = EvalResult("", 0) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + self.cog.upload_output = AsyncMock() # Should not be called + + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + + job = EvalJob.from_code('MyAwesomeCode') + await self.cog.send_job(ctx, job), + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], + '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed ' + 'with return code 0.\n\n```\n[No output]\n```' + ) + allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions'] + expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) + self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) + + self.cog.post_job.assert_called_once_with(job) + self.cog.format_output.assert_called_once_with('') + self.cog.upload_output.assert_not_called() + + async def test_send_job_with_paste_link(self): + """Test the send_job function with a too long output that generate a paste link.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + + eval_result = EvalResult("Way too long beard", 0) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) + + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], + '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job ' + 'has completed with return code 0.' + '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' + ) + + self.cog.post_job.assert_called_once_with(job) + self.cog.format_output.assert_called_once_with('Way too long beard') + + async def test_send_job_with_non_zero_eval(self): + """Test the send_job function with a code returning a non-zero code.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + + eval_result = EvalResult("ERROR", 127) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.upload_output = AsyncMock() # This function isn't called + + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], + '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.' + '\n\n```\nERROR\n```' + ) + + self.cog.post_job.assert_called_once_with(job) + self.cog.upload_output.assert_not_called() + + @patch("bot.exts.utils.snekbox._cog.partial") + async def test_continue_job_does_continue(self, partial_mock): + """Test that the continue_job function does continue if required conditions are met.""" + ctx = MockContext( + message=MockMessage( + id=4, + add_reaction=AsyncMock(), + clear_reactions=AsyncMock() + ), + author=MockMember(id=14) + ) + response = MockMessage(id=42, delete=AsyncMock()) + new_msg = MockMessage() + self.cog.jobs = {4: 42} + self.bot.wait_for.side_effect = ((None, new_msg), None) + expected = "NewCode" + self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) + + actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) + self.assertEqual(actual, EvalJob.from_code(expected)) + self.bot.wait_for.assert_has_awaits( + ( + call( + 'message_edit', + check=partial_mock(snekbox._cog.predicate_message_edit, ctx), + timeout=snekbox._cog.REDO_TIMEOUT, + ), + call('reaction_add', check=partial_mock(snekbox._cog.predicate_emoji_reaction, ctx), timeout=10) + ) + ) + ctx.message.add_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) + response.delete.assert_called_once() + + async def test_continue_job_does_not_continue(self): + ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) + self.bot.wait_for.side_effect = asyncio.TimeoutError + + actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) + self.assertEqual(actual, None) + ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) + + async def test_get_code(self): + """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" + prefix = constants.Bot.prefix + subtests = ( + (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"), + (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None), + (MagicMock(spec=commands.Command), f"{prefix}tags get foo"), + (None, "print(123)") + ) + + for command, content, *expected_code in subtests: + if not expected_code: + expected_code = content + else: + [expected_code] = expected_code + + with self.subTest(content=content, expected_code=expected_code): + self.bot.get_context.reset_mock() + self.bot.get_context.return_value = MockContext(command=command) + message = MockMessage(content=content) + + actual_code = await self.cog.get_code(message, self.cog.eval_command) + + self.bot.get_context.assert_awaited_once_with(message) + self.assertEqual(actual_code, expected_code) + + def test_predicate_message_edit(self): + """Test the predicate_message_edit function.""" + msg0 = MockMessage(id=1, content='abc') + msg1 = MockMessage(id=2, content='abcdef') + msg2 = MockMessage(id=1, content='abcdef') + + cases = ( + (msg0, msg0, False, 'same ID, same content'), + (msg0, msg1, False, 'different ID, different content'), + (msg0, msg2, True, 'same ID, different content') + ) + for ctx_msg, new_msg, expected, testname in cases: + with self.subTest(msg=f'Messages with {testname} return {expected}'): + ctx = MockContext(message=ctx_msg) + actual = snekbox._cog.predicate_message_edit(ctx, ctx_msg, new_msg) + self.assertEqual(actual, expected) + + def test_predicate_emoji_reaction(self): + """Test the predicate_emoji_reaction function.""" + valid_reaction = MockReaction(message=MockMessage(id=1)) + valid_reaction.__str__.return_value = snekbox._cog.REDO_EMOJI + valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) + valid_user = MockUser(id=2) + + invalid_reaction_id = MockReaction(message=MockMessage(id=42)) + invalid_reaction_id.__str__.return_value = snekbox._cog.REDO_EMOJI + invalid_user_id = MockUser(id=42) + invalid_reaction_str = MockReaction(message=MockMessage(id=1)) + invalid_reaction_str.__str__.return_value = ':longbeard:' + + cases = ( + (invalid_reaction_id, valid_user, False, 'invalid reaction ID'), + (valid_reaction, invalid_user_id, False, 'invalid user ID'), + (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'), + (valid_reaction, valid_user, True, 'matching attributes') + ) + for reaction, user, expected, testname in cases: + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + actual = snekbox._cog.predicate_emoji_reaction(valid_ctx, reaction, user) + self.assertEqual(actual, expected) + + +class SnekboxSetupTests(unittest.IsolatedAsyncioTestCase): + """Tests setup of the `Snekbox` cog.""" + + async def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + await snekbox.setup(bot) + bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py deleted file mode 100644 index b129bfcdb..000000000 --- a/tests/bot/exts/utils/test_snekbox.py +++ /dev/null @@ -1,510 +0,0 @@ -import asyncio -import unittest -from base64 import b64encode -from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch - -from discord import AllowedMentions -from discord.ext import commands - -from bot import constants -from bot.errors import LockedResourceError -from bot.exts.utils import snekbox -from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox -from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser - - -class SnekboxTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Add mocked bot and cog to the instance.""" - self.bot = MockBot() - self.cog = Snekbox(bot=self.bot) - self.job = EvalJob.from_code("import random") - - @staticmethod - def code_args(code: str) -> tuple[EvalJob]: - """Converts code to a tuple of arguments expected.""" - return EvalJob.from_code(code), - - async def test_post_job(self): - """Post the eval code to the URLs.snekbox_eval_api endpoint.""" - resp = MagicMock() - resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []}) - - context_manager = MagicMock() - context_manager.__aenter__.return_value = resp - self.bot.http_session.post.return_value = context_manager - - job = EvalJob.from_code("import random").as_version("3.10") - self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137)) - - expected = { - "args": ["main.py"], - "files": [ - { - "path": "main.py", - "content": b64encode("import random".encode()).decode() - } - ] - } - self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json=expected, - raise_for_status=True - ) - resp.json.assert_awaited_once() - - async def test_upload_output_reject_too_long(self): - """Reject output longer than MAX_PASTE_LENGTH.""" - result = await self.cog.upload_output("-" * (snekbox._cog.MAX_PASTE_LENGTH + 1)) - self.assertEqual(result, "too long to upload") - - @patch("bot.exts.utils.snekbox._cog.send_to_paste_service") - async def test_upload_output(self, mock_paste_util): - """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" - await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with( - "Test output.", - extension="txt", - max_length=snekbox._cog.MAX_PASTE_LENGTH - ) - - async def test_codeblock_converter(self): - ctx = MockContext() - cases = ( - ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), - ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), - ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), - ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'), - ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', - 'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'), - ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', - 'print("How\'s it going?")', 'code block preceded by inline code'), - ('`print("Hello world!")`\ntext\n`print("Hello world!")`', - 'print("Hello world!")', 'one inline code block of two') - ) - for case, expected, testname in cases: - with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual( - '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected - ) - - def test_prepare_timeit_input(self): - """Test the prepare_timeit_input codeblock detection.""" - base_args = ('-m', 'timeit', '-s') - cases = ( - (['print("Hello World")'], '', 'single block of code'), - (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), - (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') - ) - - for case, setup_code, test_name in cases: - setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) - expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)] - with self.subTest(msg=f'Test with {test_name} and expected return {expected}'): - self.assertEqual(self.cog.prepare_timeit_input(case), expected) - - def test_eval_result_message(self): - """EvalResult.get_message(), should return message.""" - cases = ( - ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')), - ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')), - ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', '')) - ) - for stdout, returncode, expected in cases: - exp_msg, exp_err, exp_files_err = expected - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - result = EvalResult(stdout=stdout, returncode=returncode) - job = EvalJob([]) - # Check all 3 message types - msg = result.get_message(job) - self.assertEqual(msg, exp_msg) - error = result.error_message - self.assertEqual(error, exp_err) - files_error = result.files_error_message - self.assertEqual(files_error, exp_files_err) - - @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) - def test_eval_result_files_error_message(self): - """EvalResult.files_error_message, should return files error message.""" - cases = [ - ([], ["abc"], ( - "Failed to upload 1 file (abc)." - " File sizes should each not exceed 8 MiB." - )), - ([], ["file1.bin", "f2.bin"], ( - "Failed to upload 2 files (file1.bin, f2.bin)." - " File sizes should each not exceed 8 MiB." - )), - (["a", "b"], ["c"], ( - "Failed to upload 1 file (c)" - " as it exceeded the 2 file limit." - )), - (["a"], ["b", "c"], ( - "Failed to upload 2 files (b, c)" - " as they exceeded the 2 file limit." - )), - ] - for files, failed_files, expected_msg in cases: - with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): - result = EvalResult("", 0, files, failed_files) - msg = result.files_error_message - self.assertEqual(msg, expected_msg) - - @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) - def test_eval_result_files_error_str(self): - """EvalResult.files_error_message, should return files error message.""" - max_file_name = "a" * 32 - cases = [ - (["x.ini"], "x.ini"), - (["dog.py", "cat.py"], "dog.py, cat.py"), - # 3 files limit - (["a", "b", "c"], "a, b, c"), - (["a", "b", "c", "d"], "a, b, c, ..."), - (["x", "y", "z"] + ["a"] * 100, "x, y, z, ..."), - # 32 char limit - ([max_file_name], max_file_name), - ([max_file_name, "b"], f"{max_file_name}, ..."), - ([max_file_name + "a"], "...") - ] - for failed_files, expected in cases: - result = EvalResult("", 0, [], failed_files) - msg = result.get_failed_files_str(char_max=32, file_max=3) - self.assertEqual(msg, expected) - - @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) - def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): - result = EvalResult(stdout="", returncode=127) - self.assertEqual( - result.get_message(EvalJob([], version="3.10")), - "Your 3.10 eval job has completed with return code 127" - ) - self.assertEqual(result.error_message, "") - self.assertEqual(result.files_error_message, "") - - @patch('bot.exts.utils.snekbox._eval.Signals') - def test_eval_result_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = "SIGTEST" - result = EvalResult(stdout="", returncode=127) - self.assertEqual( - result.get_message(EvalJob([], version="3.11")), - "Your 3.11 eval job has completed with return code 127 (SIGTEST)" - ) - - def test_eval_result_status_emoji(self): - """Return emoji according to the eval result.""" - cases = ( - (' ', -1, ':warning:'), - ('Hello world!', 0, ':white_check_mark:'), - ('Invalid beard size', -1, ':x:') - ) - for stdout, returncode, expected in cases: - with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - result = EvalResult(stdout=stdout, returncode=returncode) - self.assertEqual(result.status_emoji, expected) - - async def test_format_output(self): - """Test output formatting.""" - self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') - - too_many_lines = ( - '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' - '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' - ) - too_long_too_many_lines = ( - "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" - ) - - cases = ( - ('', ('[No output]', None), 'No output'), - ('My awesome output', ('My awesome output', None), 'One line output'), - ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), - (' dict: - """Delay the post_job call to ensure the job runs long enough to conflict.""" - await asyncio.sleep(1) - return {'stdout': '', 'returncode': 0} - - self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) - with self.assertRaises(LockedResourceError): - await asyncio.gather( - self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), - self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), - ) - - async def test_send_job(self): - """Test the send_job function.""" - ctx = MockContext() - ctx.message = MockMessage() - ctx.send = AsyncMock() - ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - - eval_result = EvalResult("", 0) - self.cog.post_job = AsyncMock(return_value=eval_result) - self.cog.format_output = AsyncMock(return_value=('[No output]', None)) - self.cog.upload_output = AsyncMock() # Should not be called - - mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) - self.bot.get_cog.return_value = mocked_filter_cog - - job = EvalJob.from_code('MyAwesomeCode') - await self.cog.send_job(ctx, job), - - ctx.send.assert_called_once() - self.assertEqual( - ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed ' - 'with return code 0.\n\n```\n[No output]\n```' - ) - allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions'] - expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) - self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - - self.cog.post_job.assert_called_once_with(job) - self.cog.format_output.assert_called_once_with('') - self.cog.upload_output.assert_not_called() - - async def test_send_job_with_paste_link(self): - """Test the send_job function with a too long output that generate a paste link.""" - ctx = MockContext() - ctx.message = MockMessage() - ctx.send = AsyncMock() - ctx.author.mention = '@LemonLemonishBeard#0042' - - eval_result = EvalResult("Way too long beard", 0) - self.cog.post_job = AsyncMock(return_value=eval_result) - self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) - - mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) - self.bot.get_cog.return_value = mocked_filter_cog - - job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") - await self.cog.send_job(ctx, job), - - ctx.send.assert_called_once() - self.assertEqual( - ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job ' - 'has completed with return code 0.' - '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' - ) - - self.cog.post_job.assert_called_once_with(job) - self.cog.format_output.assert_called_once_with('Way too long beard') - - async def test_send_job_with_non_zero_eval(self): - """Test the send_job function with a code returning a non-zero code.""" - ctx = MockContext() - ctx.message = MockMessage() - ctx.send = AsyncMock() - ctx.author.mention = '@LemonLemonishBeard#0042' - - eval_result = EvalResult("ERROR", 127) - self.cog.post_job = AsyncMock(return_value=eval_result) - self.cog.upload_output = AsyncMock() # This function isn't called - - mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) - self.bot.get_cog.return_value = mocked_filter_cog - - job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") - await self.cog.send_job(ctx, job), - - ctx.send.assert_called_once() - self.assertEqual( - ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.' - '\n\n```\nERROR\n```' - ) - - self.cog.post_job.assert_called_once_with(job) - self.cog.upload_output.assert_not_called() - - @patch("bot.exts.utils.snekbox._cog.partial") - async def test_continue_job_does_continue(self, partial_mock): - """Test that the continue_job function does continue if required conditions are met.""" - ctx = MockContext( - message=MockMessage( - id=4, - add_reaction=AsyncMock(), - clear_reactions=AsyncMock() - ), - author=MockMember(id=14) - ) - response = MockMessage(id=42, delete=AsyncMock()) - new_msg = MockMessage() - self.cog.jobs = {4: 42} - self.bot.wait_for.side_effect = ((None, new_msg), None) - expected = "NewCode" - self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) - - actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) - self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, EvalJob.from_code(expected)) - self.bot.wait_for.assert_has_awaits( - ( - call( - 'message_edit', - check=partial_mock(snekbox._cog.predicate_message_edit, ctx), - timeout=snekbox._cog.REDO_TIMEOUT, - ), - call('reaction_add', check=partial_mock(snekbox._cog.predicate_emoji_reaction, ctx), timeout=10) - ) - ) - ctx.message.add_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) - response.delete.assert_called_once() - - async def test_continue_job_does_not_continue(self): - ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) - self.bot.wait_for.side_effect = asyncio.TimeoutError - - actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) - self.assertEqual(actual, None) - ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) - - async def test_get_code(self): - """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" - prefix = constants.Bot.prefix - subtests = ( - (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name} print(1)", "print(1)"), - (self.cog.eval_command, f"{prefix}{self.cog.eval_command.name}", None), - (MagicMock(spec=commands.Command), f"{prefix}tags get foo"), - (None, "print(123)") - ) - - for command, content, *expected_code in subtests: - if not expected_code: - expected_code = content - else: - [expected_code] = expected_code - - with self.subTest(content=content, expected_code=expected_code): - self.bot.get_context.reset_mock() - self.bot.get_context.return_value = MockContext(command=command) - message = MockMessage(content=content) - - actual_code = await self.cog.get_code(message, self.cog.eval_command) - - self.bot.get_context.assert_awaited_once_with(message) - self.assertEqual(actual_code, expected_code) - - def test_predicate_message_edit(self): - """Test the predicate_message_edit function.""" - msg0 = MockMessage(id=1, content='abc') - msg1 = MockMessage(id=2, content='abcdef') - msg2 = MockMessage(id=1, content='abcdef') - - cases = ( - (msg0, msg0, False, 'same ID, same content'), - (msg0, msg1, False, 'different ID, different content'), - (msg0, msg2, True, 'same ID, different content') - ) - for ctx_msg, new_msg, expected, testname in cases: - with self.subTest(msg=f'Messages with {testname} return {expected}'): - ctx = MockContext(message=ctx_msg) - actual = snekbox._cog.predicate_message_edit(ctx, ctx_msg, new_msg) - self.assertEqual(actual, expected) - - def test_predicate_emoji_reaction(self): - """Test the predicate_emoji_reaction function.""" - valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox._cog.REDO_EMOJI - valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) - valid_user = MockUser(id=2) - - invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox._cog.REDO_EMOJI - invalid_user_id = MockUser(id=42) - invalid_reaction_str = MockReaction(message=MockMessage(id=1)) - invalid_reaction_str.__str__.return_value = ':longbeard:' - - cases = ( - (invalid_reaction_id, valid_user, False, 'invalid reaction ID'), - (valid_reaction, invalid_user_id, False, 'invalid user ID'), - (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'), - (valid_reaction, valid_user, True, 'matching attributes') - ) - for reaction, user, expected, testname in cases: - with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox._cog.predicate_emoji_reaction(valid_ctx, reaction, user) - self.assertEqual(actual, expected) - - -class SnekboxSetupTests(unittest.IsolatedAsyncioTestCase): - """Tests setup of the `Snekbox` cog.""" - - async def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - await snekbox.setup(bot) - bot.add_cog.assert_awaited_once() -- cgit v1.2.3 From b04143ca971e2f272ab78da808a63b5fd700ab68 Mon Sep 17 00:00:00 2001 From: ionite34 Date: Tue, 20 Dec 2022 17:54:48 +0800 Subject: Add normalize file name tests --- tests/bot/exts/utils/snekbox/test_io.py | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/bot/exts/utils/snekbox/test_io.py (limited to 'tests') diff --git a/tests/bot/exts/utils/snekbox/test_io.py b/tests/bot/exts/utils/snekbox/test_io.py new file mode 100644 index 000000000..36ac720ba --- /dev/null +++ b/tests/bot/exts/utils/snekbox/test_io.py @@ -0,0 +1,34 @@ +import unittest + +# noinspection PyProtectedMember +from bot.exts.utils.snekbox import _io + + +class SnekboxIOTests(unittest.TestCase): + # noinspection SpellCheckingInspection + def test_normalize_file_name(self): + """Invalid file names should be normalized.""" + cases = [ + # ANSI escape sequences -> underscore + (r"\u001b[31mText", "_Text"), + # (Multiple consecutive should be collapsed to one underscore) + (r"a\u001b[35m\u001b[37mb", "a_b"), + # Backslash escaped chars -> underscore + (r"\n", "_"), + (r"\r", "_"), + (r"A\0\tB", "A__B"), + # Any other disallowed chars -> underscore + (r"\\.txt", "_.txt"), + (r"A!@#$%^&*B, C()[]{}+=D.txt", "A_B_C_D.txt"), # noqa: P103 + (" ", "_"), + # Normal file names should be unchanged + ("legal_file-name.txt", "legal_file-name.txt"), + ("_-.", "_-."), + ] + for name, expected in cases: + with self.subTest(name=name, expected=expected): + # Test function directly + self.assertEqual(_io.normalize_discord_file_name(name), expected) + # Test FileAttachment.to_file() + obj = _io.FileAttachment(name, b"") + self.assertEqual(obj.to_file().filename, expected) -- cgit v1.2.3 From 2b4c85e947a73295c72161190302d960597420be Mon Sep 17 00:00:00 2001 From: ionite34 Date: Tue, 20 Dec 2022 18:26:34 +0800 Subject: Change failed files str to truncate on chars only --- bot/exts/utils/snekbox/_eval.py | 23 +++++++++++++++++++---- tests/bot/exts/utils/snekbox/test_snekbox.py | 24 +++++++++++------------- 2 files changed, 30 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py index 95039f0bd..6bc7d7bb3 100644 --- a/bot/exts/utils/snekbox/_eval.py +++ b/bot/exts/utils/snekbox/_eval.py @@ -102,15 +102,30 @@ class EvalResult: return msg - def get_failed_files_str(self, char_max: int = 85, file_max: int = 5) -> str: - """Return a string containing the names of failed files, truncated to lower of char_max and file_max.""" + def get_failed_files_str(self, char_max: int = 85) -> str: + """ + Return a string containing the names of failed files, truncated char_max. + + Will truncate on whole file names if less than 3 characters remaining. + """ names = [] for file in self.failed_files: - char_max -= len(file) - if char_max < 0 or len(names) >= file_max: + # Only attempt to truncate name if more than 3 chars remaining + if char_max < 3: names.append("...") break + + to_display = min(char_max, len(file)) + name_short = file[:to_display] + # Add ellipsis if name was truncated + if to_display < len(file): + name_short += "..." + names.append(name_short) + break + + char_max -= len(file) names.append(file) + text = ", ".join(names) # Since the file names are provided by user text = escape_markdown(text) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index b129bfcdb..faa849178 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -154,23 +154,21 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) def test_eval_result_files_error_str(self): """EvalResult.files_error_message, should return files error message.""" - max_file_name = "a" * 32 cases = [ + # Normal (["x.ini"], "x.ini"), - (["dog.py", "cat.py"], "dog.py, cat.py"), - # 3 files limit - (["a", "b", "c"], "a, b, c"), - (["a", "b", "c", "d"], "a, b, c, ..."), - (["x", "y", "z"] + ["a"] * 100, "x, y, z, ..."), - # 32 char limit - ([max_file_name], max_file_name), - ([max_file_name, "b"], f"{max_file_name}, ..."), - ([max_file_name + "a"], "...") + (["123456", "879"], "123456, 879"), + # Break on whole name if less than 3 characters remaining + (["12345678", "9"], "12345678, ..."), + # Otherwise break on max chars + (["123", "345", "67890000"], "123, 345, 6789..."), + (["abcdefg1234567"], "abcdefg123..."), ] for failed_files, expected in cases: - result = EvalResult("", 0, [], failed_files) - msg = result.get_failed_files_str(char_max=32, file_max=3) - self.assertEqual(msg, expected) + with self.subTest(failed_files=failed_files, expected=expected): + result = EvalResult("", 0, [], failed_files) + msg = result.get_failed_files_str(char_max=10) + self.assertEqual(msg, expected) @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): -- cgit v1.2.3 From 5d6b042a2acc104178b1a4e68229a5b9714e9920 Mon Sep 17 00:00:00 2001 From: Ionite Date: Mon, 6 Feb 2023 21:28:11 -0500 Subject: Add disallowed file extensions tests --- tests/bot/exts/utils/snekbox/test_snekbox.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'tests') diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index faa849178..686dc0291 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -10,6 +10,7 @@ from bot import constants from bot.errors import LockedResourceError from bot.exts.utils import snekbox from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox +from bot.exts.utils.snekbox._io import FileAttachment from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser @@ -387,6 +388,34 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job.assert_called_once_with(job) self.cog.upload_output.assert_not_called() + async def test_send_job_with_disallowed_file_ext(self): + """Test send_job with disallowed file extensions.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = "@user#7700" + + eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")]) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.upload_output = AsyncMock() # This function isn't called + + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), + + ctx.send.assert_called_once() + self.assertEqual( + ctx.send.call_args.args[0], + '@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.' + '\n\n1 file was not uploaded due to disallowed extension: **.disallowed**' + ) + + self.cog.post_job.assert_called_once_with(job) + self.cog.upload_output.assert_not_called() + @patch("bot.exts.utils.snekbox._cog.partial") async def test_continue_job_does_continue(self, partial_mock): """Test that the continue_job function does continue if required conditions are met.""" -- cgit v1.2.3 From 810165935dce10bc76bcbeb4c510d5510fdfb42c Mon Sep 17 00:00:00 2001 From: Ionite Date: Tue, 7 Feb 2023 02:53:21 -0500 Subject: Fix unit tests for new failmail emoji --- tests/bot/exts/utils/snekbox/test_snekbox.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 686dc0291..8f4b2e85c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -150,7 +150,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): result = EvalResult("", 0, files, failed_files) msg = result.files_error_message - self.assertEqual(msg, expected_msg) + self.assertIn(expected_msg, msg) @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) def test_eval_result_files_error_str(self): @@ -407,11 +407,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.send_job(ctx, job), ctx.send.assert_called_once() - self.assertEqual( - ctx.send.call_args.args[0], - '@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.' - '\n\n1 file was not uploaded due to disallowed extension: **.disallowed**' + res = ctx.send.call_args.args[0] + self.assertTrue( + res.startswith("@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.") ) + self.assertIn("Some files with disallowed extensions can't be uploaded: **.disallowed**", res) self.cog.post_job.assert_called_once_with(job) self.cog.upload_output.assert_not_called() -- cgit v1.2.3 From 8f2b323b083de318351cc856c7eeee5f44537253 Mon Sep 17 00:00:00 2001 From: Ionite Date: Tue, 7 Feb 2023 02:53:36 -0500 Subject: Add skip condition for windows not able to test path escapes --- tests/bot/exts/utils/snekbox/test_io.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/snekbox/test_io.py b/tests/bot/exts/utils/snekbox/test_io.py index 36ac720ba..a544a2056 100644 --- a/tests/bot/exts/utils/snekbox/test_io.py +++ b/tests/bot/exts/utils/snekbox/test_io.py @@ -1,11 +1,15 @@ -import unittest +import platform +from unittest import TestCase, skipIf # noinspection PyProtectedMember from bot.exts.utils.snekbox import _io -class SnekboxIOTests(unittest.TestCase): +class SnekboxIOTests(TestCase): # noinspection SpellCheckingInspection + # Skip Windows since both pathlib and os strips the escape sequences + # and many of these aren't valid Windows file paths + @skipIf(platform.system() == "Windows", "File names normalizer tests requires Unix-like OS.") def test_normalize_file_name(self): """Invalid file names should be normalized.""" cases = [ -- cgit v1.2.3 From a59eaf405eb91ce8c2961b820a0a9ae44d2215e1 Mon Sep 17 00:00:00 2001 From: Ionite Date: Wed, 8 Feb 2023 12:04:30 -0500 Subject: Update unit tests for file error message changes --- tests/bot/exts/utils/snekbox/test_snekbox.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 8f4b2e85c..9dcf7fd8c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -130,20 +130,16 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """EvalResult.files_error_message, should return files error message.""" cases = [ ([], ["abc"], ( - "Failed to upload 1 file (abc)." - " File sizes should each not exceed 8 MiB." + "1 file upload (abc) failed because its file size exceeds 8 MiB." )), ([], ["file1.bin", "f2.bin"], ( - "Failed to upload 2 files (file1.bin, f2.bin)." - " File sizes should each not exceed 8 MiB." + "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." )), (["a", "b"], ["c"], ( - "Failed to upload 1 file (c)" - " as it exceeded the 2 file limit." + "1 file upload (c) failed as it exceeded the 2 file limit." )), (["a"], ["b", "c"], ( - "Failed to upload 2 files (b, c)" - " as they exceeded the 2 file limit." + "2 file uploads (b, c) failed as they exceeded the 2 file limit." )), ] for files, failed_files, expected_msg in cases: @@ -411,7 +407,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertTrue( res.startswith("@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.") ) - self.assertIn("Some files with disallowed extensions can't be uploaded: **.disallowed**", res) + self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res) self.cog.post_job.assert_called_once_with(job) self.cog.upload_output.assert_not_called() -- cgit v1.2.3 From 1cff5bf589a848576d3d1f4a9c1ab71633406caf Mon Sep 17 00:00:00 2001 From: Ibrahim2750mi Date: Tue, 14 Feb 2023 21:08:09 +0530 Subject: Update tests for `/tag` as of migration to slash commands --- bot/exts/backend/error_handler.py | 22 +++++++++----- tests/bot/exts/backend/test_error_handler.py | 44 ++++++++++++++-------------- tests/helpers.py | 20 +++++++++++++ 3 files changed, 56 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index cc2b5ef56..561bf8068 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,7 +1,8 @@ import copy import difflib +import typing as t -from discord import Embed +from discord import Embed, Interaction from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -21,6 +22,10 @@ class ErrorHandler(Cog): def __init__(self, bot: Bot): self.bot = bot + @staticmethod + async def _can_run(_: Interaction) -> bool: + return False + def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" return Embed( @@ -159,7 +164,7 @@ class ErrorHandler(Cog): return True return False - async def try_get_tag(self, ctx: Context) -> None: + async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Interaction], bool] = False) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -168,27 +173,28 @@ class ErrorHandler(Cog): the context to prevent infinite recursion in the case of a CommandNotFound exception. """ tags_get_command = self.bot.get_command("tags get") + tags_get_command.can_run = can_run if can_run else self._can_run if not tags_get_command: log.debug("Not attempting to parse message as a tag as could not find `tags get` command.") return - ctx.invoked_from_error_handler = True + interaction.invoked_from_error_handler = True log_msg = "Cancelling attempt to fall back to a tag due to failed checks." try: - if not await tags_get_command.can_run(ctx): + if not await tags_get_command.can_run(interaction): log.debug(log_msg) return except errors.CommandError as tag_error: log.debug(log_msg) - await self.on_command_error(ctx, tag_error) + await self.on_command_error(interaction, tag_error) return - if await ctx.invoke(tags_get_command, argument_string=ctx.message.content): + if await interaction.invoke(tags_get_command, tag_name=interaction.message.content): return - if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): - await self.send_command_suggestion(ctx, ctx.invoked_with) + if not any(role.id in MODERATION_ROLES for role in interaction.user.roles): + await self.send_command_suggestion(interaction, interaction.invoked_with) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index adb0252a5..83bc3c4a1 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel, MockVoiceChannel +from tests.helpers import MockBot, MockContext, MockGuild, MockInteraction, MockRole, MockTextChannel, MockVoiceChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -331,7 +331,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() - self.ctx = MockContext() + self.interaction = MockInteraction() self.tag = Tags(self.bot) self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command @@ -339,57 +339,57 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_try_get_tag_get_command(self): """Should call `Bot.get_command` with `tags get` argument.""" self.bot.get_command.reset_mock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction) self.bot.get_command.assert_called_once_with("tags get") async def test_try_get_tag_invoked_from_error_handler(self): - """`self.ctx` should have `invoked_from_error_handler` `True`.""" - self.ctx.invoked_from_error_handler = False - await self.cog.try_get_tag(self.ctx) - self.assertTrue(self.ctx.invoked_from_error_handler) + """`self.interaction` should have `invoked_from_error_handler` `True`.""" + self.interaction.invoked_from_error_handler = False + await self.cog.try_get_tag(self.interaction) + self.assertTrue(self.interaction.invoked_from_error_handler) async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" self.tag.get_command.can_run = AsyncMock(return_value=False) - self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.interaction.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(return_value=False))) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() self.tag.get_command.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) - self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) + self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(side_effect=err))) + self.cog.on_command_error.assert_awaited_once_with(self.interaction, err) async def test_dont_call_suggestion_tag_sent(self): """Should never call command suggestion if tag is already sent.""" - self.ctx.message = MagicMock(content="foo") - self.ctx.invoke = AsyncMock(return_value=True) + self.interaction.message = MagicMock(content="foo") + self.interaction.invoke = AsyncMock(return_value=True) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction, AsyncMock()) self.cog.send_command_suggestion.assert_not_awaited() @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) async def test_dont_call_suggestion_if_user_mod(self): """Should not call command suggestion if user is a mod.""" - self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) - self.ctx.author.roles = [MockRole(id=1234)] + self.interaction.invoked_with = "foo" + self.interaction.invoke = AsyncMock(return_value=False) + self.interaction.user.roles = [MockRole(id=1234)] self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction, AsyncMock()) self.cog.send_command_suggestion.assert_not_awaited() async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" - self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) + self.interaction.invoked_with = "foo" + self.interaction.invoke = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) - self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") + await self.cog.try_get_tag(self.interaction, AsyncMock()) + self.cog.send_command_suggestion.assert_awaited_once_with(self.interaction, "foo") class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 4b980ac21..2d20b4d07 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -479,6 +479,26 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) +class MockInteraction(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Interaction objects. + + Instances of this class will follow the specifications of `discord.Interaction` + instances. For more information, see the `MockGuild` docstring. + """ + # spec_set = context_instance + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.me = kwargs.get('me', MockMember()) + self.client = kwargs.get('client', MockBot()) + self.guild = kwargs.get('guild', MockGuild()) + self.user = kwargs.get('user', MockMember()) + self.channel = kwargs.get('channel', MockTextChannel()) + self.message = kwargs.get('message', MockMessage()) + self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) + + attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) -- cgit v1.2.3 From d29e8c6c4240a5e9cb5293788529621ee919c0b7 Mon Sep 17 00:00:00 2001 From: Ionite Date: Wed, 22 Feb 2023 16:24:56 -0500 Subject: Use PurePosixPath so tests work on windows --- bot/exts/utils/snekbox/_io.py | 9 ++++----- tests/bot/exts/utils/snekbox/test_io.py | 6 +----- 2 files changed, 5 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py index 404681936..9be396335 100644 --- a/bot/exts/utils/snekbox/_io.py +++ b/bot/exts/utils/snekbox/_io.py @@ -4,7 +4,7 @@ from __future__ import annotations from base64 import b64decode, b64encode from dataclasses import dataclass from io import BytesIO -from pathlib import Path +from pathlib import PurePosixPath import regex from discord import File @@ -64,12 +64,12 @@ class FileAttachment: @property def suffix(self) -> str: """Return the file suffix.""" - return Path(self.path).suffix + return PurePosixPath(self.path).suffix @property def name(self) -> str: """Return the file name.""" - return Path(self.path).name + return PurePosixPath(self.path).name @classmethod def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: @@ -98,6 +98,5 @@ class FileAttachment: def to_file(self) -> File: """Convert to a discord.File.""" - name = Path(self.path).name - name = normalize_discord_file_name(name) + name = normalize_discord_file_name(self.name) return File(BytesIO(self.content), filename=name) diff --git a/tests/bot/exts/utils/snekbox/test_io.py b/tests/bot/exts/utils/snekbox/test_io.py index a544a2056..bcf1162b8 100644 --- a/tests/bot/exts/utils/snekbox/test_io.py +++ b/tests/bot/exts/utils/snekbox/test_io.py @@ -1,5 +1,4 @@ -import platform -from unittest import TestCase, skipIf +from unittest import TestCase # noinspection PyProtectedMember from bot.exts.utils.snekbox import _io @@ -7,9 +6,6 @@ from bot.exts.utils.snekbox import _io class SnekboxIOTests(TestCase): # noinspection SpellCheckingInspection - # Skip Windows since both pathlib and os strips the escape sequences - # and many of these aren't valid Windows file paths - @skipIf(platform.system() == "Windows", "File names normalizer tests requires Unix-like OS.") def test_normalize_file_name(self): """Invalid file names should be normalized.""" cases = [ -- cgit v1.2.3 From 9b98dfe78bb226e26a8d9cb6e8a0e8f8504286dd Mon Sep 17 00:00:00 2001 From: Ibrahim Date: Thu, 23 Feb 2023 04:08:57 +0530 Subject: Implement all reviews + Remove commented code + Remove unecessarily syncting the bot + Handle direct tag commads + 3.10 type hinting in concerned functions + Add `MockInteractionMessage` + Fix tests for `try_get_tag` --- bot/exts/backend/error_handler.py | 40 ++++++++----- bot/exts/info/tags.py | 90 +++++++++++++++++++--------- bot/pagination.py | 4 +- bot/utils/messages.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 50 ++++++++-------- tests/helpers.py | 11 +++- 6 files changed, 125 insertions(+), 72 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 561bf8068..6561f84e4 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -2,7 +2,7 @@ import copy import difflib import typing as t -from discord import Embed, Interaction +from discord import Embed, Interaction, utils from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -23,8 +23,19 @@ class ErrorHandler(Cog): self.bot = bot @staticmethod - async def _can_run(_: Interaction) -> bool: - return False + async def _can_run(ctx: Context) -> bool: + """ + Add checks for the `get_command_ctx` function here. + + Use discord.utils to run the checks. + """ + checks = [] + predicates = checks + if not predicates: + # Since we have no checks, then we just return True. + return True + + return await utils.async_all(predicate(ctx) for predicate in predicates) def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" @@ -164,7 +175,7 @@ class ErrorHandler(Cog): return True return False - async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Interaction], bool] = False) -> None: + async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -172,29 +183,30 @@ class ErrorHandler(Cog): by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to the context to prevent infinite recursion in the case of a CommandNotFound exception. """ - tags_get_command = self.bot.get_command("tags get") - tags_get_command.can_run = can_run if can_run else self._can_run - if not tags_get_command: - log.debug("Not attempting to parse message as a tag as could not find `tags get` command.") + tags_cog = self.bot.get_cog("Tags") + if not tags_cog: + log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.") return + tags_get_command = tags_cog.get_command_ctx + can_run = can_run if can_run else self._can_run - interaction.invoked_from_error_handler = True + ctx.invoked_from_error_handler = True log_msg = "Cancelling attempt to fall back to a tag due to failed checks." try: - if not await tags_get_command.can_run(interaction): + if not await can_run(ctx): log.debug(log_msg) return except errors.CommandError as tag_error: log.debug(log_msg) - await self.on_command_error(interaction, tag_error) + await self.on_command_error(ctx, tag_error) return - if await interaction.invoke(tags_get_command, tag_name=interaction.message.content): + if await tags_get_command(ctx, ctx.message.content): return - if not any(role.id in MODERATION_ROLES for role in interaction.user.roles): - await self.send_command_suggestion(interaction, interaction.invoked_with) + if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): + await self.send_command_suggestion(ctx, ctx.invoked_with) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 25c51def9..60f730586 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -8,7 +8,7 @@ from typing import Literal, NamedTuple, Optional, Union import discord import frontmatter -from discord import Embed, Member, app_commands +from discord import Embed, Interaction, Member, app_commands from discord.ext.commands import Cog from bot import constants @@ -140,15 +140,8 @@ class Tags(Cog): self.bot = bot self.tags: dict[TagIdentifier, Tag] = {} self.initialize_tags() - self.bot.tree.copy_global_to(guild=discord.Object(id=GUILD_ID)) tag_group = app_commands.Group(name="tag", description="...") - # search_tag = app_commands.Group(name="search", description="...", parent=tag_group) - - @Cog.listener() - async def on_ready(self) -> None: - """Called when the cog is ready.""" - await self.bot.tree.sync(guild=discord.Object(id=GUILD_ID)) def initialize_tags(self) -> None: """Load all tags from resources into `self.tags`.""" @@ -195,7 +188,8 @@ class Tags(Cog): async def get_tag_embed( self, - interaction: discord.Interaction, + author: discord.Member, + channel: discord.TextChannel | discord.Thread, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ @@ -206,7 +200,7 @@ class Tags(Cog): filtered_tags = [ (ident, tag) for ident, tag in self.get_fuzzy_matches(tag_identifier)[:10] - if tag.accessible_by(interaction.user) + if tag.accessible_by(author) ] # Try exact match, includes checking through alt names @@ -225,10 +219,10 @@ class Tags(Cog): tag = filtered_tags[0][1] if tag is not None: - if tag.on_cooldown_in(interaction.channel): + if tag.on_cooldown_in(channel): log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.") return COOLDOWN.obj - tag.set_cooldown_for(interaction.channel) + tag.set_cooldown_for(channel) self.bot.stats.incr( f"tags.usages" @@ -243,7 +237,7 @@ class Tags(Cog): suggested_tags_text = "\n".join( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" for identifier, tag in filtered_tags - if not tag.on_cooldown_in(interaction.channel) + if not tag.on_cooldown_in(channel) ) return Embed( title="Did you mean ...", @@ -292,8 +286,37 @@ class Tags(Cog): if identifier.group == group and tag.accessible_by(user) ) + async def get_command_ctx( + self, + ctx: discord.Context, + name: str + ) -> bool: + """Made specifically for `error_handler.py`, See `get_command` for more info.""" + identifier = TagIdentifier.from_string(name) + + if identifier.group is None: + # Try to find accessible tags from a group matching the identifier's name. + if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author): + await LinePaginator.paginate( + group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS + ) + return True + + embed = await self.get_tag_embed(ctx.author, ctx.channel, identifier) + if embed is None: + return False + + if embed is not COOLDOWN.obj: + + await wait_for_deletion( + await ctx.send(embed=embed), + (ctx.author.id,) + ) + # A valid tag was found and was either sent, or is on cooldown + return True + @tag_group.command(name="get") - async def get_command(self, interaction: discord.Interaction, *, tag_name: Optional[str]) -> bool: + async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool: """ If a single argument matching a group name is given, list all accessible tags from that group Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name. @@ -303,7 +326,7 @@ class Tags(Cog): Returns True if a message was sent, or if the tag is on cooldown. Returns False if no message was sent. """ # noqa: D205, D415 - if not tag_name: + if not name: if self.tags: await LinePaginator.paginate( self.accessible_tags(interaction.user), @@ -314,7 +337,7 @@ class Tags(Cog): await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) return True - identifier = TagIdentifier.from_string(tag_name) + identifier = TagIdentifier.from_string(name) if identifier.group is None: # Try to find accessible tags from a group matching the identifier's name. @@ -324,33 +347,43 @@ class Tags(Cog): ) return True - embed = await self.get_tag_embed(interaction, identifier) + embed = await self.get_tag_embed(interaction.user, interaction.channel, identifier) + ephemeral = False if embed is None: - return False - - if embed is not COOLDOWN.obj: + description = f"**There are no tags matching the name {name!r}!**" + embed = Embed(description=description) + ephemeral = True + elif embed is COOLDOWN.obj: + description = f"Tag {name!r} is on cooldown." + embed = Embed(description=description) + ephemeral = True + + await interaction.response.send_message(embed=embed, ephemeral=ephemeral) + if not ephemeral: await wait_for_deletion( - await interaction.response.send_message(embed=embed), + await interaction.original_response(), (interaction.user.id,) ) + # A valid tag was found and was either sent, or is on cooldown return True - @get_command.autocomplete("tag_name") - async def tag_name_autocomplete( + @get_command.autocomplete("name") + async def name_autocomplete( self, - interaction: discord.Interaction, + interaction: Interaction, current: str ) -> list[app_commands.Choice[str]]: """Autocompleter for `/tag get` command.""" - tag_names = [tag.name for tag in self.tags.keys()] - return [ + names = [tag.name for tag in self.tags.keys()] + choices = [ app_commands.Choice(name=tag, value=tag) - for tag in tag_names if current.lower() in tag + for tag in names if current.lower() in tag ] + return choices[:25] if len(choices) > 25 else choices @tag_group.command(name="list") - async def list_command(self, interaction: discord.Interaction) -> bool: + async def list_command(self, interaction: Interaction) -> bool: """Lists all accessible tags.""" if self.tags: await LinePaginator.paginate( @@ -367,4 +400,3 @@ class Tags(Cog): async def setup(bot: Bot) -> None: """Load the Tags cog.""" await bot.add_cog(Tags(bot)) - await bot.tree.sync(guild=discord.Object(id=GUILD_ID)) diff --git a/bot/pagination.py b/bot/pagination.py index 1c63a4768..c39ce211b 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -190,8 +190,8 @@ class LinePaginator(Paginator): @classmethod async def paginate( cls, - lines: t.List[str], - ctx: t.Union[Context, discord.Interaction], + lines: list[str], + ctx: Context | discord.Interaction, embed: discord.Embed, prefix: str = "", suffix: str = "", diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 27f2eac97..f6bdceaef 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -58,7 +58,7 @@ def reaction_check( async def wait_for_deletion( - message: discord.Message, + message: discord.Message | discord.InteractionMessage, user_ids: Sequence[int], deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 83bc3c4a1..14e7a4125 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockInteraction, MockRole, MockTextChannel, MockVoiceChannel +from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel, MockVoiceChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -331,65 +331,65 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() - self.interaction = MockInteraction() + self.ctx = MockContext() self.tag = Tags(self.bot) self.cog = error_handler.ErrorHandler(self.bot) - self.bot.get_command.return_value = self.tag.get_command + self.bot.get_cog.return_value = self.tag async def test_try_get_tag_get_command(self): """Should call `Bot.get_command` with `tags get` argument.""" - self.bot.get_command.reset_mock() - await self.cog.try_get_tag(self.interaction) - self.bot.get_command.assert_called_once_with("tags get") + self.bot.get_cog.reset_mock() + await self.cog.try_get_tag(self.ctx) + self.bot.get_cog.assert_called_once_with("Tags") async def test_try_get_tag_invoked_from_error_handler(self): - """`self.interaction` should have `invoked_from_error_handler` `True`.""" - self.interaction.invoked_from_error_handler = False - await self.cog.try_get_tag(self.interaction) - self.assertTrue(self.interaction.invoked_from_error_handler) + """`self.ctx` should have `invoked_from_error_handler` `True`.""" + self.ctx.invoked_from_error_handler = False + await self.cog.try_get_tag(self.ctx) + self.assertTrue(self.ctx.invoked_from_error_handler) async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" self.tag.get_command.can_run = AsyncMock(return_value=False) - self.interaction.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(return_value=False))) + self.ctx.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False))) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() self.tag.get_command.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(side_effect=err))) - self.cog.on_command_error.assert_awaited_once_with(self.interaction, err) + self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err))) + self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) async def test_dont_call_suggestion_tag_sent(self): """Should never call command suggestion if tag is already sent.""" - self.interaction.message = MagicMock(content="foo") - self.interaction.invoke = AsyncMock(return_value=True) + self.ctx.message = MagicMock(content="foo") + self.tag.get_command_ctx = AsyncMock(return_value=True) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.interaction, AsyncMock()) + await self.cog.try_get_tag(self.ctx) self.cog.send_command_suggestion.assert_not_awaited() @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) async def test_dont_call_suggestion_if_user_mod(self): """Should not call command suggestion if user is a mod.""" - self.interaction.invoked_with = "foo" - self.interaction.invoke = AsyncMock(return_value=False) - self.interaction.user.roles = [MockRole(id=1234)] + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=False) + self.ctx.author.roles = [MockRole(id=1234)] self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.interaction, AsyncMock()) + await self.cog.try_get_tag(self.ctx) self.cog.send_command_suggestion.assert_not_awaited() async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" - self.interaction.invoked_with = "foo" - self.interaction.invoke = AsyncMock(return_value=False) + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.interaction, AsyncMock()) - self.cog.send_command_suggestion.assert_awaited_once_with(self.interaction, "foo") + await self.cog.try_get_tag(self.ctx) + self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 2d20b4d07..0d955b521 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -486,7 +486,6 @@ class MockInteraction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Interaction` instances. For more information, see the `MockGuild` docstring. """ - # spec_set = context_instance def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -550,6 +549,16 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): self.channel = kwargs.get('channel', MockTextChannel()) +class MockInteractionMessage(MockMessage): + """ + A MagicMock subclass to mock InteractionMessage objects. + + Instances of this class will follow the specifications of `discord.InteractionMessage` instances. For more + information, see the `MockGuild` docstring. + """ + pass + + emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) -- cgit v1.2.3 From 03614c313341497e61c45bbb2a364b969d2bb163 Mon Sep 17 00:00:00 2001 From: Ibrahim Date: Sun, 26 Feb 2023 17:44:47 +0530 Subject: Implement reviews + used both `discord.User` and `discord.Member` in typehinting as `InteractionResponse.user` returns `discord.User` object + removed `ErrorHandler()._can_run` + edited `try_get_tag` to use `bot.can_run` + removed `/tag list` + change `/tag get ` to `/tag ` + remove redundant `GUILD_ID` in `tags.py` + using `discord.abc.Messageable` because `ctx.channel` returns that instead of `Channel` Object --- bot/exts/backend/error_handler.py | 50 ++++++++++------------------ bot/exts/info/tags.py | 36 +++++--------------- tests/bot/exts/backend/test_error_handler.py | 10 +++--- 3 files changed, 32 insertions(+), 64 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 839d882de..e274e337a 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,8 +1,7 @@ import copy import difflib -import typing as t -from discord import Embed, Interaction, utils +from discord import Embed, Member from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -22,22 +21,6 @@ class ErrorHandler(Cog): def __init__(self, bot: Bot): self.bot = bot - @staticmethod - async def _can_run(ctx: Context) -> bool: - """ - Add checks for the `get_command_ctx` function here. - - The command code style is copied from discord.ext.commands.Command.can_run itself. - Append checks in the checks list. - """ - checks = [] - predicates = checks - if not predicates: - # Since we have no checks, then we just return True. - return True - - return await utils.async_all(predicate(ctx) for predicate in predicates) - def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" return Embed( @@ -176,7 +159,7 @@ class ErrorHandler(Cog): return True return False - async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None: + async def try_get_tag(self, ctx: Context) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -189,25 +172,28 @@ class ErrorHandler(Cog): log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.") return tags_get_command = tags_cog.get_command_ctx - can_run = can_run if can_run else self._can_run - ctx.invoked_from_error_handler = True + maybe_tag_name = ctx.invoked_with + if not maybe_tag_name or not isinstance(ctx.author, Member): + return - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + ctx.invoked_from_error_handler = True try: - if not await can_run(ctx): - log.debug(log_msg) + if not await self.bot.can_run(ctx): + log.debug("Cancelling attempt to fall back to a tag due to failed checks.") return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - if await tags_get_command(ctx, ctx.message.content): - return + if await tags_get_command(ctx, maybe_tag_name): + return - if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): - await self.send_command_suggestion(ctx, ctx.invoked_with) + if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): + await self.send_command_suggestion(ctx, maybe_tag_name) + except Exception as err: + log.debug("Error while attempting to invoke tag fallback.") + if isinstance(err, errors.CommandError): + await self.on_command_error(ctx, err) + else: + await self.on_command_error(ctx, errors.CommandInvokeError(err)) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 60f730586..0c244ff37 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -8,8 +8,8 @@ from typing import Literal, NamedTuple, Optional, Union import discord import frontmatter -from discord import Embed, Interaction, Member, app_commands -from discord.ext.commands import Cog +from discord import Embed, Interaction, Member, User, app_commands +from discord.ext.commands import Cog, Context from bot import constants from bot.bot import Bot @@ -27,8 +27,6 @@ TEST_CHANNELS = ( REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." -GUILD_ID = constants.Guild.id - class COOLDOWN(enum.Enum): """Sentinel value to signal that a tag is on cooldown.""" @@ -93,7 +91,7 @@ class Tag: embed.description = self.content return embed - def accessible_by(self, member: discord.Member) -> bool: + def accessible_by(self, member: Member | User) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to @@ -141,8 +139,6 @@ class Tags(Cog): self.tags: dict[TagIdentifier, Tag] = {} self.initialize_tags() - tag_group = app_commands.Group(name="tag", description="...") - def initialize_tags(self) -> None: """Load all tags from resources into `self.tags`.""" base_path = Path("bot", "resources", "tags") @@ -188,8 +184,8 @@ class Tags(Cog): async def get_tag_embed( self, - author: discord.Member, - channel: discord.TextChannel | discord.Thread, + author: Member | User, + channel: discord.abc.Messageable, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ @@ -244,7 +240,7 @@ class Tags(Cog): description=suggested_tags_text ) - def accessible_tags(self, user: Member) -> list[str]: + def accessible_tags(self, user: Member | User) -> list[str]: """Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted.""" def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: group, name = tag_item[0] @@ -278,7 +274,7 @@ class Tags(Cog): return result_lines - def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: + def accessible_tags_in_group(self, group: str, user: Member | User) -> list[str]: """Return a formatted list of tags in `group`, that are accessible by `user`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" @@ -288,7 +284,7 @@ class Tags(Cog): async def get_command_ctx( self, - ctx: discord.Context, + ctx: Context, name: str ) -> bool: """Made specifically for `error_handler.py`, See `get_command` for more info.""" @@ -315,7 +311,7 @@ class Tags(Cog): # A valid tag was found and was either sent, or is on cooldown return True - @tag_group.command(name="get") + @app_commands.command(name="tag") async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool: """ If a single argument matching a group name is given, list all accessible tags from that group @@ -382,20 +378,6 @@ class Tags(Cog): ] return choices[:25] if len(choices) > 25 else choices - @tag_group.command(name="list") - async def list_command(self, interaction: Interaction) -> bool: - """Lists all accessible tags.""" - if self.tags: - await LinePaginator.paginate( - self.accessible_tags(interaction.user), - interaction, - Embed(title="Available tags"), - **self.PAGINATOR_DEFAULTS, - ) - else: - await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) - return True - async def setup(bot: Bot) -> None: """Load the Tags cog.""" diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 14e7a4125..533eaeda6 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -350,16 +350,16 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" - self.tag.get_command.can_run = AsyncMock(return_value=False) + self.bot.can_run = AsyncMock(return_value=False) self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False))) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() - self.tag.get_command.can_run = AsyncMock(side_effect=err) + self.bot.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err))) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) async def test_dont_call_suggestion_tag_sent(self): @@ -385,7 +385,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) + self.tag.get_command_ctx = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() await self.cog.try_get_tag(self.ctx) -- cgit v1.2.3 From bec7980bf02246c7572a0a20acf6768337535613 Mon Sep 17 00:00:00 2001 From: shtlrs Date: Tue, 28 Feb 2023 16:16:46 +0100 Subject: add the `flags` key to the member_data dictionary The value 2 represents the `COMPLETED_ONBOARDING` flag, found here https://discord.com/developers/docs/resources/guild#guild-member-object-guild-member-flags --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/helpers.py b/tests/helpers.py index 0d955b521..1a71f210a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -222,7 +222,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): # Create a Member instance to get a realistic Mock of `discord.Member` -member_data = {'user': 'lemon', 'roles': [1]} +member_data = {'user': 'lemon', 'roles': [1], 'flags': 2} state_mock = unittest.mock.MagicMock() member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) -- cgit v1.2.3 From bf8f8f4c1f9522a942c88ca69a2d48427d2bbc28 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Mar 2023 21:00:53 +0000 Subject: Bump markdownify from 0.6.1 to 0.11.6 (#2429) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: wookie184 --- bot/exts/info/doc/_markdown.py | 12 ++++++++++-- poetry.lock | 12 ++++++------ pyproject.toml | 6 +----- tests/bot/exts/info/doc/test_parsing.py | 23 +++++++++++++++++++++++ 4 files changed, 40 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/bot/exts/info/doc/_markdown.py b/bot/exts/info/doc/_markdown.py index 1b7d8232b..315adda66 100644 --- a/bot/exts/info/doc/_markdown.py +++ b/bot/exts/info/doc/_markdown.py @@ -1,10 +1,14 @@ +import re from urllib.parse import urljoin +import markdownify from bs4.element import PageElement -from markdownify import MarkdownConverter +# See https://github.com/matthewwithanm/python-markdownify/issues/31 +markdownify.whitespace_re = re.compile(r"[\r\n\s\t ]+") -class DocMarkdownConverter(MarkdownConverter): + +class DocMarkdownConverter(markdownify.MarkdownConverter): """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" def __init__(self, *, page_url: str, **options): @@ -56,3 +60,7 @@ class DocMarkdownConverter(MarkdownConverter): if parent is not None and parent.name == "li": return f"{text}\n" return super().convert_p(el, text, convert_as_inline) + + def convert_hr(self, el: PageElement, text: str, convert_as_inline: bool) -> str: + """Ignore `hr` tag.""" + return "" diff --git a/poetry.lock b/poetry.lock index 8a718ab82..de777c828 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1041,19 +1041,19 @@ source = ["Cython (>=0.29.7)"] [[package]] name = "markdownify" -version = "0.6.1" +version = "0.11.6" description = "Convert HTML to markdown." category = "main" optional = false python-versions = "*" files = [ - {file = "markdownify-0.6.1-py3-none-any.whl", hash = "sha256:7489fd5c601536996a376c4afbcd1dd034db7690af807120681461e82fbc0acc"}, - {file = "markdownify-0.6.1.tar.gz", hash = "sha256:31d7c13ac2ada8bfc7535a25fee6622ca720e1b5f2d4a9cbc429d167c21f886d"}, + {file = "markdownify-0.11.6-py3-none-any.whl", hash = "sha256:ba35fe289d5e9073bcd7d2cad629278fe25f1a93741fcdc0bfb4f009076d8324"}, + {file = "markdownify-0.11.6.tar.gz", hash = "sha256:009b240e0c9f4c8eaf1d085625dcd4011e12f0f8cec55dedf9ea6f7655e49bfe"}, ] [package.dependencies] -beautifulsoup4 = "*" -six = "*" +beautifulsoup4 = ">=4.9,<5" +six = ">=1.15,<2" [[package]] name = "mccabe" @@ -2317,4 +2317,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "3.10.*" -content-hash = "68bfdf2115a5242df097155a2660a1c0276cf25b4785bdb761580bd35b77383c" +content-hash = "4b3549e9e47535d1fea6015a0f7ebf056a42e4d27e766583ccd8b59ebe8297d6" diff --git a/pyproject.toml b/pyproject.toml index 71981e8d0..11e99ecbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,11 +22,7 @@ deepdiff = "6.2.1" emoji = "2.2.0" feedparser = "6.0.10" lxml = "4.9.1" - -# Must be kept on this version unless doc command output is fixed -# See https://github.com/python-discord/bot/pull/2156 -markdownify = "0.6.1" - +markdownify = "0.11.6" more-itertools = "9.0.0" python-dateutil = "2.8.2" python-frontmatter = "1.0.0" diff --git a/tests/bot/exts/info/doc/test_parsing.py b/tests/bot/exts/info/doc/test_parsing.py index 1663d8491..d2105a53c 100644 --- a/tests/bot/exts/info/doc/test_parsing.py +++ b/tests/bot/exts/info/doc/test_parsing.py @@ -1,6 +1,7 @@ from unittest import TestCase from bot.exts.info.doc import _parsing as parsing +from bot.exts.info.doc._markdown import DocMarkdownConverter class SignatureSplitter(TestCase): @@ -64,3 +65,25 @@ class SignatureSplitter(TestCase): for input_string, expected_output in test_cases: with self.subTest(input_string=input_string): self.assertEqual(list(parsing._split_parameters(input_string)), expected_output) + + +class MarkdownConverterTest(TestCase): + def test_hr_removed(self): + test_cases = ( + ('
', ""), + ("
", ""), + ) + self._run_tests(test_cases) + + def test_whitespace_removed(self): + test_cases = ( + ("lines\nof\ntext", "lines of text"), + ("lines\n\nof\n\ntext", "lines of text"), + ) + self._run_tests(test_cases) + + def _run_tests(self, test_cases: tuple[tuple[str, str], ...]): + for input_string, expected_output in test_cases: + with self.subTest(input_string=input_string): + d = DocMarkdownConverter(page_url="https://example.com") + self.assertEqual(d.convert(input_string), expected_output) -- cgit v1.2.3