diff options
-rw-r--r-- | bot/converters.py | 44 | ||||
-rw-r--r-- | bot/exts/backend/error_handler.py | 20 | ||||
-rw-r--r-- | bot/exts/info/source.py | 13 | ||||
-rw-r--r-- | bot/exts/info/tags.py | 487 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 36 | ||||
-rw-r--r-- | tests/bot/test_converters.py | 16 |
6 files changed, 332 insertions, 284 deletions
diff --git a/bot/converters.py b/bot/converters.py index bd4044c7e..48a5e3dc2 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -18,6 +18,7 @@ from bot import exts from bot.api import ResponseCodeError from bot.constants import URLs from bot.exts.info.doc import _inventory_parser +from bot.exts.info.tags import TagIdentifier from bot.utils.extensions import EXTENSIONS, unqualify from bot.utils.regex import INVITE_RE from bot.utils.time import parse_duration_string @@ -279,41 +280,6 @@ class Snowflake(IDConverter): return snowflake -class TagNameConverter(Converter): - """ - Ensure that a proposed tag name is valid. - - Valid tag names meet the following conditions: - * All ASCII characters - * Has at least one non-whitespace character - * Not solely numeric - * Shorter than 127 characters - """ - - @staticmethod - async def convert(ctx: Context, tag_name: str) -> str: - """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" - tag_name = tag_name.lower().strip() - - # The tag name has at least one invalid character. - if ascii(tag_name)[1:-1] != tag_name: - raise BadArgument("Don't be ridiculous, you can't use that character!") - - # The tag name is either empty, or consists of nothing but whitespace. - elif not tag_name: - raise BadArgument("Tag names should not be empty, or filled with whitespace.") - - # The tag name is longer than 127 characters. - elif len(tag_name) > 127: - raise BadArgument("Are you insane? That's way too long!") - - # The tag name is ascii but does not contain any letters. - elif not any(character.isalpha() for character in tag_name): - raise BadArgument("Tag names must contain at least one letter.") - - return tag_name - - class SourceConverter(Converter): """Convert an argument into a help command, tag, command, or cog.""" @@ -336,9 +302,10 @@ class SourceConverter(Converter): if not tags_cog: show_tag = False - elif argument.lower() in tags_cog._cache: - return argument.lower() - + else: + identifier = TagIdentifier.from_string(argument.lower()) + if identifier in tags_cog.tags: + return identifier escaped_arg = escape_markdown(argument) raise BadArgument( @@ -579,7 +546,6 @@ if t.TYPE_CHECKING: ValidURL = str # noqa: F811 Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811 Snowflake = int # noqa: F811 - TagNameConverter = str # noqa: F811 SourceConverter = SourceType # noqa: F811 DurationDelta = relativedelta # noqa: F811 Duration = datetime # noqa: F811 diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 578c372c3..128e72c84 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -9,8 +9,8 @@ from sentry_sdk import push_scope from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Colours, Icons, MODERATION_ROLES -from bot.converters import TagNameConverter from bot.errors import InvalidInfractedUserError, LockedResourceError +from bot.exts.info import tags from bot.utils.checks import ContextCheckFailure log = logging.getLogger(__name__) @@ -174,16 +174,16 @@ class ErrorHandler(Cog): await self.on_command_error(ctx, tag_error) return - try: - tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) - except errors.BadArgument: - log.debug( - f"{ctx.author} tried to use an invalid command " - f"and the fallback tag failed validation in TagNameConverter." - ) + tag_identifier = tags.TagIdentifier.from_string(ctx.message.content) + if tag_identifier.group is not None: + tag_name = tag_identifier.name + tag_name_or_group = tag_identifier.group else: - if await ctx.invoke(tags_get_command, tag_name=tag_name): - return + tag_name = None + tag_name_or_group = tag_identifier.name + + if await ctx.invoke(tags_get_command, tag_name_or_group, tag_name): + return if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): await self.send_command_suggestion(ctx, ctx.invoked_with) diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 8ce25b4e8..e3e7029ca 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -8,8 +8,9 @@ from discord.ext import commands from bot.bot import Bot from bot.constants import URLs from bot.converters import SourceConverter +from bot.exts.info.tags import TagIdentifier -SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] +SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, TagIdentifier, commands.ExtensionNotLoaded] class BotSource(commands.Cog): @@ -41,9 +42,9 @@ class BotSource(commands.Cog): source_item = inspect.unwrap(source_item.callback) src = source_item.__code__ filename = src.co_filename - elif isinstance(source_item, str): + elif isinstance(source_item, TagIdentifier): tags_cog = self.bot.get_cog("Tags") - filename = tags_cog._cache[source_item]["location"] + filename = tags_cog.tags[source_item].file_path else: src = type(source_item) try: @@ -51,7 +52,7 @@ class BotSource(commands.Cog): except TypeError: raise commands.BadArgument("Cannot get source for a dynamically-created object.") - if not isinstance(source_item, str): + if not isinstance(source_item, TagIdentifier): try: lines, first_line_no = inspect.getsourcelines(src) except OSError: @@ -64,7 +65,7 @@ class BotSource(commands.Cog): # Handle tag file location differently than others to avoid errors in some cases if not first_line_no: - file_location = Path(filename).relative_to("/bot/") + file_location = Path(filename).relative_to("bot/") else: file_location = Path(filename).relative_to(Path.cwd()).as_posix() @@ -82,7 +83,7 @@ class BotSource(commands.Cog): elif isinstance(source_object, commands.Command): description = source_object.short_doc title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, str): + elif isinstance(source_object, TagIdentifier): title = f"Tag: {source_object}" description = "" else: diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index bb91a8563..3d222933a 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -1,15 +1,19 @@ +from __future__ import annotations + +import enum import logging import re import time from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional +from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union -from discord import Colour, Embed, Member +import discord +import frontmatter +from discord import Embed, Member from discord.ext.commands import Cog, Context, group from bot import constants from bot.bot import Bot -from bot.converters import TagNameConverter from bot.pagination import LinePaginator from bot.utils.messages import wait_for_deletion @@ -24,99 +28,166 @@ REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags <tagname>." +class COOLDOWN(enum.Enum): + """Sentinel value to signal that a tag is on cooldown.""" + + obj = object() + + +class TagIdentifier(NamedTuple): + """Stores the group and name used as an identifier for a tag.""" + + group: Optional[str] + name: str + + def get_fuzzy_score(self, fuzz_tag_identifier: TagIdentifier) -> float: + """Get fuzzy score, using `fuzz_tag_identifier` as the identifier to fuzzy match with.""" + if (self.group is None) != (fuzz_tag_identifier.group is None): + # Ignore tags without groups if the identifier has a group and vice versa + return .0 + if self.group == fuzz_tag_identifier.group: + # Completely identical, or both None + group_score = 1 + else: + group_score = _fuzzy_search(fuzz_tag_identifier.group, self.group) + + fuzzy_score = group_score * _fuzzy_search(fuzz_tag_identifier.name, self.name) * 100 + if fuzzy_score: + log.trace(f"Fuzzy score {fuzzy_score:=06.2f} for tag {self!r} with fuzz {fuzz_tag_identifier!r}") + return fuzzy_score + + def __str__(self) -> str: + if self.group is not None: + return f"{self.group} {self.name}" + else: + return self.name + + @classmethod + def from_string(cls, string: str) -> TagIdentifier: + """Create a `TagIdentifier` instance from the beginning of `string`.""" + split_string = string.removeprefix(constants.Bot.prefix).split(" ", maxsplit=2) + if len(split_string) == 1: + return cls(None, split_string[0]) + else: + return cls(split_string[0], split_string[1]) + + +class Tag: + """Provide an interface to a tag from resources with `file_content`.""" + + def __init__(self, content_path: Path): + post = frontmatter.loads(content_path.read_text("utf8")) + self.file_path = content_path + self.content = post.content + self.metadata = post.metadata + self._restricted_to: set[int] = set(self.metadata.get("restricted_to", ())) + self._cooldowns: dict[discord.TextChannel, float] = {} + + @property + def embed(self) -> Embed: + """Create an embed for the tag.""" + embed = Embed.from_dict(self.metadata.get("embed", {})) + embed.description = self.content + return embed + + def accessible_by(self, member: discord.Member) -> bool: + """Check whether `member` can access the tag.""" + return bool( + not self._restricted_to + or self._restricted_to & {role.id for role in member.roles} + ) + + def on_cooldown_in(self, channel: discord.TextChannel) -> bool: + """Check whether the tag is on cooldown in `channel`.""" + return channel in self._cooldowns and self._cooldowns[channel] > time.time() + + def set_cooldown_for(self, channel: discord.TextChannel) -> None: + """Set the tag to be on cooldown in `channel` for `constants.Cooldowns.tags` seconds.""" + self._cooldowns[channel] = time.time() + constants.Cooldowns.tags + + +def _fuzzy_search(search: str, target: str) -> float: + """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" + _search = REGEX_NON_ALPHABET.sub("", search.lower()) + if not _search: + return 0 + + _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + + current = 0 + for _target in _targets: + index = 0 + try: + while index < len(_target) and _search[current] == _target[index]: + current += 1 + index += 1 + except IndexError: + # Exit when _search runs out + break + + return current / len(_search) + + class Tags(Cog): - """Save new tags and fetch existing tags.""" + """Fetch tags by name or content.""" def __init__(self, bot: Bot): self.bot = bot - self.tag_cooldowns = {} - self._cache = self.get_tags() - - @staticmethod - def get_tags() -> dict: - """Get all tags.""" - cache = {} + self.tags: dict[TagIdentifier, Tag] = {} + self.initialize_tags() + def initialize_tags(self) -> None: + """Load all tags from resources into `self.tags`.""" base_path = Path("bot", "resources", "tags") + for file in base_path.glob("**/*"): if file.is_file(): - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text(encoding="utf8"), - }, - "restricted_to": None, - "location": f"/bot/{file}" - } - - # Convert to a list to allow negative indexing. - parents = list(file.relative_to(base_path).parents) - if len(parents) > 1: - # -1 would be '.' hence -2 is used as the index. - tag["restricted_to"] = parents[-2].name - - cache[tag_title] = tag - - return cache - - @staticmethod - def check_accessibility(user: Member, tag: dict) -> bool: - """Check if user can access a tag.""" - return not tag["restricted_to"] or tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] - - @staticmethod - def _fuzzy_search(search: str, target: str) -> float: - """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" - current, index = 0, 0 - _search = REGEX_NON_ALPHABET.sub('', search.lower()) - _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) - _target = next(_targets) - try: - while True: - while index < len(_target) and _search[current] == _target[index]: - current += 1 - index += 1 - index, _target = 0, next(_targets) - except (StopIteration, IndexError): - pass - return current / len(_search) * 100 - - def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: - """Return a list of suggested tags.""" - scores: Dict[str, int] = { - tag_title: Tags._fuzzy_search(tag_name, tag['title']) - for tag_title, tag in self._cache.items() - } - - thresholds = thresholds or [100, 90, 80, 70, 60] - - for threshold in thresholds: + parent_dir = file.relative_to(base_path).parent + tag_name = file.stem + # Files directly under `base_path` have an empty string as the parent directory name + tag_group = parent_dir.name or None + + self.tags[TagIdentifier(tag_group, tag_name)] = Tag(file) + + def _get_suggestions(self, tag_identifier: TagIdentifier) -> list[tuple[TagIdentifier, Tag]]: + """Return a list of suggested tags for `tag_identifier`.""" + for threshold in [100, 90, 80, 70, 60]: suggestions = [ - self._cache[tag_title] - for tag_title, matching_score in scores.items() - if matching_score >= threshold + (identifier, tag) + for identifier, tag in self.tags.items() + if identifier.get_fuzzy_score(tag_identifier) >= threshold ] if suggestions: return suggestions return [] - def _get_tag(self, tag_name: str) -> list: - """Get a specific tag.""" - found = [self._cache.get(tag_name.lower(), None)] - if not found[0]: - return self._get_suggestions(tag_name) - return found + def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIdentifier, Tag]]: + """Get tags with identifiers similar to `tag_identifier`.""" + suggestions = [] + + if tag_identifier.group is not None and len(tag_identifier.group) >= 3: + # Try fuzzy matching with only a name first + suggestions += self._get_suggestions(TagIdentifier(None, tag_identifier.group)) + + if len(tag_identifier.name) >= 3: + suggestions += self._get_suggestions(tag_identifier) - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: + return suggestions + + def _get_tags_via_content( + self, + check: Callable[[Iterable], bool], + keywords: str, + user: Member, + ) -> list[tuple[TagIdentifier, Tag]]: """ Search for tags via contents. `predicate` will be the built-in any, all, or a custom callable. Must return a bool. """ - keywords_processed: List[str] = [] - for keyword in keywords.split(','): + keywords_processed = [] + for keyword in keywords.split(","): keyword_sanitized = keyword.strip().casefold() if not keyword_sanitized: # this happens when there are leading / trailing / consecutive comma. @@ -124,32 +195,37 @@ class Tags(Cog): keywords_processed.append(keyword_sanitized) if not keywords_processed: - # after sanitizing, we can end up with an empty list, for example when keywords is ',' + # after sanitizing, we can end up with an empty list, for example when keywords is "," # in that case, we simply want to search for such keywords directly instead. keywords_processed = [keywords] matching_tags = [] - for tag in self._cache.values(): - matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) - if self.check_accessibility(user, tag) and check(matches): - matching_tags.append(tag) + for identifier, tag in self.tags.items(): + matches = (query in tag.content.casefold() for query in keywords_processed) + if tag.accessible_by(user) and check(matches): + matching_tags.append((identifier, tag)) return matching_tags - async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: + async def _send_matching_tags( + self, + ctx: Context, + keywords: str, + matching_tags: list[tuple[TagIdentifier, Tag]], + ) -> None: """Send the result of matching tags to user.""" - if not matching_tags: - pass - elif len(matching_tags) == 1: - await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) - else: - is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 + if len(matching_tags) == 1: + await ctx.send(embed=matching_tags[0][1].embed) + elif matching_tags: + is_plural = keywords.strip().count(" ") > 0 or keywords.strip().count(",") > 0 embed = Embed( title=f"Here are the tags containing the given keyword{'s' * is_plural}:", - description='\n'.join(tag['title'] for tag in matching_tags[:10]) ) await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in matching_tags), + sorted( + f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}" + for identifier, _ in matching_tags + ), ctx, embed, footer_text=FOOTER_TEXT, @@ -157,12 +233,17 @@ class Tags(Cog): max_lines=15 ) - @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) - async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: + @group(name="tags", aliases=("tag", "t"), invoke_without_command=True) + async def tags_group( + self, + ctx: Context, + tag_name_or_group: str = None, + tag_name: str = None, + ) -> None: """Show all known tags, a single tag, or run a subcommand.""" - await self.get_command(ctx, tag_name=tag_name) + await self.get_command(ctx, tag_name_or_group=tag_name_or_group, tag_name=tag_name) - @tags_group.group(name='search', invoke_without_command=True) + @tags_group.group(name="search", invoke_without_command=True) async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: """ Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. @@ -172,123 +253,155 @@ class Tags(Cog): matching_tags = self._get_tags_via_content(all, keywords, ctx.author) await self._send_matching_tags(ctx, keywords, matching_tags) - @search_tag_content.command(name='any') - async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: + @search_tag_content.command(name="any") + async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = "any") -> None: """ Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. Search for tags that has ANY of the keywords. """ - matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) + matching_tags = self._get_tags_via_content(any, keywords or "any", ctx.author) await self._send_matching_tags(ctx, keywords, matching_tags) - async def display_tag(self, ctx: Context, tag_name: str = None) -> bool: + async def get_tag_embed( + self, + ctx: Context, + tag_identifier: TagIdentifier, + ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ - If a tag is not found, display similar tag names as suggestions. - - If a tag is not specified, display a paginated embed of all tags. + Generate an embed of the requested tag or of suggestions if the tag doesn't exist/isn't accessible by the user. - Tags are on cooldowns on a per-tag, per-channel basis. If a tag is on cooldown, display - nothing and return True. + If the requested tag is on cooldown return `COOLDOWN.obj`, otherwise if no suggestions were found return None. """ - def _command_on_cooldown(tag_name: str) -> bool: - """ - Check if the command is currently on cooldown, on a per-tag, per-channel basis. - - The cooldown duration is set in constants.py. - """ - now = time.time() - - cooldown_conditions = ( - tag_name - and tag_name in self.tag_cooldowns - and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags - and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id + filtered_tags = [ + (ident, tag) for ident, tag in + self.get_fuzzy_matches(tag_identifier)[:10] + if tag.accessible_by(ctx.author) + ] + + tag = self.tags.get(tag_identifier) + if tag is None and len(filtered_tags) == 1: + tag_identifier = filtered_tags[0][0] + tag = filtered_tags[0][1] + + if tag is not None: + if tag.on_cooldown_in(ctx.channel): + log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.") + return COOLDOWN.obj + tag.set_cooldown_for(ctx.channel) + + self.bot.stats.incr( + f"tags.usages" + f"{'.' + tag_identifier.group.replace('-', '_') if tag_identifier.group else ''}" + f".{tag_identifier.name.replace('-', '_')}" ) + return tag.embed - if cooldown_conditions: - return True - return False - - if _command_on_cooldown(tag_name): - time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] - time_left = constants.Cooldowns.tags - time_elapsed - log.info( - f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds." + else: + if not filtered_tags: + return None + suggested_tags_text = "\n".join( + f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" + for identifier, tag in filtered_tags + if not tag.on_cooldown_in(ctx.channel) + ) + return Embed( + title="Did you mean ...", + description=suggested_tags_text ) - return True - - if tag_name is not None: - temp_founds = self._get_tag(tag_name) - - founds = [] - - for found_tag in temp_founds: - if self.check_accessibility(ctx.author, found_tag): - founds.append(found_tag) - if len(founds) == 1: - tag = founds[0] - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } + async def list_all_tags(self, ctx: Context) -> None: + """Send a paginator with all loaded tags accessible by `ctx.author`, groups first, and alphabetically sorted.""" + def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: + group, name = tag_item[0] + if group is None: + # Max codepoint character to force tags without a group to the end + group = chr(0x10ffff) + + return group + name + + result_lines = [] + current_group = "" + group_accessible = True + + for identifier, tag in sorted(self.tags.items(), key=tag_sort_key): + + if identifier.group != current_group: + if not group_accessible: + # Remove group separator line if no tags in the previous group were accessible by the user. + result_lines.pop() + # A new group began, add a separator with the group name. + current_group = identifier.group + if current_group is not None: + group_accessible = False + result_lines.append(f"\n\N{BULLET} **{current_group}**") + else: + result_lines.append("\n\N{BULLET}") + + if tag.accessible_by(ctx.author): + result_lines.append(f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}") + group_accessible = True + + embed = Embed(title="Current tags") + await LinePaginator.paginate(result_lines, ctx, embed, max_lines=15, empty=False, footer_text=FOOTER_TEXT) + + async def list_tags_in_group(self, ctx: Context, group: str) -> None: + """Send a paginator with all tags under `group`, that are accessible by `ctx.author`.""" + embed = Embed(title=f"Tags under *{group}*") + tag_lines = sorted( + f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" + for identifier, tag in self.tags.items() + if identifier.group == group and tag.accessible_by(ctx.author) + ) + await LinePaginator.paginate(tag_lines, ctx, embed, footer_text=FOOTER_TEXT, empty=False, max_lines=15) + + @tags_group.command(name="get", aliases=("show", "g")) + async def get_command( + self, ctx: Context, + tag_name_or_group: str = None, + tag_name: str = None, + ) -> bool: + """ + If a single argument matching a group name is given, list all accessible tags from that group + Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name. - self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") + With no arguments, list all accessible tags. - await wait_for_deletion( - await ctx.send(embed=Embed.from_dict(tag['embed'])), - [ctx.author.id], - ) + Returns True if a message was sent, or if the tag is on cooldown. + Returns False if no message was sent. + """ # noqa: D205, D415 + if tag_name_or_group is None and tag_name is None: + if self.tags: + await self.list_all_tags(ctx) return True - elif founds and len(tag_name) >= 3: - await wait_for_deletion( - await ctx.send( - embed=Embed( - title='Did you mean ...', - description='\n'.join(tag['title'] for tag in founds[:10]) - ) - ), - [ctx.author.id], - ) + else: + await ctx.send(embed=Embed(description="**There are no tags!**")) return True - else: - tags = self._cache.values() - if not tags: - await ctx.send(embed=Embed( - description="**There are no tags in the database!**", - colour=Colour.red() - )) + elif tag_name is None: + if any( + tag_name_or_group == identifier.group and tag.accessible_by(ctx.author) + for identifier, tag in self.tags.items() + ): + await self.list_tags_in_group(ctx, tag_name_or_group) return True else: - embed: Embed = Embed(title="**Current tags**") - await LinePaginator.paginate( - sorted( - f"**»** {tag['title']}" for tag in tags - if self.check_accessibility(ctx.author, tag) - ), - ctx, - embed, - footer_text=FOOTER_TEXT, - empty=False, - max_lines=15 - ) - return True - - return False + tag_name = tag_name_or_group + tag_group = None + else: + tag_group = tag_name_or_group - @tags_group.command(name='get', aliases=('show', 'g')) - async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> bool: - """ - Get a specified tag, or a list of all tags if no tag is specified. + embed = await self.get_tag_embed(ctx, TagIdentifier(tag_group, tag_name)) + if embed is None: + return False - Returns True if something can be sent, or if the tag is on cooldown. - Returns False if no matches are found. - """ - return await self.display_tag(ctx, tag_name) + if embed is not COOLDOWN.obj: + await wait_for_deletion( + await ctx.send(embed=embed), + (ctx.author.id,) + ) + # A valid tag was found and was either sent, or is on cooldown + return True def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 2b0549b98..ce59ee5fa 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -337,14 +337,12 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_try_get_tag_get_command(self): """Should call `Bot.get_command` with `tags get` argument.""" self.bot.get_command.reset_mock() - self.ctx.invoked_with = "foo" await self.cog.try_get_tag(self.ctx) self.bot.get_command.assert_called_once_with("tags get") async def test_try_get_tag_invoked_from_error_handler(self): """`self.ctx` should have `invoked_from_error_handler` `True`.""" self.ctx.invoked_from_error_handler = False - self.ctx.invoked_with = "foo" await self.cog.try_get_tag(self.ctx) self.assertTrue(self.ctx.invoked_from_error_handler) @@ -359,38 +357,24 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): err = errors.CommandError() self.tag.get_command.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.ctx.invoked_with = "foo" self.assertIsNone(await self.cog.try_get_tag(self.ctx)) self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) - @patch("bot.exts.backend.error_handler.TagNameConverter") - async def test_try_get_tag_convert_success(self, tag_converter): - """Converting tag should successful.""" - self.ctx.invoked_with = "foo" - tag_converter.convert = AsyncMock(return_value="foo") - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) - tag_converter.convert.assert_awaited_once_with(self.ctx, "foo") - self.ctx.invoke.assert_awaited_once() - - @patch("bot.exts.backend.error_handler.TagNameConverter") - async def test_try_get_tag_convert_fail(self, tag_converter): - """Converting tag should raise `BadArgument`.""" - self.ctx.reset_mock() - self.ctx.invoked_with = "bar" - tag_converter.convert = AsyncMock(side_effect=errors.BadArgument()) - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) - self.ctx.invoke.assert_not_awaited() - async def test_try_get_tag_ctx_invoke(self): """Should call `ctx.invoke` with proper args/kwargs.""" - self.ctx.reset_mock() - self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) - self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo") + test_cases = ( + ("foo", ("foo", None)), + ("foo bar", ("foo", "bar")), + ) + for message_content, args in test_cases: + self.ctx.reset_mock() + self.ctx.message = MagicMock(content=message_content) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, *args) async def test_dont_call_suggestion_tag_sent(self): """Should never call command suggestion if tag is already sent.""" - self.ctx.invoked_with = "foo" + self.ctx.message = MagicMock(content="foo") self.ctx.invoke = AsyncMock(return_value=True) self.cog.send_command_suggestion = AsyncMock() diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 6e3a6b898..f84de453d 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -11,7 +11,6 @@ from bot.converters import ( HushDurationConverter, ISODateTime, PackageName, - TagNameConverter, ) @@ -25,21 +24,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') - async def test_tag_name_converter_for_invalid(self): - """TagNameConverter should raise the correct exception for invalid tag names.""" - test_values = ( - ('👋', "Don't be ridiculous, you can't use that character!"), - ('', "Tag names should not be empty, or filled with whitespace."), - (' ', "Tag names should not be empty, or filled with whitespace."), - ('42', "Tag names must contain at least one letter."), - ('x' * 128, "Are you insane? That's way too long!"), - ) - - for invalid_name, exception_message in test_values: - with self.subTest(invalid_name=invalid_name, exception_message=exception_message): - with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): - await TagNameConverter.convert(self.context, invalid_name) - async def test_package_name_for_valid(self): """PackageName returns valid package names unchanged.""" test_values = ('foo', 'le_mon', 'num83r') |