aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/reddit.py73
-rw-r--r--bot/constants.py2
-rw-r--r--config-default.yml2
3 files changed, 71 insertions, 6 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 0f575cece..7b183221c 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -5,7 +5,9 @@ import textwrap
from datetime import datetime, timedelta
from typing import List
+from aiohttp import BasicAuth
from discord import Colour, Embed, Message, TextChannel
+from discord.ext import tasks
from discord.ext.commands import Bot, Cog, Context, group
from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES
@@ -19,8 +21,13 @@ log = logging.getLogger(__name__)
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
- HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"}
+ # Change your client's User-Agent string to something unique and descriptive,
+ # including the target platform, a unique application identifier, a version string,
+ # and your username as contact information, in the following format:
+ # <platform>:<app ID>:<version string> (by /u/<reddit username>)
+ USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)"
URL = "https://www.reddit.com"
+ OAUTH_URL = "https://oauth.reddit.com"
MAX_FETCH_RETRIES = 3
def __init__(self, bot: Bot):
@@ -36,6 +43,59 @@ class Reddit(Cog):
self.bot.loop.create_task(self.init_reddit_polling())
+ @tasks.loop(hours=0.99) # access tokens are valid for one hour
+ async def refresh_access_token(self) -> None:
+ """Refresh Reddits access token."""
+ headers = {"Authorization": self.client_auth}
+ data = {
+ "grant_type": "refresh_token",
+ "refresh_token": self.refresh_token
+ }
+
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=headers,
+ data=data,
+ )
+
+ content = await response.json()
+ self.access_token = content["access_token"]
+ self.headers = {
+ "Authorization": "bearer " + self.access_token,
+ "User-Agent": self.USER_AGENT
+ }
+
+ @refresh_access_token.before_loop
+ async def get_tokens(self) -> None:
+ """Get Reddit access and refresh tokens."""
+ headers = {"User-Agent": self.USER_AGENT}
+ data = {
+ "grant_type": "client_credentials",
+ "duration": "permanent"
+ }
+
+ self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
+
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=headers,
+ auth=self.client_auth,
+ data=data
+ )
+
+ if response.status == 200 and response.content_type == "application/json":
+ content = await response.json()
+ self.access_token = content["access_token"]
+ self.refresh_token = content["refresh_token"]
+ self.headers = {
+ "Authorization": "bearer " + self.access_token,
+ "User-Agent": self.USER_AGENT
+ }
+ else:
+ log.error("Authentication with Reddit API failed. Unloading extension.")
+ self.bot.remove_cog(self.__class__.__name__)
+ return
+
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
# Reddit's JSON responses only provide 25 posts at most.
@@ -45,11 +105,11 @@ class Reddit(Cog):
if params is None:
params = {}
- url = f"{self.URL}/{route}.json"
+ url = f"{self.OAUTH_URL}/{route}"
for _ in range(self.MAX_FETCH_RETRIES):
response = await self.bot.http_session.get(
url=url,
- headers=self.HEADERS,
+ headers=self.headers,
params=params
)
if response.status == 200 and response.content_type == 'application/json':
@@ -57,7 +117,7 @@ class Reddit(Cog):
content = await response.json()
posts = content["data"]["children"]
return posts[:amount]
-
+
await asyncio.sleep(3)
log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")
@@ -129,8 +189,8 @@ class Reddit(Cog):
for subreddit in RedditConfig.subreddits:
# Make a HEAD request to the subreddit
head_response = await self.bot.http_session.head(
- url=f"{self.URL}/{subreddit}/new.rss",
- headers=self.HEADERS
+ url=f"{self.OAUTH_URL}/{subreddit}/new.rss",
+ headers=self.headers
)
content_length = head_response.headers["content-length"]
@@ -268,6 +328,7 @@ class Reddit(Cog):
"""Initiate reddit post event loop."""
await self.bot.wait_until_ready()
self.reddit_channel = await self.bot.fetch_channel(Channels.reddit)
+ self.refresh_access_token.start()
if self.reddit_channel is not None:
if self.new_posts_task is None:
diff --git a/bot/constants.py b/bot/constants.py
index f4f45eb2c..c49242d5e 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter):
request_delay: int
subreddits: list
+ client_id: str
+ secret: str
class Wolfram(metaclass=YAMLGetter):
diff --git a/config-default.yml b/config-default.yml
index ca405337e..3487dff27 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -326,6 +326,8 @@ reddit:
request_delay: 60
subreddits:
- 'r/Python'
+ client_id: !ENV "REDDIT_CLIENT_ID"
+ secret: !ENV "REDDIT_SECRET"
wolfram: