aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ChrisJL <[email protected]>2022-10-27 09:59:51 +0100
committerGravatar GitHub <[email protected]>2022-10-27 09:59:51 +0100
commitdbf3dcb537d59c261f1d5aa1ef5faba3d13910b6 (patch)
treed58de5005ed75d97264cde28bad636c5572e2061
parentMerge pull request #2308 from python-discord/move_security_cog (diff)
parentReverse changes to invalid arg break (diff)
Merge pull request #2310 from python-discord/rules-fix
Fix for rules greedy parsing freezing
-rw-r--r--bot/exts/info/information.py18
-rw-r--r--tests/bot/exts/info/test_information.py24
2 files changed, 22 insertions, 20 deletions
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 2592e093d..2eb9382e3 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -524,7 +524,7 @@ class Information(Cog):
await self.send_raw_content(ctx, message, json=True)
@command(aliases=("rule",))
- async def rules(self, ctx: Context, *args: Optional[str]) -> Optional[Set[int]]:
+ async def rules(self, ctx: Context, *, args: Optional[str]) -> Optional[Set[int]]:
"""
Provides a link to all rules or, if specified, displays specific rule(s).
@@ -541,13 +541,15 @@ class Information(Cog):
for rule_keyword in rule_keywords:
keyword_to_rule_number[rule_keyword] = rule_number
- for word in args:
- try:
- rule_numbers.append(int(word))
- except ValueError:
- if (kw := word.lower()) not in keyword_to_rule_number:
- break
- keywords.append(kw)
+ if args:
+ for word in args.split(maxsplit=100):
+ try:
+ rule_numbers.append(int(word))
+ except ValueError:
+ # Stop on first invalid keyword/index to allow for normal messaging after
+ if (kw := word.lower()) not in keyword_to_rule_number:
+ break
+ keywords.append(kw)
if not rule_numbers and not keywords:
# Neither rules nor keywords were submitted. Return the default description.
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index 9f5143c01..65595e959 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -603,9 +603,9 @@ class RuleCommandTests(unittest.IsolatedAsyncioTestCase):
async def test_return_none_if_one_rule_number_is_invalid(self):
test_cases = [
- (('1', '6', '7', '8'), (6, 7, 8)),
- (('10', "first"), (10, )),
- (("first", 10), (10, ))
+ ("1 6 7 8", (6, 7, 8)),
+ ("10 first", (10,)),
+ ("first 10", (10,))
]
for raw_user_input, extracted_rule_numbers in test_cases:
@@ -614,7 +614,7 @@ class RuleCommandTests(unittest.IsolatedAsyncioTestCase):
str(rule_number) for rule_number in extracted_rule_numbers
if rule_number < 1 or rule_number > len(self.full_rules))
- final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, args=raw_user_input)
self.assertEqual(
self.ctx.send.call_args,
@@ -624,26 +624,26 @@ class RuleCommandTests(unittest.IsolatedAsyncioTestCase):
async def test_return_correct_rule_numbers(self):
test_cases = [
- (("1", "2", "first"), {1, 2}),
- (("1", "hello", "2", "second"), {1}),
- (("second", "third", "unknown", "999"), {2, 3})
+ ("1 2 first", {1, 2}),
+ ("1 hello 2 second", {1}),
+ ("second third unknown 999", {2, 3}),
]
for raw_user_input, expected_matched_rule_numbers in test_cases:
with self.subTest(identifier=raw_user_input):
- final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, args=raw_user_input)
self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)
async def test_return_default_rules_when_no_input_or_no_match_are_found(self):
test_cases = [
- ((), None),
- (("hello", "2", "second"), None),
- (("hello", "999"), None),
+ ("", None),
+ ("hello 2 second", None),
+ ("hello 999", None),
]
for raw_user_input, expected_matched_rule_numbers in test_cases:
with self.subTest(identifier=raw_user_input):
- final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, args=raw_user_input)
embed = self.ctx.send.call_args.kwargs['embed']
self.assertEqual(information.DEFAULT_RULES_DESCRIPTION, embed.description)
self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)