diff options
author | 2020-05-26 17:42:05 +0700 | |
---|---|---|
committer | 2020-05-26 17:42:05 +0700 | |
commit | df4e91879ea2a85459ab21e52a97102b1865fd35 (patch) | |
tree | 6772024cd793c01fb7c8fb6479fdd4f668d7a2a8 | |
parent | [stats] Do not report modmail channels to stats (diff) | |
parent | Merge branch 'master' into restricted_tags (diff) |
Merge pull request #866 from python-discord/restricted_tags
feature to restrict tags to specific role(s)
-rw-r--r-- | bot/cogs/tags.py | 59 |
1 files changed, 42 insertions, 17 deletions
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index a813ffff5..bc7f53f68 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -4,7 +4,7 @@ import time from pathlib import Path from typing import Callable, Dict, Iterable, List, Optional -from discord import Colour, Embed +from discord import Colour, Embed, Member from discord.ext.commands import Cog, Context, group from bot import constants @@ -35,21 +35,36 @@ class Tags(Cog): @staticmethod def get_tags() -> dict: """Get all tags.""" - # Save all tags in memory. cache = {} - tag_files = Path("bot", "resources", "tags").iterdir() - for file in tag_files: - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text(encoding="utf-8") + + base_path = Path("bot", "resources", "tags") + for file in base_path.glob("**/*"): + if file.is_file(): + tag_title = file.stem + tag = { + "title": tag_title, + "embed": { + "description": file.read_text(), + }, + "restricted_to": "developers", } - } - cache[tag_title] = tag + + # Convert to a list to allow negative indexing. + parents = list(file.relative_to(base_path).parents) + if len(parents) > 1: + # -1 would be '.' hence -2 is used as the index. + tag["restricted_to"] = parents[-2].name + + cache[tag_title] = tag + return cache @staticmethod + def check_accessibility(user: Member, tag: dict) -> bool: + """Check if user can access a tag.""" + return tag["restricted_to"].lower() in [role.name.lower() for role in user.roles] + + @staticmethod def _fuzzy_search(search: str, target: str) -> float: """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" current, index = 0, 0 @@ -93,7 +108,7 @@ class Tags(Cog): return self._get_suggestions(tag_name) return found - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str) -> list: + def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: """ Search for tags via contents. @@ -114,7 +129,8 @@ class Tags(Cog): matching_tags = [] for tag in self._cache.values(): - if check(query in tag['embed']['description'].casefold() for query in keywords_processed): + matches = (query in tag['embed']['description'].casefold() for query in keywords_processed) + if self.check_accessibility(user, tag) and check(matches): matching_tags.append(tag) return matching_tags @@ -152,7 +168,7 @@ class Tags(Cog): Only search for tags that has ALL the keywords. """ - matching_tags = self._get_tags_via_content(all, keywords) + matching_tags = self._get_tags_via_content(all, keywords, ctx.author) await self._send_matching_tags(ctx, keywords, matching_tags) @search_tag_content.command(name='any') @@ -162,7 +178,7 @@ class Tags(Cog): Search for tags that has ANY of the keywords. """ - matching_tags = self._get_tags_via_content(any, keywords or 'any') + matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) await self._send_matching_tags(ctx, keywords, matching_tags) @tags_group.command(name='get', aliases=('show', 'g')) @@ -198,7 +214,13 @@ class Tags(Cog): return if tag_name is not None: - founds = self._get_tag(tag_name) + temp_founds = self._get_tag(tag_name) + + founds = [] + + for found_tag in temp_founds: + if self.check_accessibility(ctx.author, found_tag): + founds.append(found_tag) if len(founds) == 1: tag = founds[0] @@ -237,7 +259,10 @@ class Tags(Cog): else: embed: Embed = Embed(title="**Current tags**") await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in tags), + sorted( + f"**»** {tag['title']}" for tag in tags + if self.check_accessibility(ctx.author, tag) + ), ctx, embed, footer_text=FOOTER_TEXT, |