diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/tags.py | 36 | 
1 files changed, 24 insertions, 12 deletions
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 9e06b702c..8d3586b19 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -1,5 +1,6 @@  import logging  import time +from typing import Dict, List, Optional  from discord import Colour, Embed  from discord.ext.commands import Cog, Context, group @@ -39,9 +40,9 @@ class Tags(Cog):              self._last_fetch = time_now      @staticmethod -    def _fuzzy_search(search: str, target: str) -> bool: -        found = 0 -        index = 0 +    def _fuzzy_search(search: str, target: str) -> int: +        """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" +        found, index = 0, 0          _search = search.lower().replace(' ', '')          _target = target.lower().replace(' ', '')          for letter in _search: @@ -51,19 +52,32 @@ class Tags(Cog):              found += index > 0          return found / len(_search) * 100 -    def _get_suggestions(self, tag_name: str, score: int) -> list: -        return sorted( -            (tag for tag in self._cache.values() if Tags._fuzzy_search(tag_name, tag['title']) >= score), -            key=lambda tag: Tags._fuzzy_search(tag_name, tag['title']), -            reverse=True -        ) +    def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: +        """Return a list of suggested tags.""" +        scores: Dict[str, int] = { +            tag_title: Tags._fuzzy_search(tag_name, tag['title']) +            for tag_title, tag in self._cache.items() +        } + +        thresholds = thresholds or [100, 80] + +        for threshold in thresholds: +            suggestions = [ +                self._cache[tag_title] +                for tag_title, matching_score in scores.items() +                if matching_score >= threshold +            ] +            if suggestions: +                return suggestions + +        return []      async def _get_tag(self, tag_name: str) -> list:          """Get a specific tag."""          await self._get_tags()          found = [self._cache.get(tag_name.lower(), None)]          if not found[0]: -            return self._get_suggestions(tag_name, 100) or self._get_suggestions(tag_name, 80) +            return self._get_suggestions(tag_name, thresholds=[100, 80])          return found      @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) @@ -102,7 +116,6 @@ class Tags(Cog):          await self._get_tags()          if tag_name is not None: -            # tag = await self.bot.api_client.get(f'bot/tags/{tag_name}')              founds = await self._get_tag(tag_name)              if len(founds) == 1: @@ -120,7 +133,6 @@ class Tags(Cog):                  ))          else: -            # tags = await self.bot.api_client.get('bot/tags')              tags = self._cache.values()              if not tags:                  await ctx.send(embed=Embed(  |