aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--azure-pipelines.yml2
-rw-r--r--bot/cogs/reddit.py82
-rw-r--r--bot/constants.py2
-rw-r--r--config-default.yml2
-rw-r--r--docker-compose.yml2
5 files changed, 83 insertions, 7 deletions
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index da3b06201..0400ac4d2 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -30,7 +30,7 @@ jobs:
- script: python -m flake8
displayName: 'Run linter'
- - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner
+ - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
displayName: Run tests
- script: coverage report -m && coverage xml -o coverage.xml
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 7749d237f..7e2ba40d5 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -5,6 +5,7 @@ import textwrap
from datetime import datetime
from typing import List
+from aiohttp import BasicAuth
from discord import Colour, Embed, TextChannel
from discord.ext.commands import Bot, Cog, Context, group
from discord.ext.tasks import loop
@@ -19,10 +20,14 @@ log = logging.getLogger(__name__)
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
-
- HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"}
+ # Change your client's User-Agent string to something unique and descriptive,
+ # including the target platform, a unique application identifier, a version string,
+ # and your username as contact information, in the following format:
+ # <platform>:<app ID>:<version string> (by /u/<reddit username>)
+ USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)"
URL = "https://www.reddit.com"
- MAX_FETCH_RETRIES = 3
+ OAUTH_URL = "https://oauth.reddit.com"
+ MAX_RETRIES = 3
def __init__(self, bot: Bot):
self.bot = bot
@@ -47,6 +52,61 @@ class Reddit(Cog):
"""Get the #reddit channel object from the bot's cache."""
return self.bot.get_channel(Channels.reddit)
+ async def get_access_tokens(self) -> None:
+ """Get Reddit access tokens."""
+ headers = {"User-Agent": self.USER_AGENT}
+ data = {
+ "grant_type": "client_credentials",
+ "duration": "temporary"
+ }
+
+ self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
+
+ for _ in range(self.MAX_RETRIES):
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=headers,
+ auth=self.client_auth,
+ data=data
+ )
+ if response.status == 200 and response.content_type == "application/json":
+ content = await response.json()
+ self.access_token = content["access_token"]
+ self.headers = {
+ "Authorization": "bearer " + self.access_token,
+ "User-Agent": self.USER_AGENT
+ }
+ return
+
+ await asyncio.sleep(3)
+
+ log.error("Authentication with Reddit API failed. Unloading extension.")
+ self.bot.remove_cog(self.__class__.__name__)
+ return
+
+ async def revoke_access_token(self) -> None:
+ """Revoke the access token for Reddit API."""
+ # Access tokens are valid for 1 hour.
+ # The token should be revoked, since the API is called only once a day.
+ headers = {"User-Agent": self.USER_AGENT}
+ data = {
+ "token": self.access_token,
+ "token_type_hint": "access_token"
+ }
+
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/revoke_token",
+ headers=headers,
+ auth=self.client_auth,
+ data=data
+ )
+ if response.status == 204 and response.content_type == "application/json":
+ self.access_token = None
+ self.headers = None
+ return
+
+ log.warning(f"Unable to revoke access token, status code {response.status}.")
+
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
# Reddit's JSON responses only provide 25 posts at most.
@@ -56,11 +116,11 @@ class Reddit(Cog):
if params is None:
params = {}
- url = f"{self.URL}/{route}.json"
- for _ in range(self.MAX_FETCH_RETRIES):
+ url = f"{self.OAUTH_URL}/{route}"
+ for _ in range(self.MAX_RETRIES):
response = await self.bot.http_session.get(
url=url,
- headers=self.HEADERS,
+ headers=self.headers,
params=params
)
if response.status == 200 and response.content_type == 'application/json':
@@ -139,6 +199,8 @@ class Reddit(Cog):
if not self.webhook:
await self.bot.fetch_webhook(Webhooks.reddit)
+ await self.get_access_tokens()
+
if datetime.utcnow().weekday() == 0:
await self.top_weekly_posts()
# if it's a monday send the top weekly posts
@@ -147,6 +209,8 @@ class Reddit(Cog):
top_posts = await self.get_top_posts(subreddit=subreddit, time="day")
await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts)
+ await self.revoke_access_token()
+
async def top_weekly_posts(self) -> None:
"""Post a summary of the top posts."""
for subreddit in RedditConfig.subreddits:
@@ -178,25 +242,31 @@ class Reddit(Cog):
async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of all time from a given subreddit."""
async with ctx.typing():
+ await self.get_access_tokens()
embed = await self.get_top_posts(subreddit=subreddit, time="all")
await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed)
+ await self.revoke_access_token()
@reddit_group.command(name="daily")
async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of today from a given subreddit."""
async with ctx.typing():
+ await self.get_access_tokens()
embed = await self.get_top_posts(subreddit=subreddit, time="day")
await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed)
+ await self.revoke_access_token()
@reddit_group.command(name="weekly")
async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of this week from a given subreddit."""
async with ctx.typing():
+ await self.get_access_tokens()
embed = await self.get_top_posts(subreddit=subreddit, time="week")
await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed)
+ await self.revoke_access_token()
@with_role(*STAFF_ROLES)
@reddit_group.command(name="subreddits", aliases=("subs",))
diff --git a/bot/constants.py b/bot/constants.py
index 838fe7a79..b11ab65e9 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -446,6 +446,8 @@ class Reddit(metaclass=YAMLGetter):
section = "reddit"
subreddits: list
+ client_id: str
+ secret: str
class Wolfram(metaclass=YAMLGetter):
diff --git a/config-default.yml b/config-default.yml
index 4638a89ee..bd85e1509 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -353,6 +353,8 @@ anti_malware:
reddit:
subreddits:
- 'r/Python'
+ client_id: !ENV "REDDIT_CLIENT_ID"
+ secret: !ENV "REDDIT_SECRET"
wolfram:
diff --git a/docker-compose.yml b/docker-compose.yml
index f79fdba58..7281c7953 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -42,3 +42,5 @@ services:
environment:
BOT_TOKEN: ${BOT_TOKEN}
BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531
+ REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID}
+ REDDIT_SECRET: ${REDDIT_SECRET}