aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/converters.py44
-rw-r--r--bot/exts/backend/error_handler.py20
-rw-r--r--bot/exts/info/source.py13
-rw-r--r--bot/exts/info/tags.py487
-rw-r--r--tests/bot/exts/backend/test_error_handler.py36
-rw-r--r--tests/bot/test_converters.py16
6 files changed, 332 insertions, 284 deletions
diff --git a/bot/converters.py b/bot/converters.py
index bd4044c7e..48a5e3dc2 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -18,6 +18,7 @@ from bot import exts
from bot.api import ResponseCodeError
from bot.constants import URLs
from bot.exts.info.doc import _inventory_parser
+from bot.exts.info.tags import TagIdentifier
from bot.utils.extensions import EXTENSIONS, unqualify
from bot.utils.regex import INVITE_RE
from bot.utils.time import parse_duration_string
@@ -279,41 +280,6 @@ class Snowflake(IDConverter):
return snowflake
-class TagNameConverter(Converter):
- """
- Ensure that a proposed tag name is valid.
-
- Valid tag names meet the following conditions:
- * All ASCII characters
- * Has at least one non-whitespace character
- * Not solely numeric
- * Shorter than 127 characters
- """
-
- @staticmethod
- async def convert(ctx: Context, tag_name: str) -> str:
- """Lowercase & strip whitespace from proposed tag_name & ensure it's valid."""
- tag_name = tag_name.lower().strip()
-
- # The tag name has at least one invalid character.
- if ascii(tag_name)[1:-1] != tag_name:
- raise BadArgument("Don't be ridiculous, you can't use that character!")
-
- # The tag name is either empty, or consists of nothing but whitespace.
- elif not tag_name:
- raise BadArgument("Tag names should not be empty, or filled with whitespace.")
-
- # The tag name is longer than 127 characters.
- elif len(tag_name) > 127:
- raise BadArgument("Are you insane? That's way too long!")
-
- # The tag name is ascii but does not contain any letters.
- elif not any(character.isalpha() for character in tag_name):
- raise BadArgument("Tag names must contain at least one letter.")
-
- return tag_name
-
-
class SourceConverter(Converter):
"""Convert an argument into a help command, tag, command, or cog."""
@@ -336,9 +302,10 @@ class SourceConverter(Converter):
if not tags_cog:
show_tag = False
- elif argument.lower() in tags_cog._cache:
- return argument.lower()
-
+ else:
+ identifier = TagIdentifier.from_string(argument.lower())
+ if identifier in tags_cog.tags:
+ return identifier
escaped_arg = escape_markdown(argument)
raise BadArgument(
@@ -579,7 +546,6 @@ if t.TYPE_CHECKING:
ValidURL = str # noqa: F811
Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811
Snowflake = int # noqa: F811
- TagNameConverter = str # noqa: F811
SourceConverter = SourceType # noqa: F811
DurationDelta = relativedelta # noqa: F811
Duration = datetime # noqa: F811
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index 578c372c3..128e72c84 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -9,8 +9,8 @@ from sentry_sdk import push_scope
from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Colours, Icons, MODERATION_ROLES
-from bot.converters import TagNameConverter
from bot.errors import InvalidInfractedUserError, LockedResourceError
+from bot.exts.info import tags
from bot.utils.checks import ContextCheckFailure
log = logging.getLogger(__name__)
@@ -174,16 +174,16 @@ class ErrorHandler(Cog):
await self.on_command_error(ctx, tag_error)
return
- try:
- tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with)
- except errors.BadArgument:
- log.debug(
- f"{ctx.author} tried to use an invalid command "
- f"and the fallback tag failed validation in TagNameConverter."
- )
+ tag_identifier = tags.TagIdentifier.from_string(ctx.message.content)
+ if tag_identifier.group is not None:
+ tag_name = tag_identifier.name
+ tag_name_or_group = tag_identifier.group
else:
- if await ctx.invoke(tags_get_command, tag_name=tag_name):
- return
+ tag_name = None
+ tag_name_or_group = tag_identifier.name
+
+ if await ctx.invoke(tags_get_command, tag_name_or_group, tag_name):
+ return
if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
await self.send_command_suggestion(ctx, ctx.invoked_with)
diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py
index 8ce25b4e8..e3e7029ca 100644
--- a/bot/exts/info/source.py
+++ b/bot/exts/info/source.py
@@ -8,8 +8,9 @@ from discord.ext import commands
from bot.bot import Bot
from bot.constants import URLs
from bot.converters import SourceConverter
+from bot.exts.info.tags import TagIdentifier
-SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded]
+SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, TagIdentifier, commands.ExtensionNotLoaded]
class BotSource(commands.Cog):
@@ -41,9 +42,9 @@ class BotSource(commands.Cog):
source_item = inspect.unwrap(source_item.callback)
src = source_item.__code__
filename = src.co_filename
- elif isinstance(source_item, str):
+ elif isinstance(source_item, TagIdentifier):
tags_cog = self.bot.get_cog("Tags")
- filename = tags_cog._cache[source_item]["location"]
+ filename = tags_cog.tags[source_item].file_path
else:
src = type(source_item)
try:
@@ -51,7 +52,7 @@ class BotSource(commands.Cog):
except TypeError:
raise commands.BadArgument("Cannot get source for a dynamically-created object.")
- if not isinstance(source_item, str):
+ if not isinstance(source_item, TagIdentifier):
try:
lines, first_line_no = inspect.getsourcelines(src)
except OSError:
@@ -64,7 +65,7 @@ class BotSource(commands.Cog):
# Handle tag file location differently than others to avoid errors in some cases
if not first_line_no:
- file_location = Path(filename).relative_to("/bot/")
+ file_location = Path(filename).relative_to("bot/")
else:
file_location = Path(filename).relative_to(Path.cwd()).as_posix()
@@ -82,7 +83,7 @@ class BotSource(commands.Cog):
elif isinstance(source_object, commands.Command):
description = source_object.short_doc
title = f"Command: {source_object.qualified_name}"
- elif isinstance(source_object, str):
+ elif isinstance(source_object, TagIdentifier):
title = f"Tag: {source_object}"
description = ""
else:
diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py
index bb91a8563..3d222933a 100644
--- a/bot/exts/info/tags.py
+++ b/bot/exts/info/tags.py
@@ -1,15 +1,19 @@
+from __future__ import annotations
+
+import enum
import logging
import re
import time
from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional
+from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union
-from discord import Colour, Embed, Member
+import discord
+import frontmatter
+from discord import Embed, Member
from discord.ext.commands import Cog, Context, group
from bot import constants
from bot.bot import Bot
-from bot.converters import TagNameConverter
from bot.pagination import LinePaginator
from bot.utils.messages import wait_for_deletion
@@ -24,99 +28,166 @@ REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE)
FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags <tagname>."
+class COOLDOWN(enum.Enum):
+ """Sentinel value to signal that a tag is on cooldown."""
+
+ obj = object()
+
+
+class TagIdentifier(NamedTuple):
+ """Stores the group and name used as an identifier for a tag."""
+
+ group: Optional[str]
+ name: str
+
+ def get_fuzzy_score(self, fuzz_tag_identifier: TagIdentifier) -> float:
+ """Get fuzzy score, using `fuzz_tag_identifier` as the identifier to fuzzy match with."""
+ if (self.group is None) != (fuzz_tag_identifier.group is None):
+ # Ignore tags without groups if the identifier has a group and vice versa
+ return .0
+ if self.group == fuzz_tag_identifier.group:
+ # Completely identical, or both None
+ group_score = 1
+ else:
+ group_score = _fuzzy_search(fuzz_tag_identifier.group, self.group)
+
+ fuzzy_score = group_score * _fuzzy_search(fuzz_tag_identifier.name, self.name) * 100
+ if fuzzy_score:
+ log.trace(f"Fuzzy score {fuzzy_score:=06.2f} for tag {self!r} with fuzz {fuzz_tag_identifier!r}")
+ return fuzzy_score
+
+ def __str__(self) -> str:
+ if self.group is not None:
+ return f"{self.group} {self.name}"
+ else:
+ return self.name
+
+ @classmethod
+ def from_string(cls, string: str) -> TagIdentifier:
+ """Create a `TagIdentifier` instance from the beginning of `string`."""
+ split_string = string.removeprefix(constants.Bot.prefix).split(" ", maxsplit=2)
+ if len(split_string) == 1:
+ return cls(None, split_string[0])
+ else:
+ return cls(split_string[0], split_string[1])
+
+
+class Tag:
+ """Provide an interface to a tag from resources with `file_content`."""
+
+ def __init__(self, content_path: Path):
+ post = frontmatter.loads(content_path.read_text("utf8"))
+ self.file_path = content_path
+ self.content = post.content
+ self.metadata = post.metadata
+ self._restricted_to: set[int] = set(self.metadata.get("restricted_to", ()))
+ self._cooldowns: dict[discord.TextChannel, float] = {}
+
+ @property
+ def embed(self) -> Embed:
+ """Create an embed for the tag."""
+ embed = Embed.from_dict(self.metadata.get("embed", {}))
+ embed.description = self.content
+ return embed
+
+ def accessible_by(self, member: discord.Member) -> bool:
+ """Check whether `member` can access the tag."""
+ return bool(
+ not self._restricted_to
+ or self._restricted_to & {role.id for role in member.roles}
+ )
+
+ def on_cooldown_in(self, channel: discord.TextChannel) -> bool:
+ """Check whether the tag is on cooldown in `channel`."""
+ return channel in self._cooldowns and self._cooldowns[channel] > time.time()
+
+ def set_cooldown_for(self, channel: discord.TextChannel) -> None:
+ """Set the tag to be on cooldown in `channel` for `constants.Cooldowns.tags` seconds."""
+ self._cooldowns[channel] = time.time() + constants.Cooldowns.tags
+
+
+def _fuzzy_search(search: str, target: str) -> float:
+ """A simple scoring algorithm based on how many letters are found / total, with order in mind."""
+ _search = REGEX_NON_ALPHABET.sub("", search.lower())
+ if not _search:
+ return 0
+
+ _targets = iter(REGEX_NON_ALPHABET.split(target.lower()))
+
+ current = 0
+ for _target in _targets:
+ index = 0
+ try:
+ while index < len(_target) and _search[current] == _target[index]:
+ current += 1
+ index += 1
+ except IndexError:
+ # Exit when _search runs out
+ break
+
+ return current / len(_search)
+
+
class Tags(Cog):
- """Save new tags and fetch existing tags."""
+ """Fetch tags by name or content."""
def __init__(self, bot: Bot):
self.bot = bot
- self.tag_cooldowns = {}
- self._cache = self.get_tags()
-
- @staticmethod
- def get_tags() -> dict:
- """Get all tags."""
- cache = {}
+ self.tags: dict[TagIdentifier, Tag] = {}
+ self.initialize_tags()
+ def initialize_tags(self) -> None:
+ """Load all tags from resources into `self.tags`."""
base_path = Path("bot", "resources", "tags")
+
for file in base_path.glob("**/*"):
if file.is_file():
- tag_title = file.stem
- tag = {
- "title": tag_title,
- "embed": {
- "description": file.read_text(encoding="utf8"),
- },
- "restricted_to": None,
- "location": f"/bot/{file}"
- }
-
- # Convert to a list to allow negative indexing.
- parents = list(file.relative_to(base_path).parents)
- if len(parents) > 1:
- # -1 would be '.' hence -2 is used as the index.
- tag["restricted_to"] = parents[-2].name
-
- cache[tag_title] = tag
-
- return cache
-
- @staticmethod
- def check_accessibility(user: Member, tag: dict) -> bool:
- """Check if user can access a tag."""
- return not tag["restricted_to"] or tag["restricted_to"].lower() in [role.name.lower() for role in user.roles]
-
- @staticmethod
- def _fuzzy_search(search: str, target: str) -> float:
- """A simple scoring algorithm based on how many letters are found / total, with order in mind."""
- current, index = 0, 0
- _search = REGEX_NON_ALPHABET.sub('', search.lower())
- _targets = iter(REGEX_NON_ALPHABET.split(target.lower()))
- _target = next(_targets)
- try:
- while True:
- while index < len(_target) and _search[current] == _target[index]:
- current += 1
- index += 1
- index, _target = 0, next(_targets)
- except (StopIteration, IndexError):
- pass
- return current / len(_search) * 100
-
- def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]:
- """Return a list of suggested tags."""
- scores: Dict[str, int] = {
- tag_title: Tags._fuzzy_search(tag_name, tag['title'])
- for tag_title, tag in self._cache.items()
- }
-
- thresholds = thresholds or [100, 90, 80, 70, 60]
-
- for threshold in thresholds:
+ parent_dir = file.relative_to(base_path).parent
+ tag_name = file.stem
+ # Files directly under `base_path` have an empty string as the parent directory name
+ tag_group = parent_dir.name or None
+
+ self.tags[TagIdentifier(tag_group, tag_name)] = Tag(file)
+
+ def _get_suggestions(self, tag_identifier: TagIdentifier) -> list[tuple[TagIdentifier, Tag]]:
+ """Return a list of suggested tags for `tag_identifier`."""
+ for threshold in [100, 90, 80, 70, 60]:
suggestions = [
- self._cache[tag_title]
- for tag_title, matching_score in scores.items()
- if matching_score >= threshold
+ (identifier, tag)
+ for identifier, tag in self.tags.items()
+ if identifier.get_fuzzy_score(tag_identifier) >= threshold
]
if suggestions:
return suggestions
return []
- def _get_tag(self, tag_name: str) -> list:
- """Get a specific tag."""
- found = [self._cache.get(tag_name.lower(), None)]
- if not found[0]:
- return self._get_suggestions(tag_name)
- return found
+ def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIdentifier, Tag]]:
+ """Get tags with identifiers similar to `tag_identifier`."""
+ suggestions = []
+
+ if tag_identifier.group is not None and len(tag_identifier.group) >= 3:
+ # Try fuzzy matching with only a name first
+ suggestions += self._get_suggestions(TagIdentifier(None, tag_identifier.group))
+
+ if len(tag_identifier.name) >= 3:
+ suggestions += self._get_suggestions(tag_identifier)
- def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list:
+ return suggestions
+
+ def _get_tags_via_content(
+ self,
+ check: Callable[[Iterable], bool],
+ keywords: str,
+ user: Member,
+ ) -> list[tuple[TagIdentifier, Tag]]:
"""
Search for tags via contents.
`predicate` will be the built-in any, all, or a custom callable. Must return a bool.
"""
- keywords_processed: List[str] = []
- for keyword in keywords.split(','):
+ keywords_processed = []
+ for keyword in keywords.split(","):
keyword_sanitized = keyword.strip().casefold()
if not keyword_sanitized:
# this happens when there are leading / trailing / consecutive comma.
@@ -124,32 +195,37 @@ class Tags(Cog):
keywords_processed.append(keyword_sanitized)
if not keywords_processed:
- # after sanitizing, we can end up with an empty list, for example when keywords is ','
+ # after sanitizing, we can end up with an empty list, for example when keywords is ","
# in that case, we simply want to search for such keywords directly instead.
keywords_processed = [keywords]
matching_tags = []
- for tag in self._cache.values():
- matches = (query in tag['embed']['description'].casefold() for query in keywords_processed)
- if self.check_accessibility(user, tag) and check(matches):
- matching_tags.append(tag)
+ for identifier, tag in self.tags.items():
+ matches = (query in tag.content.casefold() for query in keywords_processed)
+ if tag.accessible_by(user) and check(matches):
+ matching_tags.append((identifier, tag))
return matching_tags
- async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None:
+ async def _send_matching_tags(
+ self,
+ ctx: Context,
+ keywords: str,
+ matching_tags: list[tuple[TagIdentifier, Tag]],
+ ) -> None:
"""Send the result of matching tags to user."""
- if not matching_tags:
- pass
- elif len(matching_tags) == 1:
- await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed']))
- else:
- is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0
+ if len(matching_tags) == 1:
+ await ctx.send(embed=matching_tags[0][1].embed)
+ elif matching_tags:
+ is_plural = keywords.strip().count(" ") > 0 or keywords.strip().count(",") > 0
embed = Embed(
title=f"Here are the tags containing the given keyword{'s' * is_plural}:",
- description='\n'.join(tag['title'] for tag in matching_tags[:10])
)
await LinePaginator.paginate(
- sorted(f"**»** {tag['title']}" for tag in matching_tags),
+ sorted(
+ f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}"
+ for identifier, _ in matching_tags
+ ),
ctx,
embed,
footer_text=FOOTER_TEXT,
@@ -157,12 +233,17 @@ class Tags(Cog):
max_lines=15
)
- @group(name='tags', aliases=('tag', 't'), invoke_without_command=True)
- async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None:
+ @group(name="tags", aliases=("tag", "t"), invoke_without_command=True)
+ async def tags_group(
+ self,
+ ctx: Context,
+ tag_name_or_group: str = None,
+ tag_name: str = None,
+ ) -> None:
"""Show all known tags, a single tag, or run a subcommand."""
- await self.get_command(ctx, tag_name=tag_name)
+ await self.get_command(ctx, tag_name_or_group=tag_name_or_group, tag_name=tag_name)
- @tags_group.group(name='search', invoke_without_command=True)
+ @tags_group.group(name="search", invoke_without_command=True)
async def search_tag_content(self, ctx: Context, *, keywords: str) -> None:
"""
Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.
@@ -172,123 +253,155 @@ class Tags(Cog):
matching_tags = self._get_tags_via_content(all, keywords, ctx.author)
await self._send_matching_tags(ctx, keywords, matching_tags)
- @search_tag_content.command(name='any')
- async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None:
+ @search_tag_content.command(name="any")
+ async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = "any") -> None:
"""
Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.
Search for tags that has ANY of the keywords.
"""
- matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author)
+ matching_tags = self._get_tags_via_content(any, keywords or "any", ctx.author)
await self._send_matching_tags(ctx, keywords, matching_tags)
- async def display_tag(self, ctx: Context, tag_name: str = None) -> bool:
+ async def get_tag_embed(
+ self,
+ ctx: Context,
+ tag_identifier: TagIdentifier,
+ ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]:
"""
- If a tag is not found, display similar tag names as suggestions.
-
- If a tag is not specified, display a paginated embed of all tags.
+ Generate an embed of the requested tag or of suggestions if the tag doesn't exist/isn't accessible by the user.
- Tags are on cooldowns on a per-tag, per-channel basis. If a tag is on cooldown, display
- nothing and return True.
+ If the requested tag is on cooldown return `COOLDOWN.obj`, otherwise if no suggestions were found return None.
"""
- def _command_on_cooldown(tag_name: str) -> bool:
- """
- Check if the command is currently on cooldown, on a per-tag, per-channel basis.
-
- The cooldown duration is set in constants.py.
- """
- now = time.time()
-
- cooldown_conditions = (
- tag_name
- and tag_name in self.tag_cooldowns
- and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags
- and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id
+ filtered_tags = [
+ (ident, tag) for ident, tag in
+ self.get_fuzzy_matches(tag_identifier)[:10]
+ if tag.accessible_by(ctx.author)
+ ]
+
+ tag = self.tags.get(tag_identifier)
+ if tag is None and len(filtered_tags) == 1:
+ tag_identifier = filtered_tags[0][0]
+ tag = filtered_tags[0][1]
+
+ if tag is not None:
+ if tag.on_cooldown_in(ctx.channel):
+ log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.")
+ return COOLDOWN.obj
+ tag.set_cooldown_for(ctx.channel)
+
+ self.bot.stats.incr(
+ f"tags.usages"
+ f"{'.' + tag_identifier.group.replace('-', '_') if tag_identifier.group else ''}"
+ f".{tag_identifier.name.replace('-', '_')}"
)
+ return tag.embed
- if cooldown_conditions:
- return True
- return False
-
- if _command_on_cooldown(tag_name):
- time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"]
- time_left = constants.Cooldowns.tags - time_elapsed
- log.info(
- f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. "
- f"Cooldown ends in {time_left:.1f} seconds."
+ else:
+ if not filtered_tags:
+ return None
+ suggested_tags_text = "\n".join(
+ f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
+ for identifier, tag in filtered_tags
+ if not tag.on_cooldown_in(ctx.channel)
+ )
+ return Embed(
+ title="Did you mean ...",
+ description=suggested_tags_text
)
- return True
-
- if tag_name is not None:
- temp_founds = self._get_tag(tag_name)
-
- founds = []
-
- for found_tag in temp_founds:
- if self.check_accessibility(ctx.author, found_tag):
- founds.append(found_tag)
- if len(founds) == 1:
- tag = founds[0]
- if ctx.channel.id not in TEST_CHANNELS:
- self.tag_cooldowns[tag_name] = {
- "time": time.time(),
- "channel": ctx.channel.id
- }
+ async def list_all_tags(self, ctx: Context) -> None:
+ """Send a paginator with all loaded tags accessible by `ctx.author`, groups first, and alphabetically sorted."""
+ def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str:
+ group, name = tag_item[0]
+ if group is None:
+ # Max codepoint character to force tags without a group to the end
+ group = chr(0x10ffff)
+
+ return group + name
+
+ result_lines = []
+ current_group = ""
+ group_accessible = True
+
+ for identifier, tag in sorted(self.tags.items(), key=tag_sort_key):
+
+ if identifier.group != current_group:
+ if not group_accessible:
+ # Remove group separator line if no tags in the previous group were accessible by the user.
+ result_lines.pop()
+ # A new group began, add a separator with the group name.
+ current_group = identifier.group
+ if current_group is not None:
+ group_accessible = False
+ result_lines.append(f"\n\N{BULLET} **{current_group}**")
+ else:
+ result_lines.append("\n\N{BULLET}")
+
+ if tag.accessible_by(ctx.author):
+ result_lines.append(f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}")
+ group_accessible = True
+
+ embed = Embed(title="Current tags")
+ await LinePaginator.paginate(result_lines, ctx, embed, max_lines=15, empty=False, footer_text=FOOTER_TEXT)
+
+ async def list_tags_in_group(self, ctx: Context, group: str) -> None:
+ """Send a paginator with all tags under `group`, that are accessible by `ctx.author`."""
+ embed = Embed(title=f"Tags under *{group}*")
+ tag_lines = sorted(
+ f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
+ for identifier, tag in self.tags.items()
+ if identifier.group == group and tag.accessible_by(ctx.author)
+ )
+ await LinePaginator.paginate(tag_lines, ctx, embed, footer_text=FOOTER_TEXT, empty=False, max_lines=15)
+
+ @tags_group.command(name="get", aliases=("show", "g"))
+ async def get_command(
+ self, ctx: Context,
+ tag_name_or_group: str = None,
+ tag_name: str = None,
+ ) -> bool:
+ """
+ If a single argument matching a group name is given, list all accessible tags from that group
+ Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name.
- self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}")
+ With no arguments, list all accessible tags.
- await wait_for_deletion(
- await ctx.send(embed=Embed.from_dict(tag['embed'])),
- [ctx.author.id],
- )
+ Returns True if a message was sent, or if the tag is on cooldown.
+ Returns False if no message was sent.
+ """ # noqa: D205, D415
+ if tag_name_or_group is None and tag_name is None:
+ if self.tags:
+ await self.list_all_tags(ctx)
return True
- elif founds and len(tag_name) >= 3:
- await wait_for_deletion(
- await ctx.send(
- embed=Embed(
- title='Did you mean ...',
- description='\n'.join(tag['title'] for tag in founds[:10])
- )
- ),
- [ctx.author.id],
- )
+ else:
+ await ctx.send(embed=Embed(description="**There are no tags!**"))
return True
- else:
- tags = self._cache.values()
- if not tags:
- await ctx.send(embed=Embed(
- description="**There are no tags in the database!**",
- colour=Colour.red()
- ))
+ elif tag_name is None:
+ if any(
+ tag_name_or_group == identifier.group and tag.accessible_by(ctx.author)
+ for identifier, tag in self.tags.items()
+ ):
+ await self.list_tags_in_group(ctx, tag_name_or_group)
return True
else:
- embed: Embed = Embed(title="**Current tags**")
- await LinePaginator.paginate(
- sorted(
- f"**»** {tag['title']}" for tag in tags
- if self.check_accessibility(ctx.author, tag)
- ),
- ctx,
- embed,
- footer_text=FOOTER_TEXT,
- empty=False,
- max_lines=15
- )
- return True
-
- return False
+ tag_name = tag_name_or_group
+ tag_group = None
+ else:
+ tag_group = tag_name_or_group
- @tags_group.command(name='get', aliases=('show', 'g'))
- async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> bool:
- """
- Get a specified tag, or a list of all tags if no tag is specified.
+ embed = await self.get_tag_embed(ctx, TagIdentifier(tag_group, tag_name))
+ if embed is None:
+ return False
- Returns True if something can be sent, or if the tag is on cooldown.
- Returns False if no matches are found.
- """
- return await self.display_tag(ctx, tag_name)
+ if embed is not COOLDOWN.obj:
+ await wait_for_deletion(
+ await ctx.send(embed=embed),
+ (ctx.author.id,)
+ )
+ # A valid tag was found and was either sent, or is on cooldown
+ return True
def setup(bot: Bot) -> None:
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 2b0549b98..ce59ee5fa 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -337,14 +337,12 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_try_get_tag_get_command(self):
"""Should call `Bot.get_command` with `tags get` argument."""
self.bot.get_command.reset_mock()
- self.ctx.invoked_with = "foo"
await self.cog.try_get_tag(self.ctx)
self.bot.get_command.assert_called_once_with("tags get")
async def test_try_get_tag_invoked_from_error_handler(self):
"""`self.ctx` should have `invoked_from_error_handler` `True`."""
self.ctx.invoked_from_error_handler = False
- self.ctx.invoked_with = "foo"
await self.cog.try_get_tag(self.ctx)
self.assertTrue(self.ctx.invoked_from_error_handler)
@@ -359,38 +357,24 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
err = errors.CommandError()
self.tag.get_command.can_run = AsyncMock(side_effect=err)
self.cog.on_command_error = AsyncMock()
- self.ctx.invoked_with = "foo"
self.assertIsNone(await self.cog.try_get_tag(self.ctx))
self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
- @patch("bot.exts.backend.error_handler.TagNameConverter")
- async def test_try_get_tag_convert_success(self, tag_converter):
- """Converting tag should successful."""
- self.ctx.invoked_with = "foo"
- tag_converter.convert = AsyncMock(return_value="foo")
- self.assertIsNone(await self.cog.try_get_tag(self.ctx))
- tag_converter.convert.assert_awaited_once_with(self.ctx, "foo")
- self.ctx.invoke.assert_awaited_once()
-
- @patch("bot.exts.backend.error_handler.TagNameConverter")
- async def test_try_get_tag_convert_fail(self, tag_converter):
- """Converting tag should raise `BadArgument`."""
- self.ctx.reset_mock()
- self.ctx.invoked_with = "bar"
- tag_converter.convert = AsyncMock(side_effect=errors.BadArgument())
- self.assertIsNone(await self.cog.try_get_tag(self.ctx))
- self.ctx.invoke.assert_not_awaited()
-
async def test_try_get_tag_ctx_invoke(self):
"""Should call `ctx.invoke` with proper args/kwargs."""
- self.ctx.reset_mock()
- self.ctx.invoked_with = "foo"
- self.assertIsNone(await self.cog.try_get_tag(self.ctx))
- self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo")
+ test_cases = (
+ ("foo", ("foo", None)),
+ ("foo bar", ("foo", "bar")),
+ )
+ for message_content, args in test_cases:
+ self.ctx.reset_mock()
+ self.ctx.message = MagicMock(content=message_content)
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
+ self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, *args)
async def test_dont_call_suggestion_tag_sent(self):
"""Should never call command suggestion if tag is already sent."""
- self.ctx.invoked_with = "foo"
+ self.ctx.message = MagicMock(content="foo")
self.ctx.invoke = AsyncMock(return_value=True)
self.cog.send_command_suggestion = AsyncMock()
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index 6e3a6b898..f84de453d 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -11,7 +11,6 @@ from bot.converters import (
HushDurationConverter,
ISODateTime,
PackageName,
- TagNameConverter,
)
@@ -25,21 +24,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
- async def test_tag_name_converter_for_invalid(self):
- """TagNameConverter should raise the correct exception for invalid tag names."""
- test_values = (
- ('👋', "Don't be ridiculous, you can't use that character!"),
- ('', "Tag names should not be empty, or filled with whitespace."),
- (' ', "Tag names should not be empty, or filled with whitespace."),
- ('42', "Tag names must contain at least one letter."),
- ('x' * 128, "Are you insane? That's way too long!"),
- )
-
- for invalid_name, exception_message in test_values:
- with self.subTest(invalid_name=invalid_name, exception_message=exception_message):
- with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
- await TagNameConverter.convert(self.context, invalid_name)
-
async def test_package_name_for_valid(self):
"""PackageName returns valid package names unchanged."""
test_values = ('foo', 'le_mon', 'num83r')