diff options
| -rw-r--r-- | bot/cogs/tags.py | 91 |
1 files changed, 82 insertions, 9 deletions
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 970301013..54a51921c 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -1,5 +1,7 @@ import logging +import re import time +from typing import Dict, List, Optional from discord import Colour, Embed from discord.ext.commands import Cog, Context, group @@ -10,7 +12,6 @@ from bot.converters import TagContentConverter, TagNameConverter from bot.decorators import with_role from bot.pagination import LinePaginator - log = logging.getLogger(__name__) TEST_CHANNELS = ( @@ -19,6 +20,8 @@ TEST_CHANNELS = ( Channels.helpers ) +REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) + class Tags(Cog): """Save new tags and fetch existing tags.""" @@ -27,6 +30,63 @@ class Tags(Cog): self.bot = bot self.tag_cooldowns = {} + self._cache = {} + self._last_fetch: float = 0.0 + + async def _get_tags(self, is_forced: bool = False) -> None: + """Get all tags.""" + # refresh only when there's a more than 5m gap from last call. + time_now: float = time.time() + if is_forced or not self._last_fetch or time_now - self._last_fetch > 5 * 60: + tags = await self.bot.api_client.get('bot/tags') + self._cache = {tag['title'].lower(): tag for tag in tags} + self._last_fetch = time_now + + @staticmethod + def _fuzzy_search(search: str, target: str) -> int: + """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" + current, index = 0, 0 + _search = REGEX_NON_ALPHABET.sub('', search.lower()) + _targets = iter(REGEX_NON_ALPHABET.split(target.lower())) + _target = next(_targets) + try: + while True: + while index < len(_target) and _search[current] == _target[index]: + current += 1 + index += 1 + index, _target = 0, next(_targets) + except (StopIteration, IndexError): + pass + return current / len(_search) * 100 + + def _get_suggestions(self, tag_name: str, thresholds: Optional[List[int]] = None) -> List[str]: + """Return a list of suggested tags.""" + scores: Dict[str, int] = { + tag_title: Tags._fuzzy_search(tag_name, tag['title']) + for tag_title, tag in self._cache.items() + } + + thresholds = thresholds or [100, 90, 80, 70, 60] + + for threshold in thresholds: + suggestions = [ + self._cache[tag_title] + for tag_title, matching_score in scores.items() + if matching_score >= threshold + ] + if suggestions: + return suggestions + + return [] + + async def _get_tag(self, tag_name: str) -> list: + """Get a specific tag.""" + await self._get_tags() + found = [self._cache.get(tag_name.lower(), None)] + if not found[0]: + return self._get_suggestions(tag_name) + return found + @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: """Show all known tags, a single tag, or run a subcommand.""" @@ -60,17 +120,27 @@ class Tags(Cog): f"Cooldown ends in {time_left:.1f} seconds.") return + await self._get_tags() + if tag_name is not None: - tag = await self.bot.api_client.get(f'bot/tags/{tag_name}') - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } - await ctx.send(embed=Embed.from_dict(tag['embed'])) + founds = await self._get_tag(tag_name) + + if len(founds) == 1: + tag = founds[0] + if ctx.channel.id not in TEST_CHANNELS: + self.tag_cooldowns[tag_name] = { + "time": time.time(), + "channel": ctx.channel.id + } + await ctx.send(embed=Embed.from_dict(tag['embed'])) + elif founds and len(tag_name) >= 3: + await ctx.send(embed=Embed( + title='Did you mean ...', + description='\n'.join(tag['title'] for tag in founds[:10]) + )) else: - tags = await self.bot.api_client.get('bot/tags') + tags = self._cache.values() if not tags: await ctx.send(embed=Embed( description="**There are no tags in the database!**", @@ -106,6 +176,7 @@ class Tags(Cog): } await self.bot.api_client.post('bot/tags', json=body) + self._cache[tag_name.lower()] = await self.bot.api_client.get(f'bot/tags/{tag_name}') log.debug(f"{ctx.author} successfully added the following tag to our database: \n" f"tag_name: {tag_name}\n" @@ -135,6 +206,7 @@ class Tags(Cog): } await self.bot.api_client.patch(f'bot/tags/{tag_name}', json=body) + self._cache[tag_name.lower()] = await self.bot.api_client.get(f'bot/tags/{tag_name}') log.debug(f"{ctx.author} successfully edited the following tag in our database: \n" f"tag_name: {tag_name}\n" @@ -151,6 +223,7 @@ class Tags(Cog): async def delete_command(self, ctx: Context, *, tag_name: TagNameConverter) -> None: """Remove a tag from the database.""" await self.bot.api_client.delete(f'bot/tags/{tag_name}') + self._cache.pop(tag_name.lower(), None) log.debug(f"{ctx.author} successfully deleted the tag called '{tag_name}'") await ctx.send(embed=Embed( |