aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Jens <[email protected]>2019-10-09 23:43:16 +0200
committerGravatar Jens <[email protected]>2019-10-09 23:43:16 +0200
commit2cb7ee12805957c7d655679ff54a14f16e059a80 (patch)
treecaddeeea2c706576257fc88c40e3427ff0fa83bf
parentMerge pull request #505 from python-discord/user-log-display-name-changes (diff)
Add Reddit OAuth tasks and refactor code
-rw-r--r--bot/cogs/reddit.py79
-rw-r--r--bot/constants.py2
-rw-r--r--config-default.yml2
3 files changed, 77 insertions, 6 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 6880aab85..bf4403ce4 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -2,10 +2,12 @@ import asyncio
import logging
import random
import textwrap
+from aiohttp import BasicAuth
from datetime import datetime, timedelta
from typing import List
from discord import Colour, Embed, Message, TextChannel
+from discord.ext import tasks
from discord.ext.commands import Bot, Cog, Context, group
from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES
@@ -19,8 +21,13 @@ log = logging.getLogger(__name__)
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
- HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"}
+ # Change your client's User-Agent string to something unique and descriptive,
+ # including the target platform, a unique application identifier, a version string,
+ # and your username as contact information, in the following format:
+ # <platform>:<app ID>:<version string> (by /u/<reddit username>)
+ USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)"
URL = "https://www.reddit.com"
+ OAUTH_URL = "https://oauth.reddit.com"
MAX_FETCH_RETRIES = 3
def __init__(self, bot: Bot):
@@ -34,6 +41,66 @@ class Reddit(Cog):
self.new_posts_task = None
self.top_weekly_posts_task = None
+ self.refresh_access_token.start()
+
+ @tasks.loop(hours=0.99) # access tokens are valid for one hour
+ async def refresh_access_token(self) -> None:
+ """Refresh the access token"""
+ headers = {"Authorization": self.client_auth}
+ data = {
+ "grant_type": "refresh_token",
+ "refresh_token": self.refresh_token
+ }
+
+ response = await self.bot.http_session.post(
+ url = f"{self.URL}/api/v1/access_token",
+ headers=headers,
+ data=data,
+ )
+
+ content = await response.json()
+ self.access_token = content["access_token"]
+ self.headers = {
+ "Authorization": "bearer " + self.access_token,
+ "User-Agent": self.USER_AGENT
+ }
+
+ @refresh_access_token.before_loop
+ async def get_tokens(self) -> None:
+ """Get Reddit access and refresh tokens"""
+ await self.bot.wait_until_ready()
+
+ headers = {"User-Agent": self.USER_AGENT}
+ data = {
+ "grant_type": "client_credentials",
+ "duration": "permanent"
+ }
+
+ if RedditConfig.client_id and RedditConfig.secret:
+ self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
+
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=headers,
+ auth=self.client_auth,
+ data=data
+ )
+
+ content = await response.json()
+ self.access_token = content["access_token"]
+ self.refresh_token = content["refresh_token"]
+ self.headers = {
+ "Authorization": "bearer " + self.access_token,
+ "User-Agent": self.USER_AGENT
+ }
+ else:
+ self.client_auth = None
+ self.access_token = None
+ self.refresh_token = None
+ self.headers = None
+
+ log.error("Unable to find client credentials.")
+
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
# Reddit's JSON responses only provide 25 posts at most.
@@ -43,11 +110,11 @@ class Reddit(Cog):
if params is None:
params = {}
- url = f"{self.URL}/{route}.json"
+ url = f"{self.OAUTH_URL}/{route}"
for _ in range(self.MAX_FETCH_RETRIES):
response = await self.bot.http_session.get(
url=url,
- headers=self.HEADERS,
+ headers=self.headers,
params=params
)
if response.status == 200 and response.content_type == 'application/json':
@@ -55,7 +122,7 @@ class Reddit(Cog):
content = await response.json()
posts = content["data"]["children"]
return posts[:amount]
-
+
await asyncio.sleep(3)
log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")
@@ -127,8 +194,8 @@ class Reddit(Cog):
for subreddit in RedditConfig.subreddits:
# Make a HEAD request to the subreddit
head_response = await self.bot.http_session.head(
- url=f"{self.URL}/{subreddit}/new.rss",
- headers=self.HEADERS
+ url=f"{self.OAUTH_URL}/{subreddit}/new.rss",
+ headers=self.headers
)
content_length = head_response.headers["content-length"]
diff --git a/bot/constants.py b/bot/constants.py
index 1deeaa3b8..f84889e10 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter):
request_delay: int
subreddits: list
+ client_id: str
+ secret: str
class Wolfram(metaclass=YAMLGetter):
diff --git a/config-default.yml b/config-default.yml
index 0dac9bf9f..c43ea4f8f 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -326,6 +326,8 @@ reddit:
request_delay: 60
subreddits:
- 'r/Python'
+ client_id: !ENV "REDDIT_CLIENT_ID"
+ secret: !ENV "REDDIT_SECRET"
wolfram: