aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/reddit.py26
1 files changed, 17 insertions, 9 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index d9e1f0a39..0a0279a39 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -41,7 +41,7 @@ class Reddit(Cog):
self.auto_poster_loop.start()
def cog_unload(self) -> None:
- """Stops the loops when the cog is unloaded."""
+ """Stop the loop task and revoke the access token when the cog is unloaded."""
self.auto_poster_loop.cancel()
if self.access_token.expires_at < datetime.utcnow():
self.revoke_access_token()
@@ -58,7 +58,12 @@ class Reddit(Cog):
return self.bot.get_channel(Channels.reddit)
async def get_access_token(self) -> None:
- """Get a Reddit API OAuth2 access token."""
+ """
+ Get a Reddit API OAuth2 access token and assign it to self.access_token.
+
+ A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog
+ will be unloaded if retrieval was still unsuccessful.
+ """
headers = {"User-Agent": self.USER_AGENT}
data = {
"grant_type": "client_credentials",
@@ -72,6 +77,7 @@ class Reddit(Cog):
auth=self.client_auth,
data=data
)
+
if response.status == 200 and response.content_type == "application/json":
content = await response.json()
expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway.
@@ -87,14 +93,16 @@ class Reddit(Cog):
await asyncio.sleep(3)
- log.error("Authentication with Reddit API failed. Unloading extension.")
+ log.error("Authentication with Reddit API failed. Unloading the cog.")
self.bot.remove_cog(self.qualified_name)
return
async def revoke_access_token(self) -> None:
- """Revoke the access token for Reddit API."""
- # Access tokens are valid for 1 hour.
- # The token should be revoked, since the API is called only once a day.
+ """
+ Revoke the OAuth2 access token for the Reddit API.
+
+ For security reasons, it's good practice to revoke the token when it's no longer being used.
+ """
headers = {"User-Agent": self.USER_AGENT}
data = {
"token": self.access_token.token,
@@ -107,12 +115,12 @@ class Reddit(Cog):
auth=self.client_auth,
data=data
)
+
if response.status == 204 and response.content_type == "application/json":
self.access_token = None
self.headers = None
- return
-
- log.warning(f"Unable to revoke access token, status code {response.status}.")
+ else:
+ log.warning(f"Unable to revoke access token: status {response.status}.")
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""