aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/snekbox.py41
-rw-r--r--tests/bot/exts/utils/test_snekbox.py7
2 files changed, 31 insertions, 17 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 213d57365..41cb00541 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -21,14 +21,12 @@ log = logging.getLogger(__name__)
ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")
FORMATTED_CODE_REGEX = re.compile(
- r"^\s*" # any leading whitespace from the beginning of the string
r"(?P<delim>(?P<block>```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block
r"(?(block)(?:(?P<lang>[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline)
r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
r"(?P<code>.*?)" # extract all code inside the markup
r"\s*" # any more whitespace before the end of the code markup
- r"(?P=delim)" # match the exact same delimiter from the start again
- r"\s*$", # any trailing whitespace until the end of the string
+ r"(?P=delim)", # match the exact same delimiter from the start again
re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive
)
RAW_CODE_REGEX = re.compile(
@@ -76,23 +74,32 @@ class Snekbox(Cog):
@staticmethod
def prepare_input(code: str) -> str:
- """Extract code from the Markdown, format it, and insert it into the code template."""
- match = FORMATTED_CODE_REGEX.fullmatch(code)
- if match:
- code, block, lang, delim = match.group("code", "block", "lang", "delim")
- code = textwrap.dedent(code)
- if block:
- info = (f"'{lang}' highlighted" if lang else "plain") + " code block"
+ """
+ Extract code from the Markdown, format it, and insert it into the code template.
+
+ If there is any code block, ignore text outside the code block.
+ Use the first code block, but prefer a fenced code block.
+ If there are several fenced code blocks, concatenate only the fenced code blocks.
+ """
+ if match := list(FORMATTED_CODE_REGEX.finditer(code)):
+ blocks = [block for block in match if block.group("block")]
+
+ if len(blocks) > 1:
+ code = '\n'.join(block.group("code") for block in blocks)
+ info = "several code blocks"
else:
- info = f"{delim}-enclosed inline code"
- log.trace(f"Extracted {info} for evaluation:\n{code}")
+ match = match[0] if len(blocks) == 0 else blocks[0]
+ code, block, lang, delim = match.group("code", "block", "lang", "delim")
+ if block:
+ info = (f"'{lang}' highlighted" if lang else "plain") + " code block"
+ else:
+ info = f"{delim}-enclosed inline code"
else:
- code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))
- log.trace(
- f"Eval message contains unformatted or badly formatted code, "
- f"stripping whitespace only:\n{code}"
- )
+ code = RAW_CODE_REGEX.fullmatch(code).group("code")
+ info = "unformatted or badly formatted code"
+ code = textwrap.dedent(code)
+ log.trace(f"Extracted {info} for evaluation:\n{code}")
return code
@staticmethod
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 6601fad2c..9a42d0610 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -52,6 +52,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),
('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'),
+ ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'),
+ ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```',
+ 'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'),
+ ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```',
+ 'print("How\'s it going?")', 'code block preceded by inline code'),
+ ('`print("Hello world!")`\ntext\n`print("Hello world!")`',
+ 'print("Hello world!")', 'one inline code block of two')
)
for case, expected, testname in cases:
with self.subTest(msg=f'Extract code from {testname}.'):