diff options
Diffstat (limited to 'bot/exts/utilities/reddit.py')
-rw-r--r-- | bot/exts/utilities/reddit.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/bot/exts/utilities/reddit.py b/bot/exts/utilities/reddit.py index cfc70d85..5dd4a377 100644 --- a/bot/exts/utilities/reddit.py +++ b/bot/exts/utilities/reddit.py @@ -20,16 +20,15 @@ from bot.utils.pagination import ImagePaginator, LinePaginator log = logging.getLogger(__name__) AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) +HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} +URL = "https://www.reddit.com" +OAUTH_URL = "https://oauth.reddit.com" +MAX_RETRIES = 3 class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} - URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 - def __init__(self, bot: Bot): self.bot = bot @@ -37,7 +36,7 @@ class Reddit(Cog): self.access_token = None self.client_auth = BasicAuth(RedditConfig.client_id.get_secret_value(), RedditConfig.secret.get_secret_value()) - # self.auto_poster_loop.start() + self.auto_poster_loop.start() async def cog_unload(self) -> None: """Stop the loop task and revoke the access token when the cog is unloaded.""" @@ -68,7 +67,7 @@ class Reddit(Cog): # Normal brackets interfere with Markdown. title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") - link = self.URL + data["permalink"] + link = URL + data["permalink"] first_page += f"**[{title.replace('*', '')}]({link})**\n" @@ -121,10 +120,10 @@ class Reddit(Cog): A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog will be unloaded and a ClientError raised if retrieval was still unsuccessful. """ - for i in range(1, self.MAX_RETRIES + 1): + for i in range(1, MAX_RETRIES + 1): response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, + url=f"{URL}/api/v1/access_token", + headers=HEADERS, auth=self.client_auth, data={ "grant_type": "client_credentials", @@ -144,7 +143,7 @@ class Reddit(Cog): return log.debug( f"Failed to get an access token: status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" + f"retrying ({i}/{MAX_RETRIES})" ) await asyncio.sleep(3) @@ -159,8 +158,8 @@ class Reddit(Cog): For security reasons, it's good practice to revoke the token when it's no longer being used. """ response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, + url=f"{URL}/api/v1/revoke_token", + headers=HEADERS, auth=self.client_auth, data={ "token": self.access_token.token, @@ -173,7 +172,7 @@ class Reddit(Cog): else: log.warning(f"Unable to revoke access token: status {response.status}.") - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> list[dict]: + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict | None = None) -> list[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. if not 25 >= amount > 0: @@ -183,11 +182,11 @@ class Reddit(Cog): if not self.access_token or self.access_token.expires_at < datetime.now(tz=UTC): await self.get_access_token() - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): + url = f"{OAUTH_URL}/{route}" + for _ in range(MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + headers=HEADERS | {"Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == "application/json": |