diff options
Diffstat (limited to 'bot/exts/utilities/reddit.py')
| -rw-r--r-- | bot/exts/utilities/reddit.py | 33 | 
1 files changed, 15 insertions, 18 deletions
| diff --git a/bot/exts/utilities/reddit.py b/bot/exts/utilities/reddit.py index f7c196ae..f8e358de 100644 --- a/bot/exts/utilities/reddit.py +++ b/bot/exts/utilities/reddit.py @@ -3,8 +3,7 @@ import logging  import random  import textwrap  from collections import namedtuple -from datetime import datetime, timedelta -from typing import Union +from datetime import UTC, datetime, timedelta  from aiohttp import BasicAuth, ClientError  from discord import Colour, Embed, TextChannel @@ -43,8 +42,8 @@ class Reddit(Cog):      async def cog_unload(self) -> None:          """Stop the loop task and revoke the access token when the cog is unloaded."""          self.auto_poster_loop.cancel() -        if self.access_token and self.access_token.expires_at > datetime.utcnow(): -            asyncio.create_task(self.revoke_access_token()) +        if self.access_token and self.access_token.expires_at > datetime.now(tz=UTC): +            await self.revoke_access_token()      async def cog_load(self) -> None:          """Sets the reddit webhook when the cog is loaded.""" @@ -55,7 +54,7 @@ class Reddit(Cog):          """Get the #reddit channel object from the bot's cache."""          return self.bot.get_channel(Channels.reddit) -    def build_pagination_pages(self, posts: list[dict], paginate: bool) -> Union[list[tuple], str]: +    def build_pagination_pages(self, posts: list[dict], paginate: bool) -> list[tuple] | str:          """Build embed pages required for Paginator."""          pages = []          first_page = "" @@ -138,17 +137,15 @@ class Reddit(Cog):                  expiration = int(content["expires_in"]) - 60  # Subtract 1 minute for leeway.                  self.access_token = AccessToken(                      token=content["access_token"], -                    expires_at=datetime.utcnow() + timedelta(seconds=expiration) +                    expires_at=datetime.now(tz=UTC) + timedelta(seconds=expiration)                  )                  log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")                  return -            else: -                log.debug( -                    f"Failed to get an access token: " -                    f"status {response.status} & content type {response.content_type}; " -                    f"retrying ({i}/{self.MAX_RETRIES})" -                ) +            log.debug( +                f"Failed to get an access token: status {response.status} & content type {response.content_type}; " +                f"retrying ({i}/{self.MAX_RETRIES})" +            )              await asyncio.sleep(3) @@ -183,7 +180,7 @@ class Reddit(Cog):              raise ValueError("Invalid amount of subreddit posts requested.")          # Renew the token if necessary. -        if not self.access_token or self.access_token.expires_at < datetime.utcnow(): +        if not self.access_token or self.access_token.expires_at < datetime.now(tz=UTC):              await self.get_access_token()          url = f"{self.OAUTH_URL}/{route}" @@ -193,7 +190,7 @@ class Reddit(Cog):                  headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},                  params=params              ) -            if response.status == 200 and response.content_type == 'application/json': +            if response.status == 200 and response.content_type == "application/json":                  # Got appropriate response - process and return.                  content = await response.json()                  posts = content["data"]["children"] @@ -205,11 +202,11 @@ class Reddit(Cog):              await asyncio.sleep(3)          log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") -        return list()  # Failed to get appropriate response within allowed number of retries. +        return []  # Failed to get appropriate response within allowed number of retries.      async def get_top_posts(              self, subreddit: Subreddit, time: str = "all", amount: int = 5, paginate: bool = False -    ) -> Union[Embed, list[tuple]]: +    ) -> Embed | list[tuple]:          """          Get the top amount of posts for a given subreddit within a specified timeframe. @@ -248,7 +245,7 @@ class Reddit(Cog):          """Post the top 5 posts daily, and the top 5 posts weekly."""          # once d.py get support for `time` parameter in loop decorator,          # this can be removed and the loop can use the `time=datetime.time.min` parameter -        now = datetime.utcnow() +        now = datetime.now(tz=UTC)          tomorrow = now + timedelta(days=1)          midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) @@ -257,7 +254,7 @@ class Reddit(Cog):          if not self.webhook:              await self.bot.fetch_webhook(RedditConfig.webhook) -        if datetime.utcnow().weekday() == 0: +        if datetime.now(tz=UTC).weekday() == 0:              await self.top_weekly_posts()              # if it's a monday send the top weekly posts | 
