aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joseph <[email protected]>2019-12-13 00:42:19 +0000
committerGravatar GitHub <[email protected]>2019-12-13 00:42:19 +0000
commitba21a419cfb0879efecb42f58d80f605f9226473 (patch)
treec352b58b29a7c0e710fc023239e859c0b17f2724
parentUse OAuth to be Reddit API compliant (#510) (diff)
parentRevert "Use OAuth to be Reddit API compliant" (diff)
Revert "Use OAuth to be Reddit API compliant" (#695)
Revert "Use OAuth to be Reddit API compliant"
-rw-r--r--azure-pipelines.yml2
-rw-r--r--bot/cogs/reddit.py91
-rw-r--r--bot/constants.py2
-rw-r--r--config-default.yml2
-rw-r--r--docker-compose.yml2
5 files changed, 11 insertions, 88 deletions
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 0400ac4d2..da3b06201 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -30,7 +30,7 @@ jobs:
- script: python -m flake8
displayName: 'Run linter'
- - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
+ - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner
displayName: Run tests
- script: coverage report -m && coverage xml -o coverage.xml
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index aa487f18e..bec316ae7 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -2,11 +2,9 @@ import asyncio
import logging
import random
import textwrap
-from collections import namedtuple
from datetime import datetime, timedelta
from typing import List
-from aiohttp import BasicAuth, ClientError
from discord import Colour, Embed, TextChannel
from discord.ext.commands import Cog, Context, group
from discord.ext.tasks import loop
@@ -19,32 +17,25 @@ from bot.pagination import LinePaginator
log = logging.getLogger(__name__)
-AccessToken = namedtuple("AccessToken", ["token", "expires_at"])
-
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
- HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}
+ HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"}
URL = "https://www.reddit.com"
- OAUTH_URL = "https://oauth.reddit.com"
- MAX_RETRIES = 3
+ MAX_FETCH_RETRIES = 3
def __init__(self, bot: Bot):
self.bot = bot
- self.webhook = None
- self.access_token = None
- self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
-
+ self.webhook = None # set in on_ready
bot.loop.create_task(self.init_reddit_ready())
+
self.auto_poster_loop.start()
def cog_unload(self) -> None:
- """Stop the loop task and revoke the access token when the cog is unloaded."""
+ """Stops the loops when the cog is unloaded."""
self.auto_poster_loop.cancel()
- if self.access_token.expires_at < datetime.utcnow():
- self.revoke_access_token()
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
@@ -57,82 +48,20 @@ class Reddit(Cog):
"""Get the #reddit channel object from the bot's cache."""
return self.bot.get_channel(Channels.reddit)
- async def get_access_token(self) -> None:
- """
- Get a Reddit API OAuth2 access token and assign it to self.access_token.
-
- A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog
- will be unloaded and a ClientError raised if retrieval was still unsuccessful.
- """
- for i in range(1, self.MAX_RETRIES + 1):
- response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/access_token",
- headers=self.HEADERS,
- auth=self.client_auth,
- data={
- "grant_type": "client_credentials",
- "duration": "temporary"
- }
- )
-
- if response.status == 200 and response.content_type == "application/json":
- content = await response.json()
- expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway.
- self.access_token = AccessToken(
- token=content["access_token"],
- expires_at=datetime.utcnow() + timedelta(seconds=expiration)
- )
-
- log.debug(f"New token acquired; expires on {self.access_token.expires_at}")
- return
- else:
- log.debug(
- f"Failed to get an access token: "
- f"status {response.status} & content type {response.content_type}; "
- f"retrying ({i}/{self.MAX_RETRIES})"
- )
-
- await asyncio.sleep(3)
-
- self.bot.remove_cog(self.qualified_name)
- raise ClientError("Authentication with the Reddit API failed. Unloading the cog.")
-
- async def revoke_access_token(self) -> None:
- """
- Revoke the OAuth2 access token for the Reddit API.
-
- For security reasons, it's good practice to revoke the token when it's no longer being used.
- """
- response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/revoke_token",
- headers=self.HEADERS,
- auth=self.client_auth,
- data={
- "token": self.access_token.token,
- "token_type_hint": "access_token"
- }
- )
-
- if response.status == 204 and response.content_type == "application/json":
- self.access_token = None
- else:
- log.warning(f"Unable to revoke access token: status {response.status}.")
-
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
# Reddit's JSON responses only provide 25 posts at most.
if not 25 >= amount > 0:
raise ValueError("Invalid amount of subreddit posts requested.")
- # Renew the token if necessary.
- if not self.access_token or self.access_token.expires_at < datetime.utcnow():
- await self.get_access_token()
+ if params is None:
+ params = {}
- url = f"{self.OAUTH_URL}/{route}"
- for _ in range(self.MAX_RETRIES):
+ url = f"{self.URL}/{route}.json"
+ for _ in range(self.MAX_FETCH_RETRIES):
response = await self.bot.http_session.get(
url=url,
- headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},
+ headers=self.HEADERS,
params=params
)
if response.status == 200 and response.content_type == 'application/json':
diff --git a/bot/constants.py b/bot/constants.py
index ed85adf6a..89504a2e0 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -465,8 +465,6 @@ class Reddit(metaclass=YAMLGetter):
section = "reddit"
subreddits: list
- client_id: str
- secret: str
class Wolfram(metaclass=YAMLGetter):
diff --git a/config-default.yml b/config-default.yml
index e6f0fda21..930a1a0e6 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -365,8 +365,6 @@ anti_malware:
reddit:
subreddits:
- 'r/Python'
- client_id: !ENV "REDDIT_CLIENT_ID"
- secret: !ENV "REDDIT_SECRET"
wolfram:
diff --git a/docker-compose.yml b/docker-compose.yml
index 7281c7953..f79fdba58 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -42,5 +42,3 @@ services:
environment:
BOT_TOKEN: ${BOT_TOKEN}
BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531
- REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID}
- REDDIT_SECRET: ${REDDIT_SECRET}