From 6470d7d4216708e58b57be942fd69cb8a0deba1a Mon Sep 17 00:00:00 2001 From: Amrou Bellalouna Date: Thu, 21 Mar 2024 15:48:10 +0100 Subject: update configuration tests (#2951) The old test relied on the old system where we loaded config from a yaml file, which ended up doing nothing. --- tests/bot/.testenv | 2 ++ tests/bot/test_constants.py | 64 ++++++++++++++++++--------------------------- 2 files changed, 28 insertions(+), 38 deletions(-) create mode 100644 tests/bot/.testenv (limited to 'tests') diff --git a/tests/bot/.testenv b/tests/bot/.testenv new file mode 100644 index 000000000..484c8809d --- /dev/null +++ b/tests/bot/.testenv @@ -0,0 +1,2 @@ +unittests_goat=volcyy +unittests_nested__server_name=pydis diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index 3492021ce..87933d59a 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -1,53 +1,41 @@ -import inspect -import typing -import unittest +import os +from pathlib import Path +from unittest import TestCase, mock -from bot import constants +from pydantic import BaseModel +from bot.constants import EnvConfig -def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: - """ - Return True if `value` is an instance of the type represented by `annotation`. +current_path = Path(__file__) +env_file_path = current_path.parent / ".testenv" - This doesn't account for things like Unions or checking for homogenous types in collections. - """ - origin = typing.get_origin(annotation) - # This is done in case a bare e.g. `typing.List` is used. - # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. - # `get_origin()` does this normalisation for us. - type_ = annotation if origin is None else origin +class TestEnvConfig( + EnvConfig, + env_file=env_file_path, +): + """Our default configuration for models that should load from .env files.""" - return isinstance(value, type_) +class NestedModel(BaseModel): + server_name: str -def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: - """Return True if `value` is an instance of any type in `types`.""" - return any(is_annotation_instance(value, type_) for type_ in types) +class _TestConfig(TestEnvConfig, env_prefix="unittests_"): -class ConstantsTests(unittest.TestCase): + goat: str + execution_env: str = "local" + nested: NestedModel + + +class ConstantsTests(TestCase): """Tests for our constants.""" + @mock.patch.dict(os.environ, {"UNITTESTS_EXECUTION_ENV": "production"}) def test_section_configuration_matches_type_specification(self): """"The section annotations should match the actual types of the sections.""" - sections = ( - cls - for (name, cls) in inspect.getmembers(constants) - if hasattr(cls, "section") and isinstance(cls, type) - ) - for section in sections: - for name, annotation in section.__annotations__.items(): - with self.subTest(section=section.__name__, name=name, annotation=annotation): - value = getattr(section, name) - origin = typing.get_origin(annotation) - annotation_args = typing.get_args(annotation) - failure_msg = f"{value} is not an instance of {annotation}" - - if origin is typing.Union: - is_instance = is_any_instance(value, annotation_args) - self.assertTrue(is_instance, failure_msg) - else: - is_instance = is_annotation_instance(value, annotation) - self.assertTrue(is_instance, failure_msg) + testconfig = _TestConfig() + self.assertEqual("volcyy", testconfig.goat) + self.assertEqual("pydis", testconfig.nested.server_name) + self.assertEqual("production", testconfig.execution_env) -- cgit v1.2.3 From 8f261610557f8b2ab8f47425159fe8b1efdd47ce Mon Sep 17 00:00:00 2001 From: wookie184 Date: Mon, 25 Mar 2024 19:29:02 +0000 Subject: Make help showable through button on command error message. (#2439) * Make help showable through button on command error message. * Improve error message Improve error message for attempting to delete other users' command invocations Co-authored-by: Boris Muratov <8bee278@gmail.com> * Use double quotes instead of single * Refactor to use `ViewWithUserAndRoleCheck` --------- Co-authored-by: Boris Muratov <8bee278@gmail.com> --- bot/exts/backend/error_handler.py | 59 ++++++++++++++++++++++------ tests/bot/exts/backend/test_error_handler.py | 7 ++-- 2 files changed, 51 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index faa39db5d..5cf07613d 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,10 +1,12 @@ import copy import difflib -from discord import Embed, Forbidden, Member +import discord +from discord import ButtonStyle, Embed, Forbidden, Interaction, Member, User from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from pydis_core.utils.error_handling import handle_forbidden_from_block +from pydis_core.utils.interactions import DeleteMessageButton, ViewWithUserAndRoleCheck from sentry_sdk import push_scope from bot.bot import Bot @@ -16,6 +18,35 @@ from bot.utils.checks import ContextCheckFailure log = get_logger(__name__) +class HelpEmbedView(ViewWithUserAndRoleCheck): + """View to allow showing the help command for command error responses.""" + + def __init__(self, help_embed: Embed, owner: User | Member): + super().__init__(allowed_roles=MODERATION_ROLES, allowed_users=[owner.id]) + self.help_embed = help_embed + + self.delete_button = DeleteMessageButton() + self.add_item(self.delete_button) + + async def interaction_check(self, interaction: Interaction) -> bool: + """Overriden check to allow anyone to use the help button.""" + if (interaction.data or {}).get("custom_id") == self.help_button.custom_id: + log.trace( + "Allowed interaction by %s (%d) on %d as interaction was with the help button.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + return await super().interaction_check(interaction) + + @discord.ui.button(label="Help", style=ButtonStyle.primary) + async def help_button(self, interaction: Interaction, button: discord.ui.Button) -> None: + """Send an ephemeral message with the contents of the help command.""" + await interaction.response.send_message(embed=self.help_embed, ephemeral=True) + + class ErrorHandler(Cog): """Handles errors emitted from commands.""" @@ -117,15 +148,6 @@ class ErrorHandler(Cog): # ExtensionError await self.handle_unexpected_error(ctx, e) - async def send_command_help(self, ctx: Context) -> None: - """Return a prepared `help` command invocation coroutine.""" - if ctx.command: - self.bot.help_command.context = ctx - await ctx.send_help(ctx.command) - return - - await ctx.send_help() - async def try_silence(self, ctx: Context) -> bool: """ Attempt to invoke the silence or unsilence command if invoke with matches a pattern. @@ -300,8 +322,21 @@ class ErrorHandler(Cog): ) self.bot.stats.incr("errors.other_user_input_error") - await ctx.send(embed=embed) - await self.send_command_help(ctx) + await self.send_error_with_help(ctx, embed) + + async def send_error_with_help(self, ctx: Context, error_embed: Embed) -> None: + """Send error message, with button to show command help.""" + # Fall back to just sending the error embed if the custom help cog isn't loaded yet. + # ctx.command shouldn't be None here, but check just to be safe. + help_embed_creator = getattr(self.bot.help_command, "command_formatting", None) + if not help_embed_creator or not ctx.command: + await ctx.send(embed=error_embed) + return + + self.bot.help_command.context = ctx + help_embed, _ = await help_embed_creator(ctx.command) + view = HelpEmbedView(help_embed, ctx.author) + view.message = await ctx.send(embed=error_embed, view=view) @staticmethod async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 9670d42a0..dbc62270b 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -414,12 +414,13 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): for case in test_cases: with self.subTest(error=case["error"], call_prepared=case["call_prepared"]): self.ctx.reset_mock() + self.cog.send_error_with_help = AsyncMock() self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"])) - self.ctx.send.assert_awaited_once() if case["call_prepared"]: - self.ctx.send_help.assert_awaited_once() + self.cog.send_error_with_help.assert_awaited_once() else: - self.ctx.send_help.assert_not_awaited() + self.ctx.send.assert_awaited_once() + self.cog.send_error_with_help.assert_not_awaited() async def test_handle_check_failure_errors(self): """Should await `ctx.send` when error is check failure.""" -- cgit v1.2.3 From 9a8520178b4a966dffc140d46bbe83466a3cf39e Mon Sep 17 00:00:00 2001 From: Mark <1515135+MarkKoz@users.noreply.github.com> Date: Sat, 3 Feb 2024 17:12:12 -0800 Subject: Snekbox: truncate blocked file extensions Avoid Discord's character limit for messages. Fix #2464 --- bot/exts/utils/snekbox/_cog.py | 131 ++++++++++++++++----------- bot/exts/utils/snekbox/_eval.py | 2 +- tests/bot/exts/utils/snekbox/test_snekbox.py | 19 ++-- 3 files changed, 90 insertions(+), 62 deletions(-) (limited to 'tests') diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py index db4181d68..f26bf1000 100644 --- a/bot/exts/utils/snekbox/_cog.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib import re +from collections.abc import Iterable from functools import partial from operator import attrgetter from textwrap import dedent @@ -289,6 +290,69 @@ class Snekbox(Cog): return output, paste_link + async def format_file_text(self, text_files: list[FileAttachment], output: str) -> str: + # Inline until budget, then upload to paste service + # Budget is shared with stdout, so subtract what we've already used + budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) + budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output) + msg = "" + + for file in text_files: + file_text = file.content.decode("utf-8", errors="replace") or "[Empty]" + # Override to always allow 1 line and <= 50 chars, since this is less than a link + if len(file_text) <= 50 and not file_text.count("\n"): + msg += f"\n`{file.name}`\n```\n{file_text}\n```" + # otherwise, use budget + else: + format_text, link_text = await self.format_output( + file_text, + budget_lines, + budget_chars, + line_nums=False, + output_default="[Empty]" + ) + # With any link, use it (don't use budget) + if link_text: + msg += f"\n`{file.name}`\n{link_text}" + else: + msg += f"\n`{file.name}`\n```\n{format_text}\n```" + budget_lines -= format_text.count("\n") + 1 + budget_chars -= len(file_text) + + return msg + + def format_blocked_extensions(self, blocked: list[FileAttachment]) -> str: + # Sort by length and then lexicographically to fit as many as possible before truncating. + blocked_sorted = sorted(set(f.suffix for f in blocked), key=lambda e: (len(e), e)) + + # Only no extension + if len(blocked_sorted) == 1 and blocked_sorted[0] == "": + blocked_msg = "Files with no extension can't be uploaded." + # Both + elif "" in blocked_sorted: + blocked_str = self.join_blocked_extensions(ext for ext in blocked_sorted if ext) + blocked_msg = ( + f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + ) + else: + blocked_str = self.join_blocked_extensions(blocked_sorted) + blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + + return f"\n{Emojis.failed_file} {blocked_msg}" + + def join_blocked_extensions(self, extensions: Iterable, delimiter: str = ", ", char_limit: int = 100) -> str: + joined = "" + for ext in extensions: + cur_delimiter = delimiter if joined else "" + if len(joined) + len(cur_delimiter) + len(ext) >= char_limit: + joined += f"{cur_delimiter}..." + break + + joined += f"{cur_delimiter}{ext}" + + return joined + + def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles: """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" # Filter files into allowed and blocked @@ -318,16 +382,18 @@ class Snekbox(Cog): """ async with ctx.typing(): result = await self.post_job(job) - msg = result.get_message(job) - error = result.error_message - - if error: - output, paste_link = error, None + # Collect stats of job fails + successes + if result.returncode != 0: + self.bot.stats.incr("snekbox.python.fail") else: - log.trace("Formatting output...") - output, paste_link = await self.format_output(result.stdout) + self.bot.stats.incr("snekbox.python.success") + + log.trace("Formatting output...") + output = result.error_message if result.error_message else result.stdout + output, paste_link = await self.format_output(output) - msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n" + status_msg = result.get_status_message(job) + msg = f"{ctx.author.mention} {result.status_emoji} {status_msg}.\n" # This is done to make sure the last line of output contains the error # and the error is not manually printed by the author with a syntax error. @@ -345,39 +411,9 @@ class Snekbox(Cog): if files_error := result.files_error_message: msg += f"\n{files_error}" - # Collect stats of job fails + successes - if result.returncode != 0: - self.bot.stats.incr("snekbox.python.fail") - else: - self.bot.stats.incr("snekbox.python.success") - # Split text files text_files = [f for f in result.files if f.suffix in TXT_LIKE_FILES] - # Inline until budget, then upload to paste service - # Budget is shared with stdout, so subtract what we've already used - budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) - budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output) - for file in text_files: - file_text = file.content.decode("utf-8", errors="replace") or "[Empty]" - # Override to always allow 1 line and <= 50 chars, since this is less than a link - if len(file_text) <= 50 and not file_text.count("\n"): - msg += f"\n`{file.name}`\n```\n{file_text}\n```" - # otherwise, use budget - else: - format_text, link_text = await self.format_output( - file_text, - budget_lines, - budget_chars, - line_nums=False, - output_default="[Empty]" - ) - # With any link, use it (don't use budget) - if link_text: - msg += f"\n`{file.name}`\n{link_text}" - else: - msg += f"\n`{file.name}`\n```\n{format_text}\n```" - budget_lines -= format_text.count("\n") + 1 - budget_chars -= len(file_text) + msg += await self.format_file_text(text_files, output) filter_cog: Filtering | None = self.bot.get_cog("Filtering") blocked_exts = set() @@ -392,23 +428,8 @@ class Snekbox(Cog): # Filter file extensions allowed, blocked = self._filter_files(ctx, result.files, blocked_exts) blocked.extend(self._filter_files(ctx, failed_files, blocked_exts).blocked) - # Add notice if any files were blocked if blocked: - blocked_sorted = sorted(set(f.suffix for f in blocked)) - # Only no extension - if len(blocked_sorted) == 1 and blocked_sorted[0] == "": - blocked_msg = "Files with no extension can't be uploaded." - # Both - elif "" in blocked_sorted: - blocked_str = ", ".join(ext for ext in blocked_sorted if ext) - blocked_msg = ( - f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" - ) - else: - blocked_str = ", ".join(blocked_sorted) - blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" - - msg += f"\n{Emojis.failed_file} {blocked_msg}" + msg += self.format_blocked_extensions(blocked) # Upload remaining non-text files files = [f.to_file() for f in allowed if f not in text_files] diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py index d3d1e7a18..3867b81de 100644 --- a/bot/exts/utils/snekbox/_eval.py +++ b/bot/exts/utils/snekbox/_eval.py @@ -141,7 +141,7 @@ class EvalResult: text = escape_mentions(text) return text - def get_message(self, job: EvalJob) -> str: + def get_status_message(self, job: EvalJob) -> str: """Return a user-friendly message corresponding to the process's return code.""" msg = f"Your {job.version} {job.name} job" diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 8ee0f46ff..d057b284d 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -113,7 +113,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): result = EvalResult(stdout=stdout, returncode=returncode) job = EvalJob([]) # Check all 3 message types - msg = result.get_message(job) + msg = result.get_status_message(job) self.assertEqual(msg, exp_msg) error = result.error_message self.assertEqual(error, exp_err) @@ -166,7 +166,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( - result.get_message(EvalJob([], version="3.10")), + result.get_status_message(EvalJob([], version="3.10")), "Your 3.10 eval job has completed with return code 127" ) self.assertEqual(result.error_message, "") @@ -177,7 +177,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mock_signals.return_value.name = "SIGTEST" result = EvalResult(stdout="", returncode=127) self.assertEqual( - result.get_message(EvalJob([], version="3.12")), + result.get_status_message(EvalJob([], version="3.12")), "Your 3.12 eval job has completed with return code 127 (SIGTEST)" ) @@ -386,12 +386,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author.mention = "@user#7700" - eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")]) + files = [ + FileAttachment("test.disallowed2", b"test"), + FileAttachment("test.disallowed", b"test"), + FileAttachment("test.allowed", b"test"), + FileAttachment("test." + ("a" * 100), b"test") + ] + eval_result = EvalResult("", 0, files=files) self.cog.post_job = AsyncMock(return_value=eval_result) self.cog.upload_output = AsyncMock() # This function isn't called + disallowed_exts = [".disallowed", "." + ("a" * 100), ".disallowed2"] mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"])) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, disallowed_exts)) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.12") @@ -402,7 +409,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.assertTrue( res.startswith("@user#7700 :white_check_mark: Your 3.12 eval job has completed with return code 0.") ) - self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res) + self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed, .disallowed2, ...**", res) self.cog.post_job.assert_called_once_with(job) self.cog.upload_output.assert_not_called() -- cgit v1.2.3