aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Jens <[email protected]>2019-10-15 23:36:36 +0200
committerGravatar Jens <[email protected]>2019-10-15 23:36:36 +0200
commit79ed098809ce3cfaa0fa75608f6f6a85af2a90dd (patch)
tree5882b2e330b271f67b31511b53d391468562ba9f
parentAdd Reddit OAuth tasks and refactor code (diff)
Unload cog on auth error and fix linting warnings
-rw-r--r--bot/cogs/reddit.py35
1 files changed, 16 insertions, 19 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 25df014f8..451d2bf4c 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -2,10 +2,10 @@ import asyncio
import logging
import random
import textwrap
-from aiohttp import BasicAuth
from datetime import datetime, timedelta
from typing import List
+from aiohttp import BasicAuth
from discord import Colour, Embed, Message, TextChannel
from discord.ext import tasks
from discord.ext.commands import Bot, Cog, Context, group
@@ -46,7 +46,7 @@ class Reddit(Cog):
@tasks.loop(hours=0.99) # access tokens are valid for one hour
async def refresh_access_token(self) -> None:
- """Refresh the access token"""
+ """Refresh Reddits access token."""
headers = {"Authorization": self.client_auth}
data = {
"grant_type": "refresh_token",
@@ -54,7 +54,7 @@ class Reddit(Cog):
}
response = await self.bot.http_session.post(
- url = f"{self.URL}/api/v1/access_token",
+ url=f"{self.URL}/api/v1/access_token",
headers=headers,
data=data,
)
@@ -68,7 +68,7 @@ class Reddit(Cog):
@refresh_access_token.before_loop
async def get_tokens(self) -> None:
- """Get Reddit access and refresh tokens"""
+ """Get Reddit access and refresh tokens."""
await self.bot.wait_until_ready()
headers = {"User-Agent": self.USER_AGENT}
@@ -77,16 +77,16 @@ class Reddit(Cog):
"duration": "permanent"
}
- if RedditConfig.client_id and RedditConfig.secret:
- self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
+ self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
- response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/access_token",
- headers=headers,
- auth=self.client_auth,
- data=data
- )
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=headers,
+ auth=self.client_auth,
+ data=data
+ )
+ if response.status == 200 and response.content_type == "application/json":
content = await response.json()
self.access_token = content["access_token"]
self.refresh_token = content["refresh_token"]
@@ -95,12 +95,9 @@ class Reddit(Cog):
"User-Agent": self.USER_AGENT
}
else:
- self.client_auth = None
- self.access_token = None
- self.refresh_token = None
- self.headers = None
-
- log.error("Unable to find client credentials.")
+ log.error("Authentication with Reddit API failed. Unloading extension.")
+ self.bot.remove_cog(self.__class__.__name__)
+ return
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
@@ -123,7 +120,7 @@ class Reddit(Cog):
content = await response.json()
posts = content["data"]["children"]
return posts[:amount]
-
+
await asyncio.sleep(3)
log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")