aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts/utilities/reddit.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/exts/utilities/reddit.py')
-rw-r--r--bot/exts/utilities/reddit.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/bot/exts/utilities/reddit.py b/bot/exts/utilities/reddit.py
index cfc70d85..5dd4a377 100644
--- a/bot/exts/utilities/reddit.py
+++ b/bot/exts/utilities/reddit.py
@@ -20,16 +20,15 @@ from bot.utils.pagination import ImagePaginator, LinePaginator
log = logging.getLogger(__name__)
AccessToken = namedtuple("AccessToken", ["token", "expires_at"])
+HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}
+URL = "https://www.reddit.com"
+OAUTH_URL = "https://oauth.reddit.com"
+MAX_RETRIES = 3
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
- HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}
- URL = "https://www.reddit.com"
- OAUTH_URL = "https://oauth.reddit.com"
- MAX_RETRIES = 3
-
def __init__(self, bot: Bot):
self.bot = bot
@@ -37,7 +36,7 @@ class Reddit(Cog):
self.access_token = None
self.client_auth = BasicAuth(RedditConfig.client_id.get_secret_value(), RedditConfig.secret.get_secret_value())
- # self.auto_poster_loop.start()
+ self.auto_poster_loop.start()
async def cog_unload(self) -> None:
"""Stop the loop task and revoke the access token when the cog is unloaded."""
@@ -68,7 +67,7 @@ class Reddit(Cog):
# Normal brackets interfere with Markdown.
title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌")
- link = self.URL + data["permalink"]
+ link = URL + data["permalink"]
first_page += f"**[{title.replace('*', '')}]({link})**\n"
@@ -121,10 +120,10 @@ class Reddit(Cog):
A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog
will be unloaded and a ClientError raised if retrieval was still unsuccessful.
"""
- for i in range(1, self.MAX_RETRIES + 1):
+ for i in range(1, MAX_RETRIES + 1):
response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/access_token",
- headers=self.HEADERS,
+ url=f"{URL}/api/v1/access_token",
+ headers=HEADERS,
auth=self.client_auth,
data={
"grant_type": "client_credentials",
@@ -144,7 +143,7 @@ class Reddit(Cog):
return
log.debug(
f"Failed to get an access token: status {response.status} & content type {response.content_type}; "
- f"retrying ({i}/{self.MAX_RETRIES})"
+ f"retrying ({i}/{MAX_RETRIES})"
)
await asyncio.sleep(3)
@@ -159,8 +158,8 @@ class Reddit(Cog):
For security reasons, it's good practice to revoke the token when it's no longer being used.
"""
response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/revoke_token",
- headers=self.HEADERS,
+ url=f"{URL}/api/v1/revoke_token",
+ headers=HEADERS,
auth=self.client_auth,
data={
"token": self.access_token.token,
@@ -173,7 +172,7 @@ class Reddit(Cog):
else:
log.warning(f"Unable to revoke access token: status {response.status}.")
- async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> list[dict]:
+ async def fetch_posts(self, route: str, *, amount: int = 25, params: dict | None = None) -> list[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
# Reddit's JSON responses only provide 25 posts at most.
if not 25 >= amount > 0:
@@ -183,11 +182,11 @@ class Reddit(Cog):
if not self.access_token or self.access_token.expires_at < datetime.now(tz=UTC):
await self.get_access_token()
- url = f"{self.OAUTH_URL}/{route}"
- for _ in range(self.MAX_RETRIES):
+ url = f"{OAUTH_URL}/{route}"
+ for _ in range(MAX_RETRIES):
response = await self.bot.http_session.get(
url=url,
- headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},
+ headers=HEADERS | {"Authorization": f"bearer {self.access_token.token}"},
params=params
)
if response.status == 200 and response.content_type == "application/json":