diff options
author | 2019-10-09 23:43:16 +0200 | |
---|---|---|
committer | 2019-10-09 23:43:16 +0200 | |
commit | 2cb7ee12805957c7d655679ff54a14f16e059a80 (patch) | |
tree | caddeeea2c706576257fc88c40e3427ff0fa83bf | |
parent | Merge pull request #505 from python-discord/user-log-display-name-changes (diff) |
Add Reddit OAuth tasks and refactor code
-rw-r--r-- | bot/cogs/reddit.py | 79 | ||||
-rw-r--r-- | bot/constants.py | 2 | ||||
-rw-r--r-- | config-default.yml | 2 |
3 files changed, 77 insertions, 6 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6880aab85..bf4403ce4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio import logging import random import textwrap +from aiohttp import BasicAuth from datetime import datetime, timedelta from typing import List from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # <platform>:<app ID>:<version string> (by /u/<reddit username>) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -34,6 +41,66 @@ class Reddit(Cog): self.new_posts_task = None self.top_weekly_posts_task = None + self.refresh_access_token.start() + + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh the access token""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url = f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens""" + await self.bot.wait_until_ready() + + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + if RedditConfig.client_id and RedditConfig.secret: + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + self.client_auth = None + self.access_token = None + self.refresh_token = None + self.headers = None + + log.error("Unable to find client credentials.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -43,11 +110,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -55,7 +122,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -127,8 +194,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: |