diff options
-rw-r--r-- | bot/cogs/reddit.py | 73 | ||||
-rw-r--r-- | bot/constants.py | 2 | ||||
-rw-r--r-- | config-default.yml | 2 |
3 files changed, 71 insertions, 6 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..7b183221c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -5,7 +5,9 @@ import textwrap from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # <platform>:<app ID>:<version string> (by /u/<reddit username>) + USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" MAX_FETCH_RETRIES = 3 def __init__(self, bot: Bot): @@ -36,6 +43,59 @@ class Reddit(Cog): self.bot.loop.create_task(self.init_reddit_polling()) + @tasks.loop(hours=0.99) # access tokens are valid for one hour + async def refresh_access_token(self) -> None: + """Refresh Reddits access token.""" + headers = {"Authorization": self.client_auth} + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token + } + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + data=data, + ) + + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + + @refresh_access_token.before_loop + async def get_tokens(self) -> None: + """Get Reddit access and refresh tokens.""" + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "permanent" + } + + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + self.access_token = content["access_token"] + self.refresh_token = content["refresh_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + else: + log.error("Authentication with Reddit API failed. Unloading extension.") + self.bot.remove_cog(self.__class__.__name__) + return + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -45,11 +105,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" + url = f"{self.OAUTH_URL}/{route}" for _ in range(self.MAX_FETCH_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -57,7 +117,7 @@ class Reddit(Cog): content = await response.json() posts = content["data"]["children"] return posts[:amount] - + await asyncio.sleep(3) log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -129,8 +189,8 @@ class Reddit(Cog): for subreddit in RedditConfig.subreddits: # Make a HEAD request to the subreddit head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS + url=f"{self.OAUTH_URL}/{subreddit}/new.rss", + headers=self.headers ) content_length = head_response.headers["content-length"] @@ -268,6 +328,7 @@ class Reddit(Cog): """Initiate reddit post event loop.""" await self.bot.wait_until_ready() self.reddit_channel = await self.bot.fetch_channel(Channels.reddit) + self.refresh_access_token.start() if self.reddit_channel is not None: if self.new_posts_task is None: diff --git a/bot/constants.py b/bot/constants.py index f4f45eb2c..c49242d5e 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter): request_delay: int subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index ca405337e..3487dff27 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit: request_delay: 60 subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: |