diff options
| -rw-r--r-- | bot/cogs/reddit.py | 52 | 
1 files changed, 32 insertions, 20 deletions
| diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 64a940af1..0ebf2e1a7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,7 +2,8 @@ import asyncio  import logging  import random  import textwrap -from datetime import datetime +from collections import namedtuple +from datetime import datetime, timedelta  from typing import List  from aiohttp import BasicAuth @@ -21,11 +22,7 @@ log = logging.getLogger(__name__)  class Reddit(Cog):      """Track subreddit posts and show detailed statistics about them.""" -    # Change your client's User-Agent string to something unique and descriptive, -    # including the target platform, a unique application identifier, a version string, -    # and your username as contact information, in the following format: -    # <platform>:<app ID>:<version string> (by /u/<reddit username>) -    USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" +    USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)"      URL = "https://www.reddit.com"      OAUTH_URL = "https://oauth.reddit.com"      MAX_RETRIES = 3 @@ -33,7 +30,8 @@ class Reddit(Cog):      def __init__(self, bot: Bot):          self.bot = bot -        self.webhook = None  # set in on_ready +        self.webhook = None +        self.access_token = None          bot.loop.create_task(self.init_reddit_ready())          self.auto_poster_loop.start() @@ -41,6 +39,8 @@ class Reddit(Cog):      def cog_unload(self) -> None:          """Stops the loops when the cog is unloaded."""          self.auto_poster_loop.cancel() +        if self.access_token.expires_at < datetime.utcnow(): +            self.revoke_access_token()      async def init_reddit_ready(self) -> None:          """Sets the reddit webhook when the cog is loaded.""" @@ -53,7 +53,7 @@ class Reddit(Cog):          """Get the #reddit channel object from the bot's cache."""          return self.bot.get_channel(Channels.reddit) -    async def get_access_tokens(self) -> None: +    async def get_access_token(self) -> None:          """Get Reddit access tokens."""          headers = {"User-Agent": self.USER_AGENT}          data = { @@ -61,6 +61,7 @@ class Reddit(Cog):              "duration": "temporary"          } +        log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}")          self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)          for _ in range(self.MAX_RETRIES): @@ -72,9 +73,13 @@ class Reddit(Cog):              )              if response.status == 200 and response.content_type == "application/json":                  content = await response.json() -                self.access_token = content["access_token"] +                AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) +                self.access_token = AccessToken( +                    token=content["access_token"], +                    expires_at=datetime.utcnow() + timedelta(hours=1) +                )                  self.headers = { -                    "Authorization": "bearer " + self.access_token, +                    "Authorization": "bearer " + self.access_token.token,                      "User-Agent": self.USER_AGENT                  }                  return @@ -91,7 +96,7 @@ class Reddit(Cog):          # The token should be revoked, since the API is called only once a day.          headers = {"User-Agent": self.USER_AGENT}          data = { -            "token": self.access_token, +            "token": self.access_token.token,              "token_type_hint": "access_token"          } @@ -200,7 +205,10 @@ class Reddit(Cog):          if not self.webhook:              await self.bot.fetch_webhook(Webhooks.reddit) -        await self.get_access_tokens() +        if not self.access_token: +            await self.get_access_token() +        elif self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token()          if datetime.utcnow().weekday() == 0:              await self.top_weekly_posts() @@ -210,8 +218,6 @@ class Reddit(Cog):              top_posts = await self.get_top_posts(subreddit=subreddit, time="day")              await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) -        await self.revoke_access_token() -      async def top_weekly_posts(self) -> None:          """Post a summary of the top posts."""          for subreddit in RedditConfig.subreddits: @@ -242,32 +248,38 @@ class Reddit(Cog):      @reddit_group.command(name="top")      async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:          """Send the top posts of all time from a given subreddit.""" +        if not self.access_token: +            await self.get_access_token() +        elif self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token()          async with ctx.typing(): -            await self.get_access_tokens()              embed = await self.get_top_posts(subreddit=subreddit, time="all")          await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) -        await self.revoke_access_token()      @reddit_group.command(name="daily")      async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:          """Send the top posts of today from a given subreddit.""" +        if not self.access_token: +            await self.get_access_token() +        elif self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token()          async with ctx.typing(): -            await self.get_access_tokens()              embed = await self.get_top_posts(subreddit=subreddit, time="day")          await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) -        await self.revoke_access_token()      @reddit_group.command(name="weekly")      async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:          """Send the top posts of this week from a given subreddit.""" +        if not self.access_token: +            await self.get_access_token() +        elif self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token()          async with ctx.typing(): -            await self.get_access_tokens()              embed = await self.get_top_posts(subreddit=subreddit, time="week")          await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) -        await self.revoke_access_token()      @with_role(*STAFF_ROLES)      @reddit_group.command(name="subreddits", aliases=("subs",)) | 
