diff options
| -rw-r--r-- | bot/cogs/reddit.py | 35 | 
1 files changed, 16 insertions, 19 deletions
| diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 25df014f8..451d2bf4c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,10 @@ import asyncio  import logging  import random  import textwrap -from aiohttp import BasicAuth  from datetime import datetime, timedelta  from typing import List +from aiohttp import BasicAuth  from discord import Colour, Embed, Message, TextChannel  from discord.ext import tasks  from discord.ext.commands import Bot, Cog, Context, group @@ -46,7 +46,7 @@ class Reddit(Cog):      @tasks.loop(hours=0.99)  # access tokens are valid for one hour      async def refresh_access_token(self) -> None: -        """Refresh the access token""" +        """Refresh Reddits access token."""          headers = {"Authorization": self.client_auth}          data = {              "grant_type": "refresh_token", @@ -54,7 +54,7 @@ class Reddit(Cog):          }          response = await self.bot.http_session.post( -            url = f"{self.URL}/api/v1/access_token", +            url=f"{self.URL}/api/v1/access_token",              headers=headers,              data=data,          ) @@ -68,7 +68,7 @@ class Reddit(Cog):      @refresh_access_token.before_loop      async def get_tokens(self) -> None: -        """Get Reddit access and refresh tokens""" +        """Get Reddit access and refresh tokens."""          await self.bot.wait_until_ready()          headers = {"User-Agent": self.USER_AGENT} @@ -77,16 +77,16 @@ class Reddit(Cog):              "duration": "permanent"          } -        if RedditConfig.client_id and RedditConfig.secret: -            self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) -            response = await self.bot.http_session.post( -                url=f"{self.URL}/api/v1/access_token", -                headers=headers, -                auth=self.client_auth, -                data=data -            ) +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/access_token", +            headers=headers, +            auth=self.client_auth, +            data=data +        ) +        if response.status == 200 and response.content_type == "application/json":              content = await response.json()              self.access_token = content["access_token"]              self.refresh_token = content["refresh_token"] @@ -95,12 +95,9 @@ class Reddit(Cog):                  "User-Agent": self.USER_AGENT              }          else: -            self.client_auth = None -            self.access_token = None -            self.refresh_token = None -            self.headers = None - -            log.error("Unable to find client credentials.") +            log.error("Authentication with Reddit API failed. Unloading extension.") +            self.bot.remove_cog(self.__class__.__name__) +            return      async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:          """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -123,7 +120,7 @@ class Reddit(Cog):                  content = await response.json()                  posts = content["data"]["children"]                  return posts[:amount] -             +              await asyncio.sleep(3)          log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") | 
