From 1b6f3d23d4c0b1f6dfe1354a3a210e589f7b4956 Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 18:40:02 +0200 Subject: Make sure that poor code does not contains token Added a new function `is_token_in_message` in `token_remover`. This function returns a `bool` and if the code contains a token then the embed message about the poorly formatted code is not displayed. --- bot/cogs/bot.py | 3 ++- bot/cogs/token_remover.py | 32 +++++++++++++++++++------------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 7583b2f2d..e8ac0a234 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple from discord import Embed, Message, RawMessageUpdateEvent from discord.ext.commands import Bot, Cog, Context, command, group +from bot.cogs.token_remover import TokenRemover from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs from bot.decorators import with_role from bot.utils.messages import wait_for_deletion @@ -237,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock: + if parse_codeblock and not TokenRemover.is_token_in_message: # if there is no token in the code on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 7dd0afbbd..8f356cf19 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -52,19 +52,8 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ - if msg.author.bot: - return - - maybe_match = TOKEN_RE.search(msg.content) - if maybe_match is None: - return - - try: - user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') - except ValueError: - return - - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): + if self.is_token_in_message(msg): + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') self.mod_log.ignore(Event.message_delete, msg.id) await msg.delete() await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) @@ -86,6 +75,23 @@ class TokenRemover(Cog): channel_id=Channels.mod_alerts, ) + def is_token_in_message(self, msg: Message) -> bool: + """Check if `msg` contains a seemly valid token.""" + if msg.author.bot: + return False + + maybe_match = TOKEN_RE.search(msg.content) + if maybe_match is None: + return False + + try: + user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') + except ValueError: + return False + + if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): + return True + @staticmethod def is_valid_user_id(b64_content: str) -> bool: """ -- cgit v1.2.3 From 2899bac85c3c0529b354a762ba27a587a520d7cd Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 18:51:18 +0200 Subject: minor fix --- bot/cogs/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index e8ac0a234..729550c1a 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message: # if there is no token in the code + if parse_codeblock and not TokenRemover.is_token_in_message(msg): # if there is no token in the code on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From b94e8487c22d7c25ab09bb3d44c44d62e5a2b613 Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 21:33:43 +0200 Subject: Another fix After a new bunch of test I found bugs, and this fix resolves them --- bot/cogs/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 729550c1a..b8de29f2a 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message(msg): # if there is no token in the code + if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # if there is no token in the code on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From 25d4c05b2656ce8d9454269c77d42e18fb1ba785 Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 21:42:49 +0200 Subject: fix linting error fix linting error --- bot/cogs/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index b8de29f2a..eab253681 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # if there is no token in the code + if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From 2cb7ee12805957c7d655679ff54a14f16e059a80 Mon Sep 17 00:00:00 2001 From: Jens Date: Wed, 9 Oct 2019 23:43:16 +0200 Subject: Add Reddit OAuth tasks and refactor code --- bot/cogs/reddit.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++++----- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 77 insertions(+), 6 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6880aab85..bf4403ce4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # :: (by /u/) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -34,6 +41,66 @@ class Reddit(Cog): self.new_posts_task = None self.top_weekly_posts_task = None + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -43,11 +110,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -55,7 +122,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -127,8 +194,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: -- cgit v1.2.3 From aa0096469e12546b0eadbb4b214cd3cae3a3a80d Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Sat, 12 Oct 2019 14:54:58 +0200 Subject: Use a `classmethod` --- bot/cogs/bot.py | 2 +- bot/cogs/token_remover.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index eab253681..53221cd8b 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # no token in the msg + if parse_codeblock and not TokenRemover.is_token_in_message(msg): # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 8f356cf19..5e83a777e 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -75,6 +75,7 @@ class TokenRemover(Cog): channel_id=Channels.mod_alerts, ) + @classmethod def is_token_in_message(self, msg: Message) -> bool: """Check if `msg` contains a seemly valid token.""" if msg.author.bot: -- cgit v1.2.3 From 9b4b98eb6f37060771c61979458737629b3c5db7 Mon Sep 17 00:00:00 2001 From: Jens Date: Wed, 9 Oct 2019 23:43:16 +0200 Subject: Add Reddit OAuth tasks and refactor code --- bot/cogs/reddit.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++++----- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..25df014f8 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # :: (by /u/) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -35,6 +42,65 @@ class Reddit(Cog): self.top_weekly_posts_task = None self.bot.loop.create_task(self.init_reddit_polling()) + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -45,11 +111,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -57,7 +123,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -129,8 +195,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: -- cgit v1.2.3 From 79ed098809ce3cfaa0fa75608f6f6a85af2a90dd Mon Sep 17 00:00:00 2001 From: Jens Date: Tue, 15 Oct 2019 23:36:36 +0200 Subject: Unload cog on auth error and fix linting warnings --- bot/cogs/reddit.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 25df014f8..451d2bf4c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,10 @@ import asyncio import logging import random import textwrap -from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth from discord import Colour, Embed, Message, TextChannel from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group @@ -46,7 +46,7 @@ class Reddit(Cog): @tasks.loop(hours=0.99) # access tokens are valid for one hour async def refresh_access_token(self) -> None: - """Refresh the access token""" + """Refresh Reddits access token.""" headers = {"Authorization": self.client_auth} data = { "grant_type": "refresh_token", @@ -54,7 +54,7 @@ class Reddit(Cog): } response = await self.bot.http_session.post( - url = f"{self.URL}/api/v1/access_token", + url=f"{self.URL}/api/v1/access_token", headers=headers, data=data, ) @@ -68,7 +68,7 @@ class Reddit(Cog): @refresh_access_token.before_loop async def get_tokens(self) -> None: - """Get Reddit access and refresh tokens""" + """Get Reddit access and refresh tokens.""" await self.bot.wait_until_ready() headers = {"User-Agent": self.USER_AGENT} @@ -77,16 +77,16 @@ class Reddit(Cog): "duration": "permanent" } - if RedditConfig.client_id and RedditConfig.secret: - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=headers, - auth=self.client_auth, - data=data - ) + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 200 and response.content_type == "application/json": content = await response.json() self.access_token = content["access_token"] self.refresh_token = content["refresh_token"] @@ -95,12 +95,9 @@ class Reddit(Cog): "User-Agent": self.USER_AGENT } else: - self.client_auth = None - self.access_token = None - self.refresh_token = None - self.headers = None - - log.error("Unable to find client credentials.") + log.error("Authentication with Reddit API failed. Unloading extension.") + self.bot.remove_cog(self.__class__.__name__) + return async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -123,7 +120,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") -- cgit v1.2.3 From 4efb97c5020f591d8cdd1e214e06df294e72d8f1 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 20 Oct 2019 18:32:25 +0200 Subject: add handling for duplicate symbols in docs inventories --- bot/cogs/doc.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index a13464bff..43315f477 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -23,7 +23,17 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) logging.getLogger('urllib3').setLevel(logging.WARNING) - +NO_OVERRIDE_GROUPS = ( + "2to3fixer", + "token", + "label", + "pdbcommand", + "term", + "function" +) +NO_OVERRIDE_PACKAGES = ( + "Python", +) UNWANTED_SIGNATURE_SYMBOLS = ('[source]', '¶') WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") @@ -125,6 +135,7 @@ class Doc(commands.Cog): self.base_urls = {} self.bot = bot self.inventories = {} + self.renamed_symbols = set() self.bot.loop.create_task(self.init_refresh_inventory()) @@ -151,12 +162,32 @@ class Doc(commands.Cog): self.base_urls[package_name] = base_url fetch_func = functools.partial(intersphinx.fetch_inventory, config, '', inventory_url) - for _, value in (await self.bot.loop.run_in_executor(None, fetch_func)).items(): + for group, value in (await self.bot.loop.run_in_executor(None, fetch_func)).items(): # Each value has a bunch of information in the form # `(package_name, version, relative_url, ???)`, and we only - # need the relative documentation URL. - for symbol, (_, _, relative_doc_url, _) in value.items(): + # need the package_name and the relative documentation URL. + for symbol, (package_name, _, relative_doc_url, _) in value.items(): absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + # get `group_name` from _:group_name + group_name = group.split(":")[1] + if (group_name in NO_OVERRIDE_GROUPS + # check if any package from `NO_OVERRIDE_PACKAGES` + # is in base URL of the symbol that would be overridden + or any(package in self.inventories[symbol].split("/", 3)[2] + for package in NO_OVERRIDE_PACKAGES)): + + symbol = f"{group_name}.{symbol}" + # if renamed `symbol` was already exists, add library name in front + if symbol in self.renamed_symbols: + # split `package_name` because of packages like Pillow that have spaces in them + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + self.inventories[symbol] = absolute_doc_url log.trace(f"Fetched inventory for {package_name}.") -- cgit v1.2.3 From f1dbb63e6c4a7ed38f8bed994c109e638498d546 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 20 Oct 2019 18:39:08 +0200 Subject: show renamed duplicates in embed footer --- bot/cogs/doc.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 43315f477..ecff43864 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -281,18 +281,23 @@ class Doc(commands.Cog): if not signature: # It's some "meta-page", for example: # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - return discord.Embed( + embed = discord.Embed( title=f'`{symbol}`', url=permalink, description="This appears to be a generic page not tied to a specific symbol." ) - - signature = textwrap.shorten(signature, 500) - return discord.Embed( - title=f'`{symbol}`', - url=permalink, - description=f"```py\n{signature}```{description}" - ) + else: + signature = textwrap.shorten(signature, 500) + embed = discord.Embed( + title=f'`{symbol}`', + url=permalink, + description=f"```py\n{signature}```{description}" + ) + # show all symbols with the same name that were renamed in the footer + embed.set_footer(text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} + if renamed.endswith(f".{symbol}")) + ) + return embed @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: -- cgit v1.2.3 From a05f28c97d0f2ea9d3dafcdbd24444c59905af84 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 20 Oct 2019 18:42:59 +0200 Subject: Auto delete messages when docs are not found --- bot/cogs/doc.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index ecff43864..9bb21cce3 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -4,17 +4,19 @@ import logging import re import textwrap from collections import OrderedDict +from contextlib import suppress from typing import Any, Callable, Optional, Tuple import discord from bs4 import BeautifulSoup from bs4.element import PageElement +from discord.errors import NotFound from discord.ext import commands from markdownify import MarkdownConverter from requests import ConnectionError from sphinx.ext import intersphinx -from bot.constants import MODERATION_ROLES +from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import ValidPythonIdentifier, ValidURL from bot.decorators import with_role from bot.pagination import LinePaginator @@ -23,6 +25,7 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) logging.getLogger('urllib3').setLevel(logging.WARNING) +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay NO_OVERRIDE_GROUPS = ( "2to3fixer", "token", @@ -343,7 +346,10 @@ class Doc(commands.Cog): description=f"Sorry, I could not find any documentation for `{symbol}`.", colour=discord.Colour.red() ) - await ctx.send(embed=error_embed) + error_message = await ctx.send(embed=error_embed) + with suppress(NotFound): + await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) + await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) else: await ctx.send(embed=doc_embed) -- cgit v1.2.3 From eda6cd7ff818454ad7bf448040a87ff0077025bc Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 20 Oct 2019 21:15:12 +0200 Subject: remove "function" from NO_OVERRIDE_GROUPS --- bot/cogs/doc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 9bb21cce3..f1213d170 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -32,7 +32,6 @@ NO_OVERRIDE_GROUPS = ( "label", "pdbcommand", "term", - "function" ) NO_OVERRIDE_PACKAGES = ( "Python", -- cgit v1.2.3 From bc907daa428d755d7f2cb0a6b945a179d523b31d Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Sun, 20 Oct 2019 21:35:38 +0200 Subject: Add check when a message is edited --- bot/cogs/token_remover.py | 60 +++++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5e83a777e..e5b0e5b45 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -53,30 +53,44 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ if self.is_token_in_message(msg): - user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') - self.mod_log.ignore(Event.message_delete, msg.id) - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - message = ( - "Censored a seemingly valid token sent by " - f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " - f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" - ) - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) + await self.take_action(msg) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + if self.is_token_in_message(after): + await self.take_action(after) + + async def take_action(self, msg: Message) -> None: + """Remove the `msg` containing a token an send a mod_log message.""" + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') + self.mod_log.ignore(Event.message_delete, msg.id) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + message = ( + "Censored a seemingly valid token sent by " + f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " + f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" + ) + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) @classmethod - def is_token_in_message(self, msg: Message) -> bool: + def is_token_in_message(cls, msg: Message) -> bool: """Check if `msg` contains a seemly valid token.""" if msg.author.bot: return False @@ -90,7 +104,7 @@ class TokenRemover(Cog): except ValueError: return False - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): + if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): return True @staticmethod -- cgit v1.2.3 From d5dea25fef79e16d726f1f0ce8d2bb25291d6c49 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 21 Oct 2019 22:09:46 +0200 Subject: Don't include a signature and only get first paragraphs when scraping when symbol is a module --- bot/cogs/doc.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index f1213d170..a13552ac0 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -222,12 +222,11 @@ class Doc(commands.Cog): """ Given a Python symbol, return its signature and description. - Returns a tuple in the form (str, str), or `None`. - The first tuple element is the signature of the given symbol as a markup-free string, and the second tuple element is the description of the given symbol with HTML markup included. - If the given symbol could not be found, returns `None`. + If the given symbol is a module, returns a tuple `(None, str)` + else if the symbol could not be found, returns `None`. """ url = self.inventories.get(symbol) if url is None: @@ -245,14 +244,23 @@ class Doc(commands.Cog): if symbol_heading is None: return None - # Traverse the tags of the signature header and ignore any - # unwanted symbols from it. Add all of it to a temporary buffer. - for tag in symbol_heading.strings: - if tag not in UNWANTED_SIGNATURE_SYMBOLS: - signature_buffer.append(tag.replace('\\', '')) + if symbol_id == f"module-{symbol}": + # Get all paragraphs until the first div after the section div + # if searched symbol is a module. + trailing_div = symbol_heading.findNext("div") + info_paragraphs = trailing_div.find_previous_siblings("p")[::-1] + signature = None + description = ''.join(str(paragraph) for paragraph in info_paragraphs).replace('¶', '') - signature = ''.join(signature_buffer) - description = str(symbol_heading.next_sibling.next_sibling).replace('¶', '') + else: + # Traverse the tags of the signature header and ignore any + # unwanted symbols from it. Add all of it to a temporary buffer. + + for tag in symbol_heading.strings: + if tag not in UNWANTED_SIGNATURE_SYMBOLS: + signature_buffer.append(tag.replace('\\', '')) + signature = ''.join(signature_buffer) + description = str(symbol_heading.next_sibling.next_sibling).replace('¶', '') return signature, description -- cgit v1.2.3 From 55b276a1f7e56a950e215bd8289b7f946b2f180e Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 21 Oct 2019 22:10:45 +0200 Subject: Allow embeds to not include signatures in case the symbol is a module --- bot/cogs/doc.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index a13552ac0..0c370f665 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -288,21 +288,24 @@ class Doc(commands.Cog): description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if not signature: + if signature is None: + # If symbol is a module, don't show signature. + embed_description = description + + elif not signature: # It's some "meta-page", for example: # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - embed = discord.Embed( - title=f'`{symbol}`', - url=permalink, - description="This appears to be a generic page not tied to a specific symbol." - ) + embed_description = "This appears to be a generic page not tied to a specific symbol." + else: signature = textwrap.shorten(signature, 500) - embed = discord.Embed( - title=f'`{symbol}`', - url=permalink, - description=f"```py\n{signature}```{description}" - ) + embed_description = f"```py\n{signature}```{description}" + + embed = discord.Embed( + title=f'`{symbol}`', + url=permalink, + description=embed_description + ) # show all symbols with the same name that were renamed in the footer embed.set_footer(text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) -- cgit v1.2.3 From 09f5cd78142201ff0133a25ee1ea6cff1c739e1f Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 21 Oct 2019 22:11:20 +0200 Subject: Grammar check comment --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 0c370f665..8b81b3053 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -306,7 +306,7 @@ class Doc(commands.Cog): url=permalink, description=embed_description ) - # show all symbols with the same name that were renamed in the footer + # Show all symbols with the same name that were renamed in the footer. embed.set_footer(text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) ) -- cgit v1.2.3 From 5ec4db0cba484f8adfc25b642a4f24f362a5b53c Mon Sep 17 00:00:00 2001 From: Jens Date: Tue, 22 Oct 2019 22:05:50 +0200 Subject: Add reddit environment variable, change User-Agent and fix lint problem --- bot/cogs/reddit.py | 4 ++-- docker-compose.yml | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7b183221c..76da0f09f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -25,7 +25,7 @@ class Reddit(Cog): # including the target platform, a unique application identifier, a version string, # and your username as contact information, in the following format: # :: (by /u/) - USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 @@ -117,7 +117,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} -- cgit v1.2.3 From 2c0fffbc697db5f33e759f733c310f9f0b754d11 Mon Sep 17 00:00:00 2001 From: Jens Date: Sat, 26 Oct 2019 17:24:09 +0200 Subject: Fix linting error --- bot/cogs/reddit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7e2ba40d5..64a940af1 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -20,6 +20,7 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" + # Change your client's User-Agent string to something unique and descriptive, # including the target platform, a unique application identifier, a version string, # and your username as contact information, in the following format: -- cgit v1.2.3 From 7f1d319a11de5e443307517ff1fd55fe87a69bb3 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 02:50:20 +0200 Subject: Add duck-pond constants. This adds the emojis, the channel, and the configuration needed for the duck-pond feature. This is added both to config-default.yml, and to the constants.py file. --- bot/constants.py | 57 +++++++++++++++++++++++++++++++----------------------- config-default.yml | 15 +++++++++----- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/bot/constants.py b/bot/constants.py index 838fe7a79..ed1e65cca 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -236,6 +236,14 @@ class Colours(metaclass=YAMLGetter): soft_orange: int +class DuckPond(metaclass=YAMLGetter): + section = "duck_pond" + + ducks_required: int + duck_custom_emojis: List[int] + duck_pond_channel: int + + class Emojis(metaclass=YAMLGetter): section = "style" subsection = "emojis" @@ -370,6 +378,7 @@ class Webhooks(metaclass=YAMLGetter): talent_pool: int big_brother: int reddit: int + duck_pond: int class Roles(metaclass=YAMLGetter): @@ -501,6 +510,30 @@ class RedirectOutput(metaclass=YAMLGetter): delete_delay: int +class Event(Enum): + """ + Event names. This does not include every event (for example, raw + events aren't here), but only events used in ModLog for now. + """ + + guild_channel_create = "guild_channel_create" + guild_channel_delete = "guild_channel_delete" + guild_channel_update = "guild_channel_update" + guild_role_create = "guild_role_create" + guild_role_delete = "guild_role_delete" + guild_role_update = "guild_role_update" + guild_update = "guild_update" + + member_join = "member_join" + member_remove = "member_remove" + member_ban = "member_ban" + member_unban = "member_unban" + member_update = "member_update" + + message_delete = "message_delete" + message_edit = "message_edit" + + # Debug mode DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False @@ -572,27 +605,3 @@ ERROR_REPLIES = [ "Noooooo!!", "I can't believe you've done this", ] - - -class Event(Enum): - """ - Event names. This does not include every event (for example, raw - events aren't here), but only events used in ModLog for now. - """ - - guild_channel_create = "guild_channel_create" - guild_channel_delete = "guild_channel_delete" - guild_channel_update = "guild_channel_update" - guild_role_create = "guild_role_create" - guild_role_delete = "guild_role_delete" - guild_role_update = "guild_role_update" - guild_update = "guild_update" - - member_join = "member_join" - member_remove = "member_remove" - member_ban = "member_ban" - member_unban = "member_unban" - member_update = "member_update" - - message_delete = "message_delete" - message_edit = "message_edit" diff --git a/config-default.yml b/config-default.yml index 4638a89ee..bad9c72db 100644 --- a/config-default.yml +++ b/config-default.yml @@ -22,11 +22,6 @@ style: defcon_enabled: "<:defconenabled:470326274213150730>" defcon_updated: "<:defconsettingsupdated:470326274082996224>" - green_chevron: "<:greenchevron:418104310329769993>" - red_chevron: "<:redchevron:418112778184818698>" - white_chevron: "<:whitechevron:418110396973711363>" - bb_message: "<:bbmessage:476273120999636992>" - status_online: "<:status_online:470326272351010816>" status_idle: "<:status_idle:470326266625785866>" status_dnd: "<:status_dnd:470326272082313216>" @@ -37,6 +32,9 @@ style: new: "\U0001F195" cross_mark: "\u274C" + ducky: &DUCKY_EMOJI 574951975574175744 + ducky_blurple: &DUCKY_BLURPLE_EMOJI 574951975310065675 + icons: crown_blurple: "https://cdn.discordapp.com/emojis/469964153289965568.png" crown_green: "https://cdn.discordapp.com/emojis/469964154719961088.png" @@ -98,6 +96,7 @@ guild: defcon: &DEFCON 464469101889454091 devlog: &DEVLOG 622895325144940554 devtest: &DEVTEST 414574275865870337 + duck_pond: &DUCK_POND 000000000000000000 help_0: 303906576991780866 help_1: 303906556754395136 help_2: 303906514266226689 @@ -148,6 +147,7 @@ guild: talent_pool: 569145364800602132 big_brother: 569133704568373283 reddit: 635408384794951680 + duck_pond: 637779355304722435 filter: @@ -382,5 +382,10 @@ redirect_output: delete_invocation: true delete_delay: 15 +duck_pond: + ducks_required: 5 + duck_custom_emojis: [*DUCKY_EMOJI, *DUCKY_BLURPLE_EMOJI] + duck_pond_channel: *DUCK_POND + config: required_keys: ['bot.token'] -- cgit v1.2.3 From 957f46226a6c9cbc9e86bab8a4365665d479885f Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 02:56:20 +0200 Subject: Add duck_pond cog. This cog will listen for duck reactions on any message, and then: - If the reaction was added by a staff member - and the reaction was a duck - and the message has not already been added to the #duck-pond It will add the message to the #duck-pond and then add a green checkbox to the original message to indicate that the message has been ponded. Messages are added to the #duck-pond via webhook, so that they can retain the appearance of having their original authors. Once this checkmark has been added, the message will not be processed in the future. If the checkmark is removed and there are more than ducks_required ducks on the message, the bot will automatically add the checkmark back. However, if all reactions are removed, the bot does not have a countermeasure for this. In order to implement a countermeasure, it would be necessary to involve the API and the database. --- bot/__main__.py | 1 + bot/cogs/duck_pond.py | 206 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 bot/cogs/duck_pond.py diff --git a/bot/__main__.py b/bot/__main__.py index f352cd60e..ea7c43a12 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -55,6 +55,7 @@ if not DEBUG_MODE: bot.load_extension("bot.cogs.alias") bot.load_extension("bot.cogs.defcon") bot.load_extension("bot.cogs.eval") +bot.load_extension("bot.cogs.duck_pond") bot.load_extension("bot.cogs.free") bot.load_extension("bot.cogs.information") bot.load_extension("bot.cogs.jams") diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py new file mode 100644 index 000000000..d5d528458 --- /dev/null +++ b/bot/cogs/duck_pond.py @@ -0,0 +1,206 @@ +import logging +from typing import List, Optional, Union + +import discord +from discord import Color, Embed, Member, Message, PartialEmoji, RawReactionActionEvent, Reaction, User, errors +from discord.ext.commands import Bot, Cog + +import bot.constants as constants +from bot.utils.messages import send_attachments + +log = logging.getLogger(__name__) + + +class DuckPond(Cog): + """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.log = log + self.webhook_id = constants.Webhooks.duck_pond + self.bot.loop.create_task(self.fetch_webhook()) + + async def fetch_webhook(self): + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_ready() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @staticmethod + def is_staff(member: Union[User, Member]) -> bool: + """Check if a specific member or user is staff""" + if hasattr(member, "roles"): + for role in member.roles: + if role.id in constants.STAFF_ROLES: + return True + return False + + @staticmethod + def has_green_checkmark(message: Optional[Message] = None, reaction_list: Optional[List[Reaction]] = None) -> bool: + """Check if the message has a green checkmark reaction.""" + assert message or reaction_list, "You can either pass message or reactions, but not both, or neither." + + if message: + reactions = message.reactions + else: + reactions = reaction_list + + for reaction in reactions: + if isinstance(reaction.emoji, str): + if reaction.emoji == "✅": + return True + elif isinstance(reaction.emoji, PartialEmoji): + if reaction.emoji.name == "✅": + return True + return False + + async def send_webhook( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + try: + await self.webhook.send( + content=content, + username=username, + avatar_url=avatar_url, + embed=embed + ) + except discord.HTTPException as exc: + self.log.exception( + f"Failed to send a message to the Duck Pool webhook", + exc_info=exc + ) + + async def count_ducks(self, message: Optional[Message] = None, reaction_list: Optional[List[Reaction]] = None) -> int: + """Count the number of ducks in the reactions of a specific message. + + Only counts ducks added by staff members. + """ + assert message or reaction_list, "You can either pass message or reactions, but not both, or neither." + + duck_count = 0 + duck_reactors = [] + + if message: + reactions = message.reactions + else: + reactions = reaction_list + + for reaction in reactions: + async for user in reaction.users(): + + # Is the user or member a staff member? + if self.is_staff(user) and user.id not in duck_reactors: + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.duck_custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + else: + if isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, PartialEmoji): + if reaction.emoji.name == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + return duck_count + + @Cog.listener() + async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: + """Determine if a message should be sent to the duck pond. + + This will count the number of duck reactions on the message, and if this amount meets the + amount of ducks specified in the config under duck_pond/ducks_required, it will + send the message off to the duck pond. + """ + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) + + # Is the member a staff member? + if not self.is_staff(member): + return + + # Bot reactions don't count. + if member.bot: + return + + # Is the emoji in the reaction a duck? + if payload.emoji.is_custom_emoji(): + if payload.emoji.id not in constants.DuckPond.duck_custom_emojis: + return + else: + if payload.emoji.name != "🦆": + return + + # Does the message already have a green checkmark? + if self.has_green_checkmark(message): + return + + # Time to count our ducks! + duck_count = await self.count_ducks(message) + + # If we've got more than the required amount of ducks, send the message to the duck_pond. + if duck_count >= constants.DuckPond.ducks_required: + clean_content = message.clean_content + + if clean_content: + await self.send_webhook( + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.send_webhook( + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException as exc: + self.log.exception( + f"Failed to send an attachment to the webhook", + exc_info=exc + ) + await message.add_reaction("✅") + + @Cog.listener() + async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: + """Ensure that people don't remove the green checkmark from duck ponded messages.""" + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + + # Prevent the green checkmark from being removed + if isinstance(payload.emoji, str): + if payload.emoji == "✅": + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.ducks_required: + await message.add_reaction("✅") + + elif isinstance(payload.emoji, PartialEmoji): + if payload.emoji.name == "✅": + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.ducks_required: + await message.add_reaction("✅") + + +def setup(bot: Bot) -> None: + """Token Remover cog load.""" + bot.add_cog(DuckPond(bot)) + log.info("Cog loaded: DuckPond") -- cgit v1.2.3 From 38579ade38a7390e9aca428410a9703dd7ba9fac Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 02:07:55 +0100 Subject: Appease the linter --- bot/cogs/duck_pond.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index d5d528458..b2b786a3f 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -5,7 +5,7 @@ import discord from discord import Color, Embed, Member, Message, PartialEmoji, RawReactionActionEvent, Reaction, User, errors from discord.ext.commands import Bot, Cog -import bot.constants as constants +from bot import constants from bot.utils.messages import send_attachments log = logging.getLogger(__name__) @@ -20,7 +20,7 @@ class DuckPond(Cog): self.webhook_id = constants.Webhooks.duck_pond self.bot.loop.create_task(self.fetch_webhook()) - async def fetch_webhook(self): + async def fetch_webhook(self) -> None: """Fetches the webhook object, so we can post to it.""" await self.bot.wait_until_ready() @@ -31,7 +31,7 @@ class DuckPond(Cog): @staticmethod def is_staff(member: Union[User, Member]) -> bool: - """Check if a specific member or user is staff""" + """Check if a specific member or user is staff.""" if hasattr(member, "roles"): for role in member.roles: if role.id in constants.STAFF_ROLES: @@ -64,6 +64,7 @@ class DuckPond(Cog): avatar_url: Optional[str] = None, embed: Optional[Embed] = None, ) -> None: + """Send a webhook to the duck_pond channel.""" try: await self.webhook.send( content=content, @@ -77,8 +78,13 @@ class DuckPond(Cog): exc_info=exc ) - async def count_ducks(self, message: Optional[Message] = None, reaction_list: Optional[List[Reaction]] = None) -> int: - """Count the number of ducks in the reactions of a specific message. + async def count_ducks( + self, + message: Optional[Message] = None, + reaction_list: Optional[List[Reaction]] = None + ) -> int: + """ + Count the number of ducks in the reactions of a specific message. Only counts ducks added by staff members. """ @@ -116,7 +122,8 @@ class DuckPond(Cog): @Cog.listener() async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: - """Determine if a message should be sent to the duck pond. + """ + Determine if a message should be sent to the duck pond. This will count the number of duck reactions on the message, and if this amount meets the amount of ducks specified in the config under duck_pond/ducks_required, it will -- cgit v1.2.3 From 0957bee71fafc575c99ac107f941e9bfe6a72397 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 02:16:41 +0100 Subject: Add correct values for constants from production server. --- config-default.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config-default.yml b/config-default.yml index bad9c72db..074143f92 100644 --- a/config-default.yml +++ b/config-default.yml @@ -96,7 +96,7 @@ guild: defcon: &DEFCON 464469101889454091 devlog: &DEVLOG 622895325144940554 devtest: &DEVTEST 414574275865870337 - duck_pond: &DUCK_POND 000000000000000000 + duck_pond: &DUCK_POND 637820308341915648 help_0: 303906576991780866 help_1: 303906556754395136 help_2: 303906514266226689 @@ -147,7 +147,7 @@ guild: talent_pool: 569145364800602132 big_brother: 569133704568373283 reddit: 635408384794951680 - duck_pond: 637779355304722435 + duck_pond: 637821475327311927 filter: -- cgit v1.2.3 From c66cd6f236c9f8c68a39caafdbbba1f5724947a5 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 02:26:52 +0100 Subject: Fix broken constant tests --- bot/constants.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/bot/constants.py b/bot/constants.py index ed1e65cca..d626fd4ba 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -252,11 +252,6 @@ class Emojis(metaclass=YAMLGetter): defcon_enabled: str # noqa: E704 defcon_updated: str # noqa: E704 - green_chevron: str - red_chevron: str - white_chevron: str - bb_message: str - status_online: str status_offline: str status_idle: str @@ -267,6 +262,9 @@ class Emojis(metaclass=YAMLGetter): pencil: str cross_mark: str + ducky: int + ducky_blurple: int + class Icons(metaclass=YAMLGetter): section = "style" @@ -344,6 +342,7 @@ class Channels(metaclass=YAMLGetter): defcon: int devlog: int devtest: int + duck_pond: int help_0: int help_1: int help_2: int -- cgit v1.2.3 From dac975e8bf238b545b60f10cd1891a68f31dc1ef Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 02:28:03 +0100 Subject: Improve the setup() docstring Co-Authored-By: Mark --- bot/cogs/duck_pond.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index b2b786a3f..6244bdf5a 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -208,6 +208,6 @@ class DuckPond(Cog): def setup(bot: Bot) -> None: - """Token Remover cog load.""" + """Load the duck pond cog.""" bot.add_cog(DuckPond(bot)) log.info("Cog loaded: DuckPond") -- cgit v1.2.3 From 51622223f4173e35a90d2a306a61e020fc0b422b Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 15:16:35 +0100 Subject: Addressing review by Mark. This refactors the duck pond cog to have fewer redundancies, removes some unused features (like supporting reaction_list in the count_duck and has_green_checkbox helpers), and makes other various minor (mostly cosmetic) improvements. --- bot/cogs/duck_pond.py | 120 ++++++++++++++++---------------------------------- bot/constants.py | 5 +-- config-default.yml | 6 +-- 3 files changed, 42 insertions(+), 89 deletions(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index b2b786a3f..70cf0d2b0 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -1,8 +1,8 @@ import logging -from typing import List, Optional, Union +from typing import Optional, Union import discord -from discord import Color, Embed, Member, Message, PartialEmoji, RawReactionActionEvent, Reaction, User, errors +from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors from discord.ext.commands import Bot, Cog from bot import constants @@ -16,7 +16,6 @@ class DuckPond(Cog): def __init__(self, bot: Bot): self.bot = bot - self.log = log self.webhook_id = constants.Webhooks.duck_pond self.bot.loop.create_task(self.fetch_webhook()) @@ -27,7 +26,7 @@ class DuckPond(Cog): try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) except discord.HTTPException: - self.log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") @staticmethod def is_staff(member: Union[User, Member]) -> bool: @@ -39,22 +38,11 @@ class DuckPond(Cog): return False @staticmethod - def has_green_checkmark(message: Optional[Message] = None, reaction_list: Optional[List[Reaction]] = None) -> bool: + def has_green_checkmark(message: Message) -> bool: """Check if the message has a green checkmark reaction.""" - assert message or reaction_list, "You can either pass message or reactions, but not both, or neither." - - if message: - reactions = message.reactions - else: - reactions = reaction_list - - for reaction in reactions: - if isinstance(reaction.emoji, str): - if reaction.emoji == "✅": - return True - elif isinstance(reaction.emoji, PartialEmoji): - if reaction.emoji.name == "✅": - return True + for reaction in message.reactions: + if reaction.emoji == "✅": + return True return False async def send_webhook( @@ -72,52 +60,34 @@ class DuckPond(Cog): avatar_url=avatar_url, embed=embed ) - except discord.HTTPException as exc: - self.log.exception( - f"Failed to send a message to the Duck Pool webhook", - exc_info=exc - ) + except discord.HTTPException: + log.exception(f"Failed to send a message to the Duck Pool webhook") - async def count_ducks( - self, - message: Optional[Message] = None, - reaction_list: Optional[List[Reaction]] = None - ) -> int: + async def count_ducks(self, message: Message) -> int: """ Count the number of ducks in the reactions of a specific message. Only counts ducks added by staff members. """ - assert message or reaction_list, "You can either pass message or reactions, but not both, or neither." - duck_count = 0 duck_reactors = [] - if message: - reactions = message.reactions - else: - reactions = reaction_list - - for reaction in reactions: + for reaction in message.reactions: async for user in reaction.users(): # Is the user or member a staff member? - if self.is_staff(user) and user.id not in duck_reactors: - - # Is the emoji a duck? - if hasattr(reaction.emoji, "id"): - if reaction.emoji.id in constants.DuckPond.duck_custom_emojis: - duck_count += 1 - duck_reactors.append(user.id) - else: - if isinstance(reaction.emoji, str): - if reaction.emoji == "🦆": - duck_count += 1 - duck_reactors.append(user.id) - elif isinstance(reaction.emoji, PartialEmoji): - if reaction.emoji.name == "🦆": - duck_count += 1 - duck_reactors.append(user.id) + if not self.is_staff(user) or not user.id not in duck_reactors: + continue + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) return duck_count @Cog.listener() @@ -126,28 +96,23 @@ class DuckPond(Cog): Determine if a message should be sent to the duck pond. This will count the number of duck reactions on the message, and if this amount meets the - amount of ducks specified in the config under duck_pond/ducks_required, it will + amount of ducks specified in the config under duck_pond/threshold, it will send the message off to the duck pond. """ channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) message = await channel.fetch_message(payload.message_id) member = discord.utils.get(message.guild.members, id=payload.user_id) - # Is the member a staff member? - if not self.is_staff(member): - return - - # Bot reactions don't count. - if member.bot: + # Is the member a human and a staff member? + if not self.is_staff(member) or member.bot: return # Is the emoji in the reaction a duck? if payload.emoji.is_custom_emoji(): - if payload.emoji.id not in constants.DuckPond.duck_custom_emojis: - return - else: - if payload.emoji.name != "🦆": + if payload.emoji.id not in constants.DuckPond.custom_emojis: return + elif payload.emoji.name != "🦆": + return # Does the message already have a green checkmark? if self.has_green_checkmark(message): @@ -157,7 +122,7 @@ class DuckPond(Cog): duck_count = await self.count_ducks(message) # If we've got more than the required amount of ducks, send the message to the duck_pond. - if duck_count >= constants.DuckPond.ducks_required: + if duck_count >= constants.DuckPond.threshold: clean_content = message.clean_content if clean_content: @@ -180,31 +145,22 @@ class DuckPond(Cog): username=message.author.display_name, avatar_url=message.author.avatar_url ) - except discord.HTTPException as exc: - self.log.exception( - f"Failed to send an attachment to the webhook", - exc_info=exc - ) + except discord.HTTPException: + log.exception(f"Failed to send an attachment to the webhook") + await message.add_reaction("✅") @Cog.listener() async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: """Ensure that people don't remove the green checkmark from duck ponded messages.""" channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) - message = await channel.fetch_message(payload.message_id) # Prevent the green checkmark from being removed - if isinstance(payload.emoji, str): - if payload.emoji == "✅": - duck_count = await self.count_ducks(message) - if duck_count >= constants.DuckPond.ducks_required: - await message.add_reaction("✅") - - elif isinstance(payload.emoji, PartialEmoji): - if payload.emoji.name == "✅": - duck_count = await self.count_ducks(message) - if duck_count >= constants.DuckPond.ducks_required: - await message.add_reaction("✅") + if payload.emoji.name == "✅": + message = await channel.fetch_message(payload.message_id) + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.threshold: + await message.add_reaction("✅") def setup(bot: Bot) -> None: diff --git a/bot/constants.py b/bot/constants.py index d626fd4ba..79845711d 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -239,9 +239,8 @@ class Colours(metaclass=YAMLGetter): class DuckPond(metaclass=YAMLGetter): section = "duck_pond" - ducks_required: int - duck_custom_emojis: List[int] - duck_pond_channel: int + threshold: int + custom_emojis: List[int] class Emojis(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 074143f92..59087f51f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -96,7 +96,6 @@ guild: defcon: &DEFCON 464469101889454091 devlog: &DEVLOG 622895325144940554 devtest: &DEVTEST 414574275865870337 - duck_pond: &DUCK_POND 637820308341915648 help_0: 303906576991780866 help_1: 303906556754395136 help_2: 303906514266226689 @@ -383,9 +382,8 @@ redirect_output: delete_delay: 15 duck_pond: - ducks_required: 5 - duck_custom_emojis: [*DUCKY_EMOJI, *DUCKY_BLURPLE_EMOJI] - duck_pond_channel: *DUCK_POND + threshold: 5 + custom_emojis: [*DUCKY_EMOJI, *DUCKY_BLURPLE_EMOJI] config: required_keys: ['bot.token'] -- cgit v1.2.3 From 08d0c46c4aca4f181e60ef4cf6aa9a450c0db200 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 27 Oct 2019 15:38:09 +0100 Subject: Adding kosas additional ducks to default-config --- config-default.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/config-default.yml b/config-default.yml index 59087f51f..76892677e 100644 --- a/config-default.yml +++ b/config-default.yml @@ -32,8 +32,13 @@ style: new: "\U0001F195" cross_mark: "\u274C" - ducky: &DUCKY_EMOJI 574951975574175744 - ducky_blurple: &DUCKY_BLURPLE_EMOJI 574951975310065675 + ducky_yellow: &DUCKY_YELLOW 574951975574175744 + ducky_blurple: &DUCKY_BLURPLE 574951975310065675 + ducky_regal: &DUCKY_REGAL 637883439185395712 + ducky_camo: &DUCKY_CAMO 637914731566596096 + ducky_ninja: &DUCKY_NINJA 637923502535606293 + ducky_devil: &DUCKY_DEVIL 637925314982576139 + ducky_tube: &DUCKY_TUBE 637881368008851456 icons: crown_blurple: "https://cdn.discordapp.com/emojis/469964153289965568.png" @@ -383,7 +388,7 @@ redirect_output: duck_pond: threshold: 5 - custom_emojis: [*DUCKY_EMOJI, *DUCKY_BLURPLE_EMOJI] + custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE] config: required_keys: ['bot.token'] -- cgit v1.2.3 From a98485e173208cc2272f5c6355ddaf5858050403 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Thu, 31 Oct 2019 07:39:03 +0100 Subject: Figure out which tests we need. This adds empty tests for all the tests I'd like to add to this pull request. It also adds a few more duckies to the emoji constant list, and adds a single line of clarification to the testing readme. --- bot/constants.py | 8 +++- tests/README.md | 1 + tests/bot/cogs/test_duck_pond.py | 80 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 tests/bot/cogs/test_duck_pond.py diff --git a/bot/constants.py b/bot/constants.py index 79845711d..dbbf32063 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -261,8 +261,13 @@ class Emojis(metaclass=YAMLGetter): pencil: str cross_mark: str - ducky: int + ducky_yellow: int ducky_blurple: int + ducky_regal: int + ducky_camo: int + ducky_ninja: int + ducky_devil: int + ducky_tube: int class Icons(metaclass=YAMLGetter): @@ -341,7 +346,6 @@ class Channels(metaclass=YAMLGetter): defcon: int devlog: int devtest: int - duck_pond: int help_0: int help_1: int help_2: int diff --git a/tests/README.md b/tests/README.md index 6ab9bc93e..d052de2f6 100644 --- a/tests/README.md +++ b/tests/README.md @@ -15,6 +15,7 @@ We are using the following modules and packages for our unit tests: To ensure the results you obtain on your personal machine are comparable to those generated in the Azure pipeline, please make sure to run your tests with the virtual environment defined by our [Pipfile](/Pipfile). To run your tests with `pipenv`, we've provided two "scripts" shortcuts: - `pipenv run test` will run `unittest` with `coverage.py` +- `pipenv run test path/to/test.py` will run a specific test. - `pipenv run report` will generate a coverage report of the tests you've run with `pipenv run test`. If you append the `-m` flag to this command, the report will include the lines and branches not covered by tests in addition to the test coverage report. If you want a coverage report, make sure to run the tests with `pipenv run test` *first*. diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py new file mode 100644 index 000000000..79f11843b --- /dev/null +++ b/tests/bot/cogs/test_duck_pond.py @@ -0,0 +1,80 @@ +import logging +import unittest +from unittest.mock import MagicMock + +from bot.cogs import duck_pond +from tests.helpers import MockBot, MockMessage + + +class DuckPondTest(unittest.TestCase): + """Tests the `DuckPond` cog.""" + + def setUp(self): + """Adds the cog, a bot, and a message to the instance for usage in tests.""" + self.bot = MockBot() + self.cog = duck_pond.DuckPond(bot=self.bot) + + self.msg = MockMessage(message_id=555, content='') + self.msg.author.__str__ = MagicMock() + self.msg.author.__str__.return_value = 'lemon' + self.msg.author.bot = False + self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' + self.msg.author.id = 42 + self.msg.author.mention = '@lemon' + self.msg.channel.mention = "#lemonade-stand" + + def test_is_staff_correctly_identifies_staff(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_has_green_checkmark(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_count_custom_duck_emojis(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_count_unicode_duck_emojis(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_count_mixed_duck_emojis(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_raw_reaction_add_rejects_bot(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_raw_reaction_add_rejects_non_staff(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_raw_reaction_add_sends_message_on_valid_input(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_raw_reaction_remove_rejects_non_checkmarks(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + def test_raw_reaction_remove_prevents_checkmark_removal(self): + """A string decoding to numeric characters is a valid user ID.""" + pass + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the cog should log a message at `INFO` level.""" + bot = MockBot() + log = logging.getLogger('bot.cogs.duck_pond') + + with self.assertLogs(logger=log, level=logging.INFO) as log_watcher: + duck_pond.setup(bot) + line = log_watcher.output[0] + + bot.add_cog.assert_called_once() + self.assertIn("Cog loaded: DuckPond", line) -- cgit v1.2.3 From efe592cc0420f325ab266afc822b8d4b8135d467 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sat, 2 Nov 2019 17:26:50 +0100 Subject: Do not cut off description in code blocks --- bot/cogs/doc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 8b81b3053..4a095fa51 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -284,7 +284,13 @@ class Doc(commands.Cog): if len(description) > 1000: shortened = description[:1000] last_paragraph_end = shortened.rfind('\n\n') - description = description[:last_paragraph_end] + f"... [read more]({permalink})" + description = description[:last_paragraph_end] + + # If there is an incomplete code block, cut it out + if description.count("```") % 2: + codeblock_start = description.rfind('```py') + description = description[:codeblock_start].rstrip() + description += f"... [read more]({permalink})" description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) -- cgit v1.2.3 From 82e1f3764ba0d102ede007ba6352406cfe3fb82a Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sat, 2 Nov 2019 17:37:42 +0100 Subject: Get symbol description by searching for a dd tag instead of traversing the siblings --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 4a095fa51..96f737c03 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -260,7 +260,7 @@ class Doc(commands.Cog): if tag not in UNWANTED_SIGNATURE_SYMBOLS: signature_buffer.append(tag.replace('\\', '')) signature = ''.join(signature_buffer) - description = str(symbol_heading.next_sibling.next_sibling).replace('¶', '') + description = str(symbol_heading.find_next_sibling("dd")).replace('¶', '') return signature, description -- cgit v1.2.3 From ae8c862a353ddc10593d36d557fc7215232baf5b Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sat, 2 Nov 2019 17:44:02 +0100 Subject: Get up to 3 signatures of a symbol --- bot/cogs/doc.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 96f737c03..2987f7245 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -36,7 +36,7 @@ NO_OVERRIDE_GROUPS = ( NO_OVERRIDE_PACKAGES = ( "Python", ) -UNWANTED_SIGNATURE_SYMBOLS = ('[source]', '¶') +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") @@ -218,7 +218,7 @@ class Doc(commands.Cog): ] await asyncio.gather(*coros) - async def get_symbol_html(self, symbol: str) -> Optional[Tuple[str, str]]: + async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: """ Given a Python symbol, return its signature and description. @@ -239,7 +239,7 @@ class Doc(commands.Cog): symbol_id = url.split('#')[-1] soup = BeautifulSoup(html, 'lxml') symbol_heading = soup.find(id=symbol_id) - signature_buffer = [] + signatures = [] if symbol_heading is None: return None @@ -253,16 +253,14 @@ class Doc(commands.Cog): description = ''.join(str(paragraph) for paragraph in info_paragraphs).replace('¶', '') else: - # Traverse the tags of the signature header and ignore any - # unwanted symbols from it. Add all of it to a temporary buffer. - - for tag in symbol_heading.strings: - if tag not in UNWANTED_SIGNATURE_SYMBOLS: - signature_buffer.append(tag.replace('\\', '')) - signature = ''.join(signature_buffer) + # Get text of up to 3 signatures, remove unwanted symbols + for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): + signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + if signature: + signatures.append(signature) description = str(symbol_heading.find_next_sibling("dd")).replace('¶', '') - return signature, description + return signatures, description @async_cache(arg_offset=1) async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: @@ -275,7 +273,7 @@ class Doc(commands.Cog): if scraped_html is None: return None - signature = scraped_html[0] + signatures = scraped_html[0] permalink = self.inventories[symbol] description = markdownify(scraped_html[1]) @@ -294,18 +292,18 @@ class Doc(commands.Cog): description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if signature is None: + if signatures is None: # If symbol is a module, don't show signature. embed_description = description - elif not signature: + elif not signatures: # It's some "meta-page", for example: # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views embed_description = "This appears to be a generic page not tied to a specific symbol." else: - signature = textwrap.shorten(signature, 500) - embed_description = f"```py\n{signature}```{description}" + embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) + embed_description += description embed = discord.Embed( title=f'`{symbol}`', -- cgit v1.2.3 From 1aed2e4f4996f5546652bbb26e8fbf403e28aac4 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sat, 2 Nov 2019 18:28:04 +0100 Subject: Improve module description searching --- bot/cogs/doc.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 2987f7245..30a14f26c 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional, Tuple import discord from bs4 import BeautifulSoup -from bs4.element import PageElement +from bs4.element import PageElement, Tag from discord.errors import NotFound from discord.ext import commands from markdownify import MarkdownConverter @@ -37,6 +37,16 @@ NO_OVERRIDE_PACKAGES = ( "Python", ) UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") +SEARCH_END_TAG_ATTRS = ( + "data", + "function", + "class", + "exception", + "seealso", + "section", + "rubric", + "sphinxsidebar", +) WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") @@ -245,12 +255,21 @@ class Doc(commands.Cog): return None if symbol_id == f"module-{symbol}": - # Get all paragraphs until the first div after the section div - # if searched symbol is a module. - trailing_div = symbol_heading.findNext("div") - info_paragraphs = trailing_div.find_previous_siblings("p")[::-1] - signature = None - description = ''.join(str(paragraph) for paragraph in info_paragraphs).replace('¶', '') + search_html = str(soup) + # Get page content from the module headerlink to the + # first tag that has its class in `SEARCH_END_TAG_ATTRS` + start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) + if start_tag is None: + return [], "" + + end_tag = start_tag.find_next(self._match_end_tag) + if end_tag is None: + return [], "" + + description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) + description_end_index = search_html.find(str(end_tag)) + description = search_html[description_start_index:description_end_index].replace('¶', '') + signatures = None else: # Get text of up to 3 signatures, remove unwanted symbols @@ -422,6 +441,15 @@ class Doc(commands.Cog): await self.refresh_inventory() await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + @staticmethod + def _match_end_tag(tag: Tag) -> bool: + """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" + for attr in SEARCH_END_TAG_ATTRS: + if attr in tag.get("class", ()): + return True + + return tag.name == "table" + def setup(bot: commands.Bot) -> None: """Doc cog load.""" -- cgit v1.2.3 From ea5b01d1369faae5485e8514fb38f7ca8d9a24cc Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Sun, 3 Nov 2019 17:54:24 +0700 Subject: Refactor Using ternary to avoid if else --- bot/cogs/moderation/modlog.py | 96 +++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 64 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 88f2b6c67..347b820de 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -635,39 +635,23 @@ class ModLog(Cog, name="ModLog"): author = before.author channel = before.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + before_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{before.id}`\n" + "\n" + f"{before.clean_content}" + ) - if channel.category: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{before.clean_content}" - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{after.clean_content}" - ) - else: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{before.clean_content}" - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{after.clean_content}" - ) + after_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{before.id}`\n" + "\n" + f"{after.clean_content}" + ) if before.edited_at: # Message was previously edited, to assist with self-bot detection, use the edited_at @@ -718,39 +702,23 @@ class ModLog(Cog, name="ModLog"): author = message.author channel = message.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + before_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) - if channel.category: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - f"{message.clean_content}" - ) - else: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - f"{message.clean_content}" - ) + after_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + f"{message.clean_content}" + ) await self.send_log_message( Icons.message_edit, Colour.blurple(), "Message edited (Before)", -- cgit v1.2.3 From 2c19c1c7e1a433c639b6fba8aa10ad744e3827db Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Sun, 3 Nov 2019 18:05:55 +0700 Subject: Merge before & after response, show only differences - Merged `before_response` and `after_response`. - Only show the differences between `before.clean_content` and `after.clean_content` - Included a `jump to message` link. --- bot/cogs/moderation/modlog.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 347b820de..92b399874 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -1,4 +1,6 @@ import asyncio +import difflib +import itertools import logging import typing as t from datetime import datetime @@ -637,20 +639,36 @@ class ModLog(Cog, name="ModLog"): channel = before.channel channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - before_response = ( + _before = before.clean_content + _after = after.clean_content + groups = tuple((g[0], tuple(g[1])) + for g in itertools.groupby(difflib.ndiff(_before.split(), _after.split()), key=lambda s: s[0])) + + for index, (name, values) in enumerate(groups): + sub = ' '.join(s[2:].strip() for s in values) + if name == '-': + _before = _before.replace(sub, f"[{sub}](http://.z)") + elif name == '+': + _after = _after.replace(sub, f"[{sub}](http://.z)") + else: + if len(values) > 2: + new = (f"{values[0].strip() if index > 0 else ''}" + " ... " + f"{values[-1].strip() if index < len(groups) - 1 else ''}") + else: + new = sub + _before = _before.replace(sub, new) + _after = _after.replace(sub, new) + + response = ( f"**Author:** {author} (`{author.id}`)\n" f"**Channel:** {channel_name} (`{channel.id}`)\n" f"**Message ID:** `{before.id}`\n" "\n" - f"{before.clean_content}" - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" + f"**Before**:\n{_before}\n" + f"**After**:\n{_after}\n" "\n" - f"{after.clean_content}" + f"[jump to message]({after.jump_url})" ) if before.edited_at: @@ -667,15 +685,10 @@ class ModLog(Cog, name="ModLog"): footer = None await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (Before)", before_response, + Icons.message_edit, Colour.blurple(), "Message edited", response, channel_id=Channels.message_log, timestamp_override=timestamp, footer=footer ) - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (After)", after_response, - channel_id=Channels.message_log, timestamp_override=after.edited_at - ) - @Cog.listener() async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: """Log raw message edit event to message change log.""" -- cgit v1.2.3 From 3140b01bff9c4912b9f89589e3b3f200dbad99ee Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 3 Nov 2019 18:36:16 +0100 Subject: Handle exceptions when fetching inventories --- bot/cogs/doc.py | 88 +++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 30a14f26c..55b69e9a4 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -13,8 +13,9 @@ from bs4.element import PageElement, Tag from discord.errors import NotFound from discord.ext import commands from markdownify import MarkdownConverter -from requests import ConnectionError +from requests import ConnectTimeout, ConnectionError, HTTPError from sphinx.ext import intersphinx +from urllib3.exceptions import ProtocolError from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import ValidPythonIdentifier, ValidURL @@ -36,6 +37,7 @@ NO_OVERRIDE_GROUPS = ( NO_OVERRIDE_PACKAGES = ( "Python", ) +FAILED_REQUEST_RETRY_AMOUNT = 3 UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") SEARCH_END_TAG_ATTRS = ( "data", @@ -173,36 +175,37 @@ class Doc(commands.Cog): """ self.base_urls[package_name] = base_url - fetch_func = functools.partial(intersphinx.fetch_inventory, config, '', inventory_url) - for group, value in (await self.bot.loop.run_in_executor(None, fetch_func)).items(): - # Each value has a bunch of information in the form - # `(package_name, version, relative_url, ???)`, and we only - # need the package_name and the relative documentation URL. - for symbol, (package_name, _, relative_doc_url, _) in value.items(): - absolute_doc_url = base_url + relative_doc_url - - if symbol in self.inventories: - # get `group_name` from _:group_name - group_name = group.split(":")[1] - if (group_name in NO_OVERRIDE_GROUPS - # check if any package from `NO_OVERRIDE_PACKAGES` - # is in base URL of the symbol that would be overridden - or any(package in self.inventories[symbol].split("/", 3)[2] - for package in NO_OVERRIDE_PACKAGES)): - - symbol = f"{group_name}.{symbol}" - # if renamed `symbol` was already exists, add library name in front - if symbol in self.renamed_symbols: - # split `package_name` because of packages like Pillow that have spaces in them - symbol = f"{package_name.split()[0]}.{symbol}" - - self.inventories[symbol] = absolute_doc_url - self.renamed_symbols.add(symbol) - continue - - self.inventories[symbol] = absolute_doc_url - - log.trace(f"Fetched inventory for {package_name}.") + package = await self._fetch_inventory(inventory_url, config) + if package: + for group, value in package.items(): + # Each value has a bunch of information in the form + # `(package_name, version, relative_url, ???)`, and we only + # need the package_name and the relative documentation URL. + for symbol, (package_name, _, relative_doc_url, _) in value.items(): + absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + # get `group_name` from _:group_name + group_name = group.split(":")[1] + if (group_name in NO_OVERRIDE_GROUPS + # check if any package from `NO_OVERRIDE_PACKAGES` + # is in base URL of the symbol that would be overridden + or any(package in self.inventories[symbol].split("/", 3)[2] + for package in NO_OVERRIDE_PACKAGES)): + + symbol = f"{group_name}.{symbol}" + # if renamed `symbol` was already exists, add library name in front + if symbol in self.renamed_symbols: + # split `package_name` because of packages like Pillow that have spaces in them + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + + self.inventories[symbol] = absolute_doc_url + + log.trace(f"Fetched inventory for {package_name}.") async def refresh_inventory(self) -> None: """Refresh internal documentation inventory.""" @@ -441,6 +444,29 @@ class Doc(commands.Cog): await self.refresh_inventory() await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + async def _fetch_inventory(self, inventory_url: str, config: SphinxConfiguration) -> Optional[dict]: + """Get and return inventory from `inventory_url`. If fetching fails, return None.""" + fetch_func = functools.partial(intersphinx.fetch_inventory, config, '', inventory_url) + for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): + try: + package = await self.bot.loop.run_in_executor(None, fetch_func) + except ConnectTimeout: + log.error(f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})") + except ProtocolError: + log.error(f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})") + except HTTPError as e: + log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") + return None + except ConnectionError: + log.error(f"Couldn't establish connection to inventory {inventory_url}.") + return None + else: + return package + log.error(f"Fetching of inventory {inventory_url} failed.") + return None + @staticmethod def _match_end_tag(tag: Tag) -> bool: """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" -- cgit v1.2.3 From a8475f5fedb91c9e0f1c5c28c7d64aebbbef64f4 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 3 Nov 2019 20:06:15 +0100 Subject: Fix case for the python package name in `NO_OVERRIDE_PACKAGES` --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 55b69e9a4..563f83040 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -35,7 +35,7 @@ NO_OVERRIDE_GROUPS = ( "term", ) NO_OVERRIDE_PACKAGES = ( - "Python", + "python", ) FAILED_REQUEST_RETRY_AMOUNT = 3 UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") -- cgit v1.2.3 From 1b0a8c8109240615e5d9309937a434e1d29bcf24 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 3 Nov 2019 20:06:41 +0100 Subject: Comment grammar --- bot/cogs/doc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 563f83040..934cb2a6d 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -194,9 +194,9 @@ class Doc(commands.Cog): for package in NO_OVERRIDE_PACKAGES)): symbol = f"{group_name}.{symbol}" - # if renamed `symbol` was already exists, add library name in front + # If renamed `symbol` already exists, add library name in front. if symbol in self.renamed_symbols: - # split `package_name` because of packages like Pillow that have spaces in them + # Split `package_name` because of packages like Pillow that have spaces in them. symbol = f"{package_name.split()[0]}.{symbol}" self.inventories[symbol] = absolute_doc_url -- cgit v1.2.3 From 254dfbb616651f875936598c9884761921de7b76 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 3 Nov 2019 20:28:07 +0100 Subject: Make sure only signatures belonging to the symbol are fetched --- bot/cogs/doc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 934cb2a6d..dcbcfe3ad 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -253,12 +253,12 @@ class Doc(commands.Cog): soup = BeautifulSoup(html, 'lxml') symbol_heading = soup.find(id=symbol_id) signatures = [] + search_html = str(soup) if symbol_heading is None: return None if symbol_id == f"module-{symbol}": - search_html = str(soup) # Get page content from the module headerlink to the # first tag that has its class in `SEARCH_END_TAG_ATTRS` start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) @@ -275,12 +275,13 @@ class Doc(commands.Cog): signatures = None else: + description = str(symbol_heading.find_next_sibling("dd")).replace('¶', '') + description_pos = search_html.find(description) # Get text of up to 3 signatures, remove unwanted symbols for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) - if signature: + if signature and search_html.find(signature) < description_pos: signatures.append(signature) - description = str(symbol_heading.find_next_sibling("dd")).replace('¶', '') return signatures, description -- cgit v1.2.3 From acb937dc24b30c84b4978f651a642208f562c36e Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sun, 3 Nov 2019 22:34:16 +0100 Subject: Test is_staff and has_green_checkmark. --- tests/bot/cogs/test_duck_pond.py | 52 +++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 79f11843b..31c7e9f89 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -1,35 +1,53 @@ import logging import unittest -from unittest.mock import MagicMock from bot.cogs import duck_pond -from tests.helpers import MockBot, MockMessage +from tests.helpers import MockBot, MockMember, MockMessage, MockReaction, MockRole class DuckPondTest(unittest.TestCase): """Tests the `DuckPond` cog.""" def setUp(self): - """Adds the cog, a bot, and a message to the instance for usage in tests.""" + """Adds the cog, a bot, and the mocks we'll need for our tests.""" self.bot = MockBot() self.cog = duck_pond.DuckPond(bot=self.bot) - self.msg = MockMessage(message_id=555, content='') - self.msg.author.__str__ = MagicMock() - self.msg.author.__str__.return_value = 'lemon' - self.msg.author.bot = False - self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' - self.msg.author.id = 42 - self.msg.author.mention = '@lemon' - self.msg.channel.mention = "#lemonade-stand" + # Set up some roles + self.admin_role = MockRole(name="Admins", role_id=476190234653229056) + self.contrib_role = MockRole(name="Contributor", role_id=476190302659543061) - def test_is_staff_correctly_identifies_staff(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + # Set up some users + self.admin_member = MockMember(roles=(self.admin_role,)) + self.contrib_member = MockMember(roles=(self.contrib_role,)) + self.no_role_member = MockMember() - def test_has_green_checkmark(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + # Set up emojis + self.checkmark_emoji = "✅" + self.thumbs_up_emoji = "👍" + + # Set up reactions + self.checkmark_reaction = MockReaction(emoji=self.checkmark_emoji) + self.thumbs_up_reaction = MockReaction(emoji=self.thumbs_up_emoji) + + # Set up a messages + self.checkmark_message = MockMessage(reactions=(self.checkmark_reaction,)) + self.thumbs_up_message = MockMessage(reactions=(self.thumbs_up_reaction,)) + self.no_reaction_message = MockMessage() + + def test_is_staff_correctly_identifies_staff(self): + """Test that is_staff correctly identifies a staff member.""" + with self.subTest(): + self.assertTrue(duck_pond.DuckPond.is_staff(self.admin_member)) + self.assertFalse(duck_pond.DuckPond.is_staff(self.contrib_member)) + self.assertFalse(duck_pond.DuckPond.is_staff(self.no_role_member)) + + def test_has_green_checkmark_correctly_identifies_messages(self): + """Test that has_green_checkmark recognizes messages with checkmarks.""" + with self.subTest(): + self.assertTrue(duck_pond.DuckPond.has_green_checkmark(self.checkmark_message)) + self.assertFalse(duck_pond.DuckPond.has_green_checkmark(self.thumbs_up_message)) + self.assertFalse(duck_pond.DuckPond.has_green_checkmark(self.no_reaction_message)) def test_count_custom_duck_emojis(self): """A string decoding to numeric characters is a valid user ID.""" -- cgit v1.2.3 From 071bc9d775a48750cc8d44236523c7ec5a30f7f9 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Fri, 25 Oct 2019 17:29:28 -0700 Subject: Add logging for moderation functions --- bot/cogs/moderation/scheduler.py | 51 ++++++++++++++++++++++++++++++++++--- bot/cogs/moderation/superstarify.py | 21 +++++++++++---- bot/cogs/moderation/utils.py | 12 +++++++++ bot/utils/scheduling.py | 6 ++++- 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 7990df226..7a08fc236 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -39,6 +39,8 @@ class InfractionScheduler(Scheduler): """Schedule expiration for previous infractions.""" await self.bot.wait_until_ready() + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + infractions = await self.bot.api_client.get( 'bot/infractions', params={'active': 'true'} @@ -59,6 +61,10 @@ class InfractionScheduler(Scheduler): # Mark as inactive if less than a minute remains. if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) await self.deactivate_infraction(infraction) return @@ -78,6 +84,9 @@ class InfractionScheduler(Scheduler): icon = utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] expiry = infraction["expires_at"] + _id = infraction['id'] + + log.trace(f"Applying {infr_type} infraction #{_id} to {user}.") if expiry: expiry = time.format_infraction(expiry) @@ -111,10 +120,20 @@ class InfractionScheduler(Scheduler): log_content = ctx.author.mention if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{_id} actor is bot; including the reason in the confirmation message." + ) + end_msg = f" (reason: {infraction['reason']})" elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{_id} context is not in a staff channel; omitting infraction count." + ) + end_msg = "" else: + log.trace(f"Fetching total infraction count for {user}.") + infractions = await self.bot.api_client.get( "bot/infractions", params={"user__id": str(user.id)} @@ -124,6 +143,7 @@ class InfractionScheduler(Scheduler): # Execute the necessary actions to apply the infraction on Discord. if action_coro: + log.trace(f"Awaiting the infraction #{_id} application action coroutine.") try: await action_coro if expiry: @@ -136,12 +156,16 @@ class InfractionScheduler(Scheduler): log_content = ctx.author.mention log_title = "failed to apply" + log.warning(f"Failed to apply {infr_type} infraction #{_id} to {user}.") + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{_id} confirmation message.") await ctx.send( f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}." ) # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{_id}.") await self.mod_log.send_log_message( icon_url=icon, colour=Colours.soft_red, @@ -157,9 +181,14 @@ class InfractionScheduler(Scheduler): footer=f"ID {infraction['id']}" ) + log.info(f"Applied {infr_type} infraction #{_id} to {user}.") + async def pardon_infraction(self, ctx: Context, infr_type: str, user: MemberObject) -> None: """Prematurely end an infraction for a user and log the action in the mod log.""" + log.trace(f"Pardoning {infr_type} infraction for {user}.") + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") response = await self.bot.api_client.get( 'bot/infractions', params={ @@ -170,6 +199,7 @@ class InfractionScheduler(Scheduler): ) if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") return @@ -179,12 +209,16 @@ class InfractionScheduler(Scheduler): log_text["Member"] = f"{user.mention}(`{user.id}`)" log_text["Actor"] = str(ctx.message.author) log_content = None - footer = f"ID: {response[0]['id']}" + _id = response[0]['id'] + footer = f"ID: {_id}" # If multiple active infractions were found, mark them as inactive in the database # and cancel their expiration tasks. if len(response) > 1: - log.warning(f"Found more than one active {infr_type} infraction for user {user.id}") + log.warning( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" @@ -227,11 +261,16 @@ class InfractionScheduler(Scheduler): confirm_msg = ":x: failed to pardon" log_title = "pardon failed" log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{_id} for {user}.") else: confirm_msg = f":ok_hand: pardoned" log_title = "pardoned" + log.info(f"Pardoned {infr_type} infraction #{_id} for {user}.") + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{_id} pardon confirmation message.") await ctx.send( f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " f"{log_text.get('Failure', '')}" @@ -268,7 +307,7 @@ class InfractionScheduler(Scheduler): _type = infraction["type"] _id = infraction["id"] - log.debug(f"Marking infraction #{_id} as inactive (expired).") + log.info(f"Marking infraction #{_id} as inactive (expired).") log_content = None log_text = { @@ -278,7 +317,9 @@ class InfractionScheduler(Scheduler): } try: + log.trace("Awaiting the pardon action coroutine.") returned_log = await self._pardon_action(infraction) + if returned_log is not None: log_text = {**log_text, **returned_log} # Merge the logs together else: @@ -296,6 +337,8 @@ class InfractionScheduler(Scheduler): # Check if the user is currently being watched by Big Brother. try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + active_watch = await self.bot.api_client.get( "bot/infractions", params={ @@ -312,6 +355,7 @@ class InfractionScheduler(Scheduler): try: # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{_id} as inactive in the database.") await self.bot.api_client.patch( f"bot/infractions/{_id}", json={"active": False} @@ -335,6 +379,7 @@ class InfractionScheduler(Scheduler): if send_log: log_title = f"expiration failed" if "Failure" in log_text else "expired" + log.trace(f"Sending deactivation mod log for infraction #{_id}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS[_type][1], colour=Colours.soft_green, diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index c66222e5a..9ab870823 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -34,8 +34,8 @@ class Superstarify(InfractionScheduler, Cog): return # User didn't change their nickname. Abort! log.trace( - f"{before.display_name} is trying to change their nickname to {after.display_name}. " - "Checking if the user is in superstar-prison..." + f"{before} ({before.display_name}) is trying to change their nickname to " + f"{after.display_name}. Checking if the user is in superstar-prison..." ) active_superstarifies = await self.bot.api_client.get( @@ -48,6 +48,7 @@ class Superstarify(InfractionScheduler, Cog): ) if not active_superstarifies: + log.trace(f"{before} has no active superstar infractions.") return infraction = active_superstarifies[0] @@ -132,15 +133,17 @@ class Superstarify(InfractionScheduler, Cog): # Post the infraction to the API reason = reason or f"old nick: {member.display_name}" infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration) + _id = infraction["id"] old_nick = member.display_name - forced_nick = self.get_nick(infraction["id"], member.id) + forced_nick = self.get_nick(_id, member.id) expiry_str = format_infraction(infraction["expires_at"]) # Apply the infraction and schedule the expiration task. + log.debug(f"Changing nickname of {member} to {forced_nick}.") self.mod_log.ignore(constants.Event.member_update, member.id) await member.edit(nick=forced_nick, reason=reason) - self.schedule_task(ctx.bot.loop, infraction["id"], infraction) + self.schedule_task(ctx.bot.loop, _id, infraction) # Send a DM to the user to notify them of their new infraction. await utils.notify_infraction( @@ -152,6 +155,7 @@ class Superstarify(InfractionScheduler, Cog): ) # Send an embed with the infraction information to the invoking context. + log.trace(f"Sending superstar #{_id} embed.") embed = Embed( title="Congratulations!", colour=constants.Colours.soft_orange, @@ -167,6 +171,7 @@ class Superstarify(InfractionScheduler, Cog): await ctx.send(embed=embed) # Log to the mod log channel. + log.trace(f"Sending apply mod log for superstar #{_id}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS["superstar"][0], colour=Colour.gold(), @@ -180,7 +185,7 @@ class Superstarify(InfractionScheduler, Cog): Old nickname: `{old_nick}` New nickname: `{forced_nick}` """), - footer=f"ID {infraction['id']}" + footer=f"ID {_id}" ) @command(name="unsuperstarify", aliases=("release_nick", "unstar")) @@ -198,6 +203,10 @@ class Superstarify(InfractionScheduler, Cog): # Don't bother sending a notification if the user left the guild. if not user: + log.debug( + "User left the guild and therefore won't be notified about superstar " + f"{infraction['id']} pardon." + ) return {} # DM the user about the expiration. @@ -216,6 +225,8 @@ class Superstarify(InfractionScheduler, Cog): @staticmethod def get_nick(infraction_id: int, member_id: int) -> str: """Randomly select a nickname from the Superstarify nickname list.""" + log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") + rng = random.Random(str(infraction_id) + str(member_id)) return rng.choice(STAR_NAMES) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index 9179c0afb..325b9567a 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -37,6 +37,8 @@ def proxy_user(user_id: str) -> discord.Object: Used when a Member or User object cannot be resolved. """ + log.trace(f"Attempting to create a proxy user for the user id {user_id}.") + try: user_id = int(user_id) except ValueError: @@ -59,6 +61,8 @@ async def post_infraction( active: bool = True, ) -> t.Optional[dict]: """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + payload = { "actor": ctx.message.author.id, "hidden": hidden, @@ -92,6 +96,8 @@ async def post_infraction( async def has_active_infraction(ctx: Context, user: MemberObject, infr_type: str) -> bool: """Checks if a user already has an active infraction of the given type.""" + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + active_infractions = await ctx.bot.api_client.get( 'bot/infractions', params={ @@ -101,12 +107,14 @@ async def has_active_infraction(ctx: Context, user: MemberObject, infr_type: str } ) if active_infractions: + log.trace(f"{user} has active infractions of type {infr_type}.") await ctx.send( f":x: According to my records, this user already has a {infr_type} infraction. " f"See infraction **#{active_infractions[0]['id']}**." ) return True else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") return False @@ -118,6 +126,8 @@ async def notify_infraction( icon_url: str = Icons.token_removed ) -> bool: """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + embed = discord.Embed( description=textwrap.dedent(f""" **Type:** {infr_type.capitalize()} @@ -146,6 +156,8 @@ async def notify_pardon( icon_url: str = Icons.user_verified ) -> bool: """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + embed = discord.Embed( description=content, colour=Colours.soft_green diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index 08abd91d7..ee6c0a8e6 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -36,11 +36,15 @@ class Scheduler(metaclass=CogABCMeta): `task_data` is passed to `Scheduler._scheduled_expiration` """ if task_id in self.scheduled_tasks: + log.debug( + f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled." + ) return task: asyncio.Task = create_task(loop, self._scheduled_task(task_data)) self.scheduled_tasks[task_id] = task + log.debug(f"{self.cog_name}: scheduled task #{task_id}.") def cancel_task(self, task_id: str) -> None: """Un-schedules a task.""" @@ -51,7 +55,7 @@ class Scheduler(metaclass=CogABCMeta): return task.cancel() - log.debug(f"{self.cog_name}: Unscheduled {task_id}.") + log.debug(f"{self.cog_name}: unscheduled task #{task_id}.") del self.scheduled_tasks[task_id] -- cgit v1.2.3 From 818e48baec3564e73cc223eb91fbbfe768d73c67 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 31 Oct 2019 14:26:25 -0700 Subject: Moderation: use trailing _ instead of leading for variable names PEP 8 states the convention is a trailing underscore when used to prevent name conflicts. --- bot/cogs/moderation/scheduler.py | 54 ++++++++++++++++++------------------- bot/cogs/moderation/superstarify.py | 12 ++++----- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 7a08fc236..462c7fc7f 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -84,9 +84,9 @@ class InfractionScheduler(Scheduler): icon = utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] expiry = infraction["expires_at"] - _id = infraction['id'] + id_ = infraction['id'] - log.trace(f"Applying {infr_type} infraction #{_id} to {user}.") + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") if expiry: expiry = time.format_infraction(expiry) @@ -121,13 +121,13 @@ class InfractionScheduler(Scheduler): if infraction["actor"] == self.bot.user.id: log.trace( - f"Infraction #{_id} actor is bot; including the reason in the confirmation message." + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." ) end_msg = f" (reason: {infraction['reason']})" elif ctx.channel.id not in STAFF_CHANNELS: log.trace( - f"Infraction #{_id} context is not in a staff channel; omitting infraction count." + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." ) end_msg = "" @@ -143,7 +143,7 @@ class InfractionScheduler(Scheduler): # Execute the necessary actions to apply the infraction on Discord. if action_coro: - log.trace(f"Awaiting the infraction #{_id} application action coroutine.") + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") try: await action_coro if expiry: @@ -156,16 +156,16 @@ class InfractionScheduler(Scheduler): log_content = ctx.author.mention log_title = "failed to apply" - log.warning(f"Failed to apply {infr_type} infraction #{_id} to {user}.") + log.warning(f"Failed to apply {infr_type} infraction #{id_} to {user}.") # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{_id} confirmation message.") + log.trace(f"Sending infraction #{id_} confirmation message.") await ctx.send( f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}." ) # Send a log message to the mod log. - log.trace(f"Sending apply mod log for infraction #{_id}.") + log.trace(f"Sending apply mod log for infraction #{id_}.") await self.mod_log.send_log_message( icon_url=icon, colour=Colours.soft_red, @@ -181,7 +181,7 @@ class InfractionScheduler(Scheduler): footer=f"ID {infraction['id']}" ) - log.info(f"Applied {infr_type} infraction #{_id} to {user}.") + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") async def pardon_infraction(self, ctx: Context, infr_type: str, user: MemberObject) -> None: """Prematurely end an infraction for a user and log the action in the mod log.""" @@ -209,8 +209,8 @@ class InfractionScheduler(Scheduler): log_text["Member"] = f"{user.mention}(`{user.id}`)" log_text["Actor"] = str(ctx.message.author) log_content = None - _id = response[0]['id'] - footer = f"ID: {_id}" + id_ = response[0]['id'] + footer = f"ID: {id_}" # If multiple active infractions were found, mark them as inactive in the database # and cancel their expiration tasks. @@ -232,15 +232,15 @@ class InfractionScheduler(Scheduler): # 1. Discord cannot store multiple active bans or assign multiples of the same role # 2. It would send a pardon DM for each active infraction, which is redundant for infraction in response[1:]: - _id = infraction['id'] + id_ = infraction['id'] try: # Mark infraction as inactive in the database. await self.bot.api_client.patch( - f"bot/infractions/{_id}", + f"bot/infractions/{id_}", json={"active": False} ) except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{_id} ({infr_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") # This is simpler and cleaner than trying to concatenate all the errors. log_text["Failure"] = "See bot's logs for details." @@ -262,15 +262,15 @@ class InfractionScheduler(Scheduler): log_title = "pardon failed" log_content = ctx.author.mention - log.warning(f"Failed to pardon {infr_type} infraction #{_id} for {user}.") + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") else: confirm_msg = f":ok_hand: pardoned" log_title = "pardoned" - log.info(f"Pardoned {infr_type} infraction #{_id} for {user}.") + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") # Send a confirmation message to the invoking context. - log.trace(f"Sending infraction #{_id} pardon confirmation message.") + log.trace(f"Sending infraction #{id_} pardon confirmation message.") await ctx.send( f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " f"{log_text.get('Failure', '')}" @@ -305,9 +305,9 @@ class InfractionScheduler(Scheduler): mod_role = guild.get_role(constants.Roles.moderator) user_id = infraction["user"] _type = infraction["type"] - _id = infraction["id"] + id_ = infraction["id"] - log.info(f"Marking infraction #{_id} as inactive (expired).") + log.info(f"Marking infraction #{id_} as inactive (expired).") log_content = None log_text = { @@ -324,14 +324,14 @@ class InfractionScheduler(Scheduler): log_text = {**log_text, **returned_log} # Merge the logs together else: raise ValueError( - f"Attempted to deactivate an unsupported infraction #{_id} ({_type})!" + f"Attempted to deactivate an unsupported infraction #{id_} ({_type})!" ) except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{_id} ({_type}): bot lacks permissions") + log.warning(f"Failed to deactivate infraction #{id_} ({_type}): bot lacks permissions") log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{_id} ({_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({_type})") log_text["Failure"] = f"HTTPException with code {e.code}." log_content = mod_role.mention @@ -355,13 +355,13 @@ class InfractionScheduler(Scheduler): try: # Mark infraction as inactive in the database. - log.trace(f"Marking infraction #{_id} as inactive in the database.") + log.trace(f"Marking infraction #{id_} as inactive in the database.") await self.bot.api_client.patch( - f"bot/infractions/{_id}", + f"bot/infractions/{id_}", json={"active": False} ) except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{_id} ({_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({_type})") log_line = f"API request failed with code {e.status}." log_content = mod_role.mention @@ -379,13 +379,13 @@ class InfractionScheduler(Scheduler): if send_log: log_title = f"expiration failed" if "Failure" in log_text else "expired" - log.trace(f"Sending deactivation mod log for infraction #{_id}.") + log.trace(f"Sending deactivation mod log for infraction #{id_}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS[_type][1], colour=Colours.soft_green, title=f"Infraction {log_title}: {_type}", text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {_id}", + footer=f"ID: {id_}", content=log_content, ) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 9ab870823..9b3c62403 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -133,17 +133,17 @@ class Superstarify(InfractionScheduler, Cog): # Post the infraction to the API reason = reason or f"old nick: {member.display_name}" infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration) - _id = infraction["id"] + id_ = infraction["id"] old_nick = member.display_name - forced_nick = self.get_nick(_id, member.id) + forced_nick = self.get_nick(id_, member.id) expiry_str = format_infraction(infraction["expires_at"]) # Apply the infraction and schedule the expiration task. log.debug(f"Changing nickname of {member} to {forced_nick}.") self.mod_log.ignore(constants.Event.member_update, member.id) await member.edit(nick=forced_nick, reason=reason) - self.schedule_task(ctx.bot.loop, _id, infraction) + self.schedule_task(ctx.bot.loop, id_, infraction) # Send a DM to the user to notify them of their new infraction. await utils.notify_infraction( @@ -155,7 +155,7 @@ class Superstarify(InfractionScheduler, Cog): ) # Send an embed with the infraction information to the invoking context. - log.trace(f"Sending superstar #{_id} embed.") + log.trace(f"Sending superstar #{id_} embed.") embed = Embed( title="Congratulations!", colour=constants.Colours.soft_orange, @@ -171,7 +171,7 @@ class Superstarify(InfractionScheduler, Cog): await ctx.send(embed=embed) # Log to the mod log channel. - log.trace(f"Sending apply mod log for superstar #{_id}.") + log.trace(f"Sending apply mod log for superstar #{id_}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS["superstar"][0], colour=Colour.gold(), @@ -185,7 +185,7 @@ class Superstarify(InfractionScheduler, Cog): Old nickname: `{old_nick}` New nickname: `{forced_nick}` """), - footer=f"ID {_id}" + footer=f"ID {id_}" ) @command(name="unsuperstarify", aliases=("release_nick", "unstar")) -- cgit v1.2.3 From e90b47e37587fedde765cf78bf27eca15202314d Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Tue, 5 Nov 2019 10:27:27 +0700 Subject: un-monstrosify code ... I think? --- bot/cogs/moderation/modlog.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 92b399874..53ea4ebcb 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -641,20 +641,26 @@ class ModLog(Cog, name="ModLog"): _before = before.clean_content _after = after.clean_content - groups = tuple((g[0], tuple(g[1])) - for g in itertools.groupby(difflib.ndiff(_before.split(), _after.split()), key=lambda s: s[0])) - for index, (name, values) in enumerate(groups): - sub = ' '.join(s[2:].strip() for s in values) - if name == '-': + # Getting the difference per words and group them by type - add, remove, same + # Note that this is intended grouping without sorting + diff = difflib.ndiff(_before.split(), _after.split()) + diff_groups = tuple( + (diff_type, tuple(s[2:] for s in diff_words)) + for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) + ) + + for index, (diff_type, words) in enumerate(diff_groups): + sub = ' '.join(words) + if diff_type == '-': _before = _before.replace(sub, f"[{sub}](http://.z)") - elif name == '+': + elif diff_type == '+': _after = _after.replace(sub, f"[{sub}](http://.z)") else: - if len(values) > 2: - new = (f"{values[0].strip() if index > 0 else ''}" + if len(words) > 2: + new = (f"{words[0] if index > 0 else ''}" " ... " - f"{values[-1].strip() if index < len(groups) - 1 else ''}") + f"{words[-1] if index < len(diff_groups) - 1 else ''}") else: new = sub _before = _before.replace(sub, new) -- cgit v1.2.3 From 7d10fcafa8302b137733d4cf84ffc15fe1a8f219 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Tue, 5 Nov 2019 10:29:54 +0700 Subject: remove unneccessary else --- bot/cogs/moderation/modlog.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 53ea4ebcb..4fbe39d7f 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -661,10 +661,8 @@ class ModLog(Cog, name="ModLog"): new = (f"{words[0] if index > 0 else ''}" " ... " f"{words[-1] if index < len(diff_groups) - 1 else ''}") - else: - new = sub - _before = _before.replace(sub, new) - _after = _after.replace(sub, new) + _before = _before.replace(sub, new) + _after = _after.replace(sub, new) response = ( f"**Author:** {author} (`{author.id}`)\n" -- cgit v1.2.3 From 990e216ea7cbb0acf7df1a6805c62d6243897a90 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Tue, 5 Nov 2019 23:59:19 +0700 Subject: Changed link used in hyperlink - A simple `http://.z` will show properly for PC client, but for android it completely broke -> changed to `http://o.hi` - minimum link to make discord think it's a link. --- bot/cogs/moderation/modlog.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 4fbe39d7f..6ce83840d 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -653,9 +653,9 @@ class ModLog(Cog, name="ModLog"): for index, (diff_type, words) in enumerate(diff_groups): sub = ' '.join(words) if diff_type == '-': - _before = _before.replace(sub, f"[{sub}](http://.z)") + _before = _before.replace(sub, f"[{sub}](http://o.hi)") elif diff_type == '+': - _after = _after.replace(sub, f"[{sub}](http://.z)") + _after = _after.replace(sub, f"[{sub}](http://o.hi)") else: if len(words) > 2: new = (f"{words[0] if index > 0 else ''}" -- cgit v1.2.3 From 80967822c06f9ece1ad6989bd9448464dea73ece Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 6 Nov 2019 09:18:35 +0700 Subject: Merged `else` and its single `if`, changed style to be more consistent Following Mark's reviews: - The single `if` inside the `else` can be merged with its `else` - this will reduce the level of complexity and indentation. - Changed from style ```py new = ('hello' 'world') ``` to ```py new = ( 'hello' 'world' ) ``` to be more consistent with the rest of the code --- bot/cogs/moderation/modlog.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 6ce83840d..c86bf6faa 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -656,13 +656,14 @@ class ModLog(Cog, name="ModLog"): _before = _before.replace(sub, f"[{sub}](http://o.hi)") elif diff_type == '+': _after = _after.replace(sub, f"[{sub}](http://o.hi)") - else: - if len(words) > 2: - new = (f"{words[0] if index > 0 else ''}" - " ... " - f"{words[-1] if index < len(diff_groups) - 1 else ''}") - _before = _before.replace(sub, new) - _after = _after.replace(sub, new) + elif len(words) > 2: + new = ( + f"{words[0] if index > 0 else ''}" + " ... " + f"{words[-1] if index < len(diff_groups) - 1 else ''}" + ) + _before = _before.replace(sub, new) + _after = _after.replace(sub, new) response = ( f"**Author:** {author} (`{author.id}`)\n" -- cgit v1.2.3 From ad79540f058e5f04bde72dc9dba86533b1e296a2 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 6 Nov 2019 17:36:03 -0800 Subject: Use trailing _ instead of leading for some variable names PEP 8 states the convention is a trailing underscore when used to prevent name conflicts. --- bot/cogs/eval.py | 4 ++-- bot/cogs/filtering.py | 8 ++++---- bot/cogs/moderation/scheduler.py | 14 +++++++------- bot/interpreter.py | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 9ce854f2c..00b988dde 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -148,7 +148,7 @@ class CodeEval(Cog): self.env.update(env) # Ignore this code, it works - _code = """ + code_ = """ async def func(): # (None,) -> Any try: with contextlib.redirect_stdout(self.stdout): @@ -162,7 +162,7 @@ async def func(): # (None,) -> Any """.format(textwrap.indent(code, ' ')) try: - exec(_code, self.env) # noqa: B102,S102 + exec(code_, self.env) # noqa: B102,S102 func = self.env['func'] res = await func() diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 4195783f1..1e7521054 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -43,7 +43,7 @@ class Filtering(Cog): def __init__(self, bot: Bot): self.bot = bot - _staff_mistake_str = "If you believe this was a mistake, please let staff know!" + staff_mistake_str = "If you believe this was a mistake, please let staff know!" self.filters = { "filter_zalgo": { "enabled": Filter.filter_zalgo, @@ -53,7 +53,7 @@ class Filtering(Cog): "user_notification": Filter.notify_user_zalgo, "notification_msg": ( "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " - f"{_staff_mistake_str}" + f"{staff_mistake_str}" ) }, "filter_invites": { @@ -63,7 +63,7 @@ class Filtering(Cog): "content_only": True, "user_notification": Filter.notify_user_invites, "notification_msg": ( - f"Per Rule 6, your invite link has been removed. {_staff_mistake_str}\n\n" + f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" r"Our server rules can be found here: " ) }, @@ -74,7 +74,7 @@ class Filtering(Cog): "content_only": True, "user_notification": Filter.notify_user_domains, "notification_msg": ( - f"Your URL has been removed because it matched a blacklisted domain. {_staff_mistake_str}" + f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" ) }, "watch_rich_embeds": { diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 462c7fc7f..49b61f35e 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -304,7 +304,7 @@ class InfractionScheduler(Scheduler): guild = self.bot.get_guild(constants.Guild.id) mod_role = guild.get_role(constants.Roles.moderator) user_id = infraction["user"] - _type = infraction["type"] + type_ = infraction["type"] id_ = infraction["id"] log.info(f"Marking infraction #{id_} as inactive (expired).") @@ -324,14 +324,14 @@ class InfractionScheduler(Scheduler): log_text = {**log_text, **returned_log} # Merge the logs together else: raise ValueError( - f"Attempted to deactivate an unsupported infraction #{id_} ({_type})!" + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({_type}): bot lacks permissions") + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions") log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{id_} ({_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") log_text["Failure"] = f"HTTPException with code {e.code}." log_content = mod_role.mention @@ -361,7 +361,7 @@ class InfractionScheduler(Scheduler): json={"active": False} ) except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{id_} ({_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") log_line = f"API request failed with code {e.status}." log_content = mod_role.mention @@ -381,9 +381,9 @@ class InfractionScheduler(Scheduler): log.trace(f"Sending deactivation mod log for infraction #{id_}.") await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[_type][1], + icon_url=utils.INFRACTION_ICONS[type_][1], colour=Colours.soft_green, - title=f"Infraction {log_title}: {_type}", + title=f"Infraction {log_title}: {type_}", text="\n".join(f"{k}: {v}" for k, v in log_text.items()), footer=f"ID: {id_}", content=log_content, diff --git a/bot/interpreter.py b/bot/interpreter.py index a42b45a2d..76a3fc293 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -20,8 +20,8 @@ class Interpreter(InteractiveInterpreter): write_callable = None def __init__(self, bot: Bot): - _locals = {"bot": bot} - super().__init__(_locals) + locals_ = {"bot": bot} + super().__init__(locals_) async def run(self, code: str, ctx: Context, io: StringIO, *args, **kwargs) -> Any: """Execute the provided source code as the bot & return the output.""" -- cgit v1.2.3 From 4d5d307f9a499cd874d90e6500f877ce560c012f Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 10 Nov 2019 19:34:28 +0100 Subject: fix signatures and descriptions not being found when present --- bot/cogs/doc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index dcbcfe3ad..6e50cd27d 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -271,19 +271,19 @@ class Doc(commands.Cog): description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) description_end_index = search_html.find(str(end_tag)) - description = search_html[description_start_index:description_end_index].replace('¶', '') + description = search_html[description_start_index:description_end_index] signatures = None else: - description = str(symbol_heading.find_next_sibling("dd")).replace('¶', '') + description = str(symbol_heading.find_next_sibling("dd")) description_pos = search_html.find(description) # Get text of up to 3 signatures, remove unwanted symbols for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) - if signature and search_html.find(signature) < description_pos: + if signature and search_html.find(str(element)) < description_pos: signatures.append(signature) - return signatures, description + return signatures, description.replace('¶', '') @async_cache(arg_offset=1) async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: -- cgit v1.2.3 From 7de5156a7719f0639021e8186f7ea17f5b853af7 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 10 Nov 2019 19:39:14 +0100 Subject: Add a newline after signatures for readability --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 6e50cd27d..653d48528 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -326,7 +326,7 @@ class Doc(commands.Cog): else: embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) - embed_description += description + embed_description += f"\n{description}" embed = discord.Embed( title=f'`{symbol}`', -- cgit v1.2.3 From 4795da86d0fef72ac677ae0a8f9e988da1923e17 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 10 Nov 2019 19:43:56 +0100 Subject: Cut off description at 1000 chars if paragraph is not found --- bot/cogs/doc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 653d48528..b04355e28 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -305,6 +305,8 @@ class Doc(commands.Cog): if len(description) > 1000: shortened = description[:1000] last_paragraph_end = shortened.rfind('\n\n') + if last_paragraph_end == -1: + last_paragraph_end = 1000 description = description[:last_paragraph_end] # If there is an incomplete code block, cut it out -- cgit v1.2.3 From 34510f52c6bbe5e2a8bbfc34f8e5d648d0d39a96 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 10 Nov 2019 20:03:48 +0100 Subject: Move paragraph search to not cut off long starting paragraphs Co-authored-by: scargly <29337040+scragly@users.noreply.github.com> --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index b04355e28..73895e3eb 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -304,7 +304,7 @@ class Doc(commands.Cog): # of a double newline (interpreted as a paragraph) before index 1000. if len(description) > 1000: shortened = description[:1000] - last_paragraph_end = shortened.rfind('\n\n') + last_paragraph_end = shortened.rfind('\n\n', 100) if last_paragraph_end == -1: last_paragraph_end = 1000 description = description[:last_paragraph_end] -- cgit v1.2.3 From 219cde70f03476ac6ae4a7f84322757bebeec51e Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 10 Nov 2019 21:30:26 +0100 Subject: Add a command for refreshing inventories --- bot/cogs/doc.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 73895e3eb..8cf32fc7f 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -447,6 +447,28 @@ class Doc(commands.Cog): await self.refresh_inventory() await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + @docs_group.command(name="refresh", aliases=("rfsh", "r")) + @with_role(*MODERATION_ROLES) + async def refresh_command(self, ctx: commands.Context) -> None: + """Refresh inventories and send differences to channel.""" + old_inventories = set(self.base_urls) + with ctx.typing(): + await self.refresh_inventory() + # Get differences of added and removed inventories + added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) + if added: + added = f"`+ {added}`" + + removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) + if removed: + removed = f"`- {removed}`" + + embed = discord.Embed( + title="Inventories refreshed", + description=f"{added}\n{removed}" if added or removed else "" + ) + await ctx.send(embed=embed) + async def _fetch_inventory(self, inventory_url: str, config: SphinxConfiguration) -> Optional[dict]: """Get and return inventory from `inventory_url`. If fetching fails, return None.""" fetch_func = functools.partial(intersphinx.fetch_inventory, config, '', inventory_url) -- cgit v1.2.3 From 4f393d7b95101cc31269eb30742195e771deb705 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Sun, 10 Nov 2019 21:31:47 +0100 Subject: Move signatures definition --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 8cf32fc7f..f7e8ae9d6 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -252,7 +252,6 @@ class Doc(commands.Cog): symbol_id = url.split('#')[-1] soup = BeautifulSoup(html, 'lxml') symbol_heading = soup.find(id=symbol_id) - signatures = [] search_html = str(soup) if symbol_heading is None: @@ -275,6 +274,7 @@ class Doc(commands.Cog): signatures = None else: + signatures = [] description = str(symbol_heading.find_next_sibling("dd")) description_pos = search_html.find(description) # Get text of up to 3 signatures, remove unwanted symbols -- cgit v1.2.3 From 6944175cea2c6595ec29b9ef67ff2ad9a8efb8ae Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 11 Nov 2019 00:58:32 +0100 Subject: clear renamed symbols on inventory refresh --- bot/cogs/doc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index f7e8ae9d6..90f496ceb 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -216,6 +216,7 @@ class Doc(commands.Cog): # Also, reset the cache used for fetching documentation. self.base_urls.clear() self.inventories.clear() + self.renamed_symbols.clear() async_cache.cache = OrderedDict() # Since Intersphinx is intended to be used with Sphinx, -- cgit v1.2.3 From 4a7de0bd155a4717f6cbc593a60dbec130e7ca40 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 11 Nov 2019 01:12:21 +0100 Subject: Do not cut off text arbitrarily but at last sentence to make sure no unfinished markdown is left in --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 90f496ceb..bf6cee101 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -307,7 +307,7 @@ class Doc(commands.Cog): shortened = description[:1000] last_paragraph_end = shortened.rfind('\n\n', 100) if last_paragraph_end == -1: - last_paragraph_end = 1000 + last_paragraph_end = shortened.rfind('. ') description = description[:last_paragraph_end] # If there is an incomplete code block, cut it out -- cgit v1.2.3 From fb338545c4c2a133e23a664c77813d2ce9aba41c Mon Sep 17 00:00:00 2001 From: Numerlor Date: Mon, 11 Nov 2019 01:15:01 +0100 Subject: syntax highlight diff of reloaded inventories --- bot/cogs/doc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index bf6cee101..0d4884e8b 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -458,15 +458,15 @@ class Doc(commands.Cog): # Get differences of added and removed inventories added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) if added: - added = f"`+ {added}`" + added = f"+ {added}" removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) if removed: - removed = f"`- {removed}`" + removed = f"- {removed}" embed = discord.Embed( title="Inventories refreshed", - description=f"{added}\n{removed}" if added or removed else "" + description=f"```diff\n{added}\n{removed}```" if added or removed else "" ) await ctx.send(embed=embed) -- cgit v1.2.3 From aac8404f65b419e212e5372015b63871fab7f3d1 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Mon, 11 Nov 2019 07:19:07 +0100 Subject: Adding ducky count tests and a new AsyncIteratorMock --- tests/bot/cogs/test_duck_pond.py | 70 ++++++++++++++++++++++++++++++++-------- tests/helpers.py | 38 +++++++++++++++++++++- 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 31c7e9f89..af8ef0e4d 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -1,8 +1,10 @@ +import asyncio import logging import unittest +from bot import constants from bot.cogs import duck_pond -from tests.helpers import MockBot, MockMember, MockMessage, MockReaction, MockRole +from tests.helpers import MockBot, MockEmoji, MockMember, MockMessage, MockReaction, MockRole class DuckPondTest(unittest.TestCase): @@ -13,49 +15,89 @@ class DuckPondTest(unittest.TestCase): self.bot = MockBot() self.cog = duck_pond.DuckPond(bot=self.bot) + # Override the constants we'll be needing + constants.STAFF_ROLES = (123,) + constants.DuckPond.custom_emojis = (789,) + constants.DuckPond.threshold = 1 + # Set up some roles - self.admin_role = MockRole(name="Admins", role_id=476190234653229056) - self.contrib_role = MockRole(name="Contributor", role_id=476190302659543061) + self.admin_role = MockRole(name="Admins", role_id=123) + self.contrib_role = MockRole(name="Contributor", role_id=456) # Set up some users - self.admin_member = MockMember(roles=(self.admin_role,)) + self.admin_member_1 = MockMember(roles=(self.admin_role,), id=1) + self.admin_member_2 = MockMember(roles=(self.admin_role,), id=2) self.contrib_member = MockMember(roles=(self.contrib_role,)) self.no_role_member = MockMember() # Set up emojis self.checkmark_emoji = "✅" self.thumbs_up_emoji = "👍" + self.unicode_duck_emoji = "🦆" + self.yellow_ducky_emoji = MockEmoji(id=789) # Set up reactions - self.checkmark_reaction = MockReaction(emoji=self.checkmark_emoji) - self.thumbs_up_reaction = MockReaction(emoji=self.thumbs_up_emoji) + self.checkmark_reaction = MockReaction( + emoji=self.checkmark_emoji, + user_list=[self.admin_member_1] + ) + self.thumbs_up_reaction = MockReaction( + emoji=self.thumbs_up_emoji, + user_list=[self.admin_member_1, self.contrib_member] + ) + self.yellow_ducky_reaction = MockReaction( + emoji=self.yellow_ducky_emoji, + user_list=[self.admin_member_1, self.contrib_member] + ) + self.unicode_duck_reaction_1 = MockReaction( + emoji=self.unicode_duck_emoji, + user_list=[self.admin_member_1] + ) + self.unicode_duck_reaction_2 = MockReaction( + emoji=self.unicode_duck_emoji, + user_list=[self.admin_member_2] + ) # Set up a messages self.checkmark_message = MockMessage(reactions=(self.checkmark_reaction,)) self.thumbs_up_message = MockMessage(reactions=(self.thumbs_up_reaction,)) + self.yellow_ducky_message = MockMessage(reactions=(self.yellow_ducky_reaction,)) + self.unicode_duck_message = MockMessage(reactions=(self.unicode_duck_reaction_1,)) + self.double_duck_message = MockMessage(reactions=(self.unicode_duck_reaction_1, self.unicode_duck_reaction_2)) self.no_reaction_message = MockMessage() def test_is_staff_correctly_identifies_staff(self): """Test that is_staff correctly identifies a staff member.""" with self.subTest(): - self.assertTrue(duck_pond.DuckPond.is_staff(self.admin_member)) - self.assertFalse(duck_pond.DuckPond.is_staff(self.contrib_member)) - self.assertFalse(duck_pond.DuckPond.is_staff(self.no_role_member)) + self.assertTrue(self.cog.is_staff(self.admin_member_1)) + self.assertFalse(self.cog.is_staff(self.contrib_member)) + self.assertFalse(self.cog.is_staff(self.no_role_member)) def test_has_green_checkmark_correctly_identifies_messages(self): """Test that has_green_checkmark recognizes messages with checkmarks.""" with self.subTest(): - self.assertTrue(duck_pond.DuckPond.has_green_checkmark(self.checkmark_message)) - self.assertFalse(duck_pond.DuckPond.has_green_checkmark(self.thumbs_up_message)) - self.assertFalse(duck_pond.DuckPond.has_green_checkmark(self.no_reaction_message)) + self.assertTrue(self.cog.has_green_checkmark(self.checkmark_message)) + self.assertFalse(self.cog.has_green_checkmark(self.thumbs_up_message)) + self.assertFalse(self.cog.has_green_checkmark(self.no_reaction_message)) def test_count_custom_duck_emojis(self): """A string decoding to numeric characters is a valid user ID.""" - pass + count_one_duck = self.cog.count_ducks(self.yellow_ducky_message) + count_no_ducks = self.cog.count_ducks(self.thumbs_up_message) + with self.subTest(): + self.assertEqual(asyncio.run(count_one_duck), 1) + self.assertEqual(asyncio.run(count_no_ducks), 0) def test_count_unicode_duck_emojis(self): """A string decoding to numeric characters is a valid user ID.""" - pass + count_no_ducks = self.cog.count_ducks(self.thumbs_up_message) + count_one_duck = self.cog.count_ducks(self.unicode_duck_message) + count_two_ducks = self.cog.count_ducks(self.double_duck_message) + + with self.subTest(): + self.assertEqual(asyncio.run(count_no_ducks), 0) + self.assertEqual(asyncio.run(count_one_duck), 1) + self.assertEqual(asyncio.run(count_two_ducks), 2) def test_count_mixed_duck_emojis(self): """A string decoding to numeric characters is a valid user ID.""" diff --git a/tests/helpers.py b/tests/helpers.py index 8496ba031..fd79141ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -102,8 +102,30 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): Python 3.8 will introduce an AsyncMock class in the standard library that will have some more features; this stand-in only overwrites the `__call__` method to an async version. """ + async def __call__(self, *args, **kwargs): - return super(AsyncMock, self).__call__(*args, **kwargs) + return super().__call__(*args, **kwargs) + + +class AsyncIteratorMock: + """ + A class to mock asyncronous iterators. + + This allows async for, which is used in certain Discord.py objects. For example, + an async iterator is returned by the Reaction.users() coroutine. + """ + + def __init__(self, sequence): + self.iter = iter(sequence) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration # Create a guild instance to get a realistic Mock of `discord.Guild` @@ -155,6 +177,7 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): For more info, see the `Mocking` section in `tests/README.md`. """ + def __init__( self, guild_id: int = 1, @@ -187,6 +210,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None: super().__init__(spec=role_instance, **kwargs) @@ -213,6 +237,7 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ + def __init__( self, name: str = "member", @@ -243,6 +268,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec=bot_instance, **kwargs) @@ -279,6 +305,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: super().__init__(spec=channel_instance, **kwargs) self.id = channel_id @@ -320,6 +347,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec=context_instance, **kwargs) self.bot = kwargs.get('bot', MockBot()) @@ -336,6 +364,7 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec=message_instance, **kwargs) self.author = kwargs.get('author', MockMember()) @@ -353,6 +382,7 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec=emoji_instance, **kwargs) self.guild = kwargs.get('guild', MockGuild()) @@ -371,6 +401,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec=partial_emoji_instance, **kwargs) @@ -385,7 +416,12 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec=reaction_instance, **kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) + self.user_list = AsyncIteratorMock(kwargs.get('user_list', [])) + + def users(self): + return self.user_list -- cgit v1.2.3 From 98ccfbc218dc762e45f0146d0503dba1fe06fdb9 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Mon, 11 Nov 2019 14:55:40 +0100 Subject: Implement a mixed duck test. Also gets started setting up for the final tests, which will require more mockwork. --- tests/bot/cogs/test_duck_pond.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index af8ef0e4d..211e8b084 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -20,6 +20,10 @@ class DuckPondTest(unittest.TestCase): constants.DuckPond.custom_emojis = (789,) constants.DuckPond.threshold = 1 + # Mock bot.get_all_channels() + CHANNEL_ID = 555 + USER_ID = 666 + # Set up some roles self.admin_role = MockRole(name="Admins", role_id=123) self.contrib_role = MockRole(name="Contributor", role_id=456) @@ -63,7 +67,12 @@ class DuckPondTest(unittest.TestCase): self.thumbs_up_message = MockMessage(reactions=(self.thumbs_up_reaction,)) self.yellow_ducky_message = MockMessage(reactions=(self.yellow_ducky_reaction,)) self.unicode_duck_message = MockMessage(reactions=(self.unicode_duck_reaction_1,)) - self.double_duck_message = MockMessage(reactions=(self.unicode_duck_reaction_1, self.unicode_duck_reaction_2)) + self.double_unicode_duck_message = MockMessage( + reactions=(self.unicode_duck_reaction_1, self.unicode_duck_reaction_2) + ) + self.double_mixed_duck_message = MockMessage( + reactions=(self.unicode_duck_reaction_1, self.yellow_ducky_reaction) + ) self.no_reaction_message = MockMessage() def test_is_staff_correctly_identifies_staff(self): @@ -81,27 +90,28 @@ class DuckPondTest(unittest.TestCase): self.assertFalse(self.cog.has_green_checkmark(self.no_reaction_message)) def test_count_custom_duck_emojis(self): - """A string decoding to numeric characters is a valid user ID.""" - count_one_duck = self.cog.count_ducks(self.yellow_ducky_message) + """Test that count_ducks counts custom ducks correctly.""" count_no_ducks = self.cog.count_ducks(self.thumbs_up_message) + count_one_duck = self.cog.count_ducks(self.yellow_ducky_message) with self.subTest(): - self.assertEqual(asyncio.run(count_one_duck), 1) self.assertEqual(asyncio.run(count_no_ducks), 0) + self.assertEqual(asyncio.run(count_one_duck), 1) def test_count_unicode_duck_emojis(self): - """A string decoding to numeric characters is a valid user ID.""" - count_no_ducks = self.cog.count_ducks(self.thumbs_up_message) + """Test that count_ducks counts unicode ducks correctly.""" count_one_duck = self.cog.count_ducks(self.unicode_duck_message) - count_two_ducks = self.cog.count_ducks(self.double_duck_message) + count_two_ducks = self.cog.count_ducks(self.double_unicode_duck_message) with self.subTest(): - self.assertEqual(asyncio.run(count_no_ducks), 0) self.assertEqual(asyncio.run(count_one_duck), 1) self.assertEqual(asyncio.run(count_two_ducks), 2) def test_count_mixed_duck_emojis(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + """Test that count_ducks counts mixed ducks correctly.""" + count_two_ducks = self.cog.count_ducks(self.double_mixed_duck_message) + + with self.subTest(): + self.assertEqual(asyncio.run(count_two_ducks), 2) def test_raw_reaction_add_rejects_bot(self): """A string decoding to numeric characters is a valid user ID.""" -- cgit v1.2.3 From 160962a56110ed970c7419ed650d9d8a84dbaa9a Mon Sep 17 00:00:00 2001 From: Numerlor Date: Tue, 12 Nov 2019 16:46:51 +0100 Subject: Adjust code style and comments --- bot/cogs/doc.py | 77 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 0d4884e8b..b82eac5fe 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -176,36 +176,34 @@ class Doc(commands.Cog): self.base_urls[package_name] = base_url package = await self._fetch_inventory(inventory_url, config) - if package: - for group, value in package.items(): - # Each value has a bunch of information in the form - # `(package_name, version, relative_url, ???)`, and we only - # need the package_name and the relative documentation URL. - for symbol, (package_name, _, relative_doc_url, _) in value.items(): - absolute_doc_url = base_url + relative_doc_url - - if symbol in self.inventories: - # get `group_name` from _:group_name - group_name = group.split(":")[1] - if (group_name in NO_OVERRIDE_GROUPS - # check if any package from `NO_OVERRIDE_PACKAGES` - # is in base URL of the symbol that would be overridden - or any(package in self.inventories[symbol].split("/", 3)[2] - for package in NO_OVERRIDE_PACKAGES)): - - symbol = f"{group_name}.{symbol}" - # If renamed `symbol` already exists, add library name in front. - if symbol in self.renamed_symbols: - # Split `package_name` because of packages like Pillow that have spaces in them. - symbol = f"{package_name.split()[0]}.{symbol}" - - self.inventories[symbol] = absolute_doc_url - self.renamed_symbols.add(symbol) - continue - - self.inventories[symbol] = absolute_doc_url - - log.trace(f"Fetched inventory for {package_name}.") + if not package: + return None + + for group, value in package.items(): + for symbol, (package_name, _, relative_doc_url, _) in value.items(): + absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + group_name = group.split(":")[1] + symbol_base_url = self.inventories[symbol].split("/", 3)[2] + if ( + group_name in NO_OVERRIDE_GROUPS + or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) + ): + + symbol = f"{group_name}.{symbol}" + # If renamed `symbol` already exists, add library name in front to differentiate between them. + if symbol in self.renamed_symbols: + # Split `package_name` because of packages like Pillow that have spaces in them. + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + + self.inventories[symbol] = absolute_doc_url + + log.trace(f"Fetched inventory for {package_name}.") async def refresh_inventory(self) -> None: """Refresh internal documentation inventory.""" @@ -337,9 +335,10 @@ class Doc(commands.Cog): description=embed_description ) # Show all symbols with the same name that were renamed in the footer. - embed.set_footer(text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} - if renamed.endswith(f".{symbol}")) - ) + embed.set_footer( + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} + if renamed.endswith(f".{symbol}")) + ) return embed @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) @@ -477,11 +476,15 @@ class Doc(commands.Cog): try: package = await self.bot.loop.run_in_executor(None, fetch_func) except ConnectTimeout: - log.error(f"Fetching of inventory {inventory_url} timed out," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})") + log.error( + f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) except ProtocolError: - log.error(f"Connection lost while fetching inventory {inventory_url}," - f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})") + log.error( + f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) except HTTPError as e: log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") return None -- cgit v1.2.3 From a89349ee32bbf2b3506cc278999575db1fbfde74 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Tue, 12 Nov 2019 22:05:28 +0100 Subject: Add tests for on_raw_reaction_add. Basically I suck at this and I can't get this return_value thing to work. I'll have Ves look at it to resolve it. As of right now, multiple tests are failing. --- tests/bot/cogs/test_duck_pond.py | 84 +++++++++++++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 211e8b084..088d8ac79 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -1,10 +1,11 @@ import asyncio import logging import unittest +from unittest.mock import MagicMock from bot import constants from bot.cogs import duck_pond -from tests.helpers import MockBot, MockEmoji, MockMember, MockMessage, MockReaction, MockRole +from tests.helpers import MockBot, MockEmoji, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel class DuckPondTest(unittest.TestCase): @@ -15,23 +16,27 @@ class DuckPondTest(unittest.TestCase): self.bot = MockBot() self.cog = duck_pond.DuckPond(bot=self.bot) + # Set up some constants + self.CHANNEL_ID = 555 + self.MESSAGE_ID = 666 + self.BOT_ID = 777 + self.CONTRIB_ID = 888 + self.ADMIN_ID = 999 + # Override the constants we'll be needing constants.STAFF_ROLES = (123,) constants.DuckPond.custom_emojis = (789,) constants.DuckPond.threshold = 1 - # Mock bot.get_all_channels() - CHANNEL_ID = 555 - USER_ID = 666 - # Set up some roles self.admin_role = MockRole(name="Admins", role_id=123) self.contrib_role = MockRole(name="Contributor", role_id=456) # Set up some users - self.admin_member_1 = MockMember(roles=(self.admin_role,), id=1) - self.admin_member_2 = MockMember(roles=(self.admin_role,), id=2) - self.contrib_member = MockMember(roles=(self.contrib_role,)) + self.admin_member_1 = MockMember(roles=(self.admin_role,), id=self.ADMIN_ID) + self.admin_member_2 = MockMember(roles=(self.admin_role,), id=911) + self.contrib_member = MockMember(roles=(self.contrib_role,), id=self.CONTRIB_ID) + self.bot_member = MockMember(roles=(self.contrib_role,), id=self.BOT_ID, bot=True) self.no_role_member = MockMember() # Set up emojis @@ -61,6 +66,14 @@ class DuckPondTest(unittest.TestCase): emoji=self.unicode_duck_emoji, user_list=[self.admin_member_2] ) + self.bot_reaction = MockReaction( + emoji=self.yellow_ducky_emoji, + user_list=[self.bot_member] + ) + self.contrib_reaction = MockReaction( + emoji=self.yellow_ducky_emoji, + user_list=[self.contrib_member] + ) # Set up a messages self.checkmark_message = MockMessage(reactions=(self.checkmark_reaction,)) @@ -73,8 +86,18 @@ class DuckPondTest(unittest.TestCase): self.double_mixed_duck_message = MockMessage( reactions=(self.unicode_duck_reaction_1, self.yellow_ducky_reaction) ) + + self.bot_message = MockMessage(reactions=(self.bot_reaction,)) + self.contrib_message = MockMessage(reactions=(self.contrib_reaction,)) self.no_reaction_message = MockMessage() + # Set up some channels + self.text_channel = MockTextChannel(id=self.CHANNEL_ID) + + @staticmethod + def _mock_send_webhook(content, username, avatar_url, embed): + """Mock for the send_webhook method in DuckPond""" + def test_is_staff_correctly_identifies_staff(self): """Test that is_staff correctly identifies a staff member.""" with self.subTest(): @@ -114,16 +137,49 @@ class DuckPondTest(unittest.TestCase): self.assertEqual(asyncio.run(count_two_ducks), 2) def test_raw_reaction_add_rejects_bot(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + """Test that send_webhook is not called if the user is a bot.""" + self.text_channel.fetch_message.return_value = self.bot_message + self.bot.get_all_channels.return_value = (self.text_channel,) + + payload = MagicMock( # RawReactionActionEvent + channel_id=self.CHANNEL_ID, + message_id=self.MESSAGE_ID, + user_id=self.BOT_ID, + ) + + with self.subTest(): + asyncio.run(self.cog.on_raw_reaction_add(payload)) + self.bot.cog.send_webhook.assert_not_called() def test_raw_reaction_add_rejects_non_staff(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + """Test that send_webhook is not called if the user is not a member of staff.""" + self.text_channel.fetch_message.return_value = self.contrib_message + self.bot.get_all_channels.return_value = (self.text_channel,) + + payload = MagicMock( # RawReactionActionEvent + channel_id=self.CHANNEL_ID, + message_id=self.MESSAGE_ID, + user_id=self.CONTRIB_ID, + ) + + with self.subTest(): + asyncio.run(self.cog.on_raw_reaction_add(payload)) + self.bot.cog.send_webhook.assert_not_called() def test_raw_reaction_add_sends_message_on_valid_input(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + """Test that send_webhook is called if payload is valid.""" + self.text_channel.fetch_message.return_value = self.unicode_duck_message + self.bot.get_all_channels.return_value = (self.text_channel,) + + payload = MagicMock( # RawReactionActionEvent + channel_id=self.CHANNEL_ID, + message_id=self.MESSAGE_ID, + user_id=self.ADMIN_ID, + ) + + with self.subTest(): + asyncio.run(self.cog.on_raw_reaction_add(payload)) + self.bot.cog.send_webhook.assert_called_once() def test_raw_reaction_remove_rejects_non_checkmarks(self): """A string decoding to numeric characters is a valid user ID.""" -- cgit v1.2.3 From 2ff711b5d7299baab84e20842cec856c5f17f992 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 13 Nov 2019 09:15:17 +0700 Subject: Switched to using list instead of `str.replace()` for much better control over each word. --- bot/cogs/moderation/modlog.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index c86bf6faa..3a7e0d3ce 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -639,39 +639,42 @@ class ModLog(Cog, name="ModLog"): channel = before.channel channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - _before = before.clean_content - _after = after.clean_content - # Getting the difference per words and group them by type - add, remove, same # Note that this is intended grouping without sorting - diff = difflib.ndiff(_before.split(), _after.split()) + diff = difflib.ndiff(before.clean_content.split(), after.clean_content.split()) diff_groups = tuple( (diff_type, tuple(s[2:] for s in diff_words)) for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) ) + _before = [] + _after = [] + for index, (diff_type, words) in enumerate(diff_groups): sub = ' '.join(words) if diff_type == '-': - _before = _before.replace(sub, f"[{sub}](http://o.hi)") + _before.append(f"[{sub}](http://o.hi)") elif diff_type == '+': - _after = _after.replace(sub, f"[{sub}](http://o.hi)") + _after.append(f"[{sub}](http://o.hi)") elif len(words) > 2: new = ( f"{words[0] if index > 0 else ''}" " ... " f"{words[-1] if index < len(diff_groups) - 1 else ''}" ) - _before = _before.replace(sub, new) - _after = _after.replace(sub, new) + _before.append(new) + _after.append(new) + elif diff_type == ' ': + _before.append(sub) + _after.append(sub) response = ( f"**Author:** {author} (`{author.id}`)\n" f"**Channel:** {channel_name} (`{channel.id}`)\n" f"**Message ID:** `{before.id}`\n" "\n" - f"**Before**:\n{_before}\n" - f"**After**:\n{_after}\n" + f"**Before**:\n{' '.join(_before)}\n" + f"**After**:\n{' '.join(_after)}\n" "\n" f"[jump to message]({after.jump_url})" ) -- cgit v1.2.3 From 501aa5655c8039f43b3cf3106474b8be16b4074a Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 14 Nov 2019 12:05:18 +0700 Subject: Condensed logic, now only check for `add` `remove` `same` diff_type only. --- bot/cogs/moderation/modlog.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 3a7e0d3ce..ce2a5e1f7 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -656,15 +656,13 @@ class ModLog(Cog, name="ModLog"): _before.append(f"[{sub}](http://o.hi)") elif diff_type == '+': _after.append(f"[{sub}](http://o.hi)") - elif len(words) > 2: - new = ( - f"{words[0] if index > 0 else ''}" - " ... " - f"{words[-1] if index < len(diff_groups) - 1 else ''}" - ) - _before.append(new) - _after.append(new) elif diff_type == ' ': + if len(words) > 2: + sub = ( + f"{words[0] if index > 0 else ''}" + " ... " + f"{words[-1] if index < len(diff_groups) - 1 else ''}" + ) _before.append(sub) _after.append(sub) -- cgit v1.2.3 From ccda39c5e42e94011c9c1bd14080d004d3d61f02 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Thu, 14 Nov 2019 10:50:49 +0100 Subject: Add bot=False default value to MockMember By default, a mocked value is considered `truthy` in Python, like all non-empty/non-zero/non-None values in Python. This means that if an attribute is not explicitly set on a mock, it will evaluate at as truthy in a boolean context, since the mock will provide a truthy mocked value by default. This is not the best default value for the `bot` attribute of our MockMember type, since members are rarely bots. It makes much more intuitive sense to me to consider a member to not be a bot, unless we explicitly set `bot=True`. This commit sets that sensible default value that can be overwritten by passing `bot=False` to the constructor or setting the `object.bot` attribute to `False` after the creation of the mock. --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 22f07934f..199d45700 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -242,7 +242,7 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin information, see the `MockGuild` docstring. """ def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'name': 'member', 'id': next(self.discord_id)} + default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] -- cgit v1.2.3 From 61051f9cc5abbf571dfa13c49324109ef16f78fc Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Thu, 14 Nov 2019 10:58:40 +0100 Subject: Add MockAttachment type and attachments default for MockMessage As stated from the start, our intention is to add custom mock types as we need them for testing. While writing tests for DuckPond, I noticed that we did not have a mock type for Attachments, so I added one with this commit. In addition, I think it's a very sensible for MockMessage to have an empty list as a default value for the `attachements` attribute. This is equal to what `discord.Message` returns for a message without attachments and makes sure that if you don't explicitely add an attachment to a message, `MockMessage.attachments` tests as falsey. --- tests/helpers.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 199d45700..3e43679fe 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -355,6 +355,20 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.channel = kwargs.get('channel', MockTextChannel()) +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) + + +class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Attachment objects. + + Instances of this class will follow the specifications of `discord.Attachment` instances. For + more information, see the `MockGuild` docstring. + """ + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=attachment_instance, **kwargs) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. @@ -364,7 +378,8 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ def __init__(self, **kwargs) -> None: - super().__init__(spec_set=message_instance, **kwargs) + default_kwargs = {'attachments': []} + super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) -- cgit v1.2.3 From 8c64fc637dda73cfa4b79d1f3541d067380e51d8 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:02:40 +0100 Subject: Check only for bot's green checkmark in DuckPond Previously, the presence of any green checkmark as a reaction would prevent a message from being relayed to the duck pond, regardless of the actor of that reaction. Since we only want to check if the bot has already processed this message, we should check for a checkmark added by the bot. This commit adds such a user check. --- bot/cogs/duck_pond.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 45bbc410b..aac023a2e 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -37,12 +37,13 @@ class DuckPond(Cog): return True return False - @staticmethod - def has_green_checkmark(message: Message) -> bool: + async def has_green_checkmark(self, message: Message) -> bool: """Check if the message has a green checkmark reaction.""" for reaction in message.reactions: if reaction.emoji == "✅": - return True + async for user in reaction.users(): + if user == self.bot.user: + return True return False async def send_webhook( @@ -115,7 +116,7 @@ class DuckPond(Cog): return # Does the message already have a green checkmark? - if self.has_green_checkmark(message): + if await self.has_green_checkmark(message): return # Time to count our ducks! -- cgit v1.2.3 From f56f6cebc5300ec3c1b52ec8988ae9c27571c14e Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:09:22 +0100 Subject: Refactor DuckPond msg relay to separate method To allow for separate testing of the code that relays messages to the duck pond, I have moved this part of the code from the event listener to a separate method. The overall logic has remained unchanged. In addition, I've kaizened to things: - Removed unnecessary f-string without interpolation; - Removed double negative (not item not in list) --- bot/cogs/duck_pond.py | 62 +++++++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index aac023a2e..b2b4ad0c2 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -62,7 +62,7 @@ class DuckPond(Cog): embed=embed ) except discord.HTTPException: - log.exception(f"Failed to send a message to the Duck Pool webhook") + log.exception("Failed to send a message to the Duck Pool webhook") async def count_ducks(self, message: Message) -> int: """ @@ -76,8 +76,8 @@ class DuckPond(Cog): for reaction in message.reactions: async for user in reaction.users(): - # Is the user or member a staff member? - if not self.is_staff(user) or not user.id not in duck_reactors: + # Is the user a staff member and not already counted as reactor? + if not self.is_staff(user) or user.id in duck_reactors: continue # Is the emoji a duck? @@ -91,6 +91,35 @@ class DuckPond(Cog): duck_reactors.append(user.id) return duck_count + async def relay_message_to_duck_pond(self, message: Message) -> None: + """Relays the message's content and attachments to the duck pond channel.""" + clean_content = message.clean_content + + if clean_content: + await self.send_webhook( + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.send_webhook( + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception(f"Failed to send an attachment to the webhook") + + await message.add_reaction("✅") + @Cog.listener() async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: """ @@ -124,32 +153,7 @@ class DuckPond(Cog): # If we've got more than the required amount of ducks, send the message to the duck_pond. if duck_count >= constants.DuckPond.threshold: - clean_content = message.clean_content - - if clean_content: - await self.send_webhook( - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - if message.attachments: - try: - await send_attachments(message, self.webhook) - except (errors.Forbidden, errors.NotFound): - e = Embed( - description=":x: **This message contained an attachment, but it could not be retrieved**", - color=Color.red() - ) - await self.send_webhook( - embed=e, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - except discord.HTTPException: - log.exception(f"Failed to send an attachment to the webhook") - - await message.add_reaction("✅") + await self.relay_message_to_duck_pond(message) @Cog.listener() async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: -- cgit v1.2.3 From 89890d6e1b673622cba918be48f325540e45db9e Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:15:01 +0100 Subject: Move payload checks to start of DuckPond.on_raw_reaction_add The `DuckPond.on_raw_message_add` event listener makes an API call to fetch the message the reaction was added to. However, we don't need to fetch the message if the reaction that was added is not relevant to the duck pond. To prevent such unnecessary API calls, I have moved the code that checks for the relevance of the reaction event to before the code that fetches the message. --- bot/cogs/duck_pond.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index b2b4ad0c2..68fb09408 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -129,6 +129,13 @@ class DuckPond(Cog): amount of ducks specified in the config under duck_pond/threshold, it will send the message off to the duck pond. """ + # Is the emoji in the reaction a duck? + if payload.emoji.is_custom_emoji(): + if payload.emoji.id not in constants.DuckPond.custom_emojis: + return + elif payload.emoji.name != "🦆": + return + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) message = await channel.fetch_message(payload.message_id) member = discord.utils.get(message.guild.members, id=payload.user_id) @@ -137,13 +144,6 @@ class DuckPond(Cog): if not self.is_staff(member) or member.bot: return - # Is the emoji in the reaction a duck? - if payload.emoji.is_custom_emoji(): - if payload.emoji.id not in constants.DuckPond.custom_emojis: - return - elif payload.emoji.name != "🦆": - return - # Does the message already have a green checkmark? if await self.has_green_checkmark(message): return -- cgit v1.2.3 From 2779a912a8fbe29453543a8fd2888a842c3beb47 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:21:33 +0100 Subject: Add `return_value` support and assertions to AsyncIteratorMock The AsyncIteratorMock included in Python 3.8 will work similarly to the mocks of callabes. This means that it allows you to set the items it will yield using the `return_value` attribute. It will also have support for the common Mock-specific assertions. This commit introduces some backports of those features in a slightly simplified way to make the transition to Python 3.8 easier in the future. --- tests/helpers.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 4 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 3e43679fe..50652ef9a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -127,14 +127,20 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): class AsyncIteratorMock: """ - A class to mock asyncronous iterators. + A class to mock asynchronous iterators. This allows async for, which is used in certain Discord.py objects. For example, - an async iterator is returned by the Reaction.users() coroutine. + an async iterator is returned by the Reaction.users() method. """ - def __init__(self, sequence): - self.iter = iter(sequence) + def __init__(self, iterable: Iterable = None): + if iterable is None: + iterable = [] + + self.iter = iter(iterable) + self.iterable = iterable + + self.call_count = 0 def __aiter__(self): return self @@ -145,6 +151,50 @@ class AsyncIteratorMock: except StopIteration: raise StopAsyncIteration + def __call__(self): + """ + Keeps track of the number of times an instance has been called. + + This is useful, since it typically shows that the iterator has actually been used somewhere after we have + instantiated the mock for an attribute that normally returns an iterator when called. + """ + self.call_count += 1 + return self + + @property + def return_value(self): + """Makes `self.iterable` accessible as self.return_value.""" + return self.iterable + + @return_value.setter + def return_value(self, iterable): + """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" + self.iter = iter(iterable) + self.iterable = iterable + + def assert_called(self): + """Asserts if the AsyncIteratorMock instance has been called at least once.""" + if self.call_count == 0: + raise AssertionError("Expected AsyncIteratorMock to have been called.") + + def assert_called_once(self): + """Asserts if the AsyncIteratorMock instance has been called exactly once.""" + if self.call_count != 1: + raise AssertionError( + f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." + ) + + def assert_not_called(self): + """Asserts if the AsyncIteratorMock instance has not been called.""" + if self.call_count != 0: + raise AssertionError( + f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." + ) + + def reset_mock(self): + """Resets the call count, but not the return value or iterator.""" + self.call_count = 0 + # Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { -- cgit v1.2.3 From 2c77288eb3ff081e70508094bb8d030900860259 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:28:56 +0100 Subject: Add MockUser to mock `discord.User` objects I have added a special mock that follows the specifications of a `discord.User` instance. This is useful, since `Users` have less attributes available than `discord.Members`. Since this difference in availability of information can be important, we should not use a `MockMember` to mock a `discord.user`. --- tests/helpers.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/helpers.py b/tests/helpers.py index 50652ef9a..4da6bf84d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -303,6 +303,25 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" +# Create a User instance to get a realistic Mock of `discord.User` +user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock()) + + +class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): + """ + A Mock subclass to mock User objects. + + Instances of this class will follow the specifications of `discord.User` instances. For more + information, see the `MockGuild` docstring. + """ + def __init__(self, **kwargs) -> None: + default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} + super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"@{self.name}" + + # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) bot_instance.http_session = None -- cgit v1.2.3 From 647370d7881d1ab242186599adb76a56a0815150 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:34:46 +0100 Subject: Adjust MockReaction for new AsyncIteratorMock protocol The new AsyncIteratorMock no longer needs an additional method to be used with a Mock object. --- tests/helpers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 4da6bf84d..13852397f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -500,7 +500,5 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): super().__init__(spec_set=reaction_instance, **kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) - self.user_list = AsyncIteratorMock(kwargs.get('user_list', [])) + self.users = AsyncIteratorMock(kwargs.get('users', [])) - def users(self): - return self.user_list -- cgit v1.2.3 From b42a7b5b7f2c1c9f9924eeb9d39f7767306824ec Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:37:02 +0100 Subject: Add MockAsyncWebhook to mock `discord.Webhook` objects I have added a mock type to mock `discord.Webhook` instances. Note that the current type is specifically meant to mock webhooks that use an AsyncAdaptor and therefore has AsyncMock/coroutine mocks for the "maybe-coroutine" methods specified in the `discord.py` docs. --- tests/helpers.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/helpers.py b/tests/helpers.py index 13852397f..b2daae92d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -502,3 +502,24 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.message = kwargs.get('message', MockMessage()) self.users = AsyncIteratorMock(kwargs.get('users', [])) + +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) + + +class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. + + Instances of this class will follow the specifications of `discord.Webhook` instances. For + more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=webhook_instance, **kwargs) + + # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined + # as coroutines. That's why we need to set the methods manually. + self.send = AsyncMock() + self.edit = AsyncMock() + self.delete = AsyncMock() + self.execute = AsyncMock() -- cgit v1.2.3 From a692a95896328adf1d52c5a5548e0c72540d6cbc Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Fri, 15 Nov 2019 01:39:51 +0100 Subject: Add unit tests with full coverage for `bot.cogs.duck_pond` This commit adds unit tests that provide a full branch coverage of the `bot.cogs.duck_pond` file. --- tests/bot/cogs/test_duck_pond.py | 649 +++++++++++++++++++++++++++++---------- 1 file changed, 490 insertions(+), 159 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 088d8ac79..ceefc286f 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -1,193 +1,524 @@ import asyncio import logging +import typing import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch + +import discord from bot import constants from bot.cogs import duck_pond -from tests.helpers import MockBot, MockEmoji, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel +from tests import base +from tests import helpers + +MODULE_PATH = "bot.cogs.duck_pond" + + +class DuckPondTests(base.LoggingTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) -class DuckPondTest(unittest.TestCase): - """Tests the `DuckPond` cog.""" + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) def setUp(self): - """Adds the cog, a bot, and the mocks we'll need for our tests.""" - self.bot = MockBot() + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) self.cog = duck_pond.DuckPond(bot=self.bot) - # Set up some constants - self.CHANNEL_ID = 555 - self.MESSAGE_ID = 666 - self.BOT_ID = 777 - self.CONTRIB_ID = 888 - self.ADMIN_ID = 999 - - # Override the constants we'll be needing - constants.STAFF_ROLES = (123,) - constants.DuckPond.custom_emojis = (789,) - constants.DuckPond.threshold = 1 - - # Set up some roles - self.admin_role = MockRole(name="Admins", role_id=123) - self.contrib_role = MockRole(name="Contributor", role_id=456) - - # Set up some users - self.admin_member_1 = MockMember(roles=(self.admin_role,), id=self.ADMIN_ID) - self.admin_member_2 = MockMember(roles=(self.admin_role,), id=911) - self.contrib_member = MockMember(roles=(self.contrib_role,), id=self.CONTRIB_ID) - self.bot_member = MockMember(roles=(self.contrib_role,), id=self.BOT_ID, bot=True) - self.no_role_member = MockMember() - - # Set up emojis - self.checkmark_emoji = "✅" - self.thumbs_up_emoji = "👍" - self.unicode_duck_emoji = "🦆" - self.yellow_ducky_emoji = MockEmoji(id=789) - - # Set up reactions - self.checkmark_reaction = MockReaction( - emoji=self.checkmark_emoji, - user_list=[self.admin_member_1] - ) - self.thumbs_up_reaction = MockReaction( - emoji=self.thumbs_up_emoji, - user_list=[self.admin_member_1, self.contrib_member] - ) - self.yellow_ducky_reaction = MockReaction( - emoji=self.yellow_ducky_emoji, - user_list=[self.admin_member_1, self.contrib_member] - ) - self.unicode_duck_reaction_1 = MockReaction( - emoji=self.unicode_duck_emoji, - user_list=[self.admin_member_1] + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_loop.called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_ready.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger('bot.cogs.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_ready.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + [record] = log_watcher.records + self.assertEqual(record.message, f"Failed to fetch webhook with id `{self.cog.webhook_id}`") + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) ) - self.unicode_duck_reaction_2 = MockReaction( - emoji=self.unicode_duck_emoji, - user_list=[self.admin_member_2] + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @helpers.async_test + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji), + helpers.MockReaction(emoji=self.thumbs_up_emoji) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) ) - self.bot_reaction = MockReaction( - emoji=self.yellow_ducky_emoji, - user_list=[self.bot_member] + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def test_send_webhook_correctly_passes_on_arguments(self): + """The `send_webhook` method should pass the arguments to the webhook correctly.""" + self.cog.webhook = helpers.MockAsyncWebhook() + + content = "fake content" + username = "fake username" + avatar_url = "fake avatar_url" + embed = "fake embed" + + asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed)) + + self.cog.webhook.send.assert_called_once_with( + content=content, + username=username, + avatar_url=avatar_url, + embed=embed ) - self.contrib_reaction = MockReaction( - emoji=self.yellow_ducky_emoji, - user_list=[self.contrib_member] + + def test_send_webhook_logs_when_sending_message_fails(self): + """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly.""" + self.cog.webhook = helpers.MockAsyncWebhook() + self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.") + + log = logging.getLogger('bot.cogs.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.send_webhook()) + + self.assertEqual(len(log_watcher.records), 1) + + [record] = log_watcher.records + self.assertEqual(record.message, "Failed to send a message to the Duck Pool webhook") + self.assertEqual(record.levelno, logging.ERROR) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + @helpers.async_test + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ]), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ]), + 3+4 + ), ) - # Set up a messages - self.checkmark_message = MockMessage(reactions=(self.checkmark_reaction,)) - self.thumbs_up_message = MockMessage(reactions=(self.thumbs_up_reaction,)) - self.yellow_ducky_message = MockMessage(reactions=(self.yellow_ducky_reaction,)) - self.unicode_duck_message = MockMessage(reactions=(self.unicode_duck_reaction_1,)) - self.double_unicode_duck_message = MockMessage( - reactions=(self.unicode_duck_reaction_1, self.unicode_duck_reaction_2) + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + @helpers.async_test + async def test_relay_message_to_duck_pond_correctly_relays_content_and_attachments(self): + """The `relay_message_to_duck_pond` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(clean_content="", attachments=[]), False, False), + (helpers.MockMessage(clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True), ) - self.double_mixed_duck_message = MockMessage( - reactions=(self.unicode_duck_reaction_1, self.yellow_ducky_reaction) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message_to_duck_pond(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_relay_message_to_duck_pond_handles_send_attachments_exceptions(self, send_attachments, send_webhook): + """The `relay_message_to_duck_pond` method should handle exceptions when calling `send_attachment`.""" + + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.cogs.duck_pond") + + # Subtests for the first `except` block + for side_effect in side_effects: + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message_to_duck_pond(message) + + self.assertEqual(send_webhook.call_count, 2) + send_webhook.reset_mock() + + # Subtests for the second `except` block + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message_to_duck_pond(message) + + send_webhook.assert_called_once_with( + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + [record] = log_watcher.records + self.assertEqual(record.message, "Failed to send an attachment to the webhook") + self.assertEqual(record.levelno, logging.ERROR) + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + @helpers.async_test + async def test_on_raw_reaction_add_returns_for_non_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + payload_custom_emoji = MagicMock(label="Non-Duck Custom Emoji") + payload_custom_emoji.emoji.is_custom_emoji.return_value = True + payload_custom_emoji.emoji.id = 12345 + + payload_unicode_emoji = MagicMock(label="Non-Duck Unicode Emoji") + payload_unicode_emoji.emoji.is_custom_emoji.return_value = False + payload_unicode_emoji.emoji.name = self.thumbs_up_emoji + + for payload in (payload_custom_emoji, payload_unicode_emoji): + with self.subTest(case=payload.label), patch(f"{MODULE_PATH}.discord.utils.get") as discord_utils_get: + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + discord_utils_get.assert_not_called() + + @helpers.async_test + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), ) - self.bot_message = MockMessage(reactions=(self.bot_reaction,)) - self.contrib_message = MockMessage(reactions=(self.contrib_reaction,)) - self.no_reaction_message = MockMessage() - - # Set up some channels - self.text_channel = MockTextChannel(id=self.CHANNEL_ID) - - @staticmethod - def _mock_send_webhook(content, username, avatar_url, embed): - """Mock for the send_webhook method in DuckPond""" - - def test_is_staff_correctly_identifies_staff(self): - """Test that is_staff correctly identifies a staff member.""" - with self.subTest(): - self.assertTrue(self.cog.is_staff(self.admin_member_1)) - self.assertFalse(self.cog.is_staff(self.contrib_member)) - self.assertFalse(self.cog.is_staff(self.no_role_member)) - - def test_has_green_checkmark_correctly_identifies_messages(self): - """Test that has_green_checkmark recognizes messages with checkmarks.""" - with self.subTest(): - self.assertTrue(self.cog.has_green_checkmark(self.checkmark_message)) - self.assertFalse(self.cog.has_green_checkmark(self.thumbs_up_message)) - self.assertFalse(self.cog.has_green_checkmark(self.no_reaction_message)) - - def test_count_custom_duck_emojis(self): - """Test that count_ducks counts custom ducks correctly.""" - count_no_ducks = self.cog.count_ducks(self.thumbs_up_message) - count_one_duck = self.cog.count_ducks(self.yellow_ducky_message) - with self.subTest(): - self.assertEqual(asyncio.run(count_no_ducks), 0) - self.assertEqual(asyncio.run(count_one_duck), 1) - - def test_count_unicode_duck_emojis(self): - """Test that count_ducks counts unicode ducks correctly.""" - count_one_duck = self.cog.count_ducks(self.unicode_duck_message) - count_two_ducks = self.cog.count_ducks(self.double_unicode_duck_message) - - with self.subTest(): - self.assertEqual(asyncio.run(count_one_duck), 1) - self.assertEqual(asyncio.run(count_two_ducks), 2) - - def test_count_mixed_duck_emojis(self): - """Test that count_ducks counts mixed ducks correctly.""" - count_two_ducks = self.cog.count_ducks(self.double_mixed_duck_message) - - with self.subTest(): - self.assertEqual(asyncio.run(count_two_ducks), 2) - - def test_raw_reaction_add_rejects_bot(self): - """Test that send_webhook is not called if the user is a bot.""" - self.text_channel.fetch_message.return_value = self.bot_message - self.bot.get_all_channels.return_value = (self.text_channel,) - - payload = MagicMock( # RawReactionActionEvent - channel_id=self.CHANNEL_ID, - message_id=self.MESSAGE_ID, - user_id=self.BOT_ID, + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + @patch(f"{MODULE_PATH}.DuckPond.relay_message_to_duck_pond", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self, count_ducks, message_relay): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold-1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold+1, True), ) - with self.subTest(): - asyncio.run(self.cog.on_raw_reaction_add(payload)) - self.bot.cog.send_webhook.assert_not_called() + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + count_ducks.reset_mock() + + # Did we relay a message? + has_relayed = message_relay.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + message_relay.assert_called_once_with(message) + message_relay.reset_mock() - def test_raw_reaction_add_rejects_non_staff(self): - """Test that send_webhook is not called if the user is not a member of staff.""" - self.text_channel.fetch_message.return_value = self.contrib_message - self.bot.get_all_channels.return_value = (self.text_channel,) + @helpers.async_test + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - payload = MagicMock( # RawReactionActionEvent - channel_id=self.CHANNEL_ID, - message_id=self.MESSAGE_ID, - user_id=self.CONTRIB_ID, + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), ) + for duck_count, should_readd_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_readd_checkmark=should_readd_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) - with self.subTest(): - asyncio.run(self.cog.on_raw_reaction_add(payload)) - self.bot.cog.send_webhook.assert_not_called() + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) - def test_raw_reaction_add_sends_message_on_valid_input(self): - """Test that send_webhook is called if payload is valid.""" - self.text_channel.fetch_message.return_value = self.unicode_duck_message - self.bot.get_all_channels.return_value = (self.text_channel,) + has_readded_checkmark = message.add_reaction.called + self.assertEqual(should_readd_checkmark, has_readded_checkmark) - payload = MagicMock( # RawReactionActionEvent - channel_id=self.CHANNEL_ID, - message_id=self.MESSAGE_ID, - user_id=self.ADMIN_ID, + if should_readd_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + count_ducks.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" ) - with self.subTest(): - asyncio.run(self.cog.on_raw_reaction_add(payload)) - self.bot.cog.send_webhook.assert_called_once() + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - def test_raw_reaction_remove_rejects_non_checkmarks(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - def test_raw_reaction_remove_prevents_checkmark_removal(self): - """A string decoding to numeric characters is a valid user ID.""" - pass + channel.fetch_message.assert_not_called() class DuckPondSetupTests(unittest.TestCase): @@ -195,7 +526,7 @@ class DuckPondSetupTests(unittest.TestCase): def test_setup(self): """Setup of the cog should log a message at `INFO` level.""" - bot = MockBot() + bot = helpers.MockBot() log = logging.getLogger('bot.cogs.duck_pond') with self.assertLogs(logger=log, level=logging.INFO) as log_watcher: -- cgit v1.2.3 From f212ddeea4de54d6eb75081c13162c2ad64bfeff Mon Sep 17 00:00:00 2001 From: Numerlor Date: Fri, 15 Nov 2019 13:10:19 +0100 Subject: join extra newline --- bot/cogs/doc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index b82eac5fe..20bc010d9 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -336,8 +336,7 @@ class Doc(commands.Cog): ) # Show all symbols with the same name that were renamed in the footer. embed.set_footer( - text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} - if renamed.endswith(f".{symbol}")) + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) ) return embed -- cgit v1.2.3 From a0ed0c1d6c6d3ba32df4d9bb355ffe1a59e8f76b Mon Sep 17 00:00:00 2001 From: Numerlor Date: Fri, 15 Nov 2019 13:13:18 +0100 Subject: Add variable info after comment was deleted Co-authored-by: scargly <29337040+scragly@users.noreply.github.com> --- bot/cogs/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 20bc010d9..76fdcd831 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -180,7 +180,7 @@ class Doc(commands.Cog): return None for group, value in package.items(): - for symbol, (package_name, _, relative_doc_url, _) in value.items(): + for symbol, (package_name, _version, relative_doc_url, _) in value.items(): absolute_doc_url = base_url + relative_doc_url if symbol in self.inventories: -- cgit v1.2.3 From f1180d9cd05329f61439c8a45dedb47e841e7216 Mon Sep 17 00:00:00 2001 From: Numerlor Date: Fri, 15 Nov 2019 13:35:44 +0100 Subject: group and order constants --- bot/cogs/doc.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 76fdcd831..dc53937ee 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -26,7 +26,6 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) logging.getLogger('urllib3').setLevel(logging.WARNING) -NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay NO_OVERRIDE_GROUPS = ( "2to3fixer", "token", @@ -37,8 +36,7 @@ NO_OVERRIDE_GROUPS = ( NO_OVERRIDE_PACKAGES = ( "python", ) -FAILED_REQUEST_RETRY_AMOUNT = 3 -UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") + SEARCH_END_TAG_ATTRS = ( "data", "function", @@ -49,8 +47,12 @@ SEARCH_END_TAG_ATTRS = ( "rubric", "sphinxsidebar", ) +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") +FAILED_REQUEST_RETRY_AMOUNT = 3 +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay + def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: """ -- cgit v1.2.3 From e66ed4b4f534e6fa4178d8b2a82bc486b97affd5 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Sat, 16 Nov 2019 08:26:38 +0700 Subject: Renamed variables to be more explicit, added type hinting for `content_before` and `content_after` --- bot/cogs/moderation/modlog.py | 48 +++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index ce2a5e1f7..41d7709e4 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -620,42 +620,42 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: + async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: """Log message edit event to message change log.""" if ( - not before.guild - or before.guild.id != GuildConstant.id - or before.channel.id in GuildConstant.ignored - or before.author.bot + not msg_before.guild + or msg_before.guild.id != GuildConstant.id + or msg_before.channel.id in GuildConstant.ignored + or msg_before.author.bot ): return - self._cached_edits.append(before.id) + self._cached_edits.append(msg_before.id) - if before.content == after.content: + if msg_before.content == msg_after.content: return - author = before.author - channel = before.channel + author = msg_before.author + channel = msg_before.channel channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" # Getting the difference per words and group them by type - add, remove, same # Note that this is intended grouping without sorting - diff = difflib.ndiff(before.clean_content.split(), after.clean_content.split()) + diff = difflib.ndiff(msg_before.clean_content.split(), msg_after.clean_content.split()) diff_groups = tuple( (diff_type, tuple(s[2:] for s in diff_words)) for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) ) - _before = [] - _after = [] + content_before: t.List[str] = [] + content_after: t.List[str] = [] for index, (diff_type, words) in enumerate(diff_groups): sub = ' '.join(words) if diff_type == '-': - _before.append(f"[{sub}](http://o.hi)") + content_before.append(f"[{sub}](http://o.hi)") elif diff_type == '+': - _after.append(f"[{sub}](http://o.hi)") + content_after.append(f"[{sub}](http://o.hi)") elif diff_type == ' ': if len(words) > 2: sub = ( @@ -663,31 +663,31 @@ class ModLog(Cog, name="ModLog"): " ... " f"{words[-1] if index < len(diff_groups) - 1 else ''}" ) - _before.append(sub) - _after.append(sub) + content_before.append(sub) + content_after.append(sub) response = ( f"**Author:** {author} (`{author.id}`)\n" f"**Channel:** {channel_name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" + f"**Message ID:** `{msg_before.id}`\n" "\n" - f"**Before**:\n{' '.join(_before)}\n" - f"**After**:\n{' '.join(_after)}\n" + f"**Before**:\n{' '.join(content_before)}\n" + f"**After**:\n{' '.join(content_after)}\n" "\n" - f"[jump to message]({after.jump_url})" + f"[jump to message]({msg_after.jump_url})" ) - if before.edited_at: + if msg_before.edited_at: # Message was previously edited, to assist with self-bot detection, use the edited_at # datetime as the baseline and create a human-readable delta between this edit event # and the last time the message was edited - timestamp = before.edited_at - delta = humanize_delta(relativedelta(after.edited_at, before.edited_at)) + timestamp = msg_before.edited_at + delta = humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) footer = f"Last edited {delta} ago" else: # Message was not previously edited, use the created_at datetime as the baseline, no # delta calculation needed - timestamp = before.created_at + timestamp = msg_before.created_at footer = None await self.send_log_message( -- cgit v1.2.3 From e67822a46621d20ff0a1a27de1322b14432e4eb9 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Sat, 16 Nov 2019 17:04:27 +0100 Subject: Apply suggestions from code review Co-Authored-By: Mark --- tests/bot/cogs/test_duck_pond.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index ceefc286f..8f0c4f068 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -269,7 +269,7 @@ class DuckPondTests(base.LoggingTestCase): self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), ]), - 3+4 + 3 + 4 ), ) @@ -310,7 +310,6 @@ class DuckPondTests(base.LoggingTestCase): @helpers.async_test async def test_relay_message_to_duck_pond_handles_send_attachments_exceptions(self, send_attachments, send_webhook): """The `relay_message_to_duck_pond` method should handle exceptions when calling `send_attachment`.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) @@ -435,9 +434,9 @@ class DuckPondTests(base.LoggingTestCase): async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self, count_ducks, message_relay): """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" test_cases = ( - (constants.DuckPond.threshold-1, False), + (constants.DuckPond.threshold - 1, False), (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold+1, True), + (constants.DuckPond.threshold + 1, True), ) channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) -- cgit v1.2.3 From f96631eba92c2c00b831004f8a70b5de5709a4cd Mon Sep 17 00:00:00 2001 From: scragly <29337040+scragly@users.noreply.github.com> Date: Tue, 19 Nov 2019 10:31:48 +1000 Subject: Relock to d.py 1.2.5 due to API breaking change for emoji. --- Pipfile.lock | 175 ++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 95 insertions(+), 80 deletions(-) diff --git a/Pipfile.lock b/Pipfile.lock index 95955ff89..69caf4646 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -18,11 +18,11 @@ "default": { "aio-pika": { "hashes": [ - "sha256:1dcec3e3e3309e277511dc0d7d157676d0165c174a6a745673fc9cf0510db8f0", - "sha256:dd5a23ca26a4872ee73bd107e4c545bace572cdec2a574aeb61f4062c7774b2a" + "sha256:1da038b3d2c1b49e0e816d87424e702912bb77f9b5197f2bf279217915b4f7ed", + "sha256:29fe851374b86c997a22174c04352b5941bc1c2e36bbf542918ac18a76cfc9d3" ], "index": "pypi", - "version": "==6.1.3" + "version": "==6.3.0" }, "aiodns": { "hashes": [ @@ -62,10 +62,10 @@ }, "aiormq": { "hashes": [ - "sha256:c3e4dd01a2948a75f739fb637334dbb8c6f1a4cecf74d5ed662dc3bab7f39973", - "sha256:e220d3f9477bb2959b729b79bec815148ddb8a7686fc6c3d05d41c88ebd7c59e" + "sha256:afc0d46837b121585e4faec0a7646706429b4e2f5110ae8d0b5cdc3708b4b0e5", + "sha256:dc0fbbc7f8ad5af6a2cc18e00ccc5f925984cde3db6e8fe952c07b7ef157b5f2" ], - "version": "==2.8.0" + "version": "==2.9.1" }, "alabaster": { "hashes": [ @@ -83,10 +83,10 @@ }, "attrs": { "hashes": [ - "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", - "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" ], - "version": "==19.2.0" + "version": "==19.3.0" }, "babel": { "hashes": [ @@ -112,36 +112,41 @@ }, "cffi": { "hashes": [ - "sha256:041c81822e9f84b1d9c401182e174996f0bae9991f33725d059b771744290774", - "sha256:046ef9a22f5d3eed06334d01b1e836977eeef500d9b78e9ef693f9380ad0b83d", - "sha256:066bc4c7895c91812eff46f4b1c285220947d4aa46fa0a2651ff85f2afae9c90", - "sha256:066c7ff148ae33040c01058662d6752fd73fbc8e64787229ea8498c7d7f4041b", - "sha256:2444d0c61f03dcd26dbf7600cf64354376ee579acad77aef459e34efcb438c63", - "sha256:300832850b8f7967e278870c5d51e3819b9aad8f0a2c8dbe39ab11f119237f45", - "sha256:34c77afe85b6b9e967bd8154e3855e847b70ca42043db6ad17f26899a3df1b25", - "sha256:46de5fa00f7ac09f020729148ff632819649b3e05a007d286242c4882f7b1dc3", - "sha256:4aa8ee7ba27c472d429b980c51e714a24f47ca296d53f4d7868075b175866f4b", - "sha256:4d0004eb4351e35ed950c14c11e734182591465a33e960a4ab5e8d4f04d72647", - "sha256:4e3d3f31a1e202b0f5a35ba3bc4eb41e2fc2b11c1eff38b362de710bcffb5016", - "sha256:50bec6d35e6b1aaeb17f7c4e2b9374ebf95a8975d57863546fa83e8d31bdb8c4", - "sha256:55cad9a6df1e2a1d62063f79d0881a414a906a6962bc160ac968cc03ed3efcfb", - "sha256:5662ad4e4e84f1eaa8efce5da695c5d2e229c563f9d5ce5b0113f71321bcf753", - "sha256:59b4dc008f98fc6ee2bb4fd7fc786a8d70000d058c2bbe2698275bc53a8d3fa7", - "sha256:73e1ffefe05e4ccd7bcea61af76f36077b914f92b76f95ccf00b0c1b9186f3f9", - "sha256:a1f0fd46eba2d71ce1589f7e50a9e2ffaeb739fb2c11e8192aa2b45d5f6cc41f", - "sha256:a2e85dc204556657661051ff4bab75a84e968669765c8a2cd425918699c3d0e8", - "sha256:a5457d47dfff24882a21492e5815f891c0ca35fefae8aa742c6c263dac16ef1f", - "sha256:a8dccd61d52a8dae4a825cdbb7735da530179fea472903eb871a5513b5abbfdc", - "sha256:ae61af521ed676cf16ae94f30fe202781a38d7178b6b4ab622e4eec8cefaff42", - "sha256:b012a5edb48288f77a63dba0840c92d0504aa215612da4541b7b42d849bc83a3", - "sha256:d2c5cfa536227f57f97c92ac30c8109688ace8fa4ac086d19d0af47d134e2909", - "sha256:d42b5796e20aacc9d15e66befb7a345454eef794fdb0737d1af593447c6c8f45", - "sha256:dee54f5d30d775f525894d67b1495625dd9322945e7fee00731952e0368ff42d", - "sha256:e070535507bd6aa07124258171be2ee8dfc19119c28ca94c9dfb7efd23564512", - "sha256:e1ff2748c84d97b065cc95429814cdba39bcbd77c9c85c89344b317dc0d9cbff", - "sha256:ed851c75d1e0e043cbf5ca9a8e1b13c4c90f3fbd863dacb01c0808e2b5204201" - ], - "version": "==1.12.3" + "sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42", + "sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04", + "sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5", + "sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54", + "sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba", + "sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57", + "sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396", + "sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12", + "sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97", + "sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43", + "sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db", + "sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3", + "sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b", + "sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579", + "sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346", + "sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159", + "sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652", + "sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e", + "sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a", + "sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506", + "sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f", + "sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d", + "sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c", + "sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20", + "sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858", + "sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc", + "sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a", + "sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3", + "sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e", + "sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410", + "sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25", + "sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b", + "sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d" + ], + "version": "==1.13.2" }, "chardet": { "hashes": [ @@ -152,18 +157,18 @@ }, "deepdiff": { "hashes": [ - "sha256:1123762580af0904621136d117c8397392a244d3ff0fa0a50de57a7939582476", - "sha256:6ab13e0cbb627dadc312deaca9bef38de88a737a9bbdbfbe6e3857748219c127" + "sha256:3457ea7cecd51ba48015d89edbb569358af4d9b9e65e28bdb3209608420627f9", + "sha256:5e2343398e90538edaa59c0c99207e996a3a834fdc878c666376f632a760c35a" ], "index": "pypi", - "version": "==4.0.7" + "version": "==4.0.9" }, "discord-py": { "hashes": [ - "sha256:4684733fa137cc7def18087ae935af615212e423e3dbbe3e84ef01d7ae8ed17d" + "sha256:7c843b523bb011062b453864e75c7b675a03faf573c58d14c9f096e85984329d" ], "index": "pypi", - "version": "==1.2.3" + "version": "==1.2.5" }, "docutils": { "hashes": [ @@ -221,6 +226,7 @@ "sha256:02ca7bf899da57084041bb0f6095333e4d239948ad3169443f454add9f4e9cb4", "sha256:096b82c5e0ea27ce9138bcbb205313343ee66a6e132f25c5ed67e2c8d960a1bc", "sha256:0a920ff98cf1aac310470c644bc23b326402d3ef667ddafecb024e1713d485f1", + "sha256:1409b14bf83a7d729f92e2a7fbfe7ec929d4883ca071b06e95c539ceedb6497c", "sha256:17cae1730a782858a6e2758fd20dd0ef7567916c47757b694a06ffafdec20046", "sha256:17e3950add54c882e032527795c625929613adbd2ce5162b94667334458b5a36", "sha256:1f4f214337f6ee5825bf90a65d04d70aab05526c08191ab888cb5149501923c5", @@ -231,11 +237,14 @@ "sha256:760c12276fee05c36f95f8040180abc7fbebb9e5011447a97cdc289b5d6ab6fc", "sha256:796685d3969815a633827c818863ee199440696b0961e200b011d79b9394bbe7", "sha256:891fe897b49abb7db470c55664b198b1095e4943b9f82b7dcab317a19116cd38", + "sha256:9277562f175d2334744ad297568677056861070399cec56ff06abbe2564d1232", "sha256:a471628e20f03dcdfde00770eeaf9c77811f0c331c8805219ca7b87ac17576c5", "sha256:a63b4fd3e2cabdcc9d918ed280bdde3e8e9641e04f3c59a2a3109644a07b9832", + "sha256:ae88588d687bd476be588010cbbe551e9c2872b816f2da8f01f6f1fda74e1ef0", "sha256:b0b84408d4eabc6de9dd1e1e0bc63e7731e890c0b378a62443e5741cfd0ae90a", "sha256:be78485e5d5f3684e875dab60f40cddace2f5b2a8f7fede412358ab3214c3a6f", "sha256:c27eaed872185f047bb7f7da2d21a7d8913457678c9a100a50db6da890bc28b9", + "sha256:c7fccd08b14aa437fe096c71c645c0f9be0655a9b1a4b7cffc77bcb23b3d61d2", "sha256:c81cb40bff373ab7a7446d6bbca0190bccc5be3448b47b51d729e37799bb5692", "sha256:d11874b3c33ee441059464711cd365b89fa1a9cf19ae75b0c189b01fbf735b84", "sha256:e9c028b5897901361d81a4718d1db217b716424a0283afe9d6735fe0caf70f79", @@ -379,18 +388,18 @@ }, "pyparsing": { "hashes": [ - "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", - "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" + "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", + "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" ], - "version": "==2.4.2" + "version": "==2.4.5" }, "python-dateutil": { "hashes": [ - "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb", - "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e" + "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", + "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" ], "index": "pypi", - "version": "==2.8.0" + "version": "==2.8.1" }, "python-json-logger": { "hashes": [ @@ -434,10 +443,10 @@ }, "six": { "hashes": [ - "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", - "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", + "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" ], - "version": "==1.12.0" + "version": "==1.13.0" }, "snowballstemmer": { "hashes": [ @@ -448,18 +457,18 @@ }, "soupsieve": { "hashes": [ - "sha256:605f89ad5fdbfefe30cdc293303665eff2d188865d4dbe4eb510bba1edfbfce3", - "sha256:b91d676b330a0ebd5b21719cb6e9b57c57d433671f65b9c28dd3461d9a1ed0b6" + "sha256:bdb0d917b03a1369ce964056fc195cfdff8819c40de04695a80bc813c3cfa1f5", + "sha256:e2c1c5dee4a1c36bcb790e0fabd5492d874b8ebd4617622c4f6a731701060dda" ], - "version": "==1.9.4" + "version": "==1.9.5" }, "sphinx": { "hashes": [ - "sha256:0d586b0f8c2fc3cc6559c5e8fd6124628110514fda0e5d7c82e682d749d2e845", - "sha256:839a3ed6f6b092bb60f492024489cc9e6991360fb9f52ed6361acd510d261069" + "sha256:31088dfb95359384b1005619827eaee3056243798c62724fd3fa4b84ee4d71bd", + "sha256:52286a0b9d7caa31efee301ec4300dbdab23c3b05da1c9024b4e84896fb73d79" ], "index": "pypi", - "version": "==2.2.0" + "version": "==2.2.1" }, "sphinxcontrib-applehelp": { "hashes": [ @@ -564,10 +573,10 @@ }, "attrs": { "hashes": [ - "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", - "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" ], - "version": "==19.2.0" + "version": "==19.3.0" }, "certifi": { "hashes": [ @@ -658,11 +667,11 @@ }, "flake8": { "hashes": [ - "sha256:19241c1cbc971b9962473e4438a2ca19749a7dd002dd1a946eaba171b4114548", - "sha256:8e9dfa3cecb2400b3738a42c54c3043e821682b9c840b0448c0503f781130696" + "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", + "sha256:49356e766643ad15072a789a20915d3c91dc89fd313ccd71802303fd67e4deca" ], "index": "pypi", - "version": "==3.7.8" + "version": "==3.7.9" }, "flake8-annotations": { "hashes": [ @@ -738,6 +747,7 @@ "sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26", "sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af" ], + "markers": "python_version < '3.8'", "version": "==0.23" }, "mccabe": { @@ -770,11 +780,11 @@ }, "pre-commit": { "hashes": [ - "sha256:1d3c0587bda7c4e537a46c27f2c84aa006acc18facf9970bf947df596ce91f3f", - "sha256:fa78ff96e8e9ac94c748388597693f18b041a181c94a4f039ad20f45287ba44a" + "sha256:9f152687127ec90642a2cc3e4d9e1e6240c4eb153615cb02aa1ad41d331cbb6e", + "sha256:c2e4810d2d3102d354947907514a78c5d30424d299dc0fe48f5aa049826e9b50" ], "index": "pypi", - "version": "==1.18.3" + "version": "==1.20.0" }, "pycodestyle": { "hashes": [ @@ -799,10 +809,10 @@ }, "pyparsing": { "hashes": [ - "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", - "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" + "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", + "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" ], - "version": "==2.4.2" + "version": "==2.4.5" }, "pyyaml": { "hashes": [ @@ -841,10 +851,10 @@ }, "six": { "hashes": [ - "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", - "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", + "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" ], - "version": "==1.12.0" + "version": "==1.13.0" }, "snowballstemmer": { "hashes": [ @@ -862,31 +872,36 @@ }, "typed-ast": { "hashes": [ + "sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161", "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", + "sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47", "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", + "sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2", + "sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e", "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", + "sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66", "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" ], "version": "==1.4.0" }, "unittest-xml-reporting": { "hashes": [ - "sha256:140982e4b58e4052d9ecb775525b246a96bfc1fc26097806e05ea06e9166dd6c", - "sha256:d1fbc7a1b6c6680ccfe75b5e9701e5431c646970de049e687b4bb35ba4325d72" + "sha256:358bbdaf24a26d904cc1c26ef3078bca7fc81541e0a54c8961693cc96a6f35e0", + "sha256:9d28ddf6524cf0ff9293f61bd12e792de298f8561a5c945acea63fb437789e0e" ], "index": "pypi", - "version": "==2.5.1" + "version": "==2.5.2" }, "urllib3": { "hashes": [ @@ -898,10 +913,10 @@ }, "virtualenv": { "hashes": [ - "sha256:680af46846662bb38c5504b78bad9ed9e4f3ba2d54f54ba42494fdf94337fe30", - "sha256:f78d81b62d3147396ac33fc9d77579ddc42cc2a98dd9ea38886f616b33bc7fb2" + "sha256:11cb4608930d5fd3afb545ecf8db83fa50e1f96fc4fca80c94b07d2c83146589", + "sha256:d257bb3773e48cac60e475a19b608996c73f4d333b3ba2e4e57d5ac6134e0136" ], - "version": "==16.7.5" + "version": "==16.7.7" }, "zipp": { "hashes": [ -- cgit v1.2.3 From d40b55841201b7546d49f9125fd54d181d67a43f Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 25 Nov 2019 15:21:08 +0100 Subject: Update antimalware.py to be more consistent with other information messages (like the codeblock reminder) & improve code a slight bit --- bot/cogs/antimalware.py | 44 ++++++++++++++++++++------------------------ 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index ababd6f18..e0c127d9a 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,11 +1,12 @@ import logging -from discord import Message, NotFound +from discord import Message, Embed from discord.ext.commands import Bot, Cog from bot.constants import AntiMalware as AntiMalwareConfig, Channels log = logging.getLogger(__name__) +PASTE_URL = "https://paste.pythondiscord.com/" class AntiMalware(Cog): @@ -17,37 +18,32 @@ class AntiMalware(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" - rejected_attachments = False - detected_pyfile = False + if len(message.attachments) == 0: + return + + embed = Embed() for attachment in message.attachments: if attachment.filename.lower().endswith('.py'): - detected_pyfile = True + embed.description = ( + "It looks like you tried to attach a Python file - please " + f"use a code-pasting service such as [{PASTE_URL}]" + f"({PASTE_URL}) instead." + ) break # Other detections irrelevant because we prioritize the .py message. if not attachment.filename.lower().endswith(tuple(AntiMalwareConfig.whitelist)): - rejected_attachments = True - - if detected_pyfile or rejected_attachments: - # Send a message to the user indicating the problem (with special treatment for .py) - author = message.author - if detected_pyfile: - msg = ( - f"{author.mention}, it looks like you tried to attach a Python file - please " - f"use a code-pasting service such as https://paste.pythondiscord.com/ instead." - ) - else: meta_channel = self.bot.get_channel(Channels.meta) - msg = ( - f"{author.mention}, it looks like you tried to attach a file type we don't " - f"allow. Feel free to ask in {meta_channel.mention} if you think this is a mistake." + embed.description = ( + "It looks like you tried to attach a file type that we " + "do not allow. We currently allow the following file " + f"types: **{', '.join(AntiMalwareConfig.whitelist)}**. \n\n" + f"Feel free to ask in {meta_channel.mention} if you think " + "this is a mistake." ) - - await message.channel.send(msg) + if embed.description: + await message.channel.send(message.author.mention, embed=embed) # Delete the offending message: - try: - await message.delete() - except NotFound: - log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + await message.delete() def setup(bot: Bot) -> None: -- cgit v1.2.3 From 98f4aee9b5b49628b86d0b9e1c952abb9389a839 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 25 Nov 2019 15:24:22 +0100 Subject: Forgot the word 'hey' --- bot/cogs/antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index e0c127d9a..4e8a3269b 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -40,7 +40,7 @@ class AntiMalware(Cog): "this is a mistake." ) if embed.description: - await message.channel.send(message.author.mention, embed=embed) + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) # Delete the offending message: await message.delete() -- cgit v1.2.3 From bc96cf43825dcdfea5f819e2bf52de46371b0b58 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 25 Nov 2019 15:56:57 +0100 Subject: Change order of imports --- bot/cogs/antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 4e8a3269b..24ce501c3 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,6 +1,6 @@ import logging -from discord import Message, Embed +from discord import Embed, Message from discord.ext.commands import Bot, Cog from bot.constants import AntiMalware as AntiMalwareConfig, Channels -- cgit v1.2.3 From 0a417f4828e229516d552d21ac678a8dc150beed Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 25 Nov 2019 19:50:20 +0100 Subject: Update PASTE_URL constant to be pydis instead of pythondiscord --- bot/cogs/antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 24ce501c3..72db5bc11 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -6,7 +6,7 @@ from discord.ext.commands import Bot, Cog from bot.constants import AntiMalware as AntiMalwareConfig, Channels log = logging.getLogger(__name__) -PASTE_URL = "https://paste.pythondiscord.com/" +PASTE_URL = "https://paste.pydis.com/" class AntiMalware(Cog): -- cgit v1.2.3 From c73c9bae2767da1a6dff5b4098d4af50a61aabe5 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 25 Nov 2019 20:17:08 +0100 Subject: Make requested tweaks: Use URL constant from constants.py, re-add try/except block and implement the changes requested by Ava --- bot/cogs/antimalware.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 72db5bc11..745dd8082 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,12 +1,11 @@ import logging -from discord import Embed, Message +from discord import Embed, Message, NotFound from discord.ext.commands import Bot, Cog -from bot.constants import AntiMalware as AntiMalwareConfig, Channels +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs log = logging.getLogger(__name__) -PASTE_URL = "https://paste.pydis.com/" class AntiMalware(Cog): @@ -18,32 +17,35 @@ class AntiMalware(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" - if len(message.attachments) == 0: + if not message.attachments: return embed = Embed() for attachment in message.attachments: - if attachment.filename.lower().endswith('.py'): + filename = attachment.filename.lower() + if filename.endswith('.py'): embed.description = ( - "It looks like you tried to attach a Python file - please " - f"use a code-pasting service such as [{PASTE_URL}]" - f"({PASTE_URL}) instead." + f"It looks like you tried to attach a Python file - please " + f"use a code-pasting service such as {URLs.paste_service}" ) break # Other detections irrelevant because we prioritize the .py message. - if not attachment.filename.lower().endswith(tuple(AntiMalwareConfig.whitelist)): + if not filename.endswith(tuple(AntiMalwareConfig.whitelist)): + whitelisted_types = ', '.join(AntiMalwareConfig.whitelist) meta_channel = self.bot.get_channel(Channels.meta) embed.description = ( - "It looks like you tried to attach a file type that we " - "do not allow. We currently allow the following file " - f"types: **{', '.join(AntiMalwareConfig.whitelist)}**. \n\n" - f"Feel free to ask in {meta_channel.mention} if you think " - "this is a mistake." + f"It looks like you tried to attach a file type that we " + f"do not allow. We currently allow the following file " + f"types: **{whitelisted_types}**. \n\n Feel free to ask " + f"in {meta_channel.mention} if you think this is a mistake." ) if embed.description: await message.channel.send(f"Hey {message.author.mention}!", embed=embed) # Delete the offending message: - await message.delete() + try: + await message.delete() + except NotFound: + log.info(f"Tried to delete message `{message.id}`, but message could not be found.") def setup(bot: Bot) -> None: -- cgit v1.2.3 From 686936646526332bcb018158488253b85b124350 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 13:54:41 +0700 Subject: Implemented `get_duration()` for `bot.utils.time` --- bot/utils/time.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index 2aea2c099..740ede0d3 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,12 +1,21 @@ import asyncio import datetime -from typing import Optional +from typing import List, Optional import dateutil.parser from dateutil.relativedelta import relativedelta RFC1123_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" INFRACTION_FORMAT = "%Y-%m-%d %H:%M" +TIME_MARKS = ( + (60, 'second'), # 1 minute + (60, 'minute'), # 1 hour + (24, 'hour'), # 1 day + (7, 'day'), # 1 week + (4, 'week'), # 1 month + (12, 'month'), # 1 year + (999, 'year') # dumb the rest as year, max 999 +) def _stringify_time_unit(value: int, unit: str) -> str: @@ -111,3 +120,28 @@ async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] def format_infraction(timestamp: str) -> str: """Format an infraction timestamp to a more readable ISO 8601 format.""" return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) + + +def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> str: + """ + Get the duration between two datetime, in human readable format. + + Will return the two biggest units avaiable, for example: + - 11 hours, 59 minutes + - 1 week, 6 minutes + - 7 months, 2 weeks + - 3 years, 3 months + - 5 minutes + + :param date_from: A datetime.datetime object. + :param date_to: A datetime.datetime object. + """ + div = abs(date_from - date_to).total_seconds() + results: List[str] = [] + for unit, name in TIME_MARKS: + div, amount = divmod(div, unit) + if amount > 0: + plural = 's' if amount > 1 else '' + results.append(f"{amount:.0f} {name}{plural}") + # We have to reverse the order of units because currently it's smallest -> largest + return ', '.join(results[::-1][:2]) -- cgit v1.2.3 From 66ffef0c0901ff00a01081eca398fac6aac3ed67 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 14:00:55 +0700 Subject: Added pytest for `get_duration()` --- tests/utils/test_time.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index 4baa6395c..29aca5cfe 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -60,3 +60,20 @@ def test_wait_until(sleep_patch): assert asyncio.run(time.wait_until(then, start)) is None sleep_patch.assert_called_once_with(10 * 60) + + +@pytest.mark.parametrize( + ('date_from', 'date_to', 'expected'), + ( + (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), '11 hours, 59 minutes'), + (datetime(2019, 12, 12), datetime(2019, 12, 11, 23, 59), '1 minute'), + (datetime(2019, 11, 23, 20, 9), datetime(2019, 11, 30, 20, 15), '1 week, 6 minutes'), + (datetime(2019, 11, 23, 20, 9), datetime(2019, 4, 25, 20, 15), '7 months, 2 weeks'), + (datetime(2019, 11, 23, 20, 58), datetime(2019, 11, 23, 21, 3), '5 minutes'), + (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 24, 0, 0), '1 minute'), + (datetime(2019, 11, 23, 23, 59), datetime(2022, 11, 23, 23, 0), '3 years, 3 months'), + (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 23, 23, 49, 5), '9 minutes, 55 seconds'), + ) +) +def test_get_duration(date_from: datetime, date_to: datetime, expected: str): + assert time.get_duration(date_from, date_to) == expected -- cgit v1.2.3 From dadb91573c519c1444608ce0cce3de7b01b860a9 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 14:34:36 +0700 Subject: Implemented `get_duration_from_expiry()` which call `get_duration()` for `expiry` and `datetime.utcnow()` --- bot/utils/time.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/bot/utils/time.py b/bot/utils/time.py index 740ede0d3..00f39b940 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -145,3 +145,21 @@ def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> st results.append(f"{amount:.0f} {name}{plural}") # We have to reverse the order of units because currently it's smallest -> largest return ', '.join(results[::-1][:2]) + + +def get_duration_from_expiry(expiry: str) -> str: + """ + Get the duration between datetime.utcnow() and an expiry, in human readable format. + + Will return the two biggest units avaiable, for example: + - 11 hours, 59 minutes + - 1 week, 6 minutes + - 7 months, 2 weeks + - 3 years, 3 months + - 5 minutes + + :param expiry: A string. + """ + date_from = datetime.datetime.utcnow() + date_to = dateutil.parser.isoparse(expiry) + return get_duration(date_from, date_to) -- cgit v1.2.3 From 4ee01649786edcd9b0bbb88d55f1672953afc6fe Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 14:39:18 +0700 Subject: Fixed TypeError raised by substracting offset-naive and offset-aware datetimes ( removed tzinfo from expiry ) --- bot/utils/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index 00f39b940..fc003f9e2 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -161,5 +161,5 @@ def get_duration_from_expiry(expiry: str) -> str: :param expiry: A string. """ date_from = datetime.datetime.utcnow() - date_to = dateutil.parser.isoparse(expiry) + date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None) return get_duration(date_from, date_to) -- cgit v1.2.3 From 1c84213045f778ef0739b474b8a2862ccf1a620b Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 15:05:43 +0700 Subject: Added test for `get_duration_from_expiry()` --- tests/utils/test_time.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index 29aca5cfe..0afffd9b1 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -77,3 +77,20 @@ def test_wait_until(sleep_patch): ) def test_get_duration(date_from: datetime, date_to: datetime, expected: str): assert time.get_duration(date_from, date_to) == expected + + +@pytest.mark.parametrize( + ('expiry', 'date_from', 'expected'), + ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), '11 hours, 59 minutes'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 30, 20, 15), '1 week, 6 minutes'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), '7 months, 2 weeks'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 21, 3), '5 minutes'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 24, 0, 0), '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2022, 11, 23, 23, 0), '3 years, 3 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), '9 minutes, 55 seconds'), + ) +) +def test_get_duration_from_expiry(expiry: str, date_from: datetime, expected: str): + assert time.get_duration_from_expiry(expiry, date_from) == expected -- cgit v1.2.3 From 44f5ae308f69aa1e3349e1a350590e58302076cb Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 15:06:47 +0700 Subject: Updated `bot.utils.time.get_duration_from_expiry()` to accept an optional `date_from` ( for pytest and more control over the behaviour ) --- bot/utils/time.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index fc003f9e2..533b7ef83 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -147,7 +147,7 @@ def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> st return ', '.join(results[::-1][:2]) -def get_duration_from_expiry(expiry: str) -> str: +def get_duration_from_expiry(expiry: str, date_from: datetime = None) -> str: """ Get the duration between datetime.utcnow() and an expiry, in human readable format. @@ -160,6 +160,6 @@ def get_duration_from_expiry(expiry: str) -> str: :param expiry: A string. """ - date_from = datetime.datetime.utcnow() + date_from = date_from or datetime.datetime.utcnow() date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None) return get_duration(date_from, date_to) -- cgit v1.2.3 From 91b213227bb83a3e4d8be1f526b45c3c6d73fbc0 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 15:07:21 +0700 Subject: Added expiry duration when applying infraction ( including in the embed sent to user ) --- bot/cogs/moderation/scheduler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 49b61f35e..9e987d9ee 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -84,12 +84,15 @@ class InfractionScheduler(Scheduler): icon = utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] expiry = infraction["expires_at"] + expiry_at = expiry id_ = infraction['id'] log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") if expiry: + duration = time.get_duration_from_expiry(expiry) expiry = time.format_infraction(expiry) + expiry_at = f"{expiry} ({duration})" # Default values for the confirmation message and mod log. confirm_msg = f":ok_hand: applied" @@ -98,11 +101,11 @@ class InfractionScheduler(Scheduler): if infr_type in ("note", "warning"): expiry_msg = "" else: - expiry_msg = f" until {expiry}" if expiry else " permanently" + expiry_msg = f" until {expiry_at}" if expiry else " permanently" dm_result = "" dm_log_text = "" - expiry_log_text = f"Expires: {expiry}" if expiry else "" + expiry_log_text = f"Expires: {expiry_at}" if expiry else "" log_title = "applied" log_content = None @@ -112,7 +115,7 @@ class InfractionScheduler(Scheduler): user = await self.bot.fetch_user(user.id) # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry, reason, icon): + if await utils.notify_infraction(user, infr_type, expiry_at, reason, icon): dm_result = ":incoming_envelope: " dm_log_text = "\nDM: Sent" else: -- cgit v1.2.3 From f737fd4f6e0a351a95af856af7addf596f65ee5b Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 15:32:39 +0700 Subject: Fixed "14 minutes, 60 seconds" by rounding `.total_seconds()` in `bot.utils.time.get_durations()` --- bot/utils/time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bot/utils/time.py b/bot/utils/time.py index 533b7ef83..873de21f0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -137,6 +137,7 @@ def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> st :param date_to: A datetime.datetime object. """ div = abs(date_from - date_to).total_seconds() + div = round(div, 0) # to avoid (14 minutes, 60 seconds) results: List[str] = [] for unit, name in TIME_MARKS: div, amount = divmod(div, unit) -- cgit v1.2.3 From 2dc74fc6d97e32cbb9cad1dd2797b02a669b3793 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 15:42:19 +0700 Subject: Added duration until expiration for infraction searching. --- bot/cogs/moderation/management.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 44a508436..5c63b19ce 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -2,6 +2,7 @@ import asyncio import logging import textwrap import typing as t +from datetime import datetime import discord from discord.ext import commands @@ -97,7 +98,8 @@ class ModManagement(commands.Cog): elif duration is not None: request_data['expires_at'] = duration.isoformat() expiry = duration.strftime(time.INFRACTION_FORMAT) - confirm_messages.append(f"set to expire on {expiry}") + duration_string = time.get_duration(duration, datetime.utcnow()) + confirm_messages.append(f"set to expire on {expiry} ({duration_string})") else: confirm_messages.append("expiry unchanged") @@ -234,7 +236,8 @@ class ModManagement(commands.Cog): if infraction["expires_at"] is None: expires = "*Permanent*" else: - expires = time.format_infraction(infraction["expires_at"]) + duration = time.get_duration_from_expiry(infraction["expires_at"]) + expires = f"{time.format_infraction(infraction['expires_at'])} ({duration})" lines = textwrap.dedent(f""" {"**===============**" if active else "==============="} -- cgit v1.2.3 From f5f92b76fb536beedbbfacd97f2977ed1c2c8606 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 17:16:06 +0700 Subject: Changed `get_duration_from_expiry()` to return the `time (duration)` or a `''` --- bot/utils/time.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index 873de21f0..311a0a576 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -148,7 +148,7 @@ def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> st return ', '.join(results[::-1][:2]) -def get_duration_from_expiry(expiry: str, date_from: datetime = None) -> str: +def get_duration_from_expiry(expiry: str = None, date_from: datetime = None) -> Optional[str]: """ Get the duration between datetime.utcnow() and an expiry, in human readable format. @@ -161,6 +161,15 @@ def get_duration_from_expiry(expiry: str, date_from: datetime = None) -> str: :param expiry: A string. """ + if not expiry: + return None + date_from = date_from or datetime.datetime.utcnow() date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None) - return get_duration(date_from, date_to) + + expiry_formatted = format_infraction(expiry) + + duration = get_duration(date_from, date_to) + duration_formatted = f" ({duration})" if duration else '' + + return f"{expiry_formatted}{duration_formatted}" -- cgit v1.2.3 From 0898ce98b6b2a9ac59369d8665ff51a077405c03 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 17:16:43 +0700 Subject: Refactored `management.py` to use the new `get_duration_from_expiry()` --- bot/cogs/moderation/management.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 5c63b19ce..5221baa81 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -97,9 +97,8 @@ class ModManagement(commands.Cog): confirm_messages.append("marked as permanent") elif duration is not None: request_data['expires_at'] = duration.isoformat() - expiry = duration.strftime(time.INFRACTION_FORMAT) - duration_string = time.get_duration(duration, datetime.utcnow()) - confirm_messages.append(f"set to expire on {expiry} ({duration_string})") + expiry = time.get_duration_from_expiry(request_data['expires_at']) + confirm_messages.append(f"set to expire on {expiry}") else: confirm_messages.append("expiry unchanged") @@ -236,8 +235,8 @@ class ModManagement(commands.Cog): if infraction["expires_at"] is None: expires = "*Permanent*" else: - duration = time.get_duration_from_expiry(infraction["expires_at"]) - expires = f"{time.format_infraction(infraction['expires_at'])} ({duration})" + date_from = datetime.strptime(created, time.INFRACTION_FORMAT) + expires = time.get_duration_from_expiry(infraction["expires_at"], date_from) lines = textwrap.dedent(f""" {"**===============**" if active else "==============="} -- cgit v1.2.3 From 2147adc592cf62a9cc21b3ebf5adeec544b4cac2 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 17:17:06 +0700 Subject: Refactored `scheduler.py` to use the new `get_duration_from_expiry()` --- bot/cogs/moderation/scheduler.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 9e987d9ee..729763322 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -83,17 +83,11 @@ class InfractionScheduler(Scheduler): infr_type = infraction["type"] icon = utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] - expiry = infraction["expires_at"] - expiry_at = expiry + expiry = time.get_duration_from_expiry(infraction["expires_at"]) id_ = infraction['id'] log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") - if expiry: - duration = time.get_duration_from_expiry(expiry) - expiry = time.format_infraction(expiry) - expiry_at = f"{expiry} ({duration})" - # Default values for the confirmation message and mod log. confirm_msg = f":ok_hand: applied" @@ -101,11 +95,11 @@ class InfractionScheduler(Scheduler): if infr_type in ("note", "warning"): expiry_msg = "" else: - expiry_msg = f" until {expiry_at}" if expiry else " permanently" + expiry_msg = f" until {expiry}" if expiry else " permanently" dm_result = "" dm_log_text = "" - expiry_log_text = f"Expires: {expiry_at}" if expiry else "" + expiry_log_text = f"Expires: {expiry}" if expiry else "" log_title = "applied" log_content = None @@ -115,7 +109,7 @@ class InfractionScheduler(Scheduler): user = await self.bot.fetch_user(user.id) # Accordingly display whether the user was successfully notified via DM. - if await utils.notify_infraction(user, infr_type, expiry_at, reason, icon): + if await utils.notify_infraction(user, infr_type, expiry, reason, icon): dm_result = ":incoming_envelope: " dm_log_text = "\nDM: Sent" else: -- cgit v1.2.3 From 493cd411ce4d7f5dbddfe40003af0049015d0ebb Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 17:21:23 +0700 Subject: Updated test cases for `get_duration_from_expiry()` --- tests/utils/test_time.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index 0afffd9b1..1df96beb8 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -82,14 +82,15 @@ def test_get_duration(date_from: datetime, date_to: datetime, expected: str): @pytest.mark.parametrize( ('expiry', 'date_from', 'expected'), ( - ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), '11 hours, 59 minutes'), - ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), '1 minute'), - ('2019-11-23T20:09:00Z', datetime(2019, 11, 30, 20, 15), '1 week, 6 minutes'), - ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), '7 months, 2 weeks'), - ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 21, 3), '5 minutes'), - ('2019-11-23T23:59:00Z', datetime(2019, 11, 24, 0, 0), '1 minute'), - ('2019-11-23T23:59:00Z', datetime(2022, 11, 23, 23, 0), '3 years, 3 months'), - ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), '9 minutes, 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), '2019-12-12 00:01 (11 hours, 59 minutes)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 30, 20, 15), '2019-11-23 20:09 (1 week, 6 minutes)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), '2019-11-23 20:09 (7 months, 2 weeks)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 21, 3), '2019-11-23 20:58 (5 minutes)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 24, 0, 0), '2019-11-23 23:59 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2022, 11, 23, 23, 0), '2019-11-23 23:59 (3 years, 3 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), '2019-11-23 23:59 (9 minutes, 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), None), ) ) def test_get_duration_from_expiry(expiry: str, date_from: datetime, expected: str): -- cgit v1.2.3 From f47ec6f65abe571110885e11cfc68d84e7f7b45e Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 17:52:08 +0700 Subject: Updated docstrings, allow passing `parts: Optional[int] = 2` to helper functions to return more than just 2 parts of the duration. --- bot/utils/time.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index 311a0a576..d3000a7c2 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -122,11 +122,11 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> str: +def get_duration(date_from: datetime.datetime, date_to: datetime.datetime, parts: Optional[int] = 2) -> str: """ Get the duration between two datetime, in human readable format. - Will return the two biggest units avaiable, for example: + Will return number of units if avaiable, for example: - 11 hours, 59 minutes - 1 week, 6 minutes - 7 months, 2 weeks @@ -135,6 +135,7 @@ def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> st :param date_from: A datetime.datetime object. :param date_to: A datetime.datetime object. + :param parts: An int, defauted to two - the amount of units to return. """ div = abs(date_from - date_to).total_seconds() div = round(div, 0) # to avoid (14 minutes, 60 seconds) @@ -144,11 +145,16 @@ def get_duration(date_from: datetime.datetime, date_to: datetime.datetime) -> st if amount > 0: plural = 's' if amount > 1 else '' results.append(f"{amount:.0f} {name}{plural}") + parts = parts if parts is not None else len(results) # allow passing None directly to return all parts # We have to reverse the order of units because currently it's smallest -> largest - return ', '.join(results[::-1][:2]) + return ', '.join(results[::-1][:parts]) -def get_duration_from_expiry(expiry: str = None, date_from: datetime = None) -> Optional[str]: +def get_duration_from_expiry( + expiry: str = None, + date_from: datetime.datetime = None, + parts: Optional[int] = 2 +) -> Optional[str]: """ Get the duration between datetime.utcnow() and an expiry, in human readable format. @@ -159,7 +165,9 @@ def get_duration_from_expiry(expiry: str = None, date_from: datetime = None) -> - 3 years, 3 months - 5 minutes - :param expiry: A string. + :param expiry: A string. If not passed in, will early return a None ( Permanent infraction ). + :param date_from: A datetime.datetime object. If not passed in, will use datetime.utcnow(). + :param parts: An int, to show how many parts will be returned ( year - month or year - month - week - day ...). """ if not expiry: return None @@ -169,7 +177,7 @@ def get_duration_from_expiry(expiry: str = None, date_from: datetime = None) -> expiry_formatted = format_infraction(expiry) - duration = get_duration(date_from, date_to) + duration = get_duration(date_from, date_to, parts) duration_formatted = f" ({duration})" if duration else '' return f"{expiry_formatted}{duration_formatted}" -- cgit v1.2.3 From b12fe618f73a0dfc31cd5ba4a9572ac0401d65ea Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 27 Nov 2019 17:52:38 +0700 Subject: Updated test cases for `parts: Optional[int]` --- tests/utils/test_time.py | 55 ++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index 1df96beb8..7bde92506 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -63,35 +63,44 @@ def test_wait_until(sleep_patch): @pytest.mark.parametrize( - ('date_from', 'date_to', 'expected'), + ('date_from', 'date_to', 'parts', 'expected'), ( - (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), '11 hours, 59 minutes'), - (datetime(2019, 12, 12), datetime(2019, 12, 11, 23, 59), '1 minute'), - (datetime(2019, 11, 23, 20, 9), datetime(2019, 11, 30, 20, 15), '1 week, 6 minutes'), - (datetime(2019, 11, 23, 20, 9), datetime(2019, 4, 25, 20, 15), '7 months, 2 weeks'), - (datetime(2019, 11, 23, 20, 58), datetime(2019, 11, 23, 21, 3), '5 minutes'), - (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 24, 0, 0), '1 minute'), - (datetime(2019, 11, 23, 23, 59), datetime(2022, 11, 23, 23, 0), '3 years, 3 months'), - (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 23, 23, 49, 5), '9 minutes, 55 seconds'), + (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), 2, '11 hours, 59 minutes'), + (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), 1, '11 hours'), + (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), None, '11 hours, 59 minutes, 5 seconds'), + (datetime(2019, 12, 12, 0, 0), datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + (datetime(2019, 11, 23, 20, 9), datetime(2019, 11, 30, 20, 15), 2, '1 week, 6 minutes'), + (datetime(2019, 11, 23, 20, 9), datetime(2019, 4, 25, 20, 15), 2, '7 months, 2 weeks'), + (datetime(2019, 11, 23, 20, 9), datetime(2019, 4, 25, 20, 15), + None, '7 months, 2 weeks, 1 day, 23 hours, 54 minutes'), + (datetime(2019, 11, 23, 20, 58), datetime(2019, 11, 23, 21, 3), 2, '5 minutes'), + (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 24, 0, 0), 2, '1 minute'), + (datetime(2019, 11, 23, 23, 59), datetime(2022, 11, 23, 23, 0), 2, '3 years, 3 months'), + (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes, 55 seconds'), ) ) -def test_get_duration(date_from: datetime, date_to: datetime, expected: str): - assert time.get_duration(date_from, date_to) == expected +def test_get_duration(date_from: datetime, date_to: datetime, parts: int, expected: str): + assert time.get_duration(date_from, date_to, parts) == expected @pytest.mark.parametrize( - ('expiry', 'date_from', 'expected'), + ('expiry', 'date_from', 'parts', 'expected'), ( - ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), '2019-12-12 00:01 (11 hours, 59 minutes)'), - ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), '2019-12-12 00:00 (1 minute)'), - ('2019-11-23T20:09:00Z', datetime(2019, 11, 30, 20, 15), '2019-11-23 20:09 (1 week, 6 minutes)'), - ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), '2019-11-23 20:09 (7 months, 2 weeks)'), - ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 21, 3), '2019-11-23 20:58 (5 minutes)'), - ('2019-11-23T23:59:00Z', datetime(2019, 11, 24, 0, 0), '2019-11-23 23:59 (1 minute)'), - ('2019-11-23T23:59:00Z', datetime(2022, 11, 23, 23, 0), '2019-11-23 23:59 (3 years, 3 months)'), - ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), '2019-11-23 23:59 (9 minutes, 55 seconds)'), - (None, datetime(2019, 11, 23, 23, 49, 5), None), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), 2, '2019-12-12 00:01 (11 hours, 59 minutes)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), 1, '2019-12-12 00:01 (11 hours)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), + None, '2019-12-12 00:01 (11 hours, 59 minutes, 5 seconds)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 30, 20, 15), 2, '2019-11-23 20:09 (1 week, 6 minutes)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (7 months, 2 weeks)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), None, + '2019-11-23 20:09 (7 months, 2 weeks, 1 day, 23 hours, 54 minutes)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 21, 3), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 24, 0, 0), 2, '2019-11-23 23:59 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2022, 11, 23, 23, 0), 2, '2019-11-23 23:59 (3 years, 3 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '2019-11-23 23:59 (9 minutes, 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), ) ) -def test_get_duration_from_expiry(expiry: str, date_from: datetime, expected: str): - assert time.get_duration_from_expiry(expiry, date_from) == expected +def test_get_duration_from_expiry(expiry: str, date_from: datetime, parts: int, expected: str): + assert time.get_duration_from_expiry(expiry, date_from, parts) == expected -- cgit v1.2.3 From 4be37c1486ccb0a8fb680cb6dce51f8ad8028569 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Wed, 27 Nov 2019 17:42:48 +0100 Subject: Move duckpond payload emoji check to method to create testable unit I moved the check that tests if a payload contains a duck emoji to a separate method. This makes it easier to test this part of the code as a separate unit than when it's contained in the larger event listener. In addition, I kaizened the name `relay_message_to_duckpond` to the less verbose `relay_message`; that's already clear enough. --- bot/cogs/duck_pond.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 68fb09408..2d25cd17e 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -91,7 +91,7 @@ class DuckPond(Cog): duck_reactors.append(user.id) return duck_count - async def relay_message_to_duck_pond(self, message: Message) -> None: + async def relay_message(self, message: Message) -> None: """Relays the message's content and attachments to the duck pond channel.""" clean_content = message.clean_content @@ -120,6 +120,17 @@ class DuckPond(Cog): await message.add_reaction("✅") + @staticmethod + def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" + if payload.emoji.is_custom_emoji(): + if payload.emoji.id in constants.DuckPond.custom_emojis: + return True + elif payload.emoji.name == "🦆": + return True + + return False + @Cog.listener() async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: """ @@ -130,10 +141,7 @@ class DuckPond(Cog): send the message off to the duck pond. """ # Is the emoji in the reaction a duck? - if payload.emoji.is_custom_emoji(): - if payload.emoji.id not in constants.DuckPond.custom_emojis: - return - elif payload.emoji.name != "🦆": + if not self._payload_has_duckpond_emoji(payload): return channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) @@ -153,7 +161,7 @@ class DuckPond(Cog): # If we've got more than the required amount of ducks, send the message to the duck_pond. if duck_count >= constants.DuckPond.threshold: - await self.relay_message_to_duck_pond(message) + await self.relay_message(message) @Cog.listener() async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: -- cgit v1.2.3 From 80d84e19d6877d2cbf2a6ce5029c1fd96b286b1e Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff <33516116+SebastiaanZ@users.noreply.github.com> Date: Wed, 27 Nov 2019 17:46:56 +0100 Subject: Apply review comments to duckpond's unit tests https://github.com/python-discord/bot/pull/621 I've changed to unit tests according to the comments made on the issue. Most changes are straightforward enough, but, for context, see the PR linked above. --- tests/bot/cogs/test_duck_pond.py | 200 +++++++++++++++++++++++++-------------- 1 file changed, 128 insertions(+), 72 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 8f0c4f068..b801e86f1 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -72,8 +72,7 @@ class DuckPondTests(base.LoggingTestCase): self.assertEqual(len(log_watcher.records), 1) - [record] = log_watcher.records - self.assertEqual(record.message, f"Failed to fetch webhook with id `{self.cog.webhook_id}`") + record = log_watcher.records[0] self.assertEqual(record.levelno, logging.ERROR) def test_is_staff_returns_correct_values_based_on_instance_passed(self): @@ -99,15 +98,15 @@ class DuckPondTests(base.LoggingTestCase): ( "No green check mark reactions", helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji), - helpers.MockReaction(emoji=self.thumbs_up_emoji) + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) ]), False ), ( "Green check mark reaction, but not from our bot", helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) ]), False @@ -115,7 +114,7 @@ class DuckPondTests(base.LoggingTestCase): ( "Green check mark reaction, with one from the bot", helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) ]), True @@ -160,8 +159,7 @@ class DuckPondTests(base.LoggingTestCase): self.assertEqual(len(log_watcher.records), 1) - [record] = log_watcher.records - self.assertEqual(record.message, "Failed to send a message to the Duck Pool webhook") + record = log_watcher.records[0] self.assertEqual(record.levelno, logging.ERROR) def _get_reaction( @@ -250,10 +248,12 @@ class DuckPondTests(base.LoggingTestCase): # A staffer with multiple duck reactions only counts once ( "Two different duck reactions from the same staffer", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ]), + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), 1 ), # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) @@ -265,10 +265,12 @@ class DuckPondTests(base.LoggingTestCase): # We correctly sum when multiple reactions are provided. ( "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ]), + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), 3 + 4 ), ) @@ -279,8 +281,8 @@ class DuckPondTests(base.LoggingTestCase): self.assertEqual(expected_count, actual_count) @helpers.async_test - async def test_relay_message_to_duck_pond_correctly_relays_content_and_attachments(self): - """The `relay_message_to_duck_pond` method should correctly relay message content and attachments.""" + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" send_attachments_path = f"{MODULE_PATH}.send_attachments" @@ -297,41 +299,47 @@ class DuckPondTests(base.LoggingTestCase): with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message_to_duck_pond(message) + await self.cog.relay_message(message) self.assertEqual(expect_webhook_call, send_webhook.called) self.assertEqual(expect_attachment_call, send_attachments.called) message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.reset_mock() - @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) @helpers.async_test - async def test_relay_message_to_duck_pond_handles_send_attachments_exceptions(self, send_attachments, send_webhook): - """The `relay_message_to_duck_pond` method should handle exceptions when calling `send_attachment`.""" + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) self.cog.webhook = helpers.MockAsyncWebhook() log = logging.getLogger("bot.cogs.duck_pond") - # Subtests for the first `except` block for side_effect in side_effects: send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message_to_duck_pond(message) + with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - self.assertEqual(send_webhook.call_count, 2) - send_webhook.reset_mock() + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.cogs.duck_pond") - # Subtests for the second `except` block side_effect = discord.HTTPException(MagicMock(), "") send_attachments.side_effect = side_effect with self.subTest(side_effect=type(side_effect).__name__): with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message_to_duck_pond(message) + await self.cog.relay_message(message) send_webhook.assert_called_once_with( content=message.clean_content, @@ -341,10 +349,75 @@ class DuckPondTests(base.LoggingTestCase): self.assertEqual(len(log_watcher.records), 1) - [record] = log_watcher.records - self.assertEqual(record.message, "Failed to send an attachment to the webhook") + record = log_watcher.records[0] self.assertEqual(record.levelno, logging.ERROR) + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + @helpers.async_test + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + def _raw_reaction_mocks(self, channel_id, message_id, user_id): """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" channel = helpers.MockTextChannel(id=channel_id) @@ -361,22 +434,6 @@ class DuckPondTests(base.LoggingTestCase): return channel, message, member, payload - @helpers.async_test - async def test_on_raw_reaction_add_returns_for_non_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - payload_custom_emoji = MagicMock(label="Non-Duck Custom Emoji") - payload_custom_emoji.emoji.is_custom_emoji.return_value = True - payload_custom_emoji.emoji.id = 12345 - - payload_unicode_emoji = MagicMock(label="Non-Duck Unicode Emoji") - payload_unicode_emoji.emoji.is_custom_emoji.return_value = False - payload_unicode_emoji.emoji.name = self.thumbs_up_emoji - - for payload in (payload_custom_emoji, payload_unicode_emoji): - with self.subTest(case=payload.label), patch(f"{MODULE_PATH}.discord.utils.get") as discord_utils_get: - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - discord_utils_get.assert_not_called() - @helpers.async_test async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" @@ -428,10 +485,8 @@ class DuckPondTests(base.LoggingTestCase): # Assert that we've made it past `self.is_staff` is_staff.assert_called_once() - @patch(f"{MODULE_PATH}.DuckPond.relay_message_to_duck_pond", new_callable=helpers.AsyncMock) - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) @helpers.async_test - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self, count_ducks, message_relay): + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" test_cases = ( (constants.DuckPond.threshold - 1, False), @@ -444,21 +499,21 @@ class DuckPondTests(base.LoggingTestCase): payload.emoji = self.duck_pond_emoji for duck_count, should_relay in test_cases: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) - # Confirm that we've made it past counting - count_ducks.assert_called_once() - count_ducks.reset_mock() + # Confirm that we've made it past counting + count_ducks.assert_called_once() - # Did we relay a message? - has_relayed = message_relay.called - self.assertEqual(has_relayed, should_relay) + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) - if should_relay: - message_relay.assert_called_once_with(message) - message_relay.reset_mock() + if should_relay: + relay_message.assert_called_once_with(message) @helpers.async_test async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): @@ -479,10 +534,10 @@ class DuckPondTests(base.LoggingTestCase): (constants.DuckPond.threshold, True), (constants.DuckPond.threshold + 1, True), ) - for duck_count, should_readd_checkmark in test_cases: + for duck_count, should_re_add_checkmark in test_cases: with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_readd_checkmark=should_readd_checkmark): + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): await self.cog.on_raw_reaction_remove(payload) # Check if we fetched the message @@ -491,16 +546,15 @@ class DuckPondTests(base.LoggingTestCase): # Check if we actually counted the number of ducks count_ducks.assert_called_once_with(message) - has_readded_checkmark = message.add_reaction.called - self.assertEqual(should_readd_checkmark, has_readded_checkmark) + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - if should_readd_checkmark: + if should_re_add_checkmark: message.add_reaction.assert_called_once_with(self.checkmark_emoji) message.add_reaction.reset_mock() # reset mocks channel.fetch_message.reset_mock() - count_ducks.reset_mock() message.reset_mock() def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): @@ -530,7 +584,9 @@ class DuckPondSetupTests(unittest.TestCase): with self.assertLogs(logger=log, level=logging.INFO) as log_watcher: duck_pond.setup(bot) - line = log_watcher.output[0] + + self.assertEqual(len(log_watcher.records), 1) + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.INFO) bot.add_cog.assert_called_once() - self.assertIn("Cog loaded: DuckPond", line) -- cgit v1.2.3 From d4269b36bbe1a57c4e1b61671c28647267c608bc Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 30 Nov 2019 13:35:41 +0100 Subject: Update bot/cogs/moderation/modlog.py --- bot/cogs/moderation/modlog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 41d7709e4..0df752a97 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -674,7 +674,7 @@ class ModLog(Cog, name="ModLog"): f"**Before**:\n{' '.join(content_before)}\n" f"**After**:\n{' '.join(content_after)}\n" "\n" - f"[jump to message]({msg_after.jump_url})" + f"[Jump to message]({msg_after.jump_url})" ) if msg_before.edited_at: -- cgit v1.2.3 From ac09fa35c03f76f50d6f7310e3ff4959270aad2b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 1 Dec 2019 13:26:27 -0800 Subject: Allow snekbox in esoteric-python channel * Add a hidden_channels parameter to in_channel decorator to hide channels from the InChannelCheckFailure error message. --- bot/cogs/snekbox.py | 2 +- bot/constants.py | 1 + bot/decorators.py | 15 ++++++++++++--- config-default.yml | 1 + 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 362968bd0..55a187ac1 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -176,7 +176,7 @@ class Snekbox(Cog): @command(name="eval", aliases=("e",)) @guild_only() - @in_channel(Channels.bot, bypass_roles=EVAL_ROLES) + @in_channel(Channels.bot, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. diff --git a/bot/constants.py b/bot/constants.py index a65c9ffa4..89504a2e0 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -353,6 +353,7 @@ class Channels(metaclass=YAMLGetter): defcon: int devlog: int devtest: int + esoteric: int help_0: int help_1: int help_2: int diff --git a/bot/decorators.py b/bot/decorators.py index 935df4af0..61587f406 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -27,11 +27,20 @@ class InChannelCheckFailure(CheckFailure): super().__init__(f"Sorry, but you may only use this command within {channels_str}.") -def in_channel(*channels: int, bypass_roles: Container[int] = None) -> Callable: - """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" +def in_channel( + *channels: int, + hidden_channels: Container[int] = None, + bypass_roles: Container[int] = None +) -> Callable: + """ + Checks that the message is in a whitelisted channel or optionally has a bypass role. + + Hidden channels are channels which will not be displayed in the InChannelCheckFailure error + message. + """ def predicate(ctx: Context) -> bool: """In-channel checker predicate.""" - if ctx.channel.id in channels: + if ctx.channel.id in channels or ctx.channel.id in hidden_channels: log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " f"The command was used in a whitelisted channel.") return True diff --git a/config-default.yml b/config-default.yml index b2ee1361f..930a1a0e6 100644 --- a/config-default.yml +++ b/config-default.yml @@ -108,6 +108,7 @@ guild: defcon: &DEFCON 464469101889454091 devlog: &DEVLOG 622895325144940554 devtest: &DEVTEST 414574275865870337 + esoteric: 470884583684964352 help_0: 303906576991780866 help_1: 303906556754395136 help_2: 303906514266226689 -- cgit v1.2.3 From 4d702cb7783639e1e442409eed7306b4ddedbd81 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Tue, 3 Dec 2019 01:15:04 +0700 Subject: Removed pytest, getting ready to migrate to unittest in another PR --- tests/utils/test_time.py | 44 -------------------------------------------- 1 file changed, 44 deletions(-) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index 7bde92506..4baa6395c 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -60,47 +60,3 @@ def test_wait_until(sleep_patch): assert asyncio.run(time.wait_until(then, start)) is None sleep_patch.assert_called_once_with(10 * 60) - - -@pytest.mark.parametrize( - ('date_from', 'date_to', 'parts', 'expected'), - ( - (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), 2, '11 hours, 59 minutes'), - (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), 1, '11 hours'), - (datetime(2019, 12, 12, 0, 1), datetime(2019, 12, 12, 12, 0, 5), None, '11 hours, 59 minutes, 5 seconds'), - (datetime(2019, 12, 12, 0, 0), datetime(2019, 12, 11, 23, 59), 2, '1 minute'), - (datetime(2019, 11, 23, 20, 9), datetime(2019, 11, 30, 20, 15), 2, '1 week, 6 minutes'), - (datetime(2019, 11, 23, 20, 9), datetime(2019, 4, 25, 20, 15), 2, '7 months, 2 weeks'), - (datetime(2019, 11, 23, 20, 9), datetime(2019, 4, 25, 20, 15), - None, '7 months, 2 weeks, 1 day, 23 hours, 54 minutes'), - (datetime(2019, 11, 23, 20, 58), datetime(2019, 11, 23, 21, 3), 2, '5 minutes'), - (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 24, 0, 0), 2, '1 minute'), - (datetime(2019, 11, 23, 23, 59), datetime(2022, 11, 23, 23, 0), 2, '3 years, 3 months'), - (datetime(2019, 11, 23, 23, 59), datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes, 55 seconds'), - ) -) -def test_get_duration(date_from: datetime, date_to: datetime, parts: int, expected: str): - assert time.get_duration(date_from, date_to, parts) == expected - - -@pytest.mark.parametrize( - ('expiry', 'date_from', 'parts', 'expected'), - ( - ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), 2, '2019-12-12 00:01 (11 hours, 59 minutes)'), - ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), 1, '2019-12-12 00:01 (11 hours)'), - ('2019-12-12T00:01:00Z', datetime(2019, 12, 12, 12, 0, 5), - None, '2019-12-12 00:01 (11 hours, 59 minutes, 5 seconds)'), - ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), - ('2019-11-23T20:09:00Z', datetime(2019, 11, 30, 20, 15), 2, '2019-11-23 20:09 (1 week, 6 minutes)'), - ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (7 months, 2 weeks)'), - ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), None, - '2019-11-23 20:09 (7 months, 2 weeks, 1 day, 23 hours, 54 minutes)'), - ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 21, 3), 2, '2019-11-23 20:58 (5 minutes)'), - ('2019-11-23T23:59:00Z', datetime(2019, 11, 24, 0, 0), 2, '2019-11-23 23:59 (1 minute)'), - ('2019-11-23T23:59:00Z', datetime(2022, 11, 23, 23, 0), 2, '2019-11-23 23:59 (3 years, 3 months)'), - ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '2019-11-23 23:59 (9 minutes, 55 seconds)'), - (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), - ) -) -def test_get_duration_from_expiry(expiry: str, date_from: datetime, parts: int, expected: str): - assert time.get_duration_from_expiry(expiry, date_from, parts) == expected -- cgit v1.2.3 From 8fee0ca7fce8919ebf853c5572d988f047043fee Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Tue, 3 Dec 2019 01:15:36 +0700 Subject: Deleted `get_duration` and switched to using the already, nicely made `humanize_delta` --- bot/utils/time.py | 56 ++++++++----------------------------------------------- 1 file changed, 8 insertions(+), 48 deletions(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index d3000a7c2..ec47fce2e 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,21 +1,12 @@ import asyncio import datetime -from typing import List, Optional +from typing import Optional import dateutil.parser from dateutil.relativedelta import relativedelta RFC1123_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" INFRACTION_FORMAT = "%Y-%m-%d %H:%M" -TIME_MARKS = ( - (60, 'second'), # 1 minute - (60, 'minute'), # 1 hour - (24, 'hour'), # 1 day - (7, 'day'), # 1 week - (4, 'week'), # 1 month - (12, 'month'), # 1 year - (999, 'year') # dumb the rest as year, max 999 -) def _stringify_time_unit(value: int, unit: str) -> str: @@ -122,48 +113,17 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def get_duration(date_from: datetime.datetime, date_to: datetime.datetime, parts: Optional[int] = 2) -> str: - """ - Get the duration between two datetime, in human readable format. - - Will return number of units if avaiable, for example: - - 11 hours, 59 minutes - - 1 week, 6 minutes - - 7 months, 2 weeks - - 3 years, 3 months - - 5 minutes - - :param date_from: A datetime.datetime object. - :param date_to: A datetime.datetime object. - :param parts: An int, defauted to two - the amount of units to return. - """ - div = abs(date_from - date_to).total_seconds() - div = round(div, 0) # to avoid (14 minutes, 60 seconds) - results: List[str] = [] - for unit, name in TIME_MARKS: - div, amount = divmod(div, unit) - if amount > 0: - plural = 's' if amount > 1 else '' - results.append(f"{amount:.0f} {name}{plural}") - parts = parts if parts is not None else len(results) # allow passing None directly to return all parts - # We have to reverse the order of units because currently it's smallest -> largest - return ', '.join(results[::-1][:parts]) - - def get_duration_from_expiry( expiry: str = None, date_from: datetime.datetime = None, - parts: Optional[int] = 2 + max_units: int = 2 ) -> Optional[str]: """ - Get the duration between datetime.utcnow() and an expiry, in human readable format. + Returns a human-readable version of the the duration between datetime.utcnow() and an expiry. - Will return the two biggest units avaiable, for example: - - 11 hours, 59 minutes - - 1 week, 6 minutes - - 7 months, 2 weeks - - 3 years, 3 months - - 5 minutes + Unlike the original function, this function will force the precision to be 'seconds' by not passing it. + max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2 :param expiry: A string. If not passed in, will early return a None ( Permanent infraction ). :param date_from: A datetime.datetime object. If not passed in, will use datetime.utcnow(). @@ -173,11 +133,11 @@ def get_duration_from_expiry( return None date_from = date_from or datetime.datetime.utcnow() - date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None) + date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) expiry_formatted = format_infraction(expiry) - duration = get_duration(date_from, date_to, parts) + duration = humanize_delta(relativedelta(date_to, date_from), max_units=max_units) duration_formatted = f" ({duration})" if duration else '' return f"{expiry_formatted}{duration_formatted}" -- cgit v1.2.3 From 6cf907a4ab1f632dbe0fb2445703a84b965d7bfa Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 09:03:26 +0700 Subject: Renamed function and improved its docstring to better reflect its purposes. Changed from `get_duration_from_expiry` -> `format_infraction_with_duration` --- bot/cogs/moderation/management.py | 4 ++-- bot/cogs/moderation/scheduler.py | 2 +- bot/utils/time.py | 19 ++++++------------- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 5221baa81..abfe5c2b3 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -97,7 +97,7 @@ class ModManagement(commands.Cog): confirm_messages.append("marked as permanent") elif duration is not None: request_data['expires_at'] = duration.isoformat() - expiry = time.get_duration_from_expiry(request_data['expires_at']) + expiry = time.format_infraction_with_duration(request_data['expires_at']) confirm_messages.append(f"set to expire on {expiry}") else: confirm_messages.append("expiry unchanged") @@ -236,7 +236,7 @@ class ModManagement(commands.Cog): expires = "*Permanent*" else: date_from = datetime.strptime(created, time.INFRACTION_FORMAT) - expires = time.get_duration_from_expiry(infraction["expires_at"], date_from) + expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) lines = textwrap.dedent(f""" {"**===============**" if active else "==============="} diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 729763322..3e0968121 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -83,7 +83,7 @@ class InfractionScheduler(Scheduler): infr_type = infraction["type"] icon = utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] - expiry = time.get_duration_from_expiry(infraction["expires_at"]) + expiry = time.format_infraction_with_duration(infraction["expires_at"]) id_ = infraction['id'] log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") diff --git a/bot/utils/time.py b/bot/utils/time.py index ec47fce2e..a024674ac 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,21 +113,14 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def get_duration_from_expiry( - expiry: str = None, - date_from: datetime.datetime = None, - max_units: int = 2 -) -> Optional[str]: +def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: """ - Returns a human-readable version of the the duration between datetime.utcnow() and an expiry. + Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. - Unlike the original function, this function will force the precision to be 'seconds' by not passing it. - max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). - By default, max_units is 2 - - :param expiry: A string. If not passed in, will early return a None ( Permanent infraction ). - :param date_from: A datetime.datetime object. If not passed in, will use datetime.utcnow(). - :param parts: An int, to show how many parts will be returned ( year - month or year - month - week - day ...). + Returns a human-readable version of the duration between datetime.utcnow() and an expiry. + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. """ if not expiry: return None -- cgit v1.2.3 From a92186f7218faf48b1ceb3b9f516b29d40e6efaf Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Tue, 3 Dec 2019 18:33:36 -0800 Subject: Antimalware: fix paste service URL showing replacement field --- bot/cogs/antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 745dd8082..602819191 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -26,7 +26,7 @@ class AntiMalware(Cog): if filename.endswith('.py'): embed.description = ( f"It looks like you tried to attach a Python file - please " - f"use a code-pasting service such as {URLs.paste_service}" + f"use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" ) break # Other detections irrelevant because we prioritize the .py message. if not filename.endswith(tuple(AntiMalwareConfig.whitelist)): -- cgit v1.2.3 From 7e25475b78df01646cbc82176443f955bb6d1964 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 15:01:40 +0700 Subject: Improved type hinting for `format_infraction_with_duration` --- bot/utils/time.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index a024674ac..9520b32f8 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: +def format_infraction_with_duration( + expiry: Optional[str], + date_from: datetime.datetime = None, + max_units: int = 2 +) -> Optional[str]: """ Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. -- cgit v1.2.3 From 51f80015c5db9ab8e85ea2304789491d4c72c053 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 15:03:16 +0700 Subject: Created `until_expiration` to get the remaining time until the infraction expires. --- bot/utils/time.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/bot/utils/time.py b/bot/utils/time.py index 9520b32f8..ac64865d6 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -138,3 +138,23 @@ def format_infraction_with_duration( duration_formatted = f" ({duration})" if duration else '' return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration(expiry: Optional[str], max_units: int = 2) -> Optional[str]: + """ + Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. + """ + if not expiry: + return None + + now = datetime.datetime.utcnow() + since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + + if since < now: + return None + + return humanize_delta(relativedelta(since, now), max_units=max_units) -- cgit v1.2.3 From 82eb5e1c46e378a6f3778e17cc342193b910ded5 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 15:04:20 +0700 Subject: Implemented remaining time until expiration for infraction searching. Will show the remaining time, `Expired.` or `Inactive.` based on the status of the infraction ( It can be inactive but not expired, like an early unmute ) --- bot/cogs/moderation/management.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..2f5e09f1b 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -232,6 +232,12 @@ class ModManagement(commands.Cog): user_id = infraction["user"] hidden = infraction["hidden"] created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or 'Expired.' + else: + remaining = 'Inactive.' + if infraction["expires_at"] is None: expires = "*Permanent*" else: @@ -247,6 +253,7 @@ class ModManagement(commands.Cog): Reason: {infraction["reason"] or "*None*"} Created: {created} Expires: {expires} + Remaining: {remaining} Actor: {actor.mention if actor else actor_id} ID: `{infraction["id"]}` {"**===============**" if active else "==============="} -- cgit v1.2.3 From c1aeb6d263172168f77845408e8d2756f6cb2813 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 17:12:25 +0700 Subject: Apply suggestions from Mark - removing `.` at the end and use double quote instead of single. Co-Authored-By: Mark --- bot/cogs/moderation/management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 2f5e09f1b..74f75781d 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -234,9 +234,9 @@ class ModManagement(commands.Cog): created = time.format_infraction(infraction["inserted_at"]) if active: - remaining = time.until_expiration(infraction["expires_at"]) or 'Expired.' + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" else: - remaining = 'Inactive.' + remaining = "Inactive" if infraction["expires_at"] is None: expires = "*Permanent*" -- cgit v1.2.3 From e07cf7342184b769d8c0655bc9b84be02809319a Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 23:38:46 +0700 Subject: Added `unittest` for `bot.utils.time` --- tests/bot/utils/test_time.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests/bot/utils/test_time.py diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..0ef59292e --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,87 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): + """Test helper functions in bot.utils.time.""" + + def setUp(self): + pass + + def test_humanize_delta(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), + + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), + ) + + for delta, precision, max_units, expected in test_cases: + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_raises_for_invalid_max_units(self): + test_cases = (-1, 0) + + for max_units in test_cases: + with self.assertRaises(ValueError) as error: + time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + self.assertEqual(str(error), 'max_units must be positive') + + def test_parse_rfc1123(self): + """Testing parse_rfc1123.""" + test_cases = ( + ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), + ) + + for stamp, expected in test_cases: + self.assertEqual(time.parse_rfc1123(stamp), expected) + + @patch('asyncio.sleep', new_callable=AsyncMock) + def test_wait_until(self, mock): + """Testing wait_until.""" + start = datetime(2019, 1, 1, 0, 0) + then = datetime(2019, 1, 1, 0, 10) + + # No return value + assert asyncio.run(time.wait_until(then, start)) is None + + mock.assert_called_once_with(10 * 60) + + def test_format_infraction_with_duration(self): + """Testing format_infraction_with_duration.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 6, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, + '2019-11-23 23:59 (9 minutes and 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, date_from, max_units, expected in test_cases: + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) -- cgit v1.2.3 From b17dbe5e3e0dfa6ae44d660924455f709abefd0d Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:34:58 +0700 Subject: Splitting test cases for `humanize_delta` into proper, independent tests. --- tests/bot/utils/test_time.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 0ef59292e..5e5f2bf2f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -15,18 +15,20 @@ class TimeTests(unittest.TestCase): def setUp(self): pass - def test_humanize_delta(self): - """Testing humanize delta.""" + def test_humanize_delta_handle_unknown_units(self): + """humanize_delta should be able to handle unknown units, and will not abort.""" test_cases = ( - (relativedelta(days=2), 'seconds', 1, '2 days'), - (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), - (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), - (relativedelta(days=2, hours=2), 'days', 2, '2 days'), - # Does not abort for unknown units, as the unit name is checked # against the attribute of the relativedelta instance. (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), + ) + for delta, precision, max_units, expected in test_cases: + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_handle_high_units(self): + """humanize_delta should be able to handle very high units.""" + test_cases = ( # Very high maximum units, but it only ever iterates over # each value the relativedelta might have. (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), @@ -35,6 +37,18 @@ class TimeTests(unittest.TestCase): for delta, precision, max_units, expected in test_cases: self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + def test_humanize_delta_should_work_normally(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + ) + + for delta, precision, max_units, expected in test_cases: + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + def test_humanize_delta_raises_for_invalid_max_units(self): test_cases = (-1, 0) -- cgit v1.2.3 From 0aee728d6d23ef24f51834f39016f938f3f1b8a9 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:36:29 +0700 Subject: Added missing docstring for `test_humanize_delta_raises_for_invalid_max_units` --- tests/bot/utils/test_time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 5e5f2bf2f..a929bee89 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -50,6 +50,7 @@ class TimeTests(unittest.TestCase): self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) def test_humanize_delta_raises_for_invalid_max_units(self): + """humanize_delta should raises ValueError('max_units must be positive') for invalid max units.""" test_cases = (-1, 0) for max_units in test_cases: -- cgit v1.2.3 From beed21355e7f0e25b69637768843c53d510b8969 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:40:38 +0700 Subject: Changed `assert` to `self.assertIs` for `test_wait_until` --- tests/bot/utils/test_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index a929bee89..0afabe400 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -74,7 +74,7 @@ class TimeTests(unittest.TestCase): then = datetime(2019, 1, 1, 0, 10) # No return value - assert asyncio.run(time.wait_until(then, start)) is None + self.assertIs(asyncio.run(time.wait_until(then, start)), None) mock.assert_called_once_with(10 * 60) -- cgit v1.2.3 From ccdd8363d75846f0841791ba54763dae28243c62 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:45:36 +0700 Subject: Splitting test cases for `format_infraction_with_duration` into proper, independent tests. --- tests/bot/utils/test_time.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 0afabe400..2a2a707d8 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -37,7 +37,7 @@ class TimeTests(unittest.TestCase): for delta, precision, max_units, expected in test_cases: self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) - def test_humanize_delta_should_work_normally(self): + def test_humanize_delta_should_normal_usage(self): """Testing humanize delta.""" test_cases = ( (relativedelta(days=2), 'seconds', 1, '2 days'), @@ -78,18 +78,38 @@ class TimeTests(unittest.TestCase): mock.assert_called_once_with(10 * 60) - def test_format_infraction_with_duration(self): - """Testing format_infraction_with_duration.""" + def test_format_infraction_with_duration_none_expiry(self): + """format_infraction_with_duration should work for None expiry.""" + self.assertEqual(time.format_infraction_with_duration(None), None) + + # To make sure that date_from and max_units are not touched + self.assertEqual(time.format_infraction_with_duration(None, date_from='Why hello there!'), None) + self.assertEqual(time.format_infraction_with_duration(None, max_units=float('inf')), None) + self.assertEqual( + time.format_infraction_with_duration(None, date_from='Why hello there!', max_units=float('inf')), + None + ) + + def test_format_infraction_with_duration_custom_units(self): + """format_infraction_with_duration should work for custom max_units.""" + self.assertEqual( + time.format_infraction_with_duration('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6), + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)' + ) + + self.assertEqual( + time.format_infraction_with_duration('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20), + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)' + ) + + def test_format_infraction_with_duration_normal_usage(self): + """format_infraction_with_duration should work for normal usage, across various durations.""" test_cases = ( ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), - ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, - '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), - ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 6, - '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)'), ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), -- cgit v1.2.3 From fa66195dbb6f79bb7174084835499a61e8cb03a3 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:52:15 +0700 Subject: Introduced test for `test_format_infraction`, refactored `test_parse_rfc1123`, fixed typo. --- tests/bot/utils/test_time.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 2a2a707d8..09fb824e4 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -50,7 +50,7 @@ class TimeTests(unittest.TestCase): self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) def test_humanize_delta_raises_for_invalid_max_units(self): - """humanize_delta should raises ValueError('max_units must be positive') for invalid max units.""" + """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" test_cases = (-1, 0) for max_units in test_cases: @@ -60,12 +60,14 @@ class TimeTests(unittest.TestCase): def test_parse_rfc1123(self): """Testing parse_rfc1123.""" - test_cases = ( - ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), + self.assertEqual( + time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), + datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) ) - for stamp, expected in test_cases: - self.assertEqual(time.parse_rfc1123(stamp), expected) + def test_format_infraction(self): + """Testing format_infraction.""" + self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') @patch('asyncio.sleep', new_callable=AsyncMock) def test_wait_until(self, mock): -- cgit v1.2.3 From 5e0b19ae841f3f355931ad331f7aa861fbafc4d9 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 01:05:41 +0700 Subject: Added `self.subTest` for tests with multiple test cases & simplified single test case tests. --- tests/bot/utils/test_time.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 09fb824e4..c47a306f0 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -17,25 +17,15 @@ class TimeTests(unittest.TestCase): def test_humanize_delta_handle_unknown_units(self): """humanize_delta should be able to handle unknown units, and will not abort.""" - test_cases = ( - # Does not abort for unknown units, as the unit name is checked - # against the attribute of the relativedelta instance. - (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), - ) - - for delta, precision, max_units, expected in test_cases: - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') def test_humanize_delta_handle_high_units(self): """humanize_delta should be able to handle very high units.""" - test_cases = ( - # Very high maximum units, but it only ever iterates over - # each value the relativedelta might have. - (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), - ) - - for delta, precision, max_units, expected in test_cases: - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') def test_humanize_delta_should_normal_usage(self): """Testing humanize delta.""" @@ -47,14 +37,15 @@ class TimeTests(unittest.TestCase): ) for delta, precision, max_units, expected in test_cases: - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) def test_humanize_delta_raises_for_invalid_max_units(self): """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" test_cases = (-1, 0) for max_units in test_cases: - with self.assertRaises(ValueError) as error: + with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) self.assertEqual(str(error), 'max_units must be positive') @@ -121,4 +112,5 @@ class TimeTests(unittest.TestCase): ) for expiry, date_from, max_units, expected in test_cases: - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) -- cgit v1.2.3 From db341d927aab42c2e874cb499ab1c2e6c0e7647b Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 01:17:56 +0700 Subject: Moved all individual test cases into iterables and test with `self.subTest` context manager. --- tests/bot/utils/test_time.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index c47a306f0..25cd3f69f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -73,27 +73,31 @@ class TimeTests(unittest.TestCase): def test_format_infraction_with_duration_none_expiry(self): """format_infraction_with_duration should work for None expiry.""" - self.assertEqual(time.format_infraction_with_duration(None), None) + test_cases = ( + (None, None, None, None), - # To make sure that date_from and max_units are not touched - self.assertEqual(time.format_infraction_with_duration(None, date_from='Why hello there!'), None) - self.assertEqual(time.format_infraction_with_duration(None, max_units=float('inf')), None) - self.assertEqual( - time.format_infraction_with_duration(None, date_from='Why hello there!', max_units=float('inf')), - None + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), ) + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + def test_format_infraction_with_duration_custom_units(self): """format_infraction_with_duration should work for custom max_units.""" - self.assertEqual( - time.format_infraction_with_duration('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6), - '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)' + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') ) - self.assertEqual( - time.format_infraction_with_duration('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20), - '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)' - ) + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) def test_format_infraction_with_duration_normal_usage(self): """format_infraction_with_duration should work for normal usage, across various durations.""" -- cgit v1.2.3 From 323306776b0312e2a32ada213a35159311a93a7f Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 01:42:23 +0700 Subject: Removed `setUp()` from `TimeTests` since it is not being used for anything. --- tests/bot/utils/test_time.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 25cd3f69f..7f55dc3ec 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -12,9 +12,6 @@ from tests.helpers import AsyncMock class TimeTests(unittest.TestCase): """Test helper functions in bot.utils.time.""" - def setUp(self): - pass - def test_humanize_delta_handle_unknown_units(self): """humanize_delta should be able to handle unknown units, and will not abort.""" # Does not abort for unknown units, as the unit name is checked -- cgit v1.2.3 From ad1a33e80152343a81eeeabf0117ced76b83e273 Mon Sep 17 00:00:00 2001 From: Daniel Brown Date: Thu, 5 Dec 2019 10:35:55 -0600 Subject: Added optional channel parameter to !echo: - Added the option to specify a channel to have Python repeat what you said to it, as well as keeping the old functionality of having it repeat what you said in the current channel if no channel argument is given. Signed-off-by: Daniel Brown --- bot/cogs/bot.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 7583b2f2d..ee0a463de 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -4,7 +4,7 @@ import re import time from typing import Optional, Tuple -from discord import Embed, Message, RawMessageUpdateEvent +from discord import Embed, Message, RawMessageUpdateEvent, TextChannel from discord.ext.commands import Bot, Cog, Context, command, group from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs @@ -71,9 +71,12 @@ class Bot(Cog): @command(name='echo', aliases=('print',)) @with_role(*MODERATION_ROLES) - async def echo_command(self, ctx: Context, *, text: str) -> None: - """Send the input verbatim to the current channel.""" - await ctx.send(text) + async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Repeat the given message in either a specified channel or the current channel.""" + if channel is None: + await ctx.send(text) + else: + await channel.send(text) @command(name='embed') @with_role(*MODERATION_ROLES) -- cgit v1.2.3 From 52163d775a6a0737f32a0c291e9275a910656fab Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Thu, 5 Dec 2019 17:49:28 +0000 Subject: Requested change Include the check about whether or not there is a token in the posted message in `parse_codeblock` boolean. --- bot/cogs/bot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 53221cd8b..f79e00454 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -236,9 +236,10 @@ class Bot(Cog): ) and not msg.author.bot and len(msg.content.splitlines()) > 3 + and not TokenRemover.is_token_in_message(msg) ) - if parse_codeblock and not TokenRemover.is_token_in_message(msg): # no token in the msg + if parse_codeblock: # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From a9dc1000872f507a850798b204befed299b6f703 Mon Sep 17 00:00:00 2001 From: Jens Date: Thu, 5 Dec 2019 22:07:59 +0100 Subject: Keeps access token alive, only revokes it on extension unload. Hard-coded version number to 1.0.0. --- bot/cogs/reddit.py | 52 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 64a940af1..0ebf2e1a7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,7 +2,8 @@ import asyncio import logging import random import textwrap -from datetime import datetime +from collections import namedtuple +from datetime import datetime, timedelta from typing import List from aiohttp import BasicAuth @@ -21,11 +22,7 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - # Change your client's User-Agent string to something unique and descriptive, - # including the target platform, a unique application identifier, a version string, - # and your username as contact information, in the following format: - # :: (by /u/) - USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_RETRIES = 3 @@ -33,7 +30,8 @@ class Reddit(Cog): def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready + self.webhook = None + self.access_token = None bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() @@ -41,6 +39,8 @@ class Reddit(Cog): def cog_unload(self) -> None: """Stops the loops when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -53,7 +53,7 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) - async def get_access_tokens(self) -> None: + async def get_access_token(self) -> None: """Get Reddit access tokens.""" headers = {"User-Agent": self.USER_AGENT} data = { @@ -61,6 +61,7 @@ class Reddit(Cog): "duration": "temporary" } + log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) for _ in range(self.MAX_RETRIES): @@ -72,9 +73,13 @@ class Reddit(Cog): ) if response.status == 200 and response.content_type == "application/json": content = await response.json() - self.access_token = content["access_token"] + AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(hours=1) + ) self.headers = { - "Authorization": "bearer " + self.access_token, + "Authorization": "bearer " + self.access_token.token, "User-Agent": self.USER_AGENT } return @@ -91,7 +96,7 @@ class Reddit(Cog): # The token should be revoked, since the API is called only once a day. headers = {"User-Agent": self.USER_AGENT} data = { - "token": self.access_token, + "token": self.access_token.token, "token_type_hint": "access_token" } @@ -200,7 +205,10 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) - await self.get_access_tokens() + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() @@ -210,8 +218,6 @@ class Reddit(Cog): top_posts = await self.get_top_posts(subreddit=subreddit, time="day") await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) - await self.revoke_access_token() - async def top_weekly_posts(self) -> None: """Post a summary of the top posts.""" for subreddit in RedditConfig.subreddits: @@ -242,32 +248,38 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="all") await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - await self.revoke_access_token() @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="day") await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - await self.revoke_access_token() @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="week") await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - await self.revoke_access_token() @with_role(*STAFF_ROLES) @reddit_group.command(name="subreddits", aliases=("subs",)) -- cgit v1.2.3 From d913a91531ba6414741d745303f89cb687cf345b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 7 Dec 2019 19:29:33 -0800 Subject: Subclass Bot --- bot/__main__.py | 26 ++------------------------ bot/bot.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 24 deletions(-) create mode 100644 bot/bot.py diff --git a/bot/__main__.py b/bot/__main__.py index ea7c43a12..84bc7094b 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,18 +1,11 @@ -import asyncio -import logging -import socket - import discord -from aiohttp import AsyncResolver, ClientSession, TCPConnector -from discord.ext.commands import Bot, when_mentioned_or +from discord.ext.commands import when_mentioned_or from bot import patches -from bot.api import APIClient, APILoggingHandler +from bot.bot import Bot from bot.constants import Bot as BotConfig, DEBUG_MODE -log = logging.getLogger('bot') - bot = Bot( command_prefix=when_mentioned_or(BotConfig.prefix), activity=discord.Game(name="Commands: !help"), @@ -20,18 +13,6 @@ bot = Bot( max_messages=10_000, ) -# Global aiohttp session for all cogs -# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads -# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. -bot.http_session = ClientSession( - connector=TCPConnector( - resolver=AsyncResolver(), - family=socket.AF_INET, - ) -) -bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.addHandler(APILoggingHandler(bot.api_client)) - # Internal/debug bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") @@ -77,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'): patches.message_edited_at.apply_patch() bot.run(BotConfig.token) - -# This calls a coroutine, so it doesn't do anything at the moment. -# bot.http_session.close() # Close the aiohttp session when the bot finishes running diff --git a/bot/bot.py b/bot/bot.py new file mode 100644 index 000000000..05734ac1d --- /dev/null +++ b/bot/bot.py @@ -0,0 +1,30 @@ +import asyncio +import logging +import socket + +import aiohttp +from discord.ext import commands + +from bot import api + +log = logging.getLogger('bot') + + +class Bot(commands.Bot): + """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Global aiohttp session for all cogs + # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads + # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. + self.http_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector( + resolver=aiohttp.AsyncResolver(), + family=socket.AF_INET, + ) + ) + + self.api_client = api.APIClient(loop=asyncio.get_event_loop()) + log.addHandler(api.APILoggingHandler(self.api_client)) -- cgit v1.2.3 From 6fe61e5919cb541a1651312a01ddf7e7f10d0f86 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 7 Dec 2019 20:11:50 -0800 Subject: Change all Bot imports to use the subclass --- bot/cogs/alias.py | 3 ++- bot/cogs/antimalware.py | 3 ++- bot/cogs/antispam.py | 3 ++- bot/cogs/bot.py | 3 ++- bot/cogs/clean.py | 3 ++- bot/cogs/defcon.py | 3 ++- bot/cogs/doc.py | 5 +++-- bot/cogs/duck_pond.py | 3 ++- bot/cogs/error_handler.py | 3 ++- bot/cogs/eval.py | 3 ++- bot/cogs/extensions.py | 3 ++- bot/cogs/filtering.py | 3 ++- bot/cogs/free.py | 3 ++- bot/cogs/help.py | 3 ++- bot/cogs/information.py | 3 ++- bot/cogs/jams.py | 5 +++-- bot/cogs/logging.py | 3 ++- bot/cogs/moderation/__init__.py | 3 +-- bot/cogs/moderation/infractions.py | 3 ++- bot/cogs/moderation/management.py | 3 ++- bot/cogs/moderation/modlog.py | 3 ++- bot/cogs/moderation/scheduler.py | 3 ++- bot/cogs/moderation/superstarify.py | 3 ++- bot/cogs/off_topic_names.py | 3 ++- bot/cogs/reddit.py | 3 ++- bot/cogs/reminders.py | 3 ++- bot/cogs/security.py | 4 +++- bot/cogs/site.py | 3 ++- bot/cogs/snekbox.py | 3 ++- bot/cogs/sync/__init__.py | 3 +-- bot/cogs/sync/cog.py | 3 ++- bot/cogs/sync/syncers.py | 7 ++++--- bot/cogs/tags.py | 3 ++- bot/cogs/token_remover.py | 3 ++- bot/cogs/utils.py | 3 ++- bot/cogs/verification.py | 3 ++- bot/cogs/watchchannels/__init__.py | 3 +-- bot/cogs/watchchannels/bigbrother.py | 3 ++- bot/cogs/watchchannels/talentpool.py | 3 ++- bot/cogs/watchchannels/watchchannel.py | 3 ++- bot/cogs/wolfram.py | 7 ++++--- bot/interpreter.py | 4 +++- tests/helpers.py | 4 +++- 43 files changed, 92 insertions(+), 52 deletions(-) diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 5190c559b..4ee5a2aed 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -3,8 +3,9 @@ import logging from typing import Union from discord import Colour, Embed, Member, User -from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group +from discord.ext.commands import Cog, Command, Context, clean_content, command, group +from bot.bot import Bot from bot.cogs.extensions import Extension from bot.cogs.watchchannels.watchchannel import proxy_user from bot.converters import TagNameConverter diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 602819191..03c1e28a1 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,8 +1,9 @@ import logging from discord import Embed, Message, NotFound -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs log = logging.getLogger(__name__) diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index 1340eb608..88912038a 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -7,9 +7,10 @@ from operator import itemgetter from typing import Dict, Iterable, List, Set from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from bot import rules +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( AntiSpam as AntiSpamConfig, Channels, diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index ee0a463de..a2edb7576 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -5,8 +5,9 @@ import time from typing import Optional, Tuple from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Bot, Cog, Context, command, group +from discord.ext.commands import Cog, Context, command, group +from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs from bot.decorators import with_role from bot.utils.messages import wait_for_deletion diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index dca411d01..3365d0934 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -4,8 +4,9 @@ import re from typing import Optional from discord import Colour, Embed, Message, User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, CleanMessages, Colours, Event, diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index bedd70c86..f062a7546 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -6,8 +6,9 @@ from datetime import datetime, timedelta from enum import Enum from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles from bot.decorators import with_role diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index e5b3a4062..7df159fd9 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -17,6 +17,7 @@ from requests import ConnectTimeout, ConnectionError, HTTPError from sphinx.ext import intersphinx from urllib3.exceptions import ProtocolError +from bot.bot import Bot from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import ValidPythonIdentifier, ValidURL from bot.decorators import with_role @@ -147,7 +148,7 @@ class InventoryURL(commands.Converter): class Doc(commands.Cog): """A set of commands for querying & displaying documentation.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.base_urls = {} self.bot = bot self.inventories = {} @@ -506,7 +507,7 @@ class Doc(commands.Cog): return tag.name == "table" -def setup(bot: commands.Bot) -> None: +def setup(bot: Bot) -> None: """Doc cog load.""" bot.add_cog(Doc(bot)) log.info("Cog loaded: Doc") diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 2d25cd17e..879071d1b 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -3,9 +3,10 @@ from typing import Optional, Union import discord from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from bot import constants +from bot.bot import Bot from bot.utils.messages import send_attachments log = logging.getLogger(__name__) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 49411814c..cf90e9f48 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -14,9 +14,10 @@ from discord.ext.commands import ( NoPrivateMessage, UserInputError, ) -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels from bot.decorators import InChannelCheckFailure diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 00b988dde..5daec3e39 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -9,8 +9,9 @@ from io import StringIO from typing import Any, Optional, Tuple import discord -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Roles from bot.decorators import with_role from bot.interpreter import Interpreter diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index bb66e0b8e..4d77d8205 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -6,8 +6,9 @@ from pkgutil import iter_modules from discord import Colour, Embed from discord.ext import commands -from discord.ext.commands import Bot, Context, group +from discord.ext.commands import Context, group +from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs from bot.pagination import LinePaginator from bot.utils.checks import with_role_check diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 1e7521054..2e54ccecb 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -5,8 +5,9 @@ from typing import Optional, Union import discord.errors from dateutil.relativedelta import relativedelta from discord import Colour, DMChannel, Member, Message, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, Colours, diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 82285656b..bbc9f063b 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -3,8 +3,9 @@ from datetime import datetime from operator import itemgetter from discord import Colour, Embed, Member, utils -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.constants import Categories, Channels, Free, STAFF_ROLES from bot.decorators import redirect_output diff --git a/bot/cogs/help.py b/bot/cogs/help.py index 9607dbd8d..6385fa467 100644 --- a/bot/cogs/help.py +++ b/bot/cogs/help.py @@ -6,10 +6,11 @@ from typing import Union from discord import Colour, Embed, HTTPException, Message, Reaction, User from discord.ext import commands -from discord.ext.commands import Bot, CheckFailure, Cog as DiscordCog, Command, Context +from discord.ext.commands import CheckFailure, Cog as DiscordCog, Command, Context from fuzzywuzzy import fuzz, process from bot import constants +from bot.bot import Bot from bot.constants import Channels, STAFF_ROLES from bot.decorators import redirect_output from bot.pagination import ( diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 530453600..56bd37bec 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -9,10 +9,11 @@ from typing import Any, Mapping, Optional import discord from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, command, group +from discord.ext.commands import BucketType, Cog, Context, command, group from discord.utils import escape_markdown from bot import constants +from bot.bot import Bot from bot.decorators import InChannelCheckFailure, in_channel, with_role from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index be9d33e3e..0c82e7962 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -4,6 +4,7 @@ from discord import Member, PermissionOverwrite, utils from discord.ext import commands from more_itertools import unique_everseen +from bot.bot import Bot from bot.constants import Roles from bot.decorators import with_role @@ -13,7 +14,7 @@ log = logging.getLogger(__name__) class CodeJams(commands.Cog): """Manages the code-jam related parts of our server.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @commands.command() @@ -108,7 +109,7 @@ class CodeJams(commands.Cog): ) -def setup(bot: commands.Bot) -> None: +def setup(bot: Bot) -> None: """Code Jams cog load.""" bot.add_cog(CodeJams(bot)) log.info("Cog loaded: CodeJams") diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index c92b619ff..44c771b42 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -1,8 +1,9 @@ import logging from discord import Embed -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 7383ed44e..0cbdb3aa6 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,7 +1,6 @@ import logging -from discord.ext.commands import Bot - +from bot.bot import Bot from .infractions import Infractions from .management import ModManagement from .modlog import ModLog diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 2713a1b68..7478e19ef 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -7,6 +7,7 @@ from discord.ext import commands from discord.ext.commands import Context, command from bot import constants +from bot.bot import Bot from bot.constants import Event from bot.decorators import respect_role_hierarchy from bot.utils.checks import with_role_check @@ -25,7 +26,7 @@ class Infractions(InfractionScheduler, commands.Cog): category = "Moderation" category_description = "Server moderation tools." - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) self.category = "Moderation" diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..feae00b7c 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,6 +9,7 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants +from bot.bot import Bot from bot.converters import InfractionSearchQuery from bot.pagination import LinePaginator from bot.utils import time @@ -36,7 +37,7 @@ class ModManagement(commands.Cog): category = "Moderation" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @property diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 0df752a97..35ef6cbcc 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -10,8 +10,9 @@ from dateutil.relativedelta import relativedelta from deepdiff import DeepDiff from discord import Colour from discord.abc import GuildChannel -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context +from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs from bot.utils.time import humanize_delta from .utils import UserTypes diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 3e0968121..937113ef4 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -7,10 +7,11 @@ from gettext import ngettext import dateutil.parser import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context from bot import constants from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Colours, STAFF_CHANNELS from bot.utils import time from bot.utils.scheduling import Scheduler diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 9b3c62403..7631d9bbe 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -6,9 +6,10 @@ import typing as t from pathlib import Path from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command from bot import constants +from bot.bot import Bot from bot.utils.checks import with_role_check from bot.utils.time import format_infraction from . import utils diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 78792240f..18d9cfb01 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -4,9 +4,10 @@ import logging from datetime import datetime, timedelta from discord import Colour, Embed -from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group +from discord.ext.commands import BadArgument, Cog, Context, Converter, group from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES from bot.decorators import with_role from bot.pagination import LinePaginator diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0d06e9c26..c76fcd937 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -6,9 +6,10 @@ from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, TextChannel -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group from discord.ext.tasks import loop +from bot.bot import Bot from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks from bot.converters import Subreddit from bot.decorators import with_role diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 81990704b..b805b24c5 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -8,8 +8,9 @@ from typing import Optional from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Message -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES from bot.converters import Duration from bot.pagination import LinePaginator diff --git a/bot/cogs/security.py b/bot/cogs/security.py index 316b33d6b..45d0eb2f5 100644 --- a/bot/cogs/security.py +++ b/bot/cogs/security.py @@ -1,6 +1,8 @@ import logging -from discord.ext.commands import Bot, Cog, Context, NoPrivateMessage +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot log = logging.getLogger(__name__) diff --git a/bot/cogs/site.py b/bot/cogs/site.py index 683613788..1d7bd03e4 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -1,8 +1,9 @@ import logging from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import URLs from bot.pagination import LinePaginator diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 55a187ac1..1ea61a8da 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -5,8 +5,9 @@ import textwrap from signal import Signals from typing import Optional, Tuple -from discord.ext.commands import Bot, Cog, Context, command, guild_only +from discord.ext.commands import Cog, Context, command, guild_only +from bot.bot import Bot from bot.constants import Channels, Roles, URLs from bot.decorators import in_channel from bot.utils.messages import wait_for_deletion diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py index d4565f848..0da81c60e 100644 --- a/bot/cogs/sync/__init__.py +++ b/bot/cogs/sync/__init__.py @@ -1,7 +1,6 @@ import logging -from discord.ext.commands import Bot - +from bot.bot import Bot from .cog import Sync log = logging.getLogger(__name__) diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index aaa581f96..90d4c40fe 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -3,10 +3,11 @@ from typing import Callable, Iterable from discord import Guild, Member, Role from discord.ext import commands -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context from bot import constants from bot.api import ResponseCodeError +from bot.bot import Bot from bot.cogs.sync import syncers log = logging.getLogger(__name__) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 2cc5a66e1..14cf51383 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -2,7 +2,8 @@ from collections import namedtuple from typing import Dict, Set, Tuple from discord import Guild -from discord.ext.commands import Bot + +from bot.bot import Bot # These objects are declared as namedtuples because tuples are hashable, # something that we make use of when diffing site roles against guild roles. @@ -52,7 +53,7 @@ async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]: Synchronize roles found on the given `guild` with the ones on the API. Arguments: - bot (discord.ext.commands.Bot): + bot (bot.bot.Bot): The bot instance that we're running with. guild (discord.Guild): @@ -169,7 +170,7 @@ async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]: Synchronize users found in the given `guild` with the ones in the API. Arguments: - bot (discord.ext.commands.Bot): + bot (bot.bot.Bot): The bot instance that we're running with. guild (discord.Guild): diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index cd70e783a..2ece0095d 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -2,8 +2,9 @@ import logging import time from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Channels, Cooldowns, MODERATION_ROLES, Roles from bot.converters import TagContentConverter, TagNameConverter from bot.decorators import with_role diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5a0d20e57..7af7ed63a 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -6,9 +6,10 @@ import struct from datetime import datetime from discord import Colour, Message -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from discord.utils import snowflake_time +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import Channels, Colours, Event, Icons diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 793fe4c1a..0ed996430 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -8,8 +8,9 @@ from typing import Tuple from dateutil import relativedelta from discord import Colour, Embed, Message, Role -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES from bot.decorators import in_channel, with_role from bot.utils.time import humanize_delta diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b5e8d4357..74eb0dbf8 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -3,8 +3,9 @@ from datetime import datetime from discord import Colour, Message, NotFound, Object from discord.ext import tasks -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Bot as BotConfig, diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py index 86e1050fa..e18aea88a 100644 --- a/bot/cogs/watchchannels/__init__.py +++ b/bot/cogs/watchchannels/__init__.py @@ -1,7 +1,6 @@ import logging -from discord.ext.commands import Bot - +from bot.bot import Bot from .bigbrother import BigBrother from .talentpool import TalentPool diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index 49783bb09..306ed4c64 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -3,8 +3,9 @@ from collections import ChainMap from typing import Union from discord import User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation.utils import post_infraction from bot.constants import Channels, MODERATION_ROLES, Webhooks from bot.decorators import with_role diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index 4ec42dcc1..cc8feeeee 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -4,9 +4,10 @@ from collections import ChainMap from typing import Union from discord import Color, Embed, Member, User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks from bot.decorators import with_role from bot.pagination import LinePaginator diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 0bf75a924..bd0622554 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -10,9 +10,10 @@ from typing import Optional import dateutil.parser import discord from discord import Color, Embed, HTTPException, Message, Object, errors -from discord.ext.commands import BadArgument, Bot, Cog, Context +from discord.ext.commands import BadArgument, Cog, Context from bot.api import ResponseCodeError +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.pagination import LinePaginator diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py index ab0ed2472..c3c193cb9 100644 --- a/bot/cogs/wolfram.py +++ b/bot/cogs/wolfram.py @@ -7,8 +7,9 @@ import discord from dateutil.relativedelta import relativedelta from discord import Embed from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, check, group +from discord.ext.commands import BucketType, Cog, Context, check, group +from bot.bot import Bot from bot.constants import Colours, STAFF_ROLES, Wolfram from bot.pagination import ImagePaginator from bot.utils.time import humanize_delta @@ -151,7 +152,7 @@ async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tup class Wolfram(Cog): """Commands for interacting with the Wolfram|Alpha API.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) @@ -266,7 +267,7 @@ class Wolfram(Cog): await send_embed(ctx, message, color) -def setup(bot: commands.Bot) -> None: +def setup(bot: Bot) -> None: """Wolfram cog load.""" bot.add_cog(Wolfram(bot)) log.info("Cog loaded: Wolfram") diff --git a/bot/interpreter.py b/bot/interpreter.py index 76a3fc293..8b7268746 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -2,7 +2,9 @@ from code import InteractiveInterpreter from io import StringIO from typing import Any -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot CODE_TEMPLATE = """ async def _func(): diff --git a/tests/helpers.py b/tests/helpers.py index b2daae92d..5df796c23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -10,7 +10,9 @@ import unittest.mock from typing import Any, Iterable, Optional import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot for logger in logging.Logger.manager.loggerDict.values(): -- cgit v1.2.3 From 52924051e27d34c3f7e32c281fbe8ae0587a9766 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 7 Dec 2019 20:31:39 -0800 Subject: Override add_cog to log loading of cogs --- bot/bot.py | 5 +++++ bot/cogs/alias.py | 3 +-- bot/cogs/antimalware.py | 3 +-- bot/cogs/antispam.py | 3 +-- bot/cogs/bot.py | 3 +-- bot/cogs/clean.py | 3 +-- bot/cogs/defcon.py | 3 +-- bot/cogs/doc.py | 3 +-- bot/cogs/duck_pond.py | 3 +-- bot/cogs/error_handler.py | 3 +-- bot/cogs/eval.py | 3 +-- bot/cogs/extensions.py | 1 - bot/cogs/filtering.py | 3 +-- bot/cogs/free.py | 3 +-- bot/cogs/information.py | 3 +-- bot/cogs/jams.py | 3 +-- bot/cogs/logging.py | 3 +-- bot/cogs/moderation/__init__.py | 13 +------------ bot/cogs/off_topic_names.py | 3 +-- bot/cogs/reddit.py | 3 +-- bot/cogs/reminders.py | 3 +-- bot/cogs/security.py | 3 +-- bot/cogs/site.py | 3 +-- bot/cogs/snekbox.py | 3 +-- bot/cogs/sync/__init__.py | 7 +------ bot/cogs/tags.py | 3 +-- bot/cogs/token_remover.py | 3 +-- bot/cogs/utils.py | 3 +-- bot/cogs/verification.py | 3 +-- bot/cogs/watchchannels/__init__.py | 10 +--------- bot/cogs/wolfram.py | 3 +-- 31 files changed, 34 insertions(+), 80 deletions(-) diff --git a/bot/bot.py b/bot/bot.py index 05734ac1d..f39bfb50a 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -28,3 +28,8 @@ class Bot(commands.Bot): self.api_client = api.APIClient(loop=asyncio.get_event_loop()) log.addHandler(api.APILoggingHandler(self.api_client)) + + def add_cog(self, cog: commands.Cog) -> None: + """Adds a "cog" to the bot and logs the operation.""" + super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 4ee5a2aed..c1db38462 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -148,6 +148,5 @@ class Alias (Cog): def setup(bot: Bot) -> None: - """Alias cog load.""" + """Load the Alias cog.""" bot.add_cog(Alias(bot)) - log.info("Cog loaded: Alias") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 03c1e28a1..28e3e5d96 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -50,6 +50,5 @@ class AntiMalware(Cog): def setup(bot: Bot) -> None: - """Antimalware cog load.""" + """Load the AntiMalware cog.""" bot.add_cog(AntiMalware(bot)) - log.info("Cog loaded: AntiMalware") diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index 88912038a..f454061a6 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -277,7 +277,6 @@ def validate_config(rules: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: def setup(bot: Bot) -> None: - """Antispam cog load.""" + """Validate the AntiSpam configs and load the AntiSpam cog.""" validation_errors = validate_config() bot.add_cog(AntiSpam(bot, validation_errors)) - log.info("Cog loaded: AntiSpam") diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index a2edb7576..b5642da82 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -378,6 +378,5 @@ class Bot(Cog): def setup(bot: Bot) -> None: - """Bot cog load.""" + """Load the Bot cog.""" bot.add_cog(Bot(bot)) - log.info("Cog loaded: Bot") diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index 3365d0934..c7168122d 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -212,6 +212,5 @@ class Clean(Cog): def setup(bot: Bot) -> None: - """Clean cog load.""" + """Load the Clean cog.""" bot.add_cog(Clean(bot)) - log.info("Cog loaded: Clean") diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index f062a7546..3e7350fcc 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -237,6 +237,5 @@ class Defcon(Cog): def setup(bot: Bot) -> None: - """DEFCON cog load.""" + """Load the Defcon cog.""" bot.add_cog(Defcon(bot)) - log.info("Cog loaded: Defcon") diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 7df159fd9..9506b195a 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -508,6 +508,5 @@ class Doc(commands.Cog): def setup(bot: Bot) -> None: - """Doc cog load.""" + """Load the Doc cog.""" bot.add_cog(Doc(bot)) - log.info("Cog loaded: Doc") diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 879071d1b..345d2856c 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -178,6 +178,5 @@ class DuckPond(Cog): def setup(bot: Bot) -> None: - """Load the duck pond cog.""" + """Load the DuckPond cog.""" bot.add_cog(DuckPond(bot)) - log.info("Cog loaded: DuckPond") diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index cf90e9f48..700f903a6 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -144,6 +144,5 @@ class ErrorHandler(Cog): def setup(bot: Bot) -> None: - """Error handler cog load.""" + """Load the ErrorHandler cog.""" bot.add_cog(ErrorHandler(bot)) - log.info("Cog loaded: Events") diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 5daec3e39..9c729f28a 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -198,6 +198,5 @@ async def func(): # (None,) -> Any def setup(bot: Bot) -> None: - """Code eval cog load.""" + """Load the CodeEval cog.""" bot.add_cog(CodeEval(bot)) - log.info("Cog loaded: Eval") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index 4d77d8205..f16e79fb7 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -234,4 +234,3 @@ class Extensions(commands.Cog): def setup(bot: Bot) -> None: """Load the Extensions cog.""" bot.add_cog(Extensions(bot)) - log.info("Cog loaded: Extensions") diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 2e54ccecb..74538542a 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -371,6 +371,5 @@ class Filtering(Cog): def setup(bot: Bot) -> None: - """Filtering cog load.""" + """Load the Filtering cog.""" bot.add_cog(Filtering(bot)) - log.info("Cog loaded: Filtering") diff --git a/bot/cogs/free.py b/bot/cogs/free.py index bbc9f063b..49cab6172 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -99,6 +99,5 @@ class Free(Cog): def setup(bot: Bot) -> None: - """Free cog load.""" + """Load the Free cog.""" bot.add_cog(Free()) - log.info("Cog loaded: Free") diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 56bd37bec..1ede95ff4 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -392,6 +392,5 @@ class Information(Cog): def setup(bot: Bot) -> None: - """Information cog load.""" + """Load the Information cog.""" bot.add_cog(Information(bot)) - log.info("Cog loaded: Information") diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index 0c82e7962..985f28ce5 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -110,6 +110,5 @@ class CodeJams(commands.Cog): def setup(bot: Bot) -> None: - """Code Jams cog load.""" + """Load the CodeJams cog.""" bot.add_cog(CodeJams(bot)) - log.info("Cog loaded: CodeJams") diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index 44c771b42..d1b7dcab3 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -38,6 +38,5 @@ class Logging(Cog): def setup(bot: Bot) -> None: - """Logging cog load.""" + """Load the Logging cog.""" bot.add_cog(Logging(bot)) - log.info("Cog loaded: Logging") diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 0cbdb3aa6..5243cb92d 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,24 +1,13 @@ -import logging - from bot.bot import Bot from .infractions import Infractions from .management import ModManagement from .modlog import ModLog from .superstarify import Superstarify -log = logging.getLogger(__name__) - def setup(bot: Bot) -> None: - """Load the moderation extension (Infractions, ModManagement, ModLog, & Superstarify cogs).""" + """Load the Infractions, ModManagement, ModLog, and Superstarify cogs.""" bot.add_cog(Infractions(bot)) - log.info("Cog loaded: Infractions") - bot.add_cog(ModLog(bot)) - log.info("Cog loaded: ModLog") - bot.add_cog(ModManagement(bot)) - log.info("Cog loaded: ModManagement") - bot.add_cog(Superstarify(bot)) - log.info("Cog loaded: Superstarify") diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 18d9cfb01..bf777ea5a 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -185,6 +185,5 @@ class OffTopicNames(Cog): def setup(bot: Bot) -> None: - """Off topic names cog load.""" + """Load the OffTopicNames cog.""" bot.add_cog(OffTopicNames(bot)) - log.info("Cog loaded: OffTopicNames") diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index c76fcd937..bec316ae7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -218,6 +218,5 @@ class Reddit(Cog): def setup(bot: Bot) -> None: - """Reddit cog load.""" + """Load the Reddit cog.""" bot.add_cog(Reddit(bot)) - log.info("Cog loaded: Reddit") diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index b805b24c5..45bf9a8f4 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -291,6 +291,5 @@ class Reminders(Scheduler, Cog): def setup(bot: Bot) -> None: - """Reminders cog load.""" + """Load the Reminders cog.""" bot.add_cog(Reminders(bot)) - log.info("Cog loaded: Reminders") diff --git a/bot/cogs/security.py b/bot/cogs/security.py index 45d0eb2f5..c680c5e27 100644 --- a/bot/cogs/security.py +++ b/bot/cogs/security.py @@ -27,6 +27,5 @@ class Security(Cog): def setup(bot: Bot) -> None: - """Security cog load.""" + """Load the Security cog.""" bot.add_cog(Security(bot)) - log.info("Cog loaded: Security") diff --git a/bot/cogs/site.py b/bot/cogs/site.py index 1d7bd03e4..2ea8c7a2e 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -139,6 +139,5 @@ class Site(Cog): def setup(bot: Bot) -> None: - """Site cog load.""" + """Load the Site cog.""" bot.add_cog(Site(bot)) - log.info("Cog loaded: Site") diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 1ea61a8da..da33e27b2 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -228,6 +228,5 @@ class Snekbox(Cog): def setup(bot: Bot) -> None: - """Snekbox cog load.""" + """Load the Snekbox cog.""" bot.add_cog(Snekbox(bot)) - log.info("Cog loaded: Snekbox") diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py index 0da81c60e..fe7df4e9b 100644 --- a/bot/cogs/sync/__init__.py +++ b/bot/cogs/sync/__init__.py @@ -1,12 +1,7 @@ -import logging - from bot.bot import Bot from .cog import Sync -log = logging.getLogger(__name__) - def setup(bot: Bot) -> None: - """Sync cog load.""" + """Load the Sync cog.""" bot.add_cog(Sync(bot)) - log.info("Cog loaded: Sync") diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 2ece0095d..970301013 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -161,6 +161,5 @@ class Tags(Cog): def setup(bot: Bot) -> None: - """Tags cog load.""" + """Load the Tags cog.""" bot.add_cog(Tags(bot)) - log.info("Cog loaded: Tags") diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 7af7ed63a..5d6618338 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -120,6 +120,5 @@ class TokenRemover(Cog): def setup(bot: Bot) -> None: - """Token Remover cog load.""" + """Load the TokenRemover cog.""" bot.add_cog(TokenRemover(bot)) - log.info("Cog loaded: TokenRemover") diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 0ed996430..47a59db66 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -177,6 +177,5 @@ class Utils(Cog): def setup(bot: Bot) -> None: - """Utils cog load.""" + """Load the Utils cog.""" bot.add_cog(Utils(bot)) - log.info("Cog loaded: Utils") diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 74eb0dbf8..b32b9a29e 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -225,6 +225,5 @@ class Verification(Cog): def setup(bot: Bot) -> None: - """Verification cog load.""" + """Load the Verification cog.""" bot.add_cog(Verification(bot)) - log.info("Cog loaded: Verification") diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py index e18aea88a..69d118df6 100644 --- a/bot/cogs/watchchannels/__init__.py +++ b/bot/cogs/watchchannels/__init__.py @@ -1,17 +1,9 @@ -import logging - from bot.bot import Bot from .bigbrother import BigBrother from .talentpool import TalentPool -log = logging.getLogger(__name__) - - def setup(bot: Bot) -> None: - """Monitoring cogs load.""" + """Load the BigBrother and TalentPool cogs.""" bot.add_cog(BigBrother(bot)) - log.info("Cog loaded: BigBrother") - bot.add_cog(TalentPool(bot)) - log.info("Cog loaded: TalentPool") diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py index c3c193cb9..5d6b4630b 100644 --- a/bot/cogs/wolfram.py +++ b/bot/cogs/wolfram.py @@ -268,6 +268,5 @@ class Wolfram(Cog): def setup(bot: Bot) -> None: - """Wolfram cog load.""" + """Load the Wolfram cog.""" bot.add_cog(Wolfram(bot)) - log.info("Cog loaded: Wolfram") -- cgit v1.2.3 From a4a53f3b9d1cc9928ee03d4f0ecb8087a527e8ca Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 7 Dec 2019 20:39:09 -0800 Subject: Fix name conflict with the Bot cog --- bot/cogs/bot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index b5642da82..e795e5960 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -17,7 +17,7 @@ log = logging.getLogger(__name__) RE_MARKDOWN = re.compile(r'([*_~`|>])') -class Bot(Cog): +class BotCog(Cog, name="Bot"): """Bot information commands.""" def __init__(self, bot: Bot): @@ -374,9 +374,9 @@ class Bot(Cog): bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) await bot_message.delete() del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") def setup(bot: Bot) -> None: """Load the Bot cog.""" - bot.add_cog(Bot(bot)) + bot.add_cog(BotCog(bot)) -- cgit v1.2.3 From 56578525ac4e5c6d20392d1208b74623c8524bcd Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 8 Dec 2019 01:40:19 -0800 Subject: Properly create and close aiohttp sessions aiohttp throws a warning when a session is created outside of a running async event loop. In aiohttp 4.0 this actually changes to an error instead of merely a warning. Since discord.py manages the event loop with client.run(), some of the "internal" coroutines of the client were overwritten in the bot subclass to be able to hook into when the bot starts and stops. Sessions of both the bot and the API client can now potentially be None if accessed before the sessions have been created. However, if called, the API client's methods will wait for a session to be ready. It will attempt to create a session as soon as the event loop starts (i.e. the bot is running). --- bot/api.py | 41 +++++++++++++++++++++++++++++++++++++++-- bot/bot.py | 34 ++++++++++++++++++++++++++-------- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/bot/api.py b/bot/api.py index 7f26e5305..56db99828 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,7 +32,7 @@ class ResponseCodeError(ValueError): class APIClient: """Django Site API wrapper.""" - def __init__(self, **kwargs): + def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs): auth_headers = { 'Authorization': f"Token {Keys.site_api}" } @@ -42,12 +42,39 @@ class APIClient: else: kwargs['headers'] = auth_headers - self.session = aiohttp.ClientSession(**kwargs) + self.session: Optional[aiohttp.ClientSession] = None + self.loop = loop + + self._ready = asyncio.Event(loop=loop) + self._creation_task = None + self._session_args = kwargs + + self.recreate() @staticmethod def _url_for(endpoint: str) -> str: return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" + async def _create_session(self) -> None: + """Create the aiohttp session and set the ready event.""" + self.session = aiohttp.ClientSession(**self._session_args) + self._ready.set() + + async def close(self) -> None: + """Close the aiohttp session and unset the ready event.""" + if not self._ready.is_set(): + return + + await self.session.close() + self._ready.clear() + + def recreate(self) -> None: + """Schedule the aiohttp session to be created if it's been closed.""" + if self.session is None or self.session.closed: + # Don't schedule a task if one is already in progress. + if self._creation_task is None or self._creation_task.done(): + self._creation_task = self.loop.create_task(self._create_session()) + async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: """Raise ResponseCodeError for non-OK response if an exception should be raised.""" if should_raise and response.status >= 400: @@ -60,30 +87,40 @@ class APIClient: async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API GET.""" + await self._ready.wait() + async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PATCH.""" + await self._ready.wait() + async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API POST.""" + await self._ready.wait() + async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PUT.""" + await self._ready.wait() + async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: """Site API DELETE.""" + await self._ready.wait() + async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: if resp.status == 204: return None diff --git a/bot/bot.py b/bot/bot.py index f39bfb50a..4b3b991a3 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,6 +1,6 @@ -import asyncio import logging import socket +from typing import Optional import aiohttp from discord.ext import commands @@ -16,6 +16,30 @@ class Bot(commands.Bot): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.http_session: Optional[aiohttp.ClientSession] = None + self.api_client = api.APIClient(loop=self.loop) + + log.addHandler(api.APILoggingHandler(self.api_client)) + + def add_cog(self, cog: commands.Cog) -> None: + """Adds a "cog" to the bot and logs the operation.""" + super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + def clear(self) -> None: + """Clears the internal state of the bot and resets the API client.""" + super().clear() + self.api_client.recreate() + + async def close(self) -> None: + """Close the aiohttp session after closing the Discord connection.""" + await super().close() + + await self.http_session.close() + await self.api_client.close() + + async def start(self, *args, **kwargs) -> None: + """Open an aiohttp session before logging in and connecting to Discord.""" # Global aiohttp session for all cogs # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. @@ -26,10 +50,4 @@ class Bot(commands.Bot): ) ) - self.api_client = api.APIClient(loop=asyncio.get_event_loop()) - log.addHandler(api.APILoggingHandler(self.api_client)) - - def add_cog(self, cog: commands.Cog) -> None: - """Adds a "cog" to the bot and logs the operation.""" - super().add_cog(cog) - log.info(f"Cog loaded: {cog.qualified_name}") + await super().start(*args, **kwargs) -- cgit v1.2.3 From 23acfce3521cd420a2df6eb51f036a2a54140ef6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 8 Dec 2019 02:01:01 -0800 Subject: Fix test failures for setup log messages --- tests/bot/cogs/test_duck_pond.py | 12 ++---------- tests/bot/cogs/test_security.py | 11 +++-------- tests/bot/cogs/test_token_remover.py | 8 ++------ 3 files changed, 7 insertions(+), 24 deletions(-) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index b801e86f1..d07b2bce1 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -578,15 +578,7 @@ class DuckPondSetupTests(unittest.TestCase): """Tests setup of the `DuckPond` cog.""" def test_setup(self): - """Setup of the cog should log a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = helpers.MockBot() - log = logging.getLogger('bot.cogs.duck_pond') - - with self.assertLogs(logger=log, level=logging.INFO) as log_watcher: - duck_pond.setup(bot) - - self.assertEqual(len(log_watcher.records), 1) - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.INFO) - + duck_pond.setup(bot) bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index efa7a50b1..9d1a62f7e 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -1,4 +1,3 @@ -import logging import unittest from unittest.mock import MagicMock @@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase): """Tests loading the `Security` cog.""" def test_security_cog_load(self): - """Cog loading logs a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MagicMock() - with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: - security.setup(bot) - bot.add_cog.assert_called_once() - - [line] = cm.output - self.assertIn("Cog loaded: Security", line) + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3276cf5a5..a54b839d7 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase): """Tests setup of the `TokenRemover` cog.""" def test_setup(self): - """Setup of the cog should log a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MockBot() - with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: - setup_cog(bot) - - [line] = cm.output + setup_cog(bot) bot.add_cog.assert_called_once() - self.assertIn("Cog loaded: TokenRemover", line) -- cgit v1.2.3 From 34bac05ccc6c11ea370aa14431e4d6d6cd28f1d6 Mon Sep 17 00:00:00 2001 From: Manuel Ignacio Pérez Alcolea Date: Mon, 9 Dec 2019 02:37:30 -0300 Subject: Ensure hidden_channels and bypass_roles use a list when not passed. The in_channel decorator raised 'NoneType' is not iterable when it wasn't passed, due to the default value being None but not checked against before iterating over it. This edit ensures the arguments are set to an empty list in cases where they have a value of None instead. --- bot/decorators.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bot/decorators.py b/bot/decorators.py index 61587f406..2d18eaa6a 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -38,6 +38,9 @@ def in_channel( Hidden channels are channels which will not be displayed in the InChannelCheckFailure error message. """ + hidden_channels = hidden_channels or [] + bypass_roles = bypass_roles or [] + def predicate(ctx: Context) -> bool: """In-channel checker predicate.""" if ctx.channel.id in channels or ctx.channel.id in hidden_channels: -- cgit v1.2.3 From dbd7220caed5a6ba759d0cf9efaa0c0c0e57f391 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sun, 8 Dec 2019 23:40:17 -0800 Subject: Use the AsyncResolver for APIClient and discord.py sessions too Active thread counts are observed to be lower with it in use. --- bot/bot.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bot/bot.py b/bot/bot.py index 4b3b991a3..8f808272f 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -14,10 +14,18 @@ class Bot(commands.Bot): """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + # Use asyncio for DNS resolution instead of threads so threads aren't spammed. + # Use AF_INET as its socket family to prevent HTTPS related problems both locally + # and in production. + self.connector = aiohttp.TCPConnector( + resolver=aiohttp.AsyncResolver(), + family=socket.AF_INET, + ) + + super().__init__(*args, connector=self.connector, **kwargs) self.http_session: Optional[aiohttp.ClientSession] = None - self.api_client = api.APIClient(loop=self.loop) + self.api_client = api.APIClient(loop=self.loop, connector=self.connector) log.addHandler(api.APILoggingHandler(self.api_client)) @@ -40,14 +48,6 @@ class Bot(commands.Bot): async def start(self, *args, **kwargs) -> None: """Open an aiohttp session before logging in and connecting to Discord.""" - # Global aiohttp session for all cogs - # - Uses asyncio for DNS resolution instead of threads, so we don't spam threads - # - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. - self.http_session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector( - resolver=aiohttp.AsyncResolver(), - family=socket.AF_INET, - ) - ) + self.http_session = aiohttp.ClientSession(connector=self.connector) await super().start(*args, **kwargs) -- cgit v1.2.3 From 1b938af27cb9901acdb86579029dc4a7cbae0b7d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Dec 2019 23:18:32 -0800 Subject: Moderation: show HTTP status code in the log for deactivation failures --- bot/cogs/moderation/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 3e0968121..703b09802 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -329,7 +329,7 @@ class InfractionScheduler(Scheduler): log_content = mod_role.mention except discord.HTTPException as e: log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with code {e.code}." + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." log_content = mod_role.mention # Check if the user is currently being watched by Big Brother. -- cgit v1.2.3 From d0e14dca855179bd71c46747ecf63d4038045881 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Dec 2019 23:34:30 -0800 Subject: Moderation: catch HTTPException when applying an infraction Only a warning is logged if it's a Forbidden error. Otherwise, the whole exception is logged. --- bot/cogs/moderation/scheduler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 703b09802..8e5b4691f 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -146,14 +146,18 @@ class InfractionScheduler(Scheduler): if expiry: # Schedule the expiration of the infraction. self.schedule_task(ctx.bot.loop, infraction["id"], infraction) - except discord.Forbidden: + except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = f":x: failed to apply" expiry_msg = "" log_content = ctx.author.mention log_title = "failed to apply" - log.warning(f"Failed to apply {infr_type} infraction #{id_} to {user}.") + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) # Send a confirmation message to the invoking context. log.trace(f"Sending infraction #{id_} confirmation message.") @@ -324,7 +328,7 @@ class InfractionScheduler(Scheduler): f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions") + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention except discord.HTTPException as e: -- cgit v1.2.3 From f0e993a3514c1ef7256c4b7593d4db94a6d34569 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Mon, 9 Dec 2019 23:41:27 -0800 Subject: Infractions: kick user from voice after muting (#644) --- bot/cogs/moderation/infractions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 2713a1b68..fe5150652 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -208,8 +208,13 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_update, user.id) - action = user.add_roles(self._muted_role, reason=reason) - await self.apply_infraction(ctx, infraction, user, action) + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) @respect_role_hierarchy() async def apply_kick(self, ctx: Context, user: Member, reason: str, **kwargs) -> None: -- cgit v1.2.3 From 9a3e83116e145b720fc47b0686b357fa6ae9e488 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 02:05:18 -0800 Subject: ErrorHandler: fix #650 tag fallback not respecting checks --- bot/cogs/error_handler.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 49411814c..5fba9633b 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -75,6 +75,16 @@ class ErrorHandler(Cog): tags_get_command = self.bot.get_command("tags get") ctx.invoked_from_error_handler = True + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + # Return to not raise the exception with contextlib.suppress(ResponseCodeError): await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) -- cgit v1.2.3 From d84fc6346197d8176a7989b9b74e94d837d26882 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:39:39 -0800 Subject: Reddit: move token renewal inside fetch_posts This removes the duplicate code for renewing the token. Since fetch_posts is the only place where the token gets used, it can just be refreshed there directly. --- bot/cogs/reddit.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 22bb66bf0..0802c6102 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -122,6 +122,10 @@ class Reddit(Cog): if params is None: params = {} + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( @@ -206,11 +210,6 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() # if it's a monday send the top weekly posts @@ -249,10 +248,6 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="all") @@ -261,10 +256,6 @@ class Reddit(Cog): @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="day") @@ -273,10 +264,6 @@ class Reddit(Cog): @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="week") -- cgit v1.2.3 From ddfbfe31b2c2d9e5bc5d46ab9ffffa5b35a63e5f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:44:12 -0800 Subject: Reddit: move BasicAuth instantiation to __init__ The object is basically just a namedtuple so there's no need to re-create it every time a token is obtained. * Remove log message which shows credentials. * Initialise headers attribute to None in __init__. --- bot/cogs/reddit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0802c6102..48f636159 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -32,8 +32,10 @@ class Reddit(Cog): self.webhook = None self.access_token = None - bot.loop.create_task(self.init_reddit_ready()) + self.headers = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() def cog_unload(self) -> None: @@ -61,9 +63,6 @@ class Reddit(Cog): "duration": "temporary" } - log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", -- cgit v1.2.3 From d0f6f794d4fa3dd78ac8be2b95cae669b4587fb3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:44:35 -0800 Subject: Reddit: use qualified_name attribute when removing the cog --- bot/cogs/reddit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 48f636159..111f3b8ab 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -86,7 +86,7 @@ class Reddit(Cog): await asyncio.sleep(3) log.error("Authentication with Reddit API failed. Unloading extension.") - self.bot.remove_cog(self.__class__.__name__) + self.bot.remove_cog(self.qualified_name) return async def revoke_access_token(self) -> None: -- cgit v1.2.3 From 249a4c185ca9680258eb7a753307fbe8d0089b4b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 13:49:21 -0800 Subject: Reddit: use expires_in from the response to calculate token expiration --- bot/cogs/reddit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 111f3b8ab..083f90573 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -56,7 +56,7 @@ class Reddit(Cog): return self.bot.get_channel(Channels.reddit) async def get_access_token(self) -> None: - """Get Reddit access tokens.""" + """Get a Reddit API OAuth2 access token.""" headers = {"User-Agent": self.USER_AGENT} data = { "grant_type": "client_credentials", @@ -72,10 +72,11 @@ class Reddit(Cog): ) if response.status == 200 and response.content_type == "application/json": content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) self.access_token = AccessToken( token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(hours=1) + expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) self.headers = { "Authorization": "bearer " + self.access_token.token, -- cgit v1.2.3 From e49d9d5429c8e1aedfde5a5d38750890b3361496 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 13:50:03 -0800 Subject: Reddit: define AccessToken type at the module level --- bot/cogs/reddit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 083f90573..d9e1f0a39 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -18,6 +18,8 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" @@ -73,7 +75,6 @@ class Reddit(Cog): if response.status == 200 and response.content_type == "application/json": content = await response.json() expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) self.access_token = AccessToken( token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) -- cgit v1.2.3 From 4889326af4f9251218e332f873349e3f8c7bea7b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:18:46 -0800 Subject: Reddit: revise docstrings --- bot/cogs/reddit.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index d9e1f0a39..0a0279a39 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -41,7 +41,7 @@ class Reddit(Cog): self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stops the loops when the cog is unloaded.""" + """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() if self.access_token.expires_at < datetime.utcnow(): self.revoke_access_token() @@ -58,7 +58,12 @@ class Reddit(Cog): return self.bot.get_channel(Channels.reddit) async def get_access_token(self) -> None: - """Get a Reddit API OAuth2 access token.""" + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded if retrieval was still unsuccessful. + """ headers = {"User-Agent": self.USER_AGENT} data = { "grant_type": "client_credentials", @@ -72,6 +77,7 @@ class Reddit(Cog): auth=self.client_auth, data=data ) + if response.status == 200 and response.content_type == "application/json": content = await response.json() expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. @@ -87,14 +93,16 @@ class Reddit(Cog): await asyncio.sleep(3) - log.error("Authentication with Reddit API failed. Unloading extension.") + log.error("Authentication with Reddit API failed. Unloading the cog.") self.bot.remove_cog(self.qualified_name) return async def revoke_access_token(self) -> None: - """Revoke the access token for Reddit API.""" - # Access tokens are valid for 1 hour. - # The token should be revoked, since the API is called only once a day. + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ headers = {"User-Agent": self.USER_AGENT} data = { "token": self.access_token.token, @@ -107,12 +115,12 @@ class Reddit(Cog): auth=self.client_auth, data=data ) + if response.status == 204 and response.content_type == "application/json": self.access_token = None self.headers = None - return - - log.warning(f"Unable to revoke access token, status code {response.status}.") + else: + log.warning(f"Unable to revoke access token: status {response.status}.") async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" -- cgit v1.2.3 From 08881979e25749cb8da9efaccead64bb15b354ec Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:33:39 -0800 Subject: Reddit: create a dict constant for the User-Agent header --- bot/cogs/reddit.py | 39 ++++++++++++--------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0a0279a39..6af33d9db 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -24,7 +24,7 @@ AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_RETRIES = 3 @@ -34,7 +34,6 @@ class Reddit(Cog): self.webhook = None self.access_token = None - self.headers = None self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) bot.loop.create_task(self.init_reddit_ready()) @@ -64,18 +63,15 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded if retrieval was still unsuccessful. """ - headers = {"User-Agent": self.USER_AGENT} - data = { - "grant_type": "client_credentials", - "duration": "temporary" - } - for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", - headers=headers, + headers=self.HEADERS, auth=self.client_auth, - data=data + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } ) if response.status == 200 and response.content_type == "application/json": @@ -85,10 +81,6 @@ class Reddit(Cog): token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) - self.headers = { - "Authorization": "bearer " + self.access_token.token, - "User-Agent": self.USER_AGENT - } return await asyncio.sleep(3) @@ -103,22 +95,18 @@ class Reddit(Cog): For security reasons, it's good practice to revoke the token when it's no longer being used. """ - headers = {"User-Agent": self.USER_AGENT} - data = { - "token": self.access_token.token, - "token_type_hint": "access_token" - } - response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/revoke_token", - headers=headers, + headers=self.HEADERS, auth=self.client_auth, - data=data + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } ) if response.status == 204 and response.content_type == "application/json": self.access_token = None - self.headers = None else: log.warning(f"Unable to revoke access token: status {response.status}.") @@ -128,9 +116,6 @@ class Reddit(Cog): if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - if params is None: - params = {} - # Renew the token if necessary. if not self.access_token or self.access_token.expires_at < datetime.utcnow(): await self.get_access_token() @@ -139,7 +124,7 @@ class Reddit(Cog): for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.headers, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == 'application/json': -- cgit v1.2.3 From a4fb51bbeb9b15e1a3718038f280d9c633acfa66 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:45:56 -0800 Subject: Reddit: log retries when getting the access token --- bot/cogs/reddit.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6af33d9db..15b4a108c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -63,7 +63,7 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded if retrieval was still unsuccessful. """ - for _ in range(self.MAX_RETRIES): + for i in range(1, self.MAX_RETRIES + 1): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", headers=self.HEADERS, @@ -81,7 +81,15 @@ class Reddit(Cog): token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) + + log.debug(f"New token acquired; expires on {self.access_token.expires_at}") return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) await asyncio.sleep(3) -- cgit v1.2.3 From 806ccf73dd896b4272726ce32edea7c882ce9a81 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:47:13 -0800 Subject: Reddit: raise ClientError when the token can't be retrieved Raising an exception allows the error handler to display a message to the user if the failure happened from a command invocation. --- bot/cogs/reddit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 15b4a108c..96af90bc4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -6,7 +6,7 @@ from collections import namedtuple from datetime import datetime, timedelta from typing import List -from aiohttp import BasicAuth +from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel from discord.ext.commands import Bot, Cog, Context, group from discord.ext.tasks import loop @@ -61,7 +61,7 @@ class Reddit(Cog): Get a Reddit API OAuth2 access token and assign it to self.access_token. A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded if retrieval was still unsuccessful. + will be unloaded and a ClientError raised if retrieval was still unsuccessful. """ for i in range(1, self.MAX_RETRIES + 1): response = await self.bot.http_session.post( @@ -93,9 +93,8 @@ class Reddit(Cog): await asyncio.sleep(3) - log.error("Authentication with Reddit API failed. Unloading the cog.") self.bot.remove_cog(self.qualified_name) - return + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") async def revoke_access_token(self) -> None: """ -- cgit v1.2.3 From 2d69e1293ad659b4f4fd7f5e5029b6591328ebc6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 18:07:06 -0800 Subject: Clean: un-hide from help and add purge alias --- bot/cogs/clean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index dca411d01..a45d30142 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -167,7 +167,7 @@ class Clean(Cog): channel_id=Channels.modlog, ) - @group(invoke_without_command=True, name="clean", hidden=True) + @group(invoke_without_command=True, name="clean", aliases=["purge"]) @with_role(*MODERATION_ROLES) async def clean_group(self, ctx: Context) -> None: """Commands for cleaning messages in channels.""" -- cgit v1.2.3 From 7130de271ddfde568d2d008823656e5371b4dc45 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 18:13:05 -0800 Subject: Clean: support specifying a channel different than the context's --- bot/cogs/clean.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index a45d30142..312c7926d 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -3,7 +3,7 @@ import random import re from typing import Optional -from discord import Colour, Embed, Message, User +from discord import Colour, Embed, Message, TextChannel, User from discord.ext.commands import Bot, Cog, Context, group from bot.cogs.moderation import ModLog @@ -39,7 +39,8 @@ class Clean(Cog): async def _clean_messages( self, amount: int, ctx: Context, bots_only: bool = False, user: User = None, - regex: Optional[str] = None + regex: Optional[str] = None, + channel: Optional[TextChannel] = None ) -> None: """A helper function that does the actual message cleaning.""" def predicate_bots_only(message: Message) -> bool: @@ -104,6 +105,10 @@ class Clean(Cog): else: predicate = None # Delete all messages + # Default to using the invoking context's channel + if not channel: + channel = ctx.channel + # Look through the history and retrieve message data messages = [] message_ids = [] @@ -111,7 +116,7 @@ class Clean(Cog): invocation_deleted = False # To account for the invocation message, we index `amount + 1` messages. - async for message in ctx.channel.history(limit=amount + 1): + async for message in channel.history(limit=amount + 1): # If at any point the cancel command is invoked, we should stop. if not self.cleaning: @@ -135,7 +140,7 @@ class Clean(Cog): self.mod_log.ignore(Event.message_delete, *message_ids) # Use bulk delete to actually do the cleaning. It's far faster. - await ctx.channel.purge( + await channel.purge( limit=amount, check=predicate ) @@ -155,7 +160,7 @@ class Clean(Cog): # Build the embed and send it message = ( - f"**{len(message_ids)}** messages deleted in <#{ctx.channel.id}> by **{ctx.author.name}**\n\n" + f"**{len(message_ids)}** messages deleted in <#{channel.id}> by **{ctx.author.name}**\n\n" f"A log of the deleted messages can be found [here]({log_url})." ) @@ -175,27 +180,27 @@ class Clean(Cog): @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) - async def clean_user(self, ctx: Context, user: User, amount: int = 10) -> None: + async def clean_user(self, ctx: Context, user: User, amount: int = 10, channel: TextChannel = None) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user) + await self._clean_messages(amount, ctx, user=user, channel=channel) @clean_group.command(name="all", aliases=["everything"]) @with_role(*MODERATION_ROLES) - async def clean_all(self, ctx: Context, amount: int = 10) -> None: + async def clean_all(self, ctx: Context, amount: int = 10, channel: TextChannel = None) -> None: """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx) + await self._clean_messages(amount, ctx, channel=channel) @clean_group.command(name="bots", aliases=["bot"]) @with_role(*MODERATION_ROLES) - async def clean_bots(self, ctx: Context, amount: int = 10) -> None: + async def clean_bots(self, ctx: Context, amount: int = 10, channel: TextChannel = None) -> None: """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True) + await self._clean_messages(amount, ctx, bots_only=True, channel=channel) @clean_group.command(name="regex", aliases=["word", "expression"]) @with_role(*MODERATION_ROLES) - async def clean_regex(self, ctx: Context, regex: str, amount: int = 10) -> None: + async def clean_regex(self, ctx: Context, regex: str, amount: int = 10, channel: TextChannel = None) -> None: """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex) + await self._clean_messages(amount, ctx, regex=regex, channel=channel) @clean_group.command(name="stop", aliases=["cancel", "abort"]) @with_role(*MODERATION_ROLES) -- cgit v1.2.3 From 0ddf9e66ea4e7a9753cc7044da4bc06d1e9cfb7d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 19:21:31 -0800 Subject: Verification: allow mods+ to use commands in checkpoint (#688) --- bot/cogs/verification.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b5e8d4357..b62a08db6 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,9 +9,10 @@ from bot.cogs.moderation import ModLog from bot.constants import ( Bot as BotConfig, Channels, Colours, Event, - Filter, Icons, Roles + Filter, Icons, MODERATION_ROLES, Roles ) from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.utils.checks import without_role_check log = logging.getLogger(__name__) @@ -189,7 +190,7 @@ class Verification(Cog): @staticmethod def bot_check(ctx: Context) -> bool: """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == Channels.verification: + if ctx.channel.id == Channels.verification and without_role_check(ctx, *MODERATION_ROLES): return ctx.command.name == "accept" else: return True -- cgit v1.2.3 From 65c7319c7bd83e6f8833b323d5dd408ca771cf9d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 19:28:41 -0800 Subject: Verification: delete bots' messages (#689) Messages are deleted after a delay of 10 seconds. This helps keep the channel clean. The periodic ping is an exception; it will remain. --- bot/cogs/verification.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b62a08db6..2d759f885 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -38,6 +38,7 @@ PERIODIC_PING = ( f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`." f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel." ) +BOT_MESSAGE_DELETE_DELAY = 10 class Verification(Cog): @@ -56,7 +57,11 @@ class Verification(Cog): async def on_message(self, message: Message) -> None: """Check new message event for messages to the checkpoint channel & process.""" if message.author.bot: - return # They're a bot, ignore + # They're a bot, delete their message after the delay. + # But not the periodic ping; we like that one. + if message.content != PERIODIC_PING: + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return if message.channel.id != Channels.verification: return # Only listen for #checkpoint messages -- cgit v1.2.3 From 9d551cc69c1935165389f26f52753895604dd3f5 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 20:26:26 -0800 Subject: Add a generic converter for only allowing certain string values --- bot/cogs/moderation/management.py | 13 ++----------- bot/converters.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..50bce3981 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,7 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, string from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -22,15 +22,6 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" @@ -61,7 +52,7 @@ class ModManagement(commands.Cog): self, ctx: Context, infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + duration: t.Union[utils.Expiry, string("permanent"), None], *, reason: str = None ) -> None: diff --git a/bot/converters.py b/bot/converters.py index cf0496541..2cfc42903 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converter to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") -- cgit v1.2.3 From 729ac3d83a3bd4620d1e9b24769466e219d45de6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:00:47 -0800 Subject: ModManagement: allow "recent" as ID to edit infraction (#624) It will attempt to find the most recent infraction authored by the invoker of the edit command. --- bot/cogs/moderation/management.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 50bce3981..35832ded5 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -51,7 +51,7 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: int, + infraction_id: t.Union[int, string("recent")], duration: t.Union[utils.Expiry, string("permanent"), None], *, reason: str = None @@ -69,6 +69,9 @@ class ModManagement(commands.Cog): \u2003`M` - minutes∗ \u2003`s` - seconds + Use "recent" as the infraction ID to specify that the ost recent infraction authored by the + command invoker should be edited. + Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp can be provided for the duration. """ @@ -77,7 +80,23 @@ class ModManagement(commands.Cog): raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') + if infraction_id == "recent": + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + f":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") request_data = {} confirm_messages = [] -- cgit v1.2.3 From c1bf0a48692d87c5cbe9ee310cd0120bda339a96 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:01:16 -0800 Subject: ModManagement: display ID of edited infraction in confirmation message --- bot/cogs/moderation/management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 35832ded5..904611e13 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -139,7 +139,8 @@ class ModManagement(commands.Cog): New expiry: {new_infraction['expires_at'] or "Permanent"} """.rstrip() - await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") # Get information about the infraction's user user_id = new_infraction['user'] -- cgit v1.2.3 From ce52c836ff2e14cd9cfade3a4fcfe8a0e5071e2e Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:21:10 -0800 Subject: Moderation: show emoji for DM failure instead of mentioning actor (#534) --- bot/cogs/moderation/scheduler.py | 5 ++--- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 3e0968121..0ab1fe997 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -113,8 +113,8 @@ class InfractionScheduler(Scheduler): dm_result = ":incoming_envelope: " dm_log_text = "\nDM: Sent" else: + dm_result = f"{constants.Emojis.failmail} " dm_log_text = "\nDM: **Failed**" - log_content = ctx.author.mention if infraction["actor"] == self.bot.user.id: log.trace( @@ -250,8 +250,7 @@ class InfractionScheduler(Scheduler): if log_text.get("DM") == "Sent": dm_emoji = ":incoming_envelope: " elif "DM" in log_text: - # Mention the actor because the DM failed to send. - log_content = ctx.author.mention + dm_emoji = f"{constants.Emojis.failmail} " # Accordingly display whether the pardon failed. if "Failure" in log_text: diff --git a/bot/constants.py b/bot/constants.py index 89504a2e0..3b1ca2887 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -256,6 +256,8 @@ class Emojis(metaclass=YAMLGetter): status_idle: str status_dnd: str + failmail: str + bullet: str new: str pencil: str diff --git a/config-default.yml b/config-default.yml index 930a1a0e6..9e6ada3dd 100644 --- a/config-default.yml +++ b/config-default.yml @@ -27,6 +27,8 @@ style: status_dnd: "<:status_dnd:470326272082313216>" status_offline: "<:status_offline:470326266537705472>" + failmail: "<:failmail:633660039931887616>" + bullet: "\u2022" pencil: "\u270F" new: "\U0001F195" -- cgit v1.2.3 From 56833bbe99ce3cd93af87e9d33cec47a059b61f3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:52:57 -0800 Subject: ModManagement: add more aliases for "special" params of infraction edit --- bot/cogs/moderation/management.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 904611e13..37bdb1934 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -51,8 +51,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: t.Union[int, string("recent")], - duration: t.Union[utils.Expiry, string("permanent"), None], + infraction_id: t.Union[int, string("l", "last", "recent")], + duration: t.Union[utils.Expiry, string("p", "permanent"), None], *, reason: str = None ) -> None: @@ -69,18 +69,18 @@ class ModManagement(commands.Cog): \u2003`M` - minutes∗ \u2003`s` - seconds - Use "recent" as the infraction ID to specify that the ost recent infraction authored by the - command invoker should be edited. + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. - Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp - can be provided for the duration. + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. """ if duration is None and reason is None: # Unlike UserInputError, the error handler will show a specified message for BadArgument raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - if infraction_id == "recent": + if isinstance(infraction_id, str): params = { "actor__id": ctx.author.id, "ordering": "-inserted_at" @@ -102,7 +102,7 @@ class ModManagement(commands.Cog): confirm_messages = [] log_text = "" - if duration == "permanent": + if isinstance(duration, str): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: -- cgit v1.2.3 From eb53a4594dff20372574058ec90062995362b098 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 22:03:20 -0800 Subject: Converters: rename string to allowed_strings --- bot/cogs/moderation/management.py | 6 +++--- bot/converters.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 37bdb1934..20ff25ba1 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,7 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery, string +from bot.converters import InfractionSearchQuery, allowed_strings from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -51,8 +51,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: t.Union[int, string("l", "last", "recent")], - duration: t.Union[utils.Expiry, string("p", "permanent"), None], + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], + duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None], *, reason: str = None ) -> None: diff --git a/bot/converters.py b/bot/converters.py index 2cfc42903..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -15,11 +15,11 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) -def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]: +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: """ Return a converter which only allows arguments equal to one of the given values. - Unless preserve_case is True, the argument is converter to lowercase. All values are then + Unless preserve_case is True, the argument is converted to lowercase. All values are then expected to have already been given in lowercase too. """ def converter(arg: str) -> str: -- cgit v1.2.3 From 38129d632648da21726a8158b76676695fd0b512 Mon Sep 17 00:00:00 2001 From: Mark Date: Wed, 11 Dec 2019 22:50:44 -0800 Subject: Clean: allow amount argument to be skipped This make the channel specifiable without the amount. Co-Authored-By: scragly <29337040+scragly@users.noreply.github.com> --- bot/cogs/clean.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index 312c7926d..ed1962565 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -180,25 +180,25 @@ class Clean(Cog): @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) - async def clean_user(self, ctx: Context, user: User, amount: int = 10, channel: TextChannel = None) -> None: + async def clean_user(self, ctx: Context, user: User, amount: Optional[int] = 10, channel: TextChannel = None) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, user=user, channel=channel) @clean_group.command(name="all", aliases=["everything"]) @with_role(*MODERATION_ROLES) - async def clean_all(self, ctx: Context, amount: int = 10, channel: TextChannel = None) -> None: + async def clean_all(self, ctx: Context, amount: Optional[int] = 10, channel: TextChannel = None) -> None: """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, channel=channel) @clean_group.command(name="bots", aliases=["bot"]) @with_role(*MODERATION_ROLES) - async def clean_bots(self, ctx: Context, amount: int = 10, channel: TextChannel = None) -> None: + async def clean_bots(self, ctx: Context, amount: Optional[int] = 10, channel: TextChannel = None) -> None: """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, bots_only=True, channel=channel) @clean_group.command(name="regex", aliases=["word", "expression"]) @with_role(*MODERATION_ROLES) - async def clean_regex(self, ctx: Context, regex: str, amount: int = 10, channel: TextChannel = None) -> None: + async def clean_regex(self, ctx: Context, regex: str, amount: Optional[int] = 10, channel: TextChannel = None) -> None: """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, regex=regex, channel=channel) -- cgit v1.2.3 From a657fd4ebfaebd2a419f81dbda14f93b395380ff Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 22:54:40 -0800 Subject: Clean: reformat arguments --- bot/cogs/clean.py | 41 +++++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index ed1962565..432c9e998 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -37,10 +37,13 @@ class Clean(Cog): return self.bot.get_cog("ModLog") async def _clean_messages( - self, amount: int, ctx: Context, - bots_only: bool = False, user: User = None, - regex: Optional[str] = None, - channel: Optional[TextChannel] = None + self, + amount: int, + ctx: Context, + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + channel: Optional[TextChannel] = None ) -> None: """A helper function that does the actual message cleaning.""" def predicate_bots_only(message: Message) -> bool: @@ -180,25 +183,47 @@ class Clean(Cog): @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) - async def clean_user(self, ctx: Context, user: User, amount: Optional[int] = 10, channel: TextChannel = None) -> None: + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, user=user, channel=channel) @clean_group.command(name="all", aliases=["everything"]) @with_role(*MODERATION_ROLES) - async def clean_all(self, ctx: Context, amount: Optional[int] = 10, channel: TextChannel = None) -> None: + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, channel=channel) @clean_group.command(name="bots", aliases=["bot"]) @with_role(*MODERATION_ROLES) - async def clean_bots(self, ctx: Context, amount: Optional[int] = 10, channel: TextChannel = None) -> None: + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, bots_only=True, channel=channel) @clean_group.command(name="regex", aliases=["word", "expression"]) @with_role(*MODERATION_ROLES) - async def clean_regex(self, ctx: Context, regex: str, amount: Optional[int] = 10, channel: TextChannel = None) -> None: + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" await self._clean_messages(amount, ctx, regex=regex, channel=channel) -- cgit v1.2.3 From 97f0cb8efb82217d28123e834454d1316f04b031 Mon Sep 17 00:00:00 2001 From: Joseph Date: Fri, 13 Dec 2019 00:41:56 +0000 Subject: Revert "Use OAuth to be Reddit API compliant" --- azure-pipelines.yml | 2 +- bot/cogs/reddit.py | 91 ++++++----------------------------------------------- bot/constants.py | 2 -- config-default.yml | 2 -- docker-compose.yml | 2 -- 5 files changed, 11 insertions(+), 88 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0400ac4d2..da3b06201 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index aa487f18e..bec316ae7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,11 +2,9 @@ import asyncio import logging import random import textwrap -from collections import namedtuple from datetime import datetime, timedelta from typing import List -from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel from discord.ext.commands import Cog, Context, group from discord.ext.tasks import loop @@ -19,32 +17,25 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 + MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - + self.webhook = None # set in on_ready bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" + """Stops the loops when the cog is unloaded.""" self.auto_poster_loop.cancel() - if self.access_token.expires_at < datetime.utcnow(): - self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -57,82 +48,20 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() + if params is None: + params = {} - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): + url = f"{self.URL}/{route}.json" + for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + headers=self.HEADERS, params=params ) if response.status == 200 and response.content_type == 'application/json': diff --git a/bot/constants.py b/bot/constants.py index ed85adf6a..89504a2e0 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -465,8 +465,6 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list - client_id: str - secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index e6f0fda21..930a1a0e6 100644 --- a/config-default.yml +++ b/config-default.yml @@ -365,8 +365,6 @@ anti_malware: reddit: subreddits: - 'r/Python' - client_id: !ENV "REDDIT_CLIENT_ID" - secret: !ENV "REDDIT_SECRET" wolfram: diff --git a/docker-compose.yml b/docker-compose.yml index 7281c7953..f79fdba58 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,5 +42,3 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 - REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} - REDDIT_SECRET: ${REDDIT_SECRET} -- cgit v1.2.3 From 8a545b495a19715cc519afc9a867e16787cd5212 Mon Sep 17 00:00:00 2001 From: Joseph Date: Fri, 13 Dec 2019 00:48:11 +0000 Subject: Revert "Revert "Use OAuth to be Reddit API compliant"" --- azure-pipelines.yml | 2 +- bot/cogs/reddit.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++------ bot/constants.py | 2 ++ config-default.yml | 2 ++ docker-compose.yml | 2 ++ 5 files changed, 88 insertions(+), 11 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index da3b06201..0400ac4d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index bec316ae7..aa487f18e 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,9 +2,11 @@ import asyncio import logging import random import textwrap +from collections import namedtuple from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel from discord.ext.commands import Cog, Context, group from discord.ext.tasks import loop @@ -17,25 +19,32 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} URL = "https://www.reddit.com" - MAX_FETCH_RETRIES = 3 + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready - bot.loop.create_task(self.init_reddit_ready()) + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stops the loops when the cog is unloaded.""" + """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -48,20 +57,82 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - if params is None: - params = {} + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() - url = f"{self.URL}/{route}.json" - for _ in range(self.MAX_FETCH_RETRIES): + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == 'application/json': diff --git a/bot/constants.py b/bot/constants.py index 89504a2e0..ed85adf6a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -465,6 +465,8 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 930a1a0e6..e6f0fda21 100644 --- a/config-default.yml +++ b/config-default.yml @@ -365,6 +365,8 @@ anti_malware: reddit: subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} -- cgit v1.2.3 From c5109844e45c37bc1cc38eb1c3da31d52ab2aa6d Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Fri, 13 Dec 2019 09:17:21 +0700 Subject: Adding an optional argument for `until_expiration`, update typehints for `format_infraction_with_duration` - `until_expiration` was being a pain to unittests without a `now` ( default to `datetime.utcnow()` ). Adding an optional argument for this will not only make writing tests easier, but also allow more control over the helper function should we need to calculate the remaining time between two dates in the past. - Changed typehint for `date_from` in `format_infraction_with_duration` to `Optional[datetime.datetime]` to better reflect what it is. --- bot/utils/time.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index ac64865d6..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -115,7 +115,7 @@ def format_infraction(timestamp: str) -> str: def format_infraction_with_duration( expiry: Optional[str], - date_from: datetime.datetime = None, + date_from: Optional[datetime.datetime] = None, max_units: int = 2 ) -> Optional[str]: """ @@ -140,10 +140,15 @@ def format_infraction_with_duration( return f"{expiry_formatted}{duration_formatted}" -def until_expiration(expiry: Optional[str], max_units: int = 2) -> Optional[str]: +def until_expiration( + expiry: Optional[str], + now: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: """ Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). By default, max_units is 2. @@ -151,7 +156,7 @@ def until_expiration(expiry: Optional[str], max_units: int = 2) -> Optional[str] if not expiry: return None - now = datetime.datetime.utcnow() + now = now or datetime.datetime.utcnow() since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) if since < now: -- cgit v1.2.3 From 520346d0b472e5cb6c9091a8323b871d2e3821cc Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Fri, 13 Dec 2019 09:18:49 +0700 Subject: Added tests for `until_expiration` Similar to `format_infraction_with_duration` ( if not outright copying it ), added 3 tests for `until_expiration`: - None `expiry`. - Custom `max_units`. - Normal use cases. --- tests/bot/utils/test_time.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 7f55dc3ec..bd04de28b 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -115,3 +115,48 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_until_expiration_with_duration_none_expiry(self): + """until_expiration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_with_duration_custom_units(self): + """until_expiration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_normal_usage(self): + """until_expiration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) -- cgit v1.2.3 From 66d4b93593b95bfa6999b70aca53328d83710c44 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Fri, 13 Dec 2019 09:29:06 +0700 Subject: Fixed a typo ( due to poor copy pasta and eyeballing skills ) --- tests/bot/utils/test_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index bd04de28b..69f35f2f5 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -121,7 +121,7 @@ class TimeTests(unittest.TestCase): test_cases = ( (None, None, None, None), - # To make sure that date_from and max_units are not touched + # To make sure that now and max_units are not touched (None, 'Why hello there!', None, None), (None, None, float('inf'), None), (None, 'Why hello there!', float('inf'), None), -- cgit v1.2.3 From c9cc19c27f3a3458910012e4c15442b0974b9fb3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Thu, 12 Dec 2019 20:34:10 -0800 Subject: Verification: check channel before checking for bot messages --- bot/cogs/verification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 2d759f885..ec0f9627e 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -56,6 +56,9 @@ class Verification(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Check new message event for messages to the checkpoint channel & process.""" + if message.channel.id != Channels.verification: + return # Only listen for #checkpoint messages + if message.author.bot: # They're a bot, delete their message after the delay. # But not the periodic ping; we like that one. @@ -63,9 +66,6 @@ class Verification(Cog): await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) return - if message.channel.id != Channels.verification: - return # Only listen for #checkpoint messages - # if a user mentions a role or guild member # alert the mods in mod-alerts channel if message.mentions or message.role_mentions: -- cgit v1.2.3