diff options
| author | 2020-04-05 21:22:28 +0530 | |
|---|---|---|
| committer | 2020-04-05 21:22:28 +0530 | |
| commit | 72768b432b07acd3b1bfd5533c55241126329886 (patch) | |
| tree | 74908a76b754266c92fcea8339b253acca202b82 | |
| parent | Set unsilence permissions to inherit instead of true (diff) | |
Add feature to restrict tags to specific role(s)
| -rw-r--r-- | bot/cogs/tags.py | 52 |
1 files changed, 37 insertions, 15 deletions
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 539105017..9c897ad36 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -4,7 +4,7 @@ import time from pathlib import Path from typing import Callable, Dict, Iterable, List, Optional -from discord import Colour, Embed +from discord import Colour, Embed, Member from discord.ext.commands import Cog, Context, group from bot import constants @@ -36,19 +36,33 @@ class Tags(Cog): """Get all tags.""" # Save all tags in memory. cache = {} - tag_files = Path("bot", "resources", "tags").iterdir() + tag_files = Path("bot", "resources", "tags").glob("**/*") for file in tag_files: - tag_title = file.stem - tag = { - "title": tag_title, - "embed": { - "description": file.read_text() + file_path = str(file).split("/") + if file.is_file(): + tag_title = file.stem + tag = { + "title": tag_title, + "embed": { + "description": file.read_text() + }, + "restricted_to": "developers" } - } - cache[tag_title] = tag + if len(file_path) == 5: + restricted_to = file_path[3] + tag["restricted_to"] = restricted_to + + cache[tag_title] = tag return cache @staticmethod + def check_accessibility(user: Member, tag: dict) -> bool: + """Check if user can access a tag.""" + if tag["restricted_to"].lower() in [role.name.lower() for role in user.roles]: + return True + return False + + @staticmethod def _fuzzy_search(search: str, target: str) -> float: """A simple scoring algorithm based on how many letters are found / total, with order in mind.""" current, index = 0, 0 @@ -92,7 +106,7 @@ class Tags(Cog): return self._get_suggestions(tag_name) return found - def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str) -> list: + def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str, user: Member) -> list: """ Search for tags via contents. @@ -113,8 +127,9 @@ class Tags(Cog): matching_tags = [] for tag in self._cache.values(): - if check(query in tag['embed']['description'].casefold() for query in keywords_processed): - matching_tags.append(tag) + if self.check_accessibility(user, tag): + if check(query in tag['embed']['description'].casefold() for query in keywords_processed): + matching_tags.append(tag) return matching_tags @@ -151,7 +166,7 @@ class Tags(Cog): Only search for tags that has ALL the keywords. """ - matching_tags = self._get_tags_via_content(all, keywords) + matching_tags = self._get_tags_via_content(all, keywords, ctx.author) await self._send_matching_tags(ctx, keywords, matching_tags) @search_tag_content.command(name='any') @@ -161,7 +176,7 @@ class Tags(Cog): Search for tags that has ANY of the keywords. """ - matching_tags = self._get_tags_via_content(any, keywords or 'any') + matching_tags = self._get_tags_via_content(any, keywords or 'any', ctx.author) await self._send_matching_tags(ctx, keywords, matching_tags) @tags_group.command(name='get', aliases=('show', 'g')) @@ -198,6 +213,10 @@ class Tags(Cog): if tag_name is not None: founds = self._get_tag(tag_name) + for found_tag in founds: + if not self.check_accessibility(ctx.author, found_tag): + founds.remove(found_tag) + if len(founds) == 1: tag = founds[0] if ctx.channel.id not in TEST_CHANNELS: @@ -222,7 +241,10 @@ class Tags(Cog): else: embed: Embed = Embed(title="**Current tags**") await LinePaginator.paginate( - sorted(f"**»** {tag['title']}" for tag in tags), + sorted( + f"**»** {tag['title']}" for tag in tags + if self.check_accessibility(ctx.author, tag) + ), ctx, embed, footer_text=FOOTER_TEXT, |