From 1b6f3d23d4c0b1f6dfe1354a3a210e589f7b4956 Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 18:40:02 +0200 Subject: Make sure that poor code does not contains token Added a new function `is_token_in_message` in `token_remover`. This function returns a `bool` and if the code contains a token then the embed message about the poorly formatted code is not displayed. --- bot/cogs/bot.py | 3 ++- bot/cogs/token_remover.py | 32 +++++++++++++++++++------------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 7583b2f2d..e8ac0a234 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple from discord import Embed, Message, RawMessageUpdateEvent from discord.ext.commands import Bot, Cog, Context, command, group +from bot.cogs.token_remover import TokenRemover from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs from bot.decorators import with_role from bot.utils.messages import wait_for_deletion @@ -237,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock: + if parse_codeblock and not TokenRemover.is_token_in_message: # if there is no token in the code on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 7dd0afbbd..8f356cf19 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -52,19 +52,8 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ - if msg.author.bot: - return - - maybe_match = TOKEN_RE.search(msg.content) - if maybe_match is None: - return - - try: - user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') - except ValueError: - return - - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): + if self.is_token_in_message(msg): + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') self.mod_log.ignore(Event.message_delete, msg.id) await msg.delete() await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) @@ -86,6 +75,23 @@ class TokenRemover(Cog): channel_id=Channels.mod_alerts, ) + def is_token_in_message(self, msg: Message) -> bool: + """Check if `msg` contains a seemly valid token.""" + if msg.author.bot: + return False + + maybe_match = TOKEN_RE.search(msg.content) + if maybe_match is None: + return False + + try: + user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') + except ValueError: + return False + + if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): + return True + @staticmethod def is_valid_user_id(b64_content: str) -> bool: """ -- cgit v1.2.3 From 2899bac85c3c0529b354a762ba27a587a520d7cd Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 18:51:18 +0200 Subject: minor fix --- bot/cogs/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index e8ac0a234..729550c1a 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message: # if there is no token in the code + if parse_codeblock and not TokenRemover.is_token_in_message(msg): # if there is no token in the code on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From b94e8487c22d7c25ab09bb3d44c44d62e5a2b613 Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 21:33:43 +0200 Subject: Another fix After a new bunch of test I found bugs, and this fix resolves them --- bot/cogs/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 729550c1a..b8de29f2a 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message(msg): # if there is no token in the code + if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # if there is no token in the code on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From 25d4c05b2656ce8d9454269c77d42e18fb1ba785 Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Mon, 7 Oct 2019 21:42:49 +0200 Subject: fix linting error fix linting error --- bot/cogs/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index b8de29f2a..eab253681 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # if there is no token in the code + if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From 2cb7ee12805957c7d655679ff54a14f16e059a80 Mon Sep 17 00:00:00 2001 From: Jens Date: Wed, 9 Oct 2019 23:43:16 +0200 Subject: Add Reddit OAuth tasks and refactor code --- bot/cogs/reddit.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++++----- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 77 insertions(+), 6 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6880aab85..bf4403ce4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # :: (by /u/) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -34,6 +41,66 @@ class Reddit(Cog): self.new_posts_task = None self.top_weekly_posts_task = None + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -43,11 +110,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -55,7 +122,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -127,8 +194,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: -- cgit v1.2.3 From aa0096469e12546b0eadbb4b214cd3cae3a3a80d Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Sat, 12 Oct 2019 14:54:58 +0200 Subject: Use a `classmethod` --- bot/cogs/bot.py | 2 +- bot/cogs/token_remover.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index eab253681..53221cd8b 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -238,7 +238,7 @@ class Bot(Cog): and len(msg.content.splitlines()) > 3 ) - if parse_codeblock and not TokenRemover.is_token_in_message(TokenRemover, msg): # no token in the msg + if parse_codeblock and not TokenRemover.is_token_in_message(msg): # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 8f356cf19..5e83a777e 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -75,6 +75,7 @@ class TokenRemover(Cog): channel_id=Channels.mod_alerts, ) + @classmethod def is_token_in_message(self, msg: Message) -> bool: """Check if `msg` contains a seemly valid token.""" if msg.author.bot: -- cgit v1.2.3 From 9b4b98eb6f37060771c61979458737629b3c5db7 Mon Sep 17 00:00:00 2001 From: Jens Date: Wed, 9 Oct 2019 23:43:16 +0200 Subject: Add Reddit OAuth tasks and refactor code --- bot/cogs/reddit.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++++----- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..25df014f8 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # :: (by /u/) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -35,6 +42,65 @@ class Reddit(Cog): self.top_weekly_posts_task = None self.bot.loop.create_task(self.init_reddit_polling()) + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -45,11 +111,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -57,7 +123,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -129,8 +195,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: -- cgit v1.2.3 From 79ed098809ce3cfaa0fa75608f6f6a85af2a90dd Mon Sep 17 00:00:00 2001 From: Jens Date: Tue, 15 Oct 2019 23:36:36 +0200 Subject: Unload cog on auth error and fix linting warnings --- bot/cogs/reddit.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 25df014f8..451d2bf4c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,10 @@ import asyncio import logging import random import textwrap -from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth from discord import Colour, Embed, Message, TextChannel from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group @@ -46,7 +46,7 @@ class Reddit(Cog): @tasks.loop(hours=0.99) # access tokens are valid for one hour async def refresh_access_token(self) -> None: - """Refresh the access token""" + """Refresh Reddits access token.""" headers = {"Authorization": self.client_auth} data = { "grant_type": "refresh_token", @@ -54,7 +54,7 @@ class Reddit(Cog): } response = await self.bot.http_session.post( - url = f"{self.URL}/api/v1/access_token", + url=f"{self.URL}/api/v1/access_token", headers=headers, data=data, ) @@ -68,7 +68,7 @@ class Reddit(Cog): @refresh_access_token.before_loop async def get_tokens(self) -> None: - """Get Reddit access and refresh tokens""" + """Get Reddit access and refresh tokens.""" await self.bot.wait_until_ready() headers = {"User-Agent": self.USER_AGENT} @@ -77,16 +77,16 @@ class Reddit(Cog): "duration": "permanent" } - if RedditConfig.client_id and RedditConfig.secret: - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=headers, - auth=self.client_auth, - data=data - ) + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 200 and response.content_type == "application/json": content = await response.json() self.access_token = content["access_token"] self.refresh_token = content["refresh_token"] @@ -95,12 +95,9 @@ class Reddit(Cog): "User-Agent": self.USER_AGENT } else: - self.client_auth = None - self.access_token = None - self.refresh_token = None - self.headers = None - - log.error("Unable to find client credentials.") + log.error("Authentication with Reddit API failed. Unloading extension.") + self.bot.remove_cog(self.__class__.__name__) + return async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -123,7 +120,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") -- cgit v1.2.3 From bc907daa428d755d7f2cb0a6b945a179d523b31d Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Sun, 20 Oct 2019 21:35:38 +0200 Subject: Add check when a message is edited --- bot/cogs/token_remover.py | 60 +++++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5e83a777e..e5b0e5b45 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -53,30 +53,44 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ if self.is_token_in_message(msg): - user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') - self.mod_log.ignore(Event.message_delete, msg.id) - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - message = ( - "Censored a seemingly valid token sent by " - f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " - f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" - ) - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) + await self.take_action(msg) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + if self.is_token_in_message(after): + await self.take_action(after) + + async def take_action(self, msg: Message) -> None: + """Remove the `msg` containing a token an send a mod_log message.""" + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') + self.mod_log.ignore(Event.message_delete, msg.id) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + message = ( + "Censored a seemingly valid token sent by " + f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " + f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" + ) + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) @classmethod - def is_token_in_message(self, msg: Message) -> bool: + def is_token_in_message(cls, msg: Message) -> bool: """Check if `msg` contains a seemly valid token.""" if msg.author.bot: return False @@ -90,7 +104,7 @@ class TokenRemover(Cog): except ValueError: return False - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): + if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): return True @staticmethod -- cgit v1.2.3 From 5ec4db0cba484f8adfc25b642a4f24f362a5b53c Mon Sep 17 00:00:00 2001 From: Jens Date: Tue, 22 Oct 2019 22:05:50 +0200 Subject: Add reddit environment variable, change User-Agent and fix lint problem --- bot/cogs/reddit.py | 4 ++-- docker-compose.yml | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7b183221c..76da0f09f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -25,7 +25,7 @@ class Reddit(Cog): # including the target platform, a unique application identifier, a version string, # and your username as contact information, in the following format: # :: (by /u/) - USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 @@ -117,7 +117,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} -- cgit v1.2.3 From 2c0fffbc697db5f33e759f733c310f9f0b754d11 Mon Sep 17 00:00:00 2001 From: Jens Date: Sat, 26 Oct 2019 17:24:09 +0200 Subject: Fix linting error --- bot/cogs/reddit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7e2ba40d5..64a940af1 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -20,6 +20,7 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" + # Change your client's User-Agent string to something unique and descriptive, # including the target platform, a unique application identifier, a version string, # and your username as contact information, in the following format: -- cgit v1.2.3 From 7e25475b78df01646cbc82176443f955bb6d1964 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 15:01:40 +0700 Subject: Improved type hinting for `format_infraction_with_duration` --- bot/utils/time.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index a024674ac..9520b32f8 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: +def format_infraction_with_duration( + expiry: Optional[str], + date_from: datetime.datetime = None, + max_units: int = 2 +) -> Optional[str]: """ Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. -- cgit v1.2.3 From 51f80015c5db9ab8e85ea2304789491d4c72c053 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 15:03:16 +0700 Subject: Created `until_expiration` to get the remaining time until the infraction expires. --- bot/utils/time.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/bot/utils/time.py b/bot/utils/time.py index 9520b32f8..ac64865d6 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -138,3 +138,23 @@ def format_infraction_with_duration( duration_formatted = f" ({duration})" if duration else '' return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration(expiry: Optional[str], max_units: int = 2) -> Optional[str]: + """ + Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. + """ + if not expiry: + return None + + now = datetime.datetime.utcnow() + since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + + if since < now: + return None + + return humanize_delta(relativedelta(since, now), max_units=max_units) -- cgit v1.2.3 From 82eb5e1c46e378a6f3778e17cc342193b910ded5 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 15:04:20 +0700 Subject: Implemented remaining time until expiration for infraction searching. Will show the remaining time, `Expired.` or `Inactive.` based on the status of the infraction ( It can be inactive but not expired, like an early unmute ) --- bot/cogs/moderation/management.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..2f5e09f1b 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -232,6 +232,12 @@ class ModManagement(commands.Cog): user_id = infraction["user"] hidden = infraction["hidden"] created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or 'Expired.' + else: + remaining = 'Inactive.' + if infraction["expires_at"] is None: expires = "*Permanent*" else: @@ -247,6 +253,7 @@ class ModManagement(commands.Cog): Reason: {infraction["reason"] or "*None*"} Created: {created} Expires: {expires} + Remaining: {remaining} Actor: {actor.mention if actor else actor_id} ID: `{infraction["id"]}` {"**===============**" if active else "==============="} -- cgit v1.2.3 From c1aeb6d263172168f77845408e8d2756f6cb2813 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 17:12:25 +0700 Subject: Apply suggestions from Mark - removing `.` at the end and use double quote instead of single. Co-Authored-By: Mark --- bot/cogs/moderation/management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 2f5e09f1b..74f75781d 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -234,9 +234,9 @@ class ModManagement(commands.Cog): created = time.format_infraction(infraction["inserted_at"]) if active: - remaining = time.until_expiration(infraction["expires_at"]) or 'Expired.' + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" else: - remaining = 'Inactive.' + remaining = "Inactive" if infraction["expires_at"] is None: expires = "*Permanent*" -- cgit v1.2.3 From e07cf7342184b769d8c0655bc9b84be02809319a Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Wed, 4 Dec 2019 23:38:46 +0700 Subject: Added `unittest` for `bot.utils.time` --- tests/bot/utils/test_time.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests/bot/utils/test_time.py diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..0ef59292e --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,87 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): + """Test helper functions in bot.utils.time.""" + + def setUp(self): + pass + + def test_humanize_delta(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), + + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), + ) + + for delta, precision, max_units, expected in test_cases: + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_raises_for_invalid_max_units(self): + test_cases = (-1, 0) + + for max_units in test_cases: + with self.assertRaises(ValueError) as error: + time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + self.assertEqual(str(error), 'max_units must be positive') + + def test_parse_rfc1123(self): + """Testing parse_rfc1123.""" + test_cases = ( + ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), + ) + + for stamp, expected in test_cases: + self.assertEqual(time.parse_rfc1123(stamp), expected) + + @patch('asyncio.sleep', new_callable=AsyncMock) + def test_wait_until(self, mock): + """Testing wait_until.""" + start = datetime(2019, 1, 1, 0, 0) + then = datetime(2019, 1, 1, 0, 10) + + # No return value + assert asyncio.run(time.wait_until(then, start)) is None + + mock.assert_called_once_with(10 * 60) + + def test_format_infraction_with_duration(self): + """Testing format_infraction_with_duration.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 6, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, + '2019-11-23 23:59 (9 minutes and 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, date_from, max_units, expected in test_cases: + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) -- cgit v1.2.3 From b17dbe5e3e0dfa6ae44d660924455f709abefd0d Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:34:58 +0700 Subject: Splitting test cases for `humanize_delta` into proper, independent tests. --- tests/bot/utils/test_time.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 0ef59292e..5e5f2bf2f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -15,18 +15,20 @@ class TimeTests(unittest.TestCase): def setUp(self): pass - def test_humanize_delta(self): - """Testing humanize delta.""" + def test_humanize_delta_handle_unknown_units(self): + """humanize_delta should be able to handle unknown units, and will not abort.""" test_cases = ( - (relativedelta(days=2), 'seconds', 1, '2 days'), - (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), - (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), - (relativedelta(days=2, hours=2), 'days', 2, '2 days'), - # Does not abort for unknown units, as the unit name is checked # against the attribute of the relativedelta instance. (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), + ) + for delta, precision, max_units, expected in test_cases: + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_handle_high_units(self): + """humanize_delta should be able to handle very high units.""" + test_cases = ( # Very high maximum units, but it only ever iterates over # each value the relativedelta might have. (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), @@ -35,6 +37,18 @@ class TimeTests(unittest.TestCase): for delta, precision, max_units, expected in test_cases: self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + def test_humanize_delta_should_work_normally(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + ) + + for delta, precision, max_units, expected in test_cases: + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + def test_humanize_delta_raises_for_invalid_max_units(self): test_cases = (-1, 0) -- cgit v1.2.3 From 0aee728d6d23ef24f51834f39016f938f3f1b8a9 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:36:29 +0700 Subject: Added missing docstring for `test_humanize_delta_raises_for_invalid_max_units` --- tests/bot/utils/test_time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 5e5f2bf2f..a929bee89 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -50,6 +50,7 @@ class TimeTests(unittest.TestCase): self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) def test_humanize_delta_raises_for_invalid_max_units(self): + """humanize_delta should raises ValueError('max_units must be positive') for invalid max units.""" test_cases = (-1, 0) for max_units in test_cases: -- cgit v1.2.3 From beed21355e7f0e25b69637768843c53d510b8969 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:40:38 +0700 Subject: Changed `assert` to `self.assertIs` for `test_wait_until` --- tests/bot/utils/test_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index a929bee89..0afabe400 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -74,7 +74,7 @@ class TimeTests(unittest.TestCase): then = datetime(2019, 1, 1, 0, 10) # No return value - assert asyncio.run(time.wait_until(then, start)) is None + self.assertIs(asyncio.run(time.wait_until(then, start)), None) mock.assert_called_once_with(10 * 60) -- cgit v1.2.3 From ccdd8363d75846f0841791ba54763dae28243c62 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:45:36 +0700 Subject: Splitting test cases for `format_infraction_with_duration` into proper, independent tests. --- tests/bot/utils/test_time.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 0afabe400..2a2a707d8 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -37,7 +37,7 @@ class TimeTests(unittest.TestCase): for delta, precision, max_units, expected in test_cases: self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) - def test_humanize_delta_should_work_normally(self): + def test_humanize_delta_should_normal_usage(self): """Testing humanize delta.""" test_cases = ( (relativedelta(days=2), 'seconds', 1, '2 days'), @@ -78,18 +78,38 @@ class TimeTests(unittest.TestCase): mock.assert_called_once_with(10 * 60) - def test_format_infraction_with_duration(self): - """Testing format_infraction_with_duration.""" + def test_format_infraction_with_duration_none_expiry(self): + """format_infraction_with_duration should work for None expiry.""" + self.assertEqual(time.format_infraction_with_duration(None), None) + + # To make sure that date_from and max_units are not touched + self.assertEqual(time.format_infraction_with_duration(None, date_from='Why hello there!'), None) + self.assertEqual(time.format_infraction_with_duration(None, max_units=float('inf')), None) + self.assertEqual( + time.format_infraction_with_duration(None, date_from='Why hello there!', max_units=float('inf')), + None + ) + + def test_format_infraction_with_duration_custom_units(self): + """format_infraction_with_duration should work for custom max_units.""" + self.assertEqual( + time.format_infraction_with_duration('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6), + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)' + ) + + self.assertEqual( + time.format_infraction_with_duration('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20), + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)' + ) + + def test_format_infraction_with_duration_normal_usage(self): + """format_infraction_with_duration should work for normal usage, across various durations.""" test_cases = ( ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), - ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, - '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), - ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 6, - '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)'), ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), -- cgit v1.2.3 From fa66195dbb6f79bb7174084835499a61e8cb03a3 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 00:52:15 +0700 Subject: Introduced test for `test_format_infraction`, refactored `test_parse_rfc1123`, fixed typo. --- tests/bot/utils/test_time.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 2a2a707d8..09fb824e4 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -50,7 +50,7 @@ class TimeTests(unittest.TestCase): self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) def test_humanize_delta_raises_for_invalid_max_units(self): - """humanize_delta should raises ValueError('max_units must be positive') for invalid max units.""" + """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" test_cases = (-1, 0) for max_units in test_cases: @@ -60,12 +60,14 @@ class TimeTests(unittest.TestCase): def test_parse_rfc1123(self): """Testing parse_rfc1123.""" - test_cases = ( - ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), + self.assertEqual( + time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), + datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) ) - for stamp, expected in test_cases: - self.assertEqual(time.parse_rfc1123(stamp), expected) + def test_format_infraction(self): + """Testing format_infraction.""" + self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') @patch('asyncio.sleep', new_callable=AsyncMock) def test_wait_until(self, mock): -- cgit v1.2.3 From 5e0b19ae841f3f355931ad331f7aa861fbafc4d9 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 01:05:41 +0700 Subject: Added `self.subTest` for tests with multiple test cases & simplified single test case tests. --- tests/bot/utils/test_time.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 09fb824e4..c47a306f0 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -17,25 +17,15 @@ class TimeTests(unittest.TestCase): def test_humanize_delta_handle_unknown_units(self): """humanize_delta should be able to handle unknown units, and will not abort.""" - test_cases = ( - # Does not abort for unknown units, as the unit name is checked - # against the attribute of the relativedelta instance. - (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), - ) - - for delta, precision, max_units, expected in test_cases: - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') def test_humanize_delta_handle_high_units(self): """humanize_delta should be able to handle very high units.""" - test_cases = ( - # Very high maximum units, but it only ever iterates over - # each value the relativedelta might have. - (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), - ) - - for delta, precision, max_units, expected in test_cases: - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') def test_humanize_delta_should_normal_usage(self): """Testing humanize delta.""" @@ -47,14 +37,15 @@ class TimeTests(unittest.TestCase): ) for delta, precision, max_units, expected in test_cases: - self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) def test_humanize_delta_raises_for_invalid_max_units(self): """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" test_cases = (-1, 0) for max_units in test_cases: - with self.assertRaises(ValueError) as error: + with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) self.assertEqual(str(error), 'max_units must be positive') @@ -121,4 +112,5 @@ class TimeTests(unittest.TestCase): ) for expiry, date_from, max_units, expected in test_cases: - self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) -- cgit v1.2.3 From db341d927aab42c2e874cb499ab1c2e6c0e7647b Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 01:17:56 +0700 Subject: Moved all individual test cases into iterables and test with `self.subTest` context manager. --- tests/bot/utils/test_time.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index c47a306f0..25cd3f69f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -73,27 +73,31 @@ class TimeTests(unittest.TestCase): def test_format_infraction_with_duration_none_expiry(self): """format_infraction_with_duration should work for None expiry.""" - self.assertEqual(time.format_infraction_with_duration(None), None) + test_cases = ( + (None, None, None, None), - # To make sure that date_from and max_units are not touched - self.assertEqual(time.format_infraction_with_duration(None, date_from='Why hello there!'), None) - self.assertEqual(time.format_infraction_with_duration(None, max_units=float('inf')), None) - self.assertEqual( - time.format_infraction_with_duration(None, date_from='Why hello there!', max_units=float('inf')), - None + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), ) + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + def test_format_infraction_with_duration_custom_units(self): """format_infraction_with_duration should work for custom max_units.""" - self.assertEqual( - time.format_infraction_with_duration('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6), - '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)' + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') ) - self.assertEqual( - time.format_infraction_with_duration('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20), - '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)' - ) + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) def test_format_infraction_with_duration_normal_usage(self): """format_infraction_with_duration should work for normal usage, across various durations.""" -- cgit v1.2.3 From 323306776b0312e2a32ada213a35159311a93a7f Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Thu, 5 Dec 2019 01:42:23 +0700 Subject: Removed `setUp()` from `TimeTests` since it is not being used for anything. --- tests/bot/utils/test_time.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 25cd3f69f..7f55dc3ec 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -12,9 +12,6 @@ from tests.helpers import AsyncMock class TimeTests(unittest.TestCase): """Test helper functions in bot.utils.time.""" - def setUp(self): - pass - def test_humanize_delta_handle_unknown_units(self): """humanize_delta should be able to handle unknown units, and will not abort.""" # Does not abort for unknown units, as the unit name is checked -- cgit v1.2.3 From 52163d775a6a0737f32a0c291e9275a910656fab Mon Sep 17 00:00:00 2001 From: kraktus <56031107+kraktus@users.noreply.github.com> Date: Thu, 5 Dec 2019 17:49:28 +0000 Subject: Requested change Include the check about whether or not there is a token in the posted message in `parse_codeblock` boolean. --- bot/cogs/bot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 53221cd8b..f79e00454 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -236,9 +236,10 @@ class Bot(Cog): ) and not msg.author.bot and len(msg.content.splitlines()) > 3 + and not TokenRemover.is_token_in_message(msg) ) - if parse_codeblock and not TokenRemover.is_token_in_message(msg): # no token in the msg + if parse_codeblock: # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: -- cgit v1.2.3 From a9dc1000872f507a850798b204befed299b6f703 Mon Sep 17 00:00:00 2001 From: Jens Date: Thu, 5 Dec 2019 22:07:59 +0100 Subject: Keeps access token alive, only revokes it on extension unload. Hard-coded version number to 1.0.0. --- bot/cogs/reddit.py | 52 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 64a940af1..0ebf2e1a7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,7 +2,8 @@ import asyncio import logging import random import textwrap -from datetime import datetime +from collections import namedtuple +from datetime import datetime, timedelta from typing import List from aiohttp import BasicAuth @@ -21,11 +22,7 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - # Change your client's User-Agent string to something unique and descriptive, - # including the target platform, a unique application identifier, a version string, - # and your username as contact information, in the following format: - # :: (by /u/) - USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_RETRIES = 3 @@ -33,7 +30,8 @@ class Reddit(Cog): def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready + self.webhook = None + self.access_token = None bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() @@ -41,6 +39,8 @@ class Reddit(Cog): def cog_unload(self) -> None: """Stops the loops when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -53,7 +53,7 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) - async def get_access_tokens(self) -> None: + async def get_access_token(self) -> None: """Get Reddit access tokens.""" headers = {"User-Agent": self.USER_AGENT} data = { @@ -61,6 +61,7 @@ class Reddit(Cog): "duration": "temporary" } + log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) for _ in range(self.MAX_RETRIES): @@ -72,9 +73,13 @@ class Reddit(Cog): ) if response.status == 200 and response.content_type == "application/json": content = await response.json() - self.access_token = content["access_token"] + AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(hours=1) + ) self.headers = { - "Authorization": "bearer " + self.access_token, + "Authorization": "bearer " + self.access_token.token, "User-Agent": self.USER_AGENT } return @@ -91,7 +96,7 @@ class Reddit(Cog): # The token should be revoked, since the API is called only once a day. headers = {"User-Agent": self.USER_AGENT} data = { - "token": self.access_token, + "token": self.access_token.token, "token_type_hint": "access_token" } @@ -200,7 +205,10 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) - await self.get_access_tokens() + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() @@ -210,8 +218,6 @@ class Reddit(Cog): top_posts = await self.get_top_posts(subreddit=subreddit, time="day") await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) - await self.revoke_access_token() - async def top_weekly_posts(self) -> None: """Post a summary of the top posts.""" for subreddit in RedditConfig.subreddits: @@ -242,32 +248,38 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="all") await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - await self.revoke_access_token() @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="day") await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - await self.revoke_access_token() @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="week") await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - await self.revoke_access_token() @with_role(*STAFF_ROLES) @reddit_group.command(name="subreddits", aliases=("subs",)) -- cgit v1.2.3 From d84fc6346197d8176a7989b9b74e94d837d26882 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:39:39 -0800 Subject: Reddit: move token renewal inside fetch_posts This removes the duplicate code for renewing the token. Since fetch_posts is the only place where the token gets used, it can just be refreshed there directly. --- bot/cogs/reddit.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 22bb66bf0..0802c6102 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -122,6 +122,10 @@ class Reddit(Cog): if params is None: params = {} + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( @@ -206,11 +210,6 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() # if it's a monday send the top weekly posts @@ -249,10 +248,6 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="all") @@ -261,10 +256,6 @@ class Reddit(Cog): @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="day") @@ -273,10 +264,6 @@ class Reddit(Cog): @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="week") -- cgit v1.2.3 From ddfbfe31b2c2d9e5bc5d46ab9ffffa5b35a63e5f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:44:12 -0800 Subject: Reddit: move BasicAuth instantiation to __init__ The object is basically just a namedtuple so there's no need to re-create it every time a token is obtained. * Remove log message which shows credentials. * Initialise headers attribute to None in __init__. --- bot/cogs/reddit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0802c6102..48f636159 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -32,8 +32,10 @@ class Reddit(Cog): self.webhook = None self.access_token = None - bot.loop.create_task(self.init_reddit_ready()) + self.headers = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() def cog_unload(self) -> None: @@ -61,9 +63,6 @@ class Reddit(Cog): "duration": "temporary" } - log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", -- cgit v1.2.3 From d0f6f794d4fa3dd78ac8be2b95cae669b4587fb3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:44:35 -0800 Subject: Reddit: use qualified_name attribute when removing the cog --- bot/cogs/reddit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 48f636159..111f3b8ab 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -86,7 +86,7 @@ class Reddit(Cog): await asyncio.sleep(3) log.error("Authentication with Reddit API failed. Unloading extension.") - self.bot.remove_cog(self.__class__.__name__) + self.bot.remove_cog(self.qualified_name) return async def revoke_access_token(self) -> None: -- cgit v1.2.3 From 249a4c185ca9680258eb7a753307fbe8d0089b4b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 13:49:21 -0800 Subject: Reddit: use expires_in from the response to calculate token expiration --- bot/cogs/reddit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 111f3b8ab..083f90573 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -56,7 +56,7 @@ class Reddit(Cog): return self.bot.get_channel(Channels.reddit) async def get_access_token(self) -> None: - """Get Reddit access tokens.""" + """Get a Reddit API OAuth2 access token.""" headers = {"User-Agent": self.USER_AGENT} data = { "grant_type": "client_credentials", @@ -72,10 +72,11 @@ class Reddit(Cog): ) if response.status == 200 and response.content_type == "application/json": content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) self.access_token = AccessToken( token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(hours=1) + expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) self.headers = { "Authorization": "bearer " + self.access_token.token, -- cgit v1.2.3 From e49d9d5429c8e1aedfde5a5d38750890b3361496 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 13:50:03 -0800 Subject: Reddit: define AccessToken type at the module level --- bot/cogs/reddit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 083f90573..d9e1f0a39 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -18,6 +18,8 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" @@ -73,7 +75,6 @@ class Reddit(Cog): if response.status == 200 and response.content_type == "application/json": content = await response.json() expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) self.access_token = AccessToken( token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) -- cgit v1.2.3 From 4889326af4f9251218e332f873349e3f8c7bea7b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:18:46 -0800 Subject: Reddit: revise docstrings --- bot/cogs/reddit.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index d9e1f0a39..0a0279a39 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -41,7 +41,7 @@ class Reddit(Cog): self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stops the loops when the cog is unloaded.""" + """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() if self.access_token.expires_at < datetime.utcnow(): self.revoke_access_token() @@ -58,7 +58,12 @@ class Reddit(Cog): return self.bot.get_channel(Channels.reddit) async def get_access_token(self) -> None: - """Get a Reddit API OAuth2 access token.""" + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded if retrieval was still unsuccessful. + """ headers = {"User-Agent": self.USER_AGENT} data = { "grant_type": "client_credentials", @@ -72,6 +77,7 @@ class Reddit(Cog): auth=self.client_auth, data=data ) + if response.status == 200 and response.content_type == "application/json": content = await response.json() expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. @@ -87,14 +93,16 @@ class Reddit(Cog): await asyncio.sleep(3) - log.error("Authentication with Reddit API failed. Unloading extension.") + log.error("Authentication with Reddit API failed. Unloading the cog.") self.bot.remove_cog(self.qualified_name) return async def revoke_access_token(self) -> None: - """Revoke the access token for Reddit API.""" - # Access tokens are valid for 1 hour. - # The token should be revoked, since the API is called only once a day. + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ headers = {"User-Agent": self.USER_AGENT} data = { "token": self.access_token.token, @@ -107,12 +115,12 @@ class Reddit(Cog): auth=self.client_auth, data=data ) + if response.status == 204 and response.content_type == "application/json": self.access_token = None self.headers = None - return - - log.warning(f"Unable to revoke access token, status code {response.status}.") + else: + log.warning(f"Unable to revoke access token: status {response.status}.") async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" -- cgit v1.2.3 From 08881979e25749cb8da9efaccead64bb15b354ec Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:33:39 -0800 Subject: Reddit: create a dict constant for the User-Agent header --- bot/cogs/reddit.py | 39 ++++++++++++--------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0a0279a39..6af33d9db 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -24,7 +24,7 @@ AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_RETRIES = 3 @@ -34,7 +34,6 @@ class Reddit(Cog): self.webhook = None self.access_token = None - self.headers = None self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) bot.loop.create_task(self.init_reddit_ready()) @@ -64,18 +63,15 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded if retrieval was still unsuccessful. """ - headers = {"User-Agent": self.USER_AGENT} - data = { - "grant_type": "client_credentials", - "duration": "temporary" - } - for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", - headers=headers, + headers=self.HEADERS, auth=self.client_auth, - data=data + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } ) if response.status == 200 and response.content_type == "application/json": @@ -85,10 +81,6 @@ class Reddit(Cog): token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) - self.headers = { - "Authorization": "bearer " + self.access_token.token, - "User-Agent": self.USER_AGENT - } return await asyncio.sleep(3) @@ -103,22 +95,18 @@ class Reddit(Cog): For security reasons, it's good practice to revoke the token when it's no longer being used. """ - headers = {"User-Agent": self.USER_AGENT} - data = { - "token": self.access_token.token, - "token_type_hint": "access_token" - } - response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/revoke_token", - headers=headers, + headers=self.HEADERS, auth=self.client_auth, - data=data + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } ) if response.status == 204 and response.content_type == "application/json": self.access_token = None - self.headers = None else: log.warning(f"Unable to revoke access token: status {response.status}.") @@ -128,9 +116,6 @@ class Reddit(Cog): if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - if params is None: - params = {} - # Renew the token if necessary. if not self.access_token or self.access_token.expires_at < datetime.utcnow(): await self.get_access_token() @@ -139,7 +124,7 @@ class Reddit(Cog): for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.headers, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == 'application/json': -- cgit v1.2.3 From a4fb51bbeb9b15e1a3718038f280d9c633acfa66 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:45:56 -0800 Subject: Reddit: log retries when getting the access token --- bot/cogs/reddit.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6af33d9db..15b4a108c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -63,7 +63,7 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded if retrieval was still unsuccessful. """ - for _ in range(self.MAX_RETRIES): + for i in range(1, self.MAX_RETRIES + 1): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", headers=self.HEADERS, @@ -81,7 +81,15 @@ class Reddit(Cog): token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) + + log.debug(f"New token acquired; expires on {self.access_token.expires_at}") return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) await asyncio.sleep(3) -- cgit v1.2.3 From 806ccf73dd896b4272726ce32edea7c882ce9a81 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:47:13 -0800 Subject: Reddit: raise ClientError when the token can't be retrieved Raising an exception allows the error handler to display a message to the user if the failure happened from a command invocation. --- bot/cogs/reddit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 15b4a108c..96af90bc4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -6,7 +6,7 @@ from collections import namedtuple from datetime import datetime, timedelta from typing import List -from aiohttp import BasicAuth +from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel from discord.ext.commands import Bot, Cog, Context, group from discord.ext.tasks import loop @@ -61,7 +61,7 @@ class Reddit(Cog): Get a Reddit API OAuth2 access token and assign it to self.access_token. A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded if retrieval was still unsuccessful. + will be unloaded and a ClientError raised if retrieval was still unsuccessful. """ for i in range(1, self.MAX_RETRIES + 1): response = await self.bot.http_session.post( @@ -93,9 +93,8 @@ class Reddit(Cog): await asyncio.sleep(3) - log.error("Authentication with Reddit API failed. Unloading the cog.") self.bot.remove_cog(self.qualified_name) - return + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") async def revoke_access_token(self) -> None: """ -- cgit v1.2.3 From 9d551cc69c1935165389f26f52753895604dd3f5 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 20:26:26 -0800 Subject: Add a generic converter for only allowing certain string values --- bot/cogs/moderation/management.py | 13 ++----------- bot/converters.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..50bce3981 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,7 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, string from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -22,15 +22,6 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" @@ -61,7 +52,7 @@ class ModManagement(commands.Cog): self, ctx: Context, infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + duration: t.Union[utils.Expiry, string("permanent"), None], *, reason: str = None ) -> None: diff --git a/bot/converters.py b/bot/converters.py index cf0496541..2cfc42903 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converter to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") -- cgit v1.2.3 From 729ac3d83a3bd4620d1e9b24769466e219d45de6 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:00:47 -0800 Subject: ModManagement: allow "recent" as ID to edit infraction (#624) It will attempt to find the most recent infraction authored by the invoker of the edit command. --- bot/cogs/moderation/management.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 50bce3981..35832ded5 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -51,7 +51,7 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: int, + infraction_id: t.Union[int, string("recent")], duration: t.Union[utils.Expiry, string("permanent"), None], *, reason: str = None @@ -69,6 +69,9 @@ class ModManagement(commands.Cog): \u2003`M` - minutes∗ \u2003`s` - seconds + Use "recent" as the infraction ID to specify that the ost recent infraction authored by the + command invoker should be edited. + Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp can be provided for the duration. """ @@ -77,7 +80,23 @@ class ModManagement(commands.Cog): raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') + if infraction_id == "recent": + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + f":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") request_data = {} confirm_messages = [] -- cgit v1.2.3 From c1bf0a48692d87c5cbe9ee310cd0120bda339a96 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:01:16 -0800 Subject: ModManagement: display ID of edited infraction in confirmation message --- bot/cogs/moderation/management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 35832ded5..904611e13 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -139,7 +139,8 @@ class ModManagement(commands.Cog): New expiry: {new_infraction['expires_at'] or "Permanent"} """.rstrip() - await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") # Get information about the infraction's user user_id = new_infraction['user'] -- cgit v1.2.3 From 56833bbe99ce3cd93af87e9d33cec47a059b61f3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 21:52:57 -0800 Subject: ModManagement: add more aliases for "special" params of infraction edit --- bot/cogs/moderation/management.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 904611e13..37bdb1934 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -51,8 +51,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: t.Union[int, string("recent")], - duration: t.Union[utils.Expiry, string("permanent"), None], + infraction_id: t.Union[int, string("l", "last", "recent")], + duration: t.Union[utils.Expiry, string("p", "permanent"), None], *, reason: str = None ) -> None: @@ -69,18 +69,18 @@ class ModManagement(commands.Cog): \u2003`M` - minutes∗ \u2003`s` - seconds - Use "recent" as the infraction ID to specify that the ost recent infraction authored by the - command invoker should be edited. + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. - Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp - can be provided for the duration. + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. """ if duration is None and reason is None: # Unlike UserInputError, the error handler will show a specified message for BadArgument raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - if infraction_id == "recent": + if isinstance(infraction_id, str): params = { "actor__id": ctx.author.id, "ordering": "-inserted_at" @@ -102,7 +102,7 @@ class ModManagement(commands.Cog): confirm_messages = [] log_text = "" - if duration == "permanent": + if isinstance(duration, str): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: -- cgit v1.2.3 From eb53a4594dff20372574058ec90062995362b098 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 22:03:20 -0800 Subject: Converters: rename string to allowed_strings --- bot/cogs/moderation/management.py | 6 +++--- bot/converters.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 37bdb1934..20ff25ba1 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,7 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery, string +from bot.converters import InfractionSearchQuery, allowed_strings from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -51,8 +51,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: t.Union[int, string("l", "last", "recent")], - duration: t.Union[utils.Expiry, string("p", "permanent"), None], + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], + duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None], *, reason: str = None ) -> None: diff --git a/bot/converters.py b/bot/converters.py index 2cfc42903..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -15,11 +15,11 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) -def string(*values, preserve_case: bool = False) -> t.Callable[[str], str]: +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: """ Return a converter which only allows arguments equal to one of the given values. - Unless preserve_case is True, the argument is converter to lowercase. All values are then + Unless preserve_case is True, the argument is converted to lowercase. All values are then expected to have already been given in lowercase too. """ def converter(arg: str) -> str: -- cgit v1.2.3 From 97f0cb8efb82217d28123e834454d1316f04b031 Mon Sep 17 00:00:00 2001 From: Joseph Date: Fri, 13 Dec 2019 00:41:56 +0000 Subject: Revert "Use OAuth to be Reddit API compliant" --- azure-pipelines.yml | 2 +- bot/cogs/reddit.py | 91 ++++++----------------------------------------------- bot/constants.py | 2 -- config-default.yml | 2 -- docker-compose.yml | 2 -- 5 files changed, 11 insertions(+), 88 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0400ac4d2..da3b06201 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index aa487f18e..bec316ae7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,11 +2,9 @@ import asyncio import logging import random import textwrap -from collections import namedtuple from datetime import datetime, timedelta from typing import List -from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel from discord.ext.commands import Cog, Context, group from discord.ext.tasks import loop @@ -19,32 +17,25 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 + MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - + self.webhook = None # set in on_ready bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" + """Stops the loops when the cog is unloaded.""" self.auto_poster_loop.cancel() - if self.access_token.expires_at < datetime.utcnow(): - self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -57,82 +48,20 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() + if params is None: + params = {} - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): + url = f"{self.URL}/{route}.json" + for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + headers=self.HEADERS, params=params ) if response.status == 200 and response.content_type == 'application/json': diff --git a/bot/constants.py b/bot/constants.py index ed85adf6a..89504a2e0 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -465,8 +465,6 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list - client_id: str - secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index e6f0fda21..930a1a0e6 100644 --- a/config-default.yml +++ b/config-default.yml @@ -365,8 +365,6 @@ anti_malware: reddit: subreddits: - 'r/Python' - client_id: !ENV "REDDIT_CLIENT_ID" - secret: !ENV "REDDIT_SECRET" wolfram: diff --git a/docker-compose.yml b/docker-compose.yml index 7281c7953..f79fdba58 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,5 +42,3 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 - REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} - REDDIT_SECRET: ${REDDIT_SECRET} -- cgit v1.2.3 From c5109844e45c37bc1cc38eb1c3da31d52ab2aa6d Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Fri, 13 Dec 2019 09:17:21 +0700 Subject: Adding an optional argument for `until_expiration`, update typehints for `format_infraction_with_duration` - `until_expiration` was being a pain to unittests without a `now` ( default to `datetime.utcnow()` ). Adding an optional argument for this will not only make writing tests easier, but also allow more control over the helper function should we need to calculate the remaining time between two dates in the past. - Changed typehint for `date_from` in `format_infraction_with_duration` to `Optional[datetime.datetime]` to better reflect what it is. --- bot/utils/time.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bot/utils/time.py b/bot/utils/time.py index ac64865d6..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -115,7 +115,7 @@ def format_infraction(timestamp: str) -> str: def format_infraction_with_duration( expiry: Optional[str], - date_from: datetime.datetime = None, + date_from: Optional[datetime.datetime] = None, max_units: int = 2 ) -> Optional[str]: """ @@ -140,10 +140,15 @@ def format_infraction_with_duration( return f"{expiry_formatted}{duration_formatted}" -def until_expiration(expiry: Optional[str], max_units: int = 2) -> Optional[str]: +def until_expiration( + expiry: Optional[str], + now: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: """ Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). By default, max_units is 2. @@ -151,7 +156,7 @@ def until_expiration(expiry: Optional[str], max_units: int = 2) -> Optional[str] if not expiry: return None - now = datetime.datetime.utcnow() + now = now or datetime.datetime.utcnow() since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) if since < now: -- cgit v1.2.3 From 520346d0b472e5cb6c9091a8323b871d2e3821cc Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Fri, 13 Dec 2019 09:18:49 +0700 Subject: Added tests for `until_expiration` Similar to `format_infraction_with_duration` ( if not outright copying it ), added 3 tests for `until_expiration`: - None `expiry`. - Custom `max_units`. - Normal use cases. --- tests/bot/utils/test_time.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 7f55dc3ec..bd04de28b 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -115,3 +115,48 @@ class TimeTests(unittest.TestCase): for expiry, date_from, max_units, expected in test_cases: with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_until_expiration_with_duration_none_expiry(self): + """until_expiration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_with_duration_custom_units(self): + """until_expiration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_normal_usage(self): + """until_expiration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) -- cgit v1.2.3 From 66d4b93593b95bfa6999b70aca53328d83710c44 Mon Sep 17 00:00:00 2001 From: Shirayuki Nekomata Date: Fri, 13 Dec 2019 09:29:06 +0700 Subject: Fixed a typo ( due to poor copy pasta and eyeballing skills ) --- tests/bot/utils/test_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index bd04de28b..69f35f2f5 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -121,7 +121,7 @@ class TimeTests(unittest.TestCase): test_cases = ( (None, None, None, None), - # To make sure that date_from and max_units are not touched + # To make sure that now and max_units are not touched (None, 'Why hello there!', None, None), (None, None, float('inf'), None), (None, 'Why hello there!', float('inf'), None), -- cgit v1.2.3