From 2cb7ee12805957c7d655679ff54a14f16e059a80 Mon Sep 17 00:00:00 2001 From: Jens Date: Wed, 9 Oct 2019 23:43:16 +0200 Subject: Add Reddit OAuth tasks and refactor code --- bot/cogs/reddit.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++++----- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 77 insertions(+), 6 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6880aab85..bf4403ce4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # :: (by /u/) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -34,6 +41,66 @@ class Reddit(Cog): self.new_posts_task = None self.top_weekly_posts_task = None + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -43,11 +110,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -55,7 +122,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -127,8 +194,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: -- cgit v1.2.3 From 9b4b98eb6f37060771c61979458737629b3c5db7 Mon Sep 17 00:00:00 2001 From: Jens Date: Wed, 9 Oct 2019 23:43:16 +0200 Subject: Add Reddit OAuth tasks and refactor code --- bot/cogs/reddit.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++++----- bot/constants.py | 2 ++ config-default.yml | 2 ++ 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..25df014f8 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # :: (by /u/) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -35,6 +42,65 @@ class Reddit(Cog): self.top_weekly_posts_task = None self.bot.loop.create_task(self.init_reddit_polling()) + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -45,11 +111,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -57,7 +123,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -129,8 +195,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: -- cgit v1.2.3 From 79ed098809ce3cfaa0fa75608f6f6a85af2a90dd Mon Sep 17 00:00:00 2001 From: Jens Date: Tue, 15 Oct 2019 23:36:36 +0200 Subject: Unload cog on auth error and fix linting warnings --- bot/cogs/reddit.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 25df014f8..451d2bf4c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,10 @@ import asyncio import logging import random import textwrap -from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth from discord import Colour, Embed, Message, TextChannel from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group @@ -46,7 +46,7 @@ class Reddit(Cog): @tasks.loop(hours=0.99) # access tokens are valid for one hour async def refresh_access_token(self) -> None: - """Refresh the access token""" + """Refresh Reddits access token.""" headers = {"Authorization": self.client_auth} data = { "grant_type": "refresh_token", @@ -54,7 +54,7 @@ class Reddit(Cog): } response = await self.bot.http_session.post( - url = f"{self.URL}/api/v1/access_token", + url=f"{self.URL}/api/v1/access_token", headers=headers, data=data, ) @@ -68,7 +68,7 @@ class Reddit(Cog): @refresh_access_token.before_loop async def get_tokens(self) -> None: - """Get Reddit access and refresh tokens""" + """Get Reddit access and refresh tokens.""" await self.bot.wait_until_ready() headers = {"User-Agent": self.USER_AGENT} @@ -77,16 +77,16 @@ class Reddit(Cog): "duration": "permanent" } - if RedditConfig.client_id and RedditConfig.secret: - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=headers, - auth=self.client_auth, - data=data - ) + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 200 and response.content_type == "application/json": content = await response.json() self.access_token = content["access_token"] self.refresh_token = content["refresh_token"] @@ -95,12 +95,9 @@ class Reddit(Cog): "User-Agent": self.USER_AGENT } else: - self.client_auth = None - self.access_token = None - self.refresh_token = None - self.headers = None - - log.error("Unable to find client credentials.") + log.error("Authentication with Reddit API failed. Unloading extension.") + self.bot.remove_cog(self.__class__.__name__) + return async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -123,7 +120,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") -- cgit v1.2.3 From 5ec4db0cba484f8adfc25b642a4f24f362a5b53c Mon Sep 17 00:00:00 2001 From: Jens Date: Tue, 22 Oct 2019 22:05:50 +0200 Subject: Add reddit environment variable, change User-Agent and fix lint problem --- bot/cogs/reddit.py | 4 ++-- docker-compose.yml | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7b183221c..76da0f09f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -25,7 +25,7 @@ class Reddit(Cog): # including the target platform, a unique application identifier, a version string, # and your username as contact information, in the following format: # :: (by /u/) - USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 @@ -117,7 +117,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} -- cgit v1.2.3 From 2c0fffbc697db5f33e759f733c310f9f0b754d11 Mon Sep 17 00:00:00 2001 From: Jens Date: Sat, 26 Oct 2019 17:24:09 +0200 Subject: Fix linting error --- bot/cogs/reddit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7e2ba40d5..64a940af1 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -20,6 +20,7 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" + # Change your client's User-Agent string to something unique and descriptive, # including the target platform, a unique application identifier, a version string, # and your username as contact information, in the following format: -- cgit v1.2.3 From a9dc1000872f507a850798b204befed299b6f703 Mon Sep 17 00:00:00 2001 From: Jens Date: Thu, 5 Dec 2019 22:07:59 +0100 Subject: Keeps access token alive, only revokes it on extension unload. Hard-coded version number to 1.0.0. --- bot/cogs/reddit.py | 52 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 64a940af1..0ebf2e1a7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,7 +2,8 @@ import asyncio import logging import random import textwrap -from datetime import datetime +from collections import namedtuple +from datetime import datetime, timedelta from typing import List from aiohttp import BasicAuth @@ -21,11 +22,7 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - # Change your client's User-Agent string to something unique and descriptive, - # including the target platform, a unique application identifier, a version string, - # and your username as contact information, in the following format: - # :: (by /u/) - USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_RETRIES = 3 @@ -33,7 +30,8 @@ class Reddit(Cog): def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready + self.webhook = None + self.access_token = None bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() @@ -41,6 +39,8 @@ class Reddit(Cog): def cog_unload(self) -> None: """Stops the loops when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -53,7 +53,7 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) - async def get_access_tokens(self) -> None: + async def get_access_token(self) -> None: """Get Reddit access tokens.""" headers = {"User-Agent": self.USER_AGENT} data = { @@ -61,6 +61,7 @@ class Reddit(Cog): "duration": "temporary" } + log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) for _ in range(self.MAX_RETRIES): @@ -72,9 +73,13 @@ class Reddit(Cog): ) if response.status == 200 and response.content_type == "application/json": content = await response.json() - self.access_token = content["access_token"] + AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(hours=1) + ) self.headers = { - "Authorization": "bearer " + self.access_token, + "Authorization": "bearer " + self.access_token.token, "User-Agent": self.USER_AGENT } return @@ -91,7 +96,7 @@ class Reddit(Cog): # The token should be revoked, since the API is called only once a day. headers = {"User-Agent": self.USER_AGENT} data = { - "token": self.access_token, + "token": self.access_token.token, "token_type_hint": "access_token" } @@ -200,7 +205,10 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) - await self.get_access_tokens() + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() @@ -210,8 +218,6 @@ class Reddit(Cog): top_posts = await self.get_top_posts(subreddit=subreddit, time="day") await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) - await self.revoke_access_token() - async def top_weekly_posts(self) -> None: """Post a summary of the top posts.""" for subreddit in RedditConfig.subreddits: @@ -242,32 +248,38 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="all") await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - await self.revoke_access_token() @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="day") await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - await self.revoke_access_token() @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): - await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="week") await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - await self.revoke_access_token() @with_role(*STAFF_ROLES) @reddit_group.command(name="subreddits", aliases=("subs",)) -- cgit v1.2.3 From d84fc6346197d8176a7989b9b74e94d837d26882 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:39:39 -0800 Subject: Reddit: move token renewal inside fetch_posts This removes the duplicate code for renewing the token. Since fetch_posts is the only place where the token gets used, it can just be refreshed there directly. --- bot/cogs/reddit.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 22bb66bf0..0802c6102 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -122,6 +122,10 @@ class Reddit(Cog): if params is None: params = {} + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( @@ -206,11 +210,6 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() # if it's a monday send the top weekly posts @@ -249,10 +248,6 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="all") @@ -261,10 +256,6 @@ class Reddit(Cog): @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="day") @@ -273,10 +264,6 @@ class Reddit(Cog): @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" - if not self.access_token: - await self.get_access_token() - elif self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="week") -- cgit v1.2.3 From ddfbfe31b2c2d9e5bc5d46ab9ffffa5b35a63e5f Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:44:12 -0800 Subject: Reddit: move BasicAuth instantiation to __init__ The object is basically just a namedtuple so there's no need to re-create it every time a token is obtained. * Remove log message which shows credentials. * Initialise headers attribute to None in __init__. --- bot/cogs/reddit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0802c6102..48f636159 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -32,8 +32,10 @@ class Reddit(Cog): self.webhook = None self.access_token = None - bot.loop.create_task(self.init_reddit_ready()) + self.headers = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() def cog_unload(self) -> None: @@ -61,9 +63,6 @@ class Reddit(Cog): "duration": "temporary" } - log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", -- cgit v1.2.3 From d0f6f794d4fa3dd78ac8be2b95cae669b4587fb3 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 12:44:35 -0800 Subject: Reddit: use qualified_name attribute when removing the cog --- bot/cogs/reddit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 48f636159..111f3b8ab 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -86,7 +86,7 @@ class Reddit(Cog): await asyncio.sleep(3) log.error("Authentication with Reddit API failed. Unloading extension.") - self.bot.remove_cog(self.__class__.__name__) + self.bot.remove_cog(self.qualified_name) return async def revoke_access_token(self) -> None: -- cgit v1.2.3 From 249a4c185ca9680258eb7a753307fbe8d0089b4b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 13:49:21 -0800 Subject: Reddit: use expires_in from the response to calculate token expiration --- bot/cogs/reddit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 111f3b8ab..083f90573 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -56,7 +56,7 @@ class Reddit(Cog): return self.bot.get_channel(Channels.reddit) async def get_access_token(self) -> None: - """Get Reddit access tokens.""" + """Get a Reddit API OAuth2 access token.""" headers = {"User-Agent": self.USER_AGENT} data = { "grant_type": "client_credentials", @@ -72,10 +72,11 @@ class Reddit(Cog): ) if response.status == 200 and response.content_type == "application/json": content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) self.access_token = AccessToken( token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(hours=1) + expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) self.headers = { "Authorization": "bearer " + self.access_token.token, -- cgit v1.2.3 From e49d9d5429c8e1aedfde5a5d38750890b3361496 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 13:50:03 -0800 Subject: Reddit: define AccessToken type at the module level --- bot/cogs/reddit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 083f90573..d9e1f0a39 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -18,6 +18,8 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" @@ -73,7 +75,6 @@ class Reddit(Cog): if response.status == 200 and response.content_type == "application/json": content = await response.json() expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) self.access_token = AccessToken( token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) -- cgit v1.2.3 From 4889326af4f9251218e332f873349e3f8c7bea7b Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:18:46 -0800 Subject: Reddit: revise docstrings --- bot/cogs/reddit.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index d9e1f0a39..0a0279a39 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -41,7 +41,7 @@ class Reddit(Cog): self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stops the loops when the cog is unloaded.""" + """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() if self.access_token.expires_at < datetime.utcnow(): self.revoke_access_token() @@ -58,7 +58,12 @@ class Reddit(Cog): return self.bot.get_channel(Channels.reddit) async def get_access_token(self) -> None: - """Get a Reddit API OAuth2 access token.""" + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded if retrieval was still unsuccessful. + """ headers = {"User-Agent": self.USER_AGENT} data = { "grant_type": "client_credentials", @@ -72,6 +77,7 @@ class Reddit(Cog): auth=self.client_auth, data=data ) + if response.status == 200 and response.content_type == "application/json": content = await response.json() expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. @@ -87,14 +93,16 @@ class Reddit(Cog): await asyncio.sleep(3) - log.error("Authentication with Reddit API failed. Unloading extension.") + log.error("Authentication with Reddit API failed. Unloading the cog.") self.bot.remove_cog(self.qualified_name) return async def revoke_access_token(self) -> None: - """Revoke the access token for Reddit API.""" - # Access tokens are valid for 1 hour. - # The token should be revoked, since the API is called only once a day. + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ headers = {"User-Agent": self.USER_AGENT} data = { "token": self.access_token.token, @@ -107,12 +115,12 @@ class Reddit(Cog): auth=self.client_auth, data=data ) + if response.status == 204 and response.content_type == "application/json": self.access_token = None self.headers = None - return - - log.warning(f"Unable to revoke access token, status code {response.status}.") + else: + log.warning(f"Unable to revoke access token: status {response.status}.") async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" -- cgit v1.2.3 From 08881979e25749cb8da9efaccead64bb15b354ec Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:33:39 -0800 Subject: Reddit: create a dict constant for the User-Agent header --- bot/cogs/reddit.py | 39 ++++++++++++--------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0a0279a39..6af33d9db 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -24,7 +24,7 @@ AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} URL = "https://www.reddit.com" OAUTH_URL = "https://oauth.reddit.com" MAX_RETRIES = 3 @@ -34,7 +34,6 @@ class Reddit(Cog): self.webhook = None self.access_token = None - self.headers = None self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) bot.loop.create_task(self.init_reddit_ready()) @@ -64,18 +63,15 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded if retrieval was still unsuccessful. """ - headers = {"User-Agent": self.USER_AGENT} - data = { - "grant_type": "client_credentials", - "duration": "temporary" - } - for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", - headers=headers, + headers=self.HEADERS, auth=self.client_auth, - data=data + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } ) if response.status == 200 and response.content_type == "application/json": @@ -85,10 +81,6 @@ class Reddit(Cog): token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) - self.headers = { - "Authorization": "bearer " + self.access_token.token, - "User-Agent": self.USER_AGENT - } return await asyncio.sleep(3) @@ -103,22 +95,18 @@ class Reddit(Cog): For security reasons, it's good practice to revoke the token when it's no longer being used. """ - headers = {"User-Agent": self.USER_AGENT} - data = { - "token": self.access_token.token, - "token_type_hint": "access_token" - } - response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/revoke_token", - headers=headers, + headers=self.HEADERS, auth=self.client_auth, - data=data + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } ) if response.status == 204 and response.content_type == "application/json": self.access_token = None - self.headers = None else: log.warning(f"Unable to revoke access token: status {response.status}.") @@ -128,9 +116,6 @@ class Reddit(Cog): if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - if params is None: - params = {} - # Renew the token if necessary. if not self.access_token or self.access_token.expires_at < datetime.utcnow(): await self.get_access_token() @@ -139,7 +124,7 @@ class Reddit(Cog): for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.headers, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == 'application/json': -- cgit v1.2.3 From a4fb51bbeb9b15e1a3718038f280d9c633acfa66 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:45:56 -0800 Subject: Reddit: log retries when getting the access token --- bot/cogs/reddit.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6af33d9db..15b4a108c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -63,7 +63,7 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded if retrieval was still unsuccessful. """ - for _ in range(self.MAX_RETRIES): + for i in range(1, self.MAX_RETRIES + 1): response = await self.bot.http_session.post( url=f"{self.URL}/api/v1/access_token", headers=self.HEADERS, @@ -81,7 +81,15 @@ class Reddit(Cog): token=content["access_token"], expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) + + log.debug(f"New token acquired; expires on {self.access_token.expires_at}") return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) await asyncio.sleep(3) -- cgit v1.2.3 From 806ccf73dd896b4272726ce32edea7c882ce9a81 Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Wed, 11 Dec 2019 14:47:13 -0800 Subject: Reddit: raise ClientError when the token can't be retrieved Raising an exception allows the error handler to display a message to the user if the failure happened from a command invocation. --- bot/cogs/reddit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 15b4a108c..96af90bc4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -6,7 +6,7 @@ from collections import namedtuple from datetime import datetime, timedelta from typing import List -from aiohttp import BasicAuth +from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel from discord.ext.commands import Bot, Cog, Context, group from discord.ext.tasks import loop @@ -61,7 +61,7 @@ class Reddit(Cog): Get a Reddit API OAuth2 access token and assign it to self.access_token. A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded if retrieval was still unsuccessful. + will be unloaded and a ClientError raised if retrieval was still unsuccessful. """ for i in range(1, self.MAX_RETRIES + 1): response = await self.bot.http_session.post( @@ -93,9 +93,8 @@ class Reddit(Cog): await asyncio.sleep(3) - log.error("Authentication with Reddit API failed. Unloading the cog.") self.bot.remove_cog(self.qualified_name) - return + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") async def revoke_access_token(self) -> None: """ -- cgit v1.2.3