diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/reddit.py | 73 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | config-default.yml | 2 | 
3 files changed, 71 insertions, 6 deletions
| diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..7b183221c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -5,7 +5,9 @@ import textwrap  from datetime import datetime, timedelta  from typing import List +from aiohttp import BasicAuth  from discord import Colour, Embed, Message, TextChannel +from discord.ext import tasks  from discord.ext.commands import Bot, Cog, Context, group  from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES @@ -19,8 +21,13 @@ log = logging.getLogger(__name__)  class Reddit(Cog):      """Track subreddit posts and show detailed statistics about them.""" -    HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} +    # Change your client's User-Agent string to something unique and descriptive, +    # including the target platform, a unique application identifier, a version string, +    # and your username as contact information, in the following format: +    # <platform>:<app ID>:<version string> (by /u/<reddit username>) +    USER_AGENT = "docker:Discord Bot of https://pythondiscord.com/:v?.?.? (by /u/PythonDiscord)"      URL = "https://www.reddit.com" +    OAUTH_URL = "https://oauth.reddit.com"      MAX_FETCH_RETRIES = 3      def __init__(self, bot: Bot): @@ -36,6 +43,59 @@ class Reddit(Cog):          self.bot.loop.create_task(self.init_reddit_polling()) +    @tasks.loop(hours=0.99)  # access tokens are valid for one hour +    async def refresh_access_token(self) -> None: +        """Refresh Reddits access token.""" +        headers = {"Authorization": self.client_auth} +        data = { +            "grant_type": "refresh_token", +            "refresh_token": self.refresh_token +        } + +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/access_token", +            headers=headers, +            data=data, +        ) + +        content = await response.json() +        self.access_token = content["access_token"] +        self.headers = { +            "Authorization": "bearer " + self.access_token, +            "User-Agent": self.USER_AGENT +        } + +    @refresh_access_token.before_loop +    async def get_tokens(self) -> None: +        """Get Reddit access and refresh tokens.""" +        headers = {"User-Agent": self.USER_AGENT} +        data = { +            "grant_type": "client_credentials", +            "duration": "permanent" +        } + +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/access_token", +            headers=headers, +            auth=self.client_auth, +            data=data +        ) + +        if response.status == 200 and response.content_type == "application/json": +            content = await response.json() +            self.access_token = content["access_token"] +            self.refresh_token = content["refresh_token"] +            self.headers = { +                "Authorization": "bearer " + self.access_token, +                "User-Agent": self.USER_AGENT +            } +        else: +            log.error("Authentication with Reddit API failed. Unloading extension.") +            self.bot.remove_cog(self.__class__.__name__) +            return +      async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:          """A helper method to fetch a certain amount of Reddit posts at a given route."""          # Reddit's JSON responses only provide 25 posts at most. @@ -45,11 +105,11 @@ class Reddit(Cog):          if params is None:              params = {} -        url = f"{self.URL}/{route}.json" +        url = f"{self.OAUTH_URL}/{route}"          for _ in range(self.MAX_FETCH_RETRIES):              response = await self.bot.http_session.get(                  url=url, -                headers=self.HEADERS, +                headers=self.headers,                  params=params              )              if response.status == 200 and response.content_type == 'application/json': @@ -57,7 +117,7 @@ class Reddit(Cog):                  content = await response.json()                  posts = content["data"]["children"]                  return posts[:amount] - +                          await asyncio.sleep(3)          log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") @@ -129,8 +189,8 @@ class Reddit(Cog):              for subreddit in RedditConfig.subreddits:                  # Make a HEAD request to the subreddit                  head_response = await self.bot.http_session.head( -                    url=f"{self.URL}/{subreddit}/new.rss", -                    headers=self.HEADERS +                    url=f"{self.OAUTH_URL}/{subreddit}/new.rss", +                    headers=self.headers                  )                  content_length = head_response.headers["content-length"] @@ -268,6 +328,7 @@ class Reddit(Cog):          """Initiate reddit post event loop."""          await self.bot.wait_until_ready()          self.reddit_channel = await self.bot.fetch_channel(Channels.reddit) +        self.refresh_access_token.start()          if self.reddit_channel is not None:              if self.new_posts_task is None: diff --git a/bot/constants.py b/bot/constants.py index f4f45eb2c..c49242d5e 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -440,6 +440,8 @@ class Reddit(metaclass=YAMLGetter):      request_delay: int      subreddits: list +    client_id: str +    secret: str  class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index ca405337e..3487dff27 100644 --- a/config-default.yml +++ b/config-default.yml @@ -326,6 +326,8 @@ reddit:      request_delay: 60      subreddits:          - 'r/Python' +    client_id: !ENV "REDDIT_CLIENT_ID" +    secret:    !ENV "REDDIT_SECRET"  wolfram: | 
