diff options
| -rw-r--r-- | bot/converters.py | 44 | ||||
| -rw-r--r-- | bot/exts/backend/error_handler.py | 13 | ||||
| -rw-r--r-- | bot/exts/info/source.py | 13 | ||||
| -rw-r--r-- | bot/exts/info/tags.py | 479 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 30 | ||||
| -rw-r--r-- | tests/bot/test_converters.py | 17 | 
6 files changed, 305 insertions, 291 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 0984fa0a3..559e759e1 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -18,6 +18,7 @@ from bot.api import ResponseCodeError  from bot.constants import URLs  from bot.errors import InvalidInfraction  from bot.exts.info.doc import _inventory_parser +from bot.exts.info.tags import TagIdentifier  from bot.log import get_logger  from bot.utils.extensions import EXTENSIONS, unqualify  from bot.utils.regex import INVITE_RE @@ -286,41 +287,6 @@ class Snowflake(IDConverter):          return snowflake -class TagNameConverter(Converter): -    """ -    Ensure that a proposed tag name is valid. - -    Valid tag names meet the following conditions: -        * All ASCII characters -        * Has at least one non-whitespace character -        * Not solely numeric -        * Shorter than 127 characters -    """ - -    @staticmethod -    async def convert(ctx: Context, tag_name: str) -> str: -        """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" -        tag_name = tag_name.lower().strip() - -        # The tag name has at least one invalid character. -        if ascii(tag_name)[1:-1] != tag_name: -            raise BadArgument("Don't be ridiculous, you can't use that character!") - -        # The tag name is either empty, or consists of nothing but whitespace. -        elif not tag_name: -            raise BadArgument("Tag names should not be empty, or filled with whitespace.") - -        # The tag name is longer than 127 characters. -        elif len(tag_name) > 127: -            raise BadArgument("Are you insane? That's way too long!") - -        # The tag name is ascii but does not contain any letters. -        elif not any(character.isalpha() for character in tag_name): -            raise BadArgument("Tag names must contain at least one letter.") - -        return tag_name - -  class SourceConverter(Converter):      """Convert an argument into a help command, tag, command, or cog.""" @@ -343,9 +309,10 @@ class SourceConverter(Converter):          if not tags_cog:              show_tag = False -        elif argument.lower() in tags_cog._cache: -            return argument.lower() - +        else: +            identifier = TagIdentifier.from_string(argument.lower()) +            if identifier in tags_cog.tags: +                return identifier          escaped_arg = escape_markdown(argument)          raise BadArgument( @@ -615,7 +582,6 @@ if t.TYPE_CHECKING:      ValidURL = str  # noqa: F811      Inventory = t.Tuple[str, _inventory_parser.InventoryDict]  # noqa: F811      Snowflake = int  # noqa: F811 -    TagNameConverter = str  # noqa: F811      SourceConverter = SourceType  # noqa: F811      DurationDelta = relativedelta  # noqa: F811      Duration = datetime  # noqa: F811 diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 5bef72808..c79c7b2a7 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -7,7 +7,6 @@ from sentry_sdk import push_scope  from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Colours, Icons, MODERATION_ROLES -from bot.converters import TagNameConverter  from bot.errors import InvalidInfractedUserError, LockedResourceError  from bot.log import get_logger  from bot.utils.checks import ContextCheckFailure @@ -174,16 +173,8 @@ class ErrorHandler(Cog):              await self.on_command_error(ctx, tag_error)              return -        try: -            tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) -        except errors.BadArgument: -            log.debug( -                f"{ctx.author} tried to use an invalid command " -                f"and the fallback tag failed validation in TagNameConverter." -            ) -        else: -            if await ctx.invoke(tags_get_command, tag_name=tag_name): -                return +        if await ctx.invoke(tags_get_command, argument_string=ctx.message.content): +            return          if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):              await self.send_command_suggestion(ctx, ctx.invoked_with) diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 8ce25b4e8..e3e7029ca 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -8,8 +8,9 @@ from discord.ext import commands  from bot.bot import Bot  from bot.constants import URLs  from bot.converters import SourceConverter +from bot.exts.info.tags import TagIdentifier -SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, commands.ExtensionNotLoaded] +SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, TagIdentifier, commands.ExtensionNotLoaded]  class BotSource(commands.Cog): @@ -41,9 +42,9 @@ class BotSource(commands.Cog):              source_item = inspect.unwrap(source_item.callback)              src = source_item.__code__              filename = src.co_filename -        elif isinstance(source_item, str): +        elif isinstance(source_item, TagIdentifier):              tags_cog = self.bot.get_cog("Tags") -            filename = tags_cog._cache[source_item]["location"] +            filename = tags_cog.tags[source_item].file_path          else:              src = type(source_item)              try: @@ -51,7 +52,7 @@ class BotSource(commands.Cog):              except TypeError:                  raise commands.BadArgument("Cannot get source for a dynamically-created object.") -        if not isinstance(source_item, str): +        if not isinstance(source_item, TagIdentifier):              try:                  lines, first_line_no = inspect.getsourcelines(src)              except OSError: @@ -64,7 +65,7 @@ class BotSource(commands.Cog):          # Handle tag file location differently than others to avoid errors in some cases          if not first_line_no: -            file_location = Path(filename).relative_to("/bot/") +            file_location = Path(filename).relative_to("bot/")          else:              file_location = Path(filename).relative_to(Path.cwd()).as_posix() @@ -82,7 +83,7 @@ class BotSource(commands.Cog):          elif isinstance(source_object, commands.Command):              description = source_object.short_doc              title = f"Command: {source_object.qualified_name}" -        elif isinstance(source_object, str): +        elif isinstance(source_object, TagIdentifier):              title = f"Tag: {source_object}"              description = ""          else: diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 842647555..7c8d378a9 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -1,14 +1,18 @@ +from __future__ import annotations + +import enum  import re  import time  from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional +from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union -from discord import Colour, Embed, Member +import discord +import frontmatter +from discord import Embed, Member  from discord.ext.commands import Cog, Context, group  from bot import constants  from bot.bot import Bot -from bot.converters import TagNameConverter  from bot.log import get_logger  from bot.pagination import LinePaginator  from bot.utils.messages import wait_for_deletion @@ -24,99 +28,168 @@ REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE)  FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags <tagname>." +class COOLDOWN(enum.Enum): +    """Sentinel value to signal that a tag is on cooldown.""" + +    obj = object() + + +class TagIdentifier(NamedTuple): +    """Stores the group and name used as an identifier for a tag.""" + +    group: Optional[str] +    name: str + +    def get_fuzzy_score(self, fuzz_tag_identifier: TagIdentifier) -> float: +        """Get fuzzy score, using `fuzz_tag_identifier` as the identifier to fuzzy match with.""" +        if (self.group is None) != (fuzz_tag_identifier.group is None): +            # Ignore tags without groups if the identifier has a group and vice versa +            return .0 +        if self.group == fuzz_tag_identifier.group: +            # Completely identical, or both None +            group_score = 1 +        else: +            group_score = _fuzzy_search(fuzz_tag_identifier.group, self.group) + +        fuzzy_score = group_score * _fuzzy_search(fuzz_tag_identifier.name, self.name) * 100 +        if fuzzy_score: +            log.trace(f"Fuzzy score {fuzzy_score:=06.2f} for tag {self!r} with fuzz {fuzz_tag_identifier!r}") +        return fuzzy_score + +    def __str__(self) -> str: +        if self.group is not None: +            return f"{self.group} {self.name}" +        else: +            return self.name + +    @classmethod +    def from_string(cls, string: str) -> TagIdentifier: +        """Create a `TagIdentifier` instance from the beginning of `string`.""" +        split_string = string.removeprefix(constants.Bot.prefix).split(" ", maxsplit=2) +        if len(split_string) == 1: +            return cls(None, split_string[0]) +        else: +            return cls(split_string[0], split_string[1]) + + +class Tag: +    """Provide an interface to a tag from resources with `file_content`.""" + +    def __init__(self, content_path: Path): +        post = frontmatter.loads(content_path.read_text("utf8")) +        self.file_path = content_path +        self.content = post.content +        self.metadata = post.metadata +        self._restricted_to: set[int] = set(self.metadata.get("restricted_to", ())) +        self._cooldowns: dict[discord.TextChannel, float] = {} + +    @property +    def embed(self) -> Embed: +        """Create an embed for the tag.""" +        embed = Embed.from_dict(self.metadata.get("embed", {})) +        embed.description = self.content +        return embed + +    def accessible_by(self, member: discord.Member) -> bool: +        """Check whether `member` can access the tag.""" +        return bool( +            not self._restricted_to +            or self._restricted_to & {role.id for role in member.roles} +        ) + +    def on_cooldown_in(self, channel: discord.TextChannel) -> bool: +        """Check whether the tag is on cooldown in `channel`.""" +        return self._cooldowns.get(channel, float("-inf")) > time.time() + +    def set_cooldown_for(self, channel: discord.TextChannel) -> None: +        """Set the tag to be on cooldown in `channel` for `constants.Cooldowns.tags` seconds.""" +        self._cooldowns[channel] = time.time() + constants.Cooldowns.tags + + +def _fuzzy_search(search: str, target: str) -> float: +    """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" +    _search = REGEX_NON_ALPHABET.sub("", search.lower()) +    if not _search: +        return 0 + +    _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + +    current = 0 +    for _target in _targets: +        index = 0 +        try: +            while index < len(_target) and _search[current] == _target[index]: +                current += 1 +                index += 1 +        except IndexError: +            # Exit when _search runs out +            break + +    return current / len(_search) + +  class Tags(Cog): -    """Save new tags and fetch existing tags.""" +    """Fetch tags by name or content.""" + +    PAGINATOR_DEFAULTS = dict(max_lines=15, empty=False, footer_text=FOOTER_TEXT)      def __init__(self, bot: Bot):          self.bot = bot -        self.tag_cooldowns = {} -        self._cache = self.get_tags() - -    @staticmethod -    def get_tags() -> dict: -        """Get all tags.""" -        cache = {} +        self.tags: dict[TagIdentifier, Tag] = {} +        self.initialize_tags() +    def initialize_tags(self) -> None: +        """Load all tags from resources into `self.tags`."""          base_path = Path("bot", "resources", "tags") +          for file in base_path.glob("**/*"):              if file.is_file(): -                tag_title = file.stem -                tag = { -                    "title": tag_title, -                    "embed": { -                        "description": file.read_text(encoding="utf8"), -                    }, -                    "restricted_to": None, -                    "location": f"/bot/{file}" -                } - -                # Convert to a list to allow negative indexing. -                parents = list(file.relative_to(base_path).parents) -                if len(parents) > 1: -                    # -1 would be '.' hence -2 is used as the index. -                    tag["restricted_to"] = parents[-2].name - -                cache[tag_title] = tag - -        return cache - -    @staticmethod -    def check_accessibility(user: Member, tag: dict) -> bool: -        """Check if user can access a tag.""" -        return not tag["restricted_to"] or tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] - -    @staticmethod -    def _fuzzy_search(search: str, target: str) -> float: -        """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" -        current, index = 0, 0 -        _search = REGEX_NON_ALPHABET.sub('', search.lower()) -        _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) -        _target = next(_targets) -        try: -            while True: -                while index < len(_target) and _search[current] == _target[index]: -                    current += 1 -                    index += 1 -                index, _target = 0, next(_targets) -        except (StopIteration, IndexError): -            pass -        return current / len(_search) * 100 - -    def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: -        """Return a list of suggested tags.""" -        scores: Dict[str, int] = { -            tag_title: Tags._fuzzy_search(tag_name, tag['title']) -            for tag_title, tag in self._cache.items() -        } - -        thresholds = thresholds or [100, 90, 80, 70, 60] - -        for threshold in thresholds: +                parent_dir = file.relative_to(base_path).parent +                tag_name = file.stem +                # Files directly under `base_path` have an empty string as the parent directory name +                tag_group = parent_dir.name or None + +                self.tags[TagIdentifier(tag_group, tag_name)] = Tag(file) + +    def _get_suggestions(self, tag_identifier: TagIdentifier) -> list[tuple[TagIdentifier, Tag]]: +        """Return a list of suggested tags for `tag_identifier`.""" +        for threshold in [100, 90, 80, 70, 60]:              suggestions = [ -                self._cache[tag_title] -                for tag_title, matching_score in scores.items() -                if matching_score >= threshold +                (identifier, tag) +                for identifier, tag in self.tags.items() +                if identifier.get_fuzzy_score(tag_identifier) >= threshold              ]              if suggestions:                  return suggestions          return [] -    def _get_tag(self, tag_name: str) -> list: -        """Get a specific tag.""" -        found = [self._cache.get(tag_name.lower(), None)] -        if not found[0]: -            return self._get_suggestions(tag_name) -        return found +    def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIdentifier, Tag]]: +        """Get tags with identifiers similar to `tag_identifier`.""" +        suggestions = [] + +        if tag_identifier.group is not None and len(tag_identifier.group) >= 3: +            # Try fuzzy matching with only a name first +            suggestions += self._get_suggestions(TagIdentifier(None, tag_identifier.group)) + +        if len(tag_identifier.name) >= 3: +            suggestions += self._get_suggestions(tag_identifier) -    def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: +        return suggestions + +    def _get_tags_via_content( +            self, +            check: Callable[[Iterable], bool], +            keywords: str, +            user: Member, +    ) -> list[tuple[TagIdentifier, Tag]]:          """          Search for tags via contents.          `predicate` will be the built-in any, all, or a custom callable. Must return a bool.          """ -        keywords_processed: List[str] = [] -        for keyword in keywords.split(','): +        keywords_processed = [] +        for keyword in keywords.split(","):              keyword_sanitized = keyword.strip().casefold()              if not keyword_sanitized:                  # this happens when there are leading / trailing / consecutive comma. @@ -124,45 +197,48 @@ class Tags(Cog):              keywords_processed.append(keyword_sanitized)          if not keywords_processed: -            # after sanitizing, we can end up with an empty list, for example when keywords is ',' +            # after sanitizing, we can end up with an empty list, for example when keywords is ","              # in that case, we simply want to search for such keywords directly instead.              keywords_processed = [keywords]          matching_tags = [] -        for tag in self._cache.values(): -            matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) -            if self.check_accessibility(user, tag) and check(matches): -                matching_tags.append(tag) +        for identifier, tag in self.tags.items(): +            matches = (query in tag.content.casefold() for query in keywords_processed) +            if tag.accessible_by(user) and check(matches): +                matching_tags.append((identifier, tag))          return matching_tags -    async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: +    async def _send_matching_tags( +            self, +            ctx: Context, +            keywords: str, +            matching_tags: list[tuple[TagIdentifier, Tag]], +    ) -> None:          """Send the result of matching tags to user.""" -        if not matching_tags: -            pass -        elif len(matching_tags) == 1: -            await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) -        else: -            is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 +        if len(matching_tags) == 1: +            await ctx.send(embed=matching_tags[0][1].embed) +        elif matching_tags: +            is_plural = keywords.strip().count(" ") > 0 or keywords.strip().count(",") > 0              embed = Embed(                  title=f"Here are the tags containing the given keyword{'s' * is_plural}:", -                description='\n'.join(tag['title'] for tag in matching_tags[:10])              )              await LinePaginator.paginate( -                sorted(f"**»**   {tag['title']}" for tag in matching_tags), +                sorted( +                    f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}" +                    for identifier, _ in matching_tags +                ),                  ctx,                  embed, -                footer_text=FOOTER_TEXT, -                empty=False, -                max_lines=15 +                **self.PAGINATOR_DEFAULTS,              ) -    @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) -    async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: +    @group(name="tags", aliases=("tag", "t"), invoke_without_command=True, usage="[tag_group] [tag_name]") +    async def tags_group(self, ctx: Context, *, argument_string: Optional[str]) -> None:          """Show all known tags, a single tag, or run a subcommand.""" -        await self.get_command(ctx, tag_name=tag_name) +        await self.get_command(ctx, argument_string=argument_string) -    @tags_group.group(name='search', invoke_without_command=True) +    @tags_group.group(name="search", invoke_without_command=True)      async def search_tag_content(self, ctx: Context, *, keywords: str) -> None:          """          Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. @@ -172,123 +248,146 @@ class Tags(Cog):          matching_tags = self._get_tags_via_content(all, keywords, ctx.author)          await self._send_matching_tags(ctx, keywords, matching_tags) -    @search_tag_content.command(name='any') -    async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = 'any') -> None: +    @search_tag_content.command(name="any") +    async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = "any") -> None:          """          Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.          Search for tags that has ANY of the keywords.          """ -        matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) +        matching_tags = self._get_tags_via_content(any, keywords or "any", ctx.author)          await self._send_matching_tags(ctx, keywords, matching_tags) -    async def display_tag(self, ctx: Context, tag_name: str = None) -> bool: +    async def get_tag_embed( +            self, +            ctx: Context, +            tag_identifier: TagIdentifier, +    ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]:          """ -        If a tag is not found, display similar tag names as suggestions. - -        If a tag is not specified, display a paginated embed of all tags. +        Generate an embed of the requested tag or of suggestions if the tag doesn't exist/isn't accessible by the user. -        Tags are on cooldowns on a per-tag, per-channel basis. If a tag is on cooldown, display -        nothing and return True. +        If the requested tag is on cooldown return `COOLDOWN.obj`, otherwise if no suggestions were found return None.          """ -        def _command_on_cooldown(tag_name: str) -> bool: -            """ -            Check if the command is currently on cooldown, on a per-tag, per-channel basis. - -            The cooldown duration is set in constants.py. -            """ -            now = time.time() - -            cooldown_conditions = ( -                tag_name -                and tag_name in self.tag_cooldowns -                and (now - self.tag_cooldowns[tag_name]["time"]) < constants.Cooldowns.tags -                and self.tag_cooldowns[tag_name]["channel"] == ctx.channel.id +        filtered_tags = [ +            (ident, tag) for ident, tag in +            self.get_fuzzy_matches(tag_identifier)[:10] +            if tag.accessible_by(ctx.author) +        ] + +        tag = self.tags.get(tag_identifier) +        if tag is None and len(filtered_tags) == 1: +            tag_identifier = filtered_tags[0][0] +            tag = filtered_tags[0][1] + +        if tag is not None: +            if tag.on_cooldown_in(ctx.channel): +                log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.") +                return COOLDOWN.obj +            tag.set_cooldown_for(ctx.channel) + +            self.bot.stats.incr( +                f"tags.usages" +                f"{'.' + tag_identifier.group.replace('-', '_') if tag_identifier.group else ''}" +                f".{tag_identifier.name.replace('-', '_')}"              ) +            return tag.embed -            if cooldown_conditions: -                return True -            return False - -        if _command_on_cooldown(tag_name): -            time_elapsed = time.time() - self.tag_cooldowns[tag_name]["time"] -            time_left = constants.Cooldowns.tags - time_elapsed -            log.info( -                f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " -                f"Cooldown ends in {time_left:.1f} seconds." +        else: +            if not filtered_tags: +                return None +            suggested_tags_text = "\n".join( +                f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" +                for identifier, tag in filtered_tags +                if not tag.on_cooldown_in(ctx.channel) +            ) +            return Embed( +                title="Did you mean ...", +                description=suggested_tags_text              ) -            return True - -        if tag_name is not None: -            temp_founds = self._get_tag(tag_name) - -            founds = [] - -            for found_tag in temp_founds: -                if self.check_accessibility(ctx.author, found_tag): -                    founds.append(found_tag) -            if len(founds) == 1: -                tag = founds[0] -                if ctx.channel.id not in TEST_CHANNELS: -                    self.tag_cooldowns[tag_name] = { -                        "time": time.time(), -                        "channel": ctx.channel.id -                    } +    def accessible_tags(self, user: Member) -> list[str]: +        """Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted.""" +        def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: +            group, name = tag_item[0] +            if group is None: +                # Max codepoint character to force tags without a group to the end +                group = chr(0x10ffff) + +            return group + name + +        result_lines = [] +        current_group = "" +        group_accessible = True + +        for identifier, tag in sorted(self.tags.items(), key=tag_sort_key): + +            if identifier.group != current_group: +                if not group_accessible: +                    # Remove group separator line if no tags in the previous group were accessible by the user. +                    result_lines.pop() +                # A new group began, add a separator with the group name. +                current_group = identifier.group +                if current_group is not None: +                    group_accessible = False +                    result_lines.append(f"\n\N{BULLET} **{current_group}**") +                else: +                    result_lines.append("\n\N{BULLET}") + +            if tag.accessible_by(user): +                result_lines.append(f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}") +                group_accessible = True + +        return result_lines + +    def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: +        """Return a formatted list of tags in `group`, that are accessible by `user`.""" +        return sorted( +            f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" +            for identifier, tag in self.tags.items() +            if identifier.group == group and tag.accessible_by(user) +        ) + +    @tags_group.command(name="get", aliases=("show", "g"), usage="[tag_group] [tag_name]") +    async def get_command(self, ctx: Context, *, argument_string: Optional[str]) -> bool: +        """ +        If a single argument matching a group name is given, list all accessible tags from that group +        Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name. -                self.bot.stats.incr(f"tags.usages.{tag['title'].replace('-', '_')}") +        With no arguments, list all accessible tags. -                await wait_for_deletion( -                    await ctx.send(embed=Embed.from_dict(tag['embed'])), -                    [ctx.author.id], -                ) -                return True -            elif founds and len(tag_name) >= 3: -                await wait_for_deletion( -                    await ctx.send( -                        embed=Embed( -                            title='Did you mean ...', -                            description='\n'.join(tag['title'] for tag in founds[:10]) -                        ) -                    ), -                    [ctx.author.id], +        Returns True if a message was sent, or if the tag is on cooldown. +        Returns False if no message was sent. +        """  # noqa: D205, D415 +        if not argument_string: +            if self.tags: +                await LinePaginator.paginate( +                    self.accessible_tags(ctx.author), ctx, Embed(title="Available tags"), **self.PAGINATOR_DEFAULTS                  ) -                return True - -        else: -            tags = self._cache.values() -            if not tags: -                await ctx.send(embed=Embed( -                    description="**There are no tags in the database!**", -                    colour=Colour.red() -                )) -                return True              else: -                embed: Embed = Embed(title="**Current tags**") +                await ctx.send(embed=Embed(description="**There are no tags!**")) +            return True + +        identifier = TagIdentifier.from_string(argument_string) + +        if identifier.group is None: +            # Try to find accessible tags from a group matching the identifier's name. +            if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author):                  await LinePaginator.paginate( -                    sorted( -                        f"**»**   {tag['title']}" for tag in tags -                        if self.check_accessibility(ctx.author, tag) -                    ), -                    ctx, -                    embed, -                    footer_text=FOOTER_TEXT, -                    empty=False, -                    max_lines=15 +                    group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS                  )                  return True -        return False - -    @tags_group.command(name='get', aliases=('show', 'g')) -    async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> bool: -        """ -        Get a specified tag, or a list of all tags if no tag is specified. +        embed = await self.get_tag_embed(ctx, identifier) +        if embed is None: +            return False -        Returns True if something can be sent, or if the tag is on cooldown. -        Returns False if no matches are found. -        """ -        return await self.display_tag(ctx, tag_name) +        if embed is not COOLDOWN.obj: +            await wait_for_deletion( +                await ctx.send(embed=embed), +                (ctx.author.id,) +            ) +        # A valid tag was found and was either sent, or is on cooldown +        return True  def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index d12329b1f..35fa0ee59 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -337,14 +337,12 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):      async def test_try_get_tag_get_command(self):          """Should call `Bot.get_command` with `tags get` argument."""          self.bot.get_command.reset_mock() -        self.ctx.invoked_with = "foo"          await self.cog.try_get_tag(self.ctx)          self.bot.get_command.assert_called_once_with("tags get")      async def test_try_get_tag_invoked_from_error_handler(self):          """`self.ctx` should have `invoked_from_error_handler` `True`."""          self.ctx.invoked_from_error_handler = False -        self.ctx.invoked_with = "foo"          await self.cog.try_get_tag(self.ctx)          self.assertTrue(self.ctx.invoked_from_error_handler) @@ -359,38 +357,12 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):          err = errors.CommandError()          self.tag.get_command.can_run = AsyncMock(side_effect=err)          self.cog.on_command_error = AsyncMock() -        self.ctx.invoked_with = "foo"          self.assertIsNone(await self.cog.try_get_tag(self.ctx))          self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) -    @patch("bot.exts.backend.error_handler.TagNameConverter") -    async def test_try_get_tag_convert_success(self, tag_converter): -        """Converting tag should successful.""" -        self.ctx.invoked_with = "foo" -        tag_converter.convert = AsyncMock(return_value="foo") -        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) -        tag_converter.convert.assert_awaited_once_with(self.ctx, "foo") -        self.ctx.invoke.assert_awaited_once() - -    @patch("bot.exts.backend.error_handler.TagNameConverter") -    async def test_try_get_tag_convert_fail(self, tag_converter): -        """Converting tag should raise `BadArgument`.""" -        self.ctx.reset_mock() -        self.ctx.invoked_with = "bar" -        tag_converter.convert = AsyncMock(side_effect=errors.BadArgument()) -        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) -        self.ctx.invoke.assert_not_awaited() - -    async def test_try_get_tag_ctx_invoke(self): -        """Should call `ctx.invoke` with proper args/kwargs.""" -        self.ctx.reset_mock() -        self.ctx.invoked_with = "foo" -        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) -        self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo") -      async def test_dont_call_suggestion_tag_sent(self):          """Should never call command suggestion if tag is already sent.""" -        self.ctx.invoked_with = "foo" +        self.ctx.message = MagicMock(content="foo")          self.ctx.invoke = AsyncMock(return_value=True)          self.cog.send_command_suggestion = AsyncMock() diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 988b3857b..1bb678db2 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch  from dateutil.relativedelta import relativedelta  from discord.ext.commands import BadArgument -from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName, TagNameConverter +from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName  class ConverterTests(unittest.IsolatedAsyncioTestCase): @@ -19,21 +19,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):          cls.fixed_utc_now = datetime.fromisoformat('2019-01-01T00:00:00+00:00') -    async def test_tag_name_converter_for_invalid(self): -        """TagNameConverter should raise the correct exception for invalid tag names.""" -        test_values = ( -            ('👋', "Don't be ridiculous, you can't use that character!"), -            ('', "Tag names should not be empty, or filled with whitespace."), -            ('  ', "Tag names should not be empty, or filled with whitespace."), -            ('42', "Tag names must contain at least one letter."), -            ('x' * 128, "Are you insane? That's way too long!"), -        ) - -        for invalid_name, exception_message in test_values: -            with self.subTest(invalid_name=invalid_name, exception_message=exception_message): -                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): -                    await TagNameConverter.convert(self.context, invalid_name) -      async def test_package_name_for_valid(self):          """PackageName returns valid package names unchanged."""          test_values = ('foo', 'le_mon', 'num83r') | 
