diff options
| -rw-r--r-- | bot/cogs/tags.py | 52 | 
1 files changed, 37 insertions, 15 deletions
| diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 539105017..9c897ad36 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -4,7 +4,7 @@ import time  from pathlib import Path  from typing import Callable, Dict, Iterable, List, Optional -from discord import Colour, Embed +from discord import Colour, Embed, Member  from discord.ext.commands import Cog, Context, group  from bot import constants @@ -36,19 +36,33 @@ class Tags(Cog):          """Get all tags."""          # Save all tags in memory.          cache = {} -        tag_files = Path("bot", "resources", "tags").iterdir() +        tag_files = Path("bot", "resources", "tags").glob("**/*")          for file in tag_files: -            tag_title = file.stem -            tag = { -                "title": tag_title, -                "embed": { -                    "description": file.read_text() +            file_path = str(file).split("/") +            if file.is_file(): +                tag_title = file.stem +                tag = { +                    "title": tag_title, +                    "embed": { +                        "description": file.read_text() +                    }, +                    "restricted_to": "developers"                  } -            } -            cache[tag_title] = tag +                if len(file_path) == 5: +                    restricted_to = file_path[3] +                    tag["restricted_to"] = restricted_to + +                cache[tag_title] = tag          return cache      @staticmethod +    def check_accessibility(user: Member, tag: dict) -> bool: +        """Check if user can access a tag.""" +        if tag["restricted_to"].lower() in [role.name.lower() for role in user.roles]: +            return True +        return False + +    @staticmethod      def _fuzzy_search(search: str, target: str) -> float:          """A simple scoring algorithm based on how many letters are found / total, with order in mind."""          current, index = 0, 0 @@ -92,7 +106,7 @@ class Tags(Cog):              return self._get_suggestions(tag_name)          return found -    def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str) -> list: +    def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list:          """          Search for tags via contents. @@ -113,8 +127,9 @@ class Tags(Cog):          matching_tags = []          for tag in self._cache.values(): -            if check(query in tag['embed']['description'].casefold() for query in keywords_processed): -                matching_tags.append(tag) +            if self.check_accessibility(user, tag): +                if check(query in tag['embed']['description'].casefold() for query in keywords_processed): +                    matching_tags.append(tag)          return matching_tags @@ -151,7 +166,7 @@ class Tags(Cog):          Only search for tags that has ALL the keywords.          """ -        matching_tags = self._get_tags_via_content(all, keywords) +        matching_tags = self._get_tags_via_content(all, keywords, ctx.author)          await self._send_matching_tags(ctx, keywords, matching_tags)      @search_tag_content.command(name='any') @@ -161,7 +176,7 @@ class Tags(Cog):          Search for tags that has ANY of the keywords.          """ -        matching_tags = self._get_tags_via_content(any, keywords or 'any') +        matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author)          await self._send_matching_tags(ctx, keywords, matching_tags)      @tags_group.command(name='get', aliases=('show', 'g')) @@ -198,6 +213,10 @@ class Tags(Cog):          if tag_name is not None:              founds = self._get_tag(tag_name) +            for found_tag in founds: +                if not self.check_accessibility(ctx.author, found_tag): +                    founds.remove(found_tag) +              if len(founds) == 1:                  tag = founds[0]                  if ctx.channel.id not in TEST_CHANNELS: @@ -222,7 +241,10 @@ class Tags(Cog):              else:                  embed: Embed = Embed(title="**Current tags**")                  await LinePaginator.paginate( -                    sorted(f"**»**   {tag['title']}" for tag in tags), +                    sorted( +                        f"**»**   {tag['title']}" for tag in tags +                        if self.check_accessibility(ctx.author, tag) +                    ),                      ctx,                      embed,                      footer_text=FOOTER_TEXT, | 
