aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Jens <[email protected]>2019-12-05 22:07:59 +0100
committerGravatar Jens <[email protected]>2019-12-05 22:07:59 +0100
commita9dc1000872f507a850798b204befed299b6f703 (patch)
tree56c92c8c17ed1034603b69088dd72975a0d25c95
parentFix linting error (diff)
Keeps access token alive, only revokes it on extension unload.
Hard-coded version number to 1.0.0.
-rw-r--r--bot/cogs/reddit.py52
1 files changed, 32 insertions, 20 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 64a940af1..0ebf2e1a7 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -2,7 +2,8 @@ import asyncio
import logging
import random
import textwrap
-from datetime import datetime
+from collections import namedtuple
+from datetime import datetime, timedelta
from typing import List
from aiohttp import BasicAuth
@@ -21,11 +22,7 @@ log = logging.getLogger(__name__)
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
- # Change your client's User-Agent string to something unique and descriptive,
- # including the target platform, a unique application identifier, a version string,
- # and your username as contact information, in the following format:
- # <platform>:<app ID>:<version string> (by /u/<reddit username>)
- USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)"
+ USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)"
URL = "https://www.reddit.com"
OAUTH_URL = "https://oauth.reddit.com"
MAX_RETRIES = 3
@@ -33,7 +30,8 @@ class Reddit(Cog):
def __init__(self, bot: Bot):
self.bot = bot
- self.webhook = None # set in on_ready
+ self.webhook = None
+ self.access_token = None
bot.loop.create_task(self.init_reddit_ready())
self.auto_poster_loop.start()
@@ -41,6 +39,8 @@ class Reddit(Cog):
def cog_unload(self) -> None:
"""Stops the loops when the cog is unloaded."""
self.auto_poster_loop.cancel()
+ if self.access_token.expires_at < datetime.utcnow():
+ self.revoke_access_token()
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
@@ -53,7 +53,7 @@ class Reddit(Cog):
"""Get the #reddit channel object from the bot's cache."""
return self.bot.get_channel(Channels.reddit)
- async def get_access_tokens(self) -> None:
+ async def get_access_token(self) -> None:
"""Get Reddit access tokens."""
headers = {"User-Agent": self.USER_AGENT}
data = {
@@ -61,6 +61,7 @@ class Reddit(Cog):
"duration": "temporary"
}
+ log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}")
self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
for _ in range(self.MAX_RETRIES):
@@ -72,9 +73,13 @@ class Reddit(Cog):
)
if response.status == 200 and response.content_type == "application/json":
content = await response.json()
- self.access_token = content["access_token"]
+ AccessToken = namedtuple("AccessToken", ["token", "expires_at"])
+ self.access_token = AccessToken(
+ token=content["access_token"],
+ expires_at=datetime.utcnow() + timedelta(hours=1)
+ )
self.headers = {
- "Authorization": "bearer " + self.access_token,
+ "Authorization": "bearer " + self.access_token.token,
"User-Agent": self.USER_AGENT
}
return
@@ -91,7 +96,7 @@ class Reddit(Cog):
# The token should be revoked, since the API is called only once a day.
headers = {"User-Agent": self.USER_AGENT}
data = {
- "token": self.access_token,
+ "token": self.access_token.token,
"token_type_hint": "access_token"
}
@@ -200,7 +205,10 @@ class Reddit(Cog):
if not self.webhook:
await self.bot.fetch_webhook(Webhooks.reddit)
- await self.get_access_tokens()
+ if not self.access_token:
+ await self.get_access_token()
+ elif self.access_token.expires_at < datetime.utcnow():
+ await self.get_access_token()
if datetime.utcnow().weekday() == 0:
await self.top_weekly_posts()
@@ -210,8 +218,6 @@ class Reddit(Cog):
top_posts = await self.get_top_posts(subreddit=subreddit, time="day")
await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts)
- await self.revoke_access_token()
-
async def top_weekly_posts(self) -> None:
"""Post a summary of the top posts."""
for subreddit in RedditConfig.subreddits:
@@ -242,32 +248,38 @@ class Reddit(Cog):
@reddit_group.command(name="top")
async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of all time from a given subreddit."""
+ if not self.access_token:
+ await self.get_access_token()
+ elif self.access_token.expires_at < datetime.utcnow():
+ await self.get_access_token()
async with ctx.typing():
- await self.get_access_tokens()
embed = await self.get_top_posts(subreddit=subreddit, time="all")
await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed)
- await self.revoke_access_token()
@reddit_group.command(name="daily")
async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of today from a given subreddit."""
+ if not self.access_token:
+ await self.get_access_token()
+ elif self.access_token.expires_at < datetime.utcnow():
+ await self.get_access_token()
async with ctx.typing():
- await self.get_access_tokens()
embed = await self.get_top_posts(subreddit=subreddit, time="day")
await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed)
- await self.revoke_access_token()
@reddit_group.command(name="weekly")
async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of this week from a given subreddit."""
+ if not self.access_token:
+ await self.get_access_token()
+ elif self.access_token.expires_at < datetime.utcnow():
+ await self.get_access_token()
async with ctx.typing():
- await self.get_access_tokens()
embed = await self.get_top_posts(subreddit=subreddit, time="week")
await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed)
- await self.revoke_access_token()
@with_role(*STAFF_ROLES)
@reddit_group.command(name="subreddits", aliases=("subs",))