diff options
| author | 2019-10-09 23:43:16 +0200 | |
|---|---|---|
| committer | 2019-10-09 23:43:16 +0200 | |
| commit | 2cb7ee12805957c7d655679ff54a14f16e059a80 (patch) | |
| tree | caddeeea2c706576257fc88c40e3427ff0fa83bf | |
| parent | Merge pull request #505 from python-discord/user-log-display-name-changes (diff) | |
Add Reddit OAuth tasks and refactor code
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/reddit.py | 79 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | config-default.yml | 2 | 
3 files changed, 77 insertions, 6 deletions
| diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 6880aab85..bf4403ce4 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,12 @@ import asyncio  import logging  import random  import textwrap +from aiohttp import BasicAuth  from datetime import datetime, timedelta  from typing import List  from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks  from discord.ext.commands import Bot, Cog, Context, group  from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__)  class Reddit(Cog):      """Track subreddit posts and show detailed statistics about them.""" -    HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} +    # Change your client's User-Agent string to something unique and descriptive, +    # including the target platform, a unique application identifier, a version string, +    # and your username as contact information, in the following format: +    # <platform>:<app ID>:<version string> (by /u/<reddit username>) +    USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)"      URL = "https://www.reddit.com" +    OAUTH_URL = "https://oauth.reddit.com"      MAX_FETCH_RETRIES = 3      def __init__(self, bot: Bot): @@ -34,6 +41,66 @@ class Reddit(Cog):          self.new_posts_task = None          self.top_weekly_posts_task = None +        self.refresh_access_token.start() + +    @tasks.loop(hours=0.99)  # access tokens are valid for one hour +    async def refresh_access_token(self) -> None: +        """Refresh the access token""" +        headers = {"Authorization": self.client_auth} +        data = { +            "grant_type": "refresh_token", +            "refresh_token": self.refresh_token +        } + +        response = await self.bot.http_session.post( +            url = f"{self.URL}/api/v1/access_token", +            headers=headers, +            data=data, +        ) + +        content = await response.json() +        self.access_token = content["access_token"] +        self.headers = { +            "Authorization": "bearer " + self.access_token, +            "User-Agent": self.USER_AGENT +        } + +    @refresh_access_token.before_loop +    async def get_tokens(self) -> None: +        """Get Reddit access and refresh tokens""" +        await self.bot.wait_until_ready() + +        headers = {"User-Agent": self.USER_AGENT} +        data = { +            "grant_type": "client_credentials", +            "duration": "permanent" +        } + +        if RedditConfig.client_id and RedditConfig.secret: +            self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + +            response = await self.bot.http_session.post( +                url=f"{self.URL}/api/v1/access_token", +                headers=headers, +                auth=self.client_auth, +                data=data +            ) + +            content = await response.json() +            self.access_token = content["access_token"] +            self.refresh_token = content["refresh_token"] +            self.headers = { +                "Authorization": "bearer " + self.access_token, +                "User-Agent": self.USER_AGENT +            } +        else: +            self.client_auth = None +            self.access_token = None +            self.refresh_token = None +            self.headers = None + +            log.error("Unable to find client credentials.") +      async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:          """A helper method to fetch a certain amount of Reddit posts at a given route."""          # Reddit's JSON responses only provide 25 posts at most. @@ -43,11 +110,11 @@ class Reddit(Cog):          if params is None:              params = {} -        url = f"{self.URL}/{route}.json" +        url = f"{self.OAUTH_URL}/{route}"          for _ in range(self.MAX_FETCH_RETRIES):              response = await self.bot.http_session.get(                  url=url, -                headers=self.HEADERS, +                headers=self.headers,                  params=params              )              if response.status == 200 and response.content_type == 'application/json': @@ -55,7 +122,7 @@ class Reddit(Cog):                  content = await response.json()                  posts = content["data"]["children"]                  return posts[:amount] - +                          await asyncio.sleep(3)          log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -127,8 +194,8 @@ class Reddit(Cog):              for subreddit in RedditConfig.subreddits:                  # Make a HEAD request to the subreddit                  head_response = await self.bot.http_session.head( -                    url=f"{self.URL}/{subreddit}/new.rss", -                    headers=self.HEADERS +                    url=f"{self.OAUTH_URL}/{subreddit}/new.rss", +                    headers=self.headers                  )                  content_length = head_response.headers["content-length"] diff --git a/bot/constants.py b/bot/constants.py index 1deeaa3b8..f84889e10 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter):      request_delay: int      subreddits: list +    client_id: str +    secret: str  class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 0dac9bf9f..c43ea4f8f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit:      request_delay: 60      subreddits:          - 'r/Python' +    client_id: !ENV "REDDIT_CLIENT_ID" +    secret:    !ENV "REDDIT_SECRET"  wolfram: | 
