diff options
| author | 2022-10-27 09:59:51 +0100 | |
|---|---|---|
| committer | 2022-10-27 09:59:51 +0100 | |
| commit | dbf3dcb537d59c261f1d5aa1ef5faba3d13910b6 (patch) | |
| tree | d58de5005ed75d97264cde28bad636c5572e2061 | |
| parent | Merge pull request #2308 from python-discord/move_security_cog (diff) | |
| parent | Reverse changes to invalid arg break (diff) | |
Merge pull request #2310 from python-discord/rules-fix
Fix for rules greedy parsing freezing
| -rw-r--r-- | bot/exts/info/information.py | 18 | ||||
| -rw-r--r-- | tests/bot/exts/info/test_information.py | 24 | 
2 files changed, 22 insertions, 20 deletions
| diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 2592e093d..2eb9382e3 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -524,7 +524,7 @@ class Information(Cog):          await self.send_raw_content(ctx, message, json=True)      @command(aliases=("rule",)) -    async def rules(self, ctx: Context, *args: Optional[str]) -> Optional[Set[int]]: +    async def rules(self, ctx: Context, *, args: Optional[str]) -> Optional[Set[int]]:          """          Provides a link to all rules or, if specified, displays specific rule(s). @@ -541,13 +541,15 @@ class Information(Cog):              for rule_keyword in rule_keywords:                  keyword_to_rule_number[rule_keyword] = rule_number -        for word in args: -            try: -                rule_numbers.append(int(word)) -            except ValueError: -                if (kw := word.lower()) not in keyword_to_rule_number: -                    break -                keywords.append(kw) +        if args: +            for word in args.split(maxsplit=100): +                try: +                    rule_numbers.append(int(word)) +                except ValueError: +                    # Stop on first invalid keyword/index to allow for normal messaging after +                    if (kw := word.lower()) not in keyword_to_rule_number: +                        break +                    keywords.append(kw)          if not rule_numbers and not keywords:              # Neither rules nor keywords were submitted. Return the default description. diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 9f5143c01..65595e959 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -603,9 +603,9 @@ class RuleCommandTests(unittest.IsolatedAsyncioTestCase):      async def test_return_none_if_one_rule_number_is_invalid(self):          test_cases = [ -            (('1', '6', '7', '8'), (6, 7, 8)), -            (('10', "first"), (10, )), -            (("first", 10), (10, )) +            ("1 6 7 8", (6, 7, 8)), +            ("10 first", (10,)), +            ("first 10", (10,))          ]          for raw_user_input, extracted_rule_numbers in test_cases: @@ -614,7 +614,7 @@ class RuleCommandTests(unittest.IsolatedAsyncioTestCase):                      str(rule_number) for rule_number in extracted_rule_numbers                      if rule_number < 1 or rule_number > len(self.full_rules)) -                final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) +                final_rule_numbers = await self.cog.rules(self.cog, self.ctx, args=raw_user_input)                  self.assertEqual(                      self.ctx.send.call_args, @@ -624,26 +624,26 @@ class RuleCommandTests(unittest.IsolatedAsyncioTestCase):      async def test_return_correct_rule_numbers(self):          test_cases = [ -            (("1", "2", "first"), {1, 2}), -            (("1", "hello", "2", "second"), {1}), -            (("second", "third", "unknown", "999"), {2, 3}) +            ("1 2 first", {1, 2}), +            ("1 hello 2 second", {1}), +            ("second third unknown 999", {2, 3}),          ]          for raw_user_input, expected_matched_rule_numbers in test_cases:              with self.subTest(identifier=raw_user_input): -                final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) +                final_rule_numbers = await self.cog.rules(self.cog, self.ctx, args=raw_user_input)                  self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)      async def test_return_default_rules_when_no_input_or_no_match_are_found(self):          test_cases = [ -            ((), None), -            (("hello", "2", "second"), None), -            (("hello", "999"), None), +            ("", None), +            ("hello 2 second", None), +            ("hello 999", None),          ]          for raw_user_input, expected_matched_rule_numbers in test_cases:              with self.subTest(identifier=raw_user_input): -                final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) +                final_rule_numbers = await self.cog.rules(self.cog, self.ctx, args=raw_user_input)                  embed = self.ctx.send.call_args.kwargs['embed']                  self.assertEqual(information.DEFAULT_RULES_DESCRIPTION, embed.description)                  self.assertEqual(expected_matched_rule_numbers, final_rule_numbers) | 
