diff options
| author | 2020-05-26 17:42:05 +0700 | |
|---|---|---|
| committer | 2020-05-26 17:42:05 +0700 | |
| commit | df4e91879ea2a85459ab21e52a97102b1865fd35 (patch) | |
| tree | 6772024cd793c01fb7c8fb6479fdd4f668d7a2a8 | |
| parent | [stats] Do not report modmail channels to stats (diff) | |
| parent | Merge branch 'master' into restricted_tags (diff) | |
Merge pull request #866 from python-discord/restricted_tags
feature to restrict tags to specific role(s)
| -rw-r--r-- | bot/cogs/tags.py | 59 | 
1 files changed, 42 insertions, 17 deletions
| diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index a813ffff5..bc7f53f68 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -4,7 +4,7 @@ import time  from pathlib import Path  from typing import Callable, Dict, Iterable, List, Optional -from discord import Colour, Embed +from discord import Colour, Embed, Member  from discord.ext.commands import Cog, Context, group  from bot import constants @@ -35,21 +35,36 @@ class Tags(Cog):      @staticmethod      def get_tags() -> dict:          """Get all tags.""" -        # Save all tags in memory.          cache = {} -        tag_files = Path("bot", "resources", "tags").iterdir() -        for file in tag_files: -            tag_title = file.stem -            tag = { -                "title": tag_title, -                "embed": { -                    "description": file.read_text(encoding="utf-8") + +        base_path = Path("bot", "resources", "tags") +        for file in base_path.glob("**/*"): +            if file.is_file(): +                tag_title = file.stem +                tag = { +                    "title": tag_title, +                    "embed": { +                        "description": file.read_text(), +                    }, +                    "restricted_to": "developers",                  } -            } -            cache[tag_title] = tag + +                # Convert to a list to allow negative indexing. +                parents = list(file.relative_to(base_path).parents) +                if len(parents) > 1: +                    # -1 would be '.' hence -2 is used as the index. +                    tag["restricted_to"] = parents[-2].name + +                cache[tag_title] = tag +          return cache      @staticmethod +    def check_accessibility(user: Member, tag: dict) -> bool: +        """Check if user can access a tag.""" +        return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] + +    @staticmethod      def _fuzzy_search(search: str, target: str) -> float:          """A simple scoring algorithm based on how many letters are found / total, with order in mind."""          current, index = 0, 0 @@ -93,7 +108,7 @@ class Tags(Cog):              return self._get_suggestions(tag_name)          return found -    def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str) -> list: +    def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list:          """          Search for tags via contents. @@ -114,7 +129,8 @@ class Tags(Cog):          matching_tags = []          for tag in self._cache.values(): -            if check(query in tag['embed']['description'].casefold() for query in keywords_processed): +            matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) +            if self.check_accessibility(user, tag) and check(matches):                  matching_tags.append(tag)          return matching_tags @@ -152,7 +168,7 @@ class Tags(Cog):          Only search for tags that has ALL the keywords.          """ -        matching_tags = self._get_tags_via_content(all, keywords) +        matching_tags = self._get_tags_via_content(all, keywords, ctx.author)          await self._send_matching_tags(ctx, keywords, matching_tags)      @search_tag_content.command(name='any') @@ -162,7 +178,7 @@ class Tags(Cog):          Search for tags that has ANY of the keywords.          """ -        matching_tags = self._get_tags_via_content(any, keywords or 'any') +        matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author)          await self._send_matching_tags(ctx, keywords, matching_tags)      @tags_group.command(name='get', aliases=('show', 'g')) @@ -198,7 +214,13 @@ class Tags(Cog):              return          if tag_name is not None: -            founds = self._get_tag(tag_name) +            temp_founds = self._get_tag(tag_name) + +            founds = [] + +            for found_tag in temp_founds: +                if self.check_accessibility(ctx.author, found_tag): +                    founds.append(found_tag)              if len(founds) == 1:                  tag = founds[0] @@ -237,7 +259,10 @@ class Tags(Cog):              else:                  embed: Embed = Embed(title="**Current tags**")                  await LinePaginator.paginate( -                    sorted(f"**»**   {tag['title']}" for tag in tags), +                    sorted( +                        f"**»**   {tag['title']}" for tag in tags +                        if self.check_accessibility(ctx.author, tag) +                    ),                      ctx,                      embed,                      footer_text=FOOTER_TEXT, | 
