aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bot/constants.py1
-rw-r--r--bot/exts/info/information.py65
-rw-r--r--tests/bot/exts/info/test_information.py74
-rw-r--r--tests/test_helpers.py2
4 files changed, 122 insertions, 20 deletions
diff --git a/bot/constants.py b/bot/constants.py
index db98e6f47..68a96876f 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -397,6 +397,7 @@ class Categories(metaclass=YAMLGetter):
# 2021 Summer Code Jam
summer_code_jam: int
+
class Channels(metaclass=YAMLGetter):
section = "guild"
subsection = "channels"
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index e7d17c971..2592e093d 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -3,12 +3,12 @@ import pprint
import textwrap
from collections import defaultdict
from textwrap import shorten
-from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union
+from typing import Any, DefaultDict, Mapping, Optional, Set, Tuple, Union
import rapidfuzz
from botcore.site_api import ResponseCodeError
from discord import AllowedMentions, Colour, Embed, Guild, Message, Role
-from discord.ext.commands import BucketType, Cog, Context, Greedy, Paginator, command, group, has_any_role
+from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role
from discord.utils import escape_markdown
from bot import constants
@@ -25,6 +25,12 @@ from bot.utils.members import get_or_fetch_member
log = get_logger(__name__)
+DEFAULT_RULES_DESCRIPTION = (
+ "The rules and guidelines that apply to this community can be found on"
+ " our [rules page](https://www.pythondiscord.com/pages/rules). We expect"
+ " all members of the community to have read and understood these."
+)
+
class Information(Cog):
"""A cog with commands for generating embeds with server info, such as server stats and user info."""
@@ -518,39 +524,60 @@ class Information(Cog):
await self.send_raw_content(ctx, message, json=True)
@command(aliases=("rule",))
- async def rules(self, ctx: Context, rules: Greedy[int]) -> None:
- """Provides a link to all rules or, if specified, displays specific rule(s)."""
- rules_embed = Embed(title="Rules", color=Colour.og_blurple(), url="https://www.pythondiscord.com/pages/rules")
+ async def rules(self, ctx: Context, *args: Optional[str]) -> Optional[Set[int]]:
+ """
+ Provides a link to all rules or, if specified, displays specific rule(s).
- if not rules:
- # Rules were not submitted. Return the default description.
- rules_embed.description = (
- "The rules and guidelines that apply to this community can be found on"
- " our [rules page](https://www.pythondiscord.com/pages/rules). We expect"
- " all members of the community to have read and understood these."
- )
+ It accepts either rule numbers or particular keywords that map to a particular rule.
+ Rule numbers and keywords can be sent in any order.
+ """
+ rules_embed = Embed(title="Rules", color=Colour.og_blurple(), url="https://www.pythondiscord.com/pages/rules")
+ keywords, rule_numbers = [], []
+ full_rules = await self.bot.api_client.get("rules", params={"link_format": "md"})
+ keyword_to_rule_number = dict()
+
+ for rule_number, (_, rule_keywords) in enumerate(full_rules, start=1):
+ for rule_keyword in rule_keywords:
+ keyword_to_rule_number[rule_keyword] = rule_number
+
+ for word in args:
+ try:
+ rule_numbers.append(int(word))
+ except ValueError:
+ if (kw := word.lower()) not in keyword_to_rule_number:
+ break
+ keywords.append(kw)
+
+ if not rule_numbers and not keywords:
+ # Neither rules nor keywords were submitted. Return the default description.
+ rules_embed.description = DEFAULT_RULES_DESCRIPTION
await ctx.send(embed=rules_embed)
return
- full_rules = await self.bot.api_client.get("rules", params={"link_format": "md"})
-
# Remove duplicates and sort the rule indices
- rules = sorted(set(rules))
+ rule_numbers = sorted(set(rule_numbers))
- invalid = ", ".join(str(index) for index in rules if index < 1 or index > len(full_rules))
+ invalid = ", ".join(
+ str(rule_number) for rule_number in rule_numbers
+ if rule_number < 1 or rule_number > len(full_rules))
if invalid:
await ctx.send(shorten(":x: Invalid rule indices: " + invalid, 75, placeholder=" ..."))
return
- for rule in rules:
- self.bot.stats.incr(f"rule_uses.{rule}")
+ final_rules = []
+ final_rule_numbers = {keyword_to_rule_number[keyword] for keyword in keywords}
+ final_rule_numbers.update(rule_numbers)
- final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules)
+ for rule_number in sorted(final_rule_numbers):
+ self.bot.stats.incr(f"rule_uses.{rule_number}")
+ final_rules.append(f"**{rule_number}.** {full_rules[rule_number - 1][0]}")
await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3)
+ return final_rule_numbers
+
async def setup(bot: Bot) -> None:
"""Load the Information cog."""
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index d896b7652..9f5143c01 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -2,6 +2,7 @@ import textwrap
import unittest
import unittest.mock
from datetime import datetime
+from textwrap import shorten
import discord
@@ -573,3 +574,76 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase):
create_embed.assert_called_once_with(ctx, self.target, False)
ctx.send.assert_called_once()
+
+
+class RuleCommandTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the `!rule` command."""
+
+ def setUp(self) -> None:
+ """Set up steps executed before each test is run."""
+ self.bot = helpers.MockBot()
+ self.cog = information.Information(self.bot)
+ self.ctx = helpers.MockContext(author=helpers.MockMember(id=1, name="Bellaluma"))
+ self.full_rules = [
+ (
+ "First rule",
+ ["first", "number_one"]
+ ),
+ (
+ "Second rule",
+ ["second", "number_two"]
+ ),
+ (
+ "Third rule",
+ ["third", "number_three"]
+ )
+ ]
+ self.bot.api_client.get.return_value = self.full_rules
+
+ async def test_return_none_if_one_rule_number_is_invalid(self):
+
+ test_cases = [
+ (('1', '6', '7', '8'), (6, 7, 8)),
+ (('10', "first"), (10, )),
+ (("first", 10), (10, ))
+ ]
+
+ for raw_user_input, extracted_rule_numbers in test_cases:
+ with self.subTest(identifier=raw_user_input):
+ invalid = ", ".join(
+ str(rule_number) for rule_number in extracted_rule_numbers
+ if rule_number < 1 or rule_number > len(self.full_rules))
+
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+
+ self.assertEqual(
+ self.ctx.send.call_args,
+ unittest.mock.call(shorten(":x: Invalid rule indices: " + invalid, 75, placeholder=" ...")))
+ self.assertEqual(None, final_rule_numbers)
+
+ async def test_return_correct_rule_numbers(self):
+
+ test_cases = [
+ (("1", "2", "first"), {1, 2}),
+ (("1", "hello", "2", "second"), {1}),
+ (("second", "third", "unknown", "999"), {2, 3})
+ ]
+
+ for raw_user_input, expected_matched_rule_numbers in test_cases:
+ with self.subTest(identifier=raw_user_input):
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)
+
+ async def test_return_default_rules_when_no_input_or_no_match_are_found(self):
+ test_cases = [
+ ((), None),
+ (("hello", "2", "second"), None),
+ (("hello", "999"), None),
+ ]
+
+ for raw_user_input, expected_matched_rule_numbers in test_cases:
+ with self.subTest(identifier=raw_user_input):
+ final_rule_numbers = await self.cog.rules(self.cog, self.ctx, *raw_user_input)
+ embed = self.ctx.send.call_args.kwargs['embed']
+ self.assertEqual(information.DEFAULT_RULES_DESCRIPTION, embed.description)
+ self.assertEqual(expected_matched_rule_numbers, final_rule_numbers)
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index f3040b305..b2686b1d0 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -14,7 +14,7 @@ class DiscordMocksTests(unittest.TestCase):
"""Test if the default initialization of MockRole results in the correct object."""
role = helpers.MockRole()
- # The `spec` argument makes sure `isistance` checks with `discord.Role` pass
+ # The `spec` argument makes sure `isinstance` checks with `discord.Role` pass
self.assertIsInstance(role, discord.Role)
self.assertEqual(role.name, "role")