diff options
author | 2022-09-20 17:58:52 -0700 | |
---|---|---|
committer | 2022-09-20 17:58:52 -0700 | |
commit | 02b1f09332e7dcbae8b9a00f21d2da6ec7983488 (patch) | |
tree | 2eb8acd044d7cbfcefcd73f67ac6cffcd308b3f6 | |
parent | Moved `escape_markdown` after Truthy check (#2279) (diff) | |
parent | Merge branch 'main' into 2108-invoke-rule-command-with-keywords (diff) |
Merge #2261 - add support to fetch rules via keywords
-rw-r--r-- | bot/constants.py | 1 | ||||
-rw-r--r-- | bot/exts/info/information.py | 65 | ||||
-rw-r--r-- | tests/bot/exts/info/test_information.py | 74 | ||||
-rw-r--r-- | tests/test_helpers.py | 2 |
4 files changed, 122 insertions, 20 deletions
diff --git a/bot/constants.py b/bot/constants.py index db98e6f47..68a96876f 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -397,6 +397,7 @@ class Categories(metaclass=YAMLGetter): # 2021 Summer Code Jam summer_code_jam: int + class Channels(metaclass=YAMLGetter): section = "guild" subsection = "channels" diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index e7d17c971..2592e093d 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -3,12 +3,12 @@ import pprint import textwrap from collections import defaultdict from textwrap import shorten -from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union +from typing import Any, DefaultDict, Mapping, Optional, Set, Tuple, Union import rapidfuzz from botcore.site_api import ResponseCodeError from discord import AllowedMentions, Colour, Embed, Guild, Message, Role -from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role from discord.utils import escape_markdown from bot import constants @@ -25,6 +25,12 @@ from bot.utils.members import get_or_fetch_member log = get_logger(__name__) +DEFAULT_RULES_DESCRIPTION = ( + "The rules and guidelines that apply to this community can be found on" + " our [rules page](https://www.pythondiscord.com/pages/rules). We expect" + " all members of the community to have read and understood these." +) + class Information(Cog): """A cog with commands for generating embeds with server info, such as server stats and user info.""" @@ -518,39 +524,60 @@ class Information(Cog): await self.send_raw_content(ctx, message, json=True) @command(aliases=("rule",)) - async def rules(self, ctx: Context, rules: Greedy[int]) -> None: - """Provides a link to all rules or, if specified, displays specific rule(s).""" - rules_embed = Embed(title="Rules", color=Colour.og_blurple(), url="https://www.pythondiscord.com/pages/rules") + async def rules(self, ctx: Context, *args: Optional[str]) -> Optional[Set[int]]: + """ + Provides a link to all rules or, if specified, displays specific rule(s). - if not rules: - # Rules were not submitted. Return the default description. - rules_embed.description = ( - "The rules and guidelines that apply to this community can be found on" - " our [rules page](https://www.pythondiscord.com/pages/rules). We expect" - " all members of the community to have read and understood these." - ) + It accepts either rule numbers or particular keywords that map to a particular rule. + Rule numbers and keywords can be sent in any order. + """ + rules_embed = Embed(title="Rules", color=Colour.og_blurple(), url="https://www.pythondiscord.com/pages/rules") + keywords, rule_numbers = [], [] + full_rules = await self.bot.api_client.get("rules", params={"link_format": "md"}) + keyword_to_rule_number = dict() + + for rule_number, (_, rule_keywords) in enumerate(full_rules, start=1): + for rule_keyword in rule_keywords: + keyword_to_rule_number[rule_keyword] = rule_number + + for word in args: + try: + rule_numbers.append(int(word)) + except ValueError: + if (kw := word.lower()) not in keyword_to_rule_number: + break + keywords.append(kw) + + if not rule_numbers and not keywords: + # Neither rules nor keywords were submitted. Return the default description. + rules_embed.description = DEFAULT_RULES_DESCRIPTION await ctx.send(embed=rules_embed) return - full_rules = await self.bot.api_client.get("rules", params={"link_format": "md"}) - # Remove duplicates and sort the rule indices - rules = sorted(set(rules)) + rule_numbers = sorted(set(rule_numbers)) - invalid = ", ".join(str(index) for index in rules if index < 1 or index > len(full_rules)) + invalid = ", ".join( + str(rule_number) for rule_number in rule_numbers + if rule_number < 1 or rule_number > len(full_rules)) if invalid: await ctx.send(shorten(":x: Invalid rule indices: " + invalid, 75, placeholder=" ...")) return - for rule in rules: - self.bot.stats.incr(f"rule_uses.{rule}") + final_rules = [] + final_rule_numbers = {keyword_to_rule_number[keyword] for keyword in keywords} + final_rule_numbers.update(rule_numbers) - final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules) + for rule_number in sorted(final_rule_numbers): + self.bot.stats.incr(f"rule_uses.{rule_number}") + final_rules.append(f"**{rule_number}.** {full_rules[rule_number - 1][0]}") await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) + return final_rule_numbers + async def setup(bot: Bot) -> None: """Load the Information cog.""" diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index d896b7652..9f5143c01 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -2,6 +2,7 @@ import textwrap import unittest import unittest.mock from datetime import datetime +from textwrap import shorten import discord @@ -573,3 +574,76 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase): create_embed.assert_called_once_with(ctx, self.target, False) ctx.send.assert_called_once() + + +class RuleCommandTests(unittest.IsolatedAsyncioTestCase): + """Tests for the `!rule` command.""" + + def setUp(self) -> None: + """Set up steps executed before each test is run.""" + self.bot = helpers.MockBot() + self.cog = information.Information(self.bot) + self.ctx = helpers.MockContext(author=helpers.MockMember(id=1, name="Bellaluma")) + self.full_rules = [ + ( + "First rule", + ["first", "number_one"] + ), + ( + "Second rule", + ["second", "number_two"] + ), + ( + "Third rule", + ["third", "number_three"] + ) + ] + self.bot.api_client.get.return_value = self.full_rules + + async def test_return_none_if_one_rule_number_is_invalid(self): + + test_cases = [ + (('1', '6', '7', '8'), (6, 7, 8)), + (('10', "first"), (10, )), + (("first", 10), (10, )) + ] + + for raw_user_input, extracted_rule_numbers in test_cases: + with self.subTest(identifier=raw_user_input): + invalid = ", ".join( + str(rule_number) for rule_number in extracted_rule_numbers + if rule_number < 1 or rule_number > len(self.full_rules)) + + final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) + + self.assertEqual( + self.ctx.send.call_args, + unittest.mock.call(shorten(":x: Invalid rule indices: " + invalid, 75, placeholder=" ..."))) + self.assertEqual(None, final_rule_numbers) + + async def test_return_correct_rule_numbers(self): + + test_cases = [ + (("1", "2", "first"), {1, 2}), + (("1", "hello", "2", "second"), {1}), + (("second", "third", "unknown", "999"), {2, 3}) + ] + + for raw_user_input, expected_matched_rule_numbers in test_cases: + with self.subTest(identifier=raw_user_input): + final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) + self.assertEqual(expected_matched_rule_numbers, final_rule_numbers) + + async def test_return_default_rules_when_no_input_or_no_match_are_found(self): + test_cases = [ + ((), None), + (("hello", "2", "second"), None), + (("hello", "999"), None), + ] + + for raw_user_input, expected_matched_rule_numbers in test_cases: + with self.subTest(identifier=raw_user_input): + final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input) + embed = self.ctx.send.call_args.kwargs['embed'] + self.assertEqual(information.DEFAULT_RULES_DESCRIPTION, embed.description) + self.assertEqual(expected_matched_rule_numbers, final_rule_numbers) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index f3040b305..b2686b1d0 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -14,7 +14,7 @@ class DiscordMocksTests(unittest.TestCase): """Test if the default initialization of MockRole results in the correct object.""" role = helpers.MockRole() - # The `spec` argument makes sure `isistance` checks with `discord.Role` pass + # The `spec` argument makes sure `isinstance` checks with `discord.Role` pass self.assertIsInstance(role, discord.Role) self.assertEqual(role.name, "role") |