diff options
Diffstat (limited to 'bot/exts/evergreen/reddit.py')
-rw-r--r-- | bot/exts/evergreen/reddit.py | 360 |
1 files changed, 360 insertions, 0 deletions
diff --git a/bot/exts/evergreen/reddit.py b/bot/exts/evergreen/reddit.py new file mode 100644 index 00000000..fb447cda --- /dev/null +++ b/bot/exts/evergreen/reddit.py @@ -0,0 +1,360 @@ +import asyncio +import logging +import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List, Union + +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group, has_any_role +from discord.ext.tasks import loop +from discord.utils import escape_markdown, sleep_until + +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES +from bot.utils.converters import Subreddit +from bot.utils.messages import sub_clyde +from bot.utils.pagination import ImagePaginator, LinePaginator + +log = logging.getLogger(__name__) + +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + + +class Reddit(Cog): + """Track subreddit posts and show detailed statistics about them.""" + + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} + URL = "https://www.reddit.com" + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 + + def __init__(self, bot: Bot): + self.bot = bot + + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + bot.loop.create_task(self.init_reddit_ready()) + self.auto_poster_loop.start() + + def cog_unload(self) -> None: + """Stop the loop task and revoke the access token when the cog is unloaded.""" + self.auto_poster_loop.cancel() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) + + async def init_reddit_ready(self) -> None: + """Sets the reddit webhook when the cog is loaded.""" + await self.bot.wait_until_guild_available() + if not self.webhook: + self.webhook = await self.bot.fetch_webhook(RedditConfig.webhook) + + @property + def channel(self) -> TextChannel: + """Get the #reddit channel object from the bot's cache.""" + return self.bot.get_channel(Channels.reddit) + + def build_pagination_pages(self, posts: List[dict]) -> List[tuple]: + """Build embed pages required for Paginator.""" + pages = [] + first_page = "" + for i, post in enumerate(posts, start=1): + post_page = "" + image_url = "" + + data = post["data"] + + title = textwrap.shorten(data["title"], width=64, placeholder="...") + + # Normal brackets interfere with Markdown. + title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") + link = self.URL + data["permalink"] + + first_page += f"**{i}. [{title.replace('*', '')}]({link})**\n" + post_page += f"**{i}. [{title}]({link})**\n\n" + + text = data["selftext"] + if text: + first_page += textwrap.shorten(text, width=128, placeholder="...").replace("*", "") + "\n" + post_page += textwrap.shorten(text, width=252, placeholder="...") + "\n\n" + + ups = data["ups"] + comments = data["num_comments"] + author = data["author"] + + content_type = Emojis.reddit_post_text + if data["is_video"] is True or "youtube" in data["url"].split("."): + # This means the content type in the post is a video. + content_type = f"{Emojis.reddit_post_video}" + + elif any(data["url"].endswith(pic_format) for pic_format in ("jpg", "png", "gif")): + # This means the content type in the post is an image. + content_type = f"{Emojis.reddit_post_photo}" + image_url = data["url"] + + first_page += ( + f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}" + f"\u2002{comments}\u2003{Emojis.reddit_users}{author}\n\n" + ) + post_page += ( + f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}\u2002" + f"{comments}\u2003{Emojis.reddit_users}{author}" + ) + + pages.append((post_page, image_url)) + + pages.insert(0, (first_page, "")) + return pages + + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: + """A helper method to fetch a certain amount of Reddit posts at a given route.""" + # Reddit's JSON responses only provide 25 posts at most. + if not 25 >= amount > 0: + raise ValueError("Invalid amount of subreddit posts requested.") + + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.get( + url=url, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, + params=params + ) + if response.status == 200 and response.content_type == 'application/json': + # Got appropriate response - process and return. + content = await response.json() + posts = content["data"]["children"] + + filtered_posts = [post for post in posts if not post["data"]["over_18"]] + + return filtered_posts[:amount] + + await asyncio.sleep(3) + + log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") + return list() # Failed to get appropriate response within allowed number of retries. + + async def get_top_posts( + self, subreddit: Subreddit, time: str = "all", amount: int = 5, paginate: bool = False + ) -> Union[Embed, List[tuple]]: + """ + Get the top amount of posts for a given subreddit within a specified timeframe. + + A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top + weekly posts. + + The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. + """ + embed = Embed(description="") + + posts = await self.fetch_posts( + route=f"{subreddit}/top", + amount=amount, + params={"t": time} + ) + if not posts: + embed.title = random.choice(ERROR_REPLIES) + embed.colour = Colour.red() + embed.description = ( + "Sorry! We couldn't find any SFW posts from that subreddit. " + "If this problem persists, please let us know." + ) + + return embed + + pages = self.build_pagination_pages(posts) + + if paginate: + return pages + + embed.description += pages[0] + embed.colour = Colour.blurple() + return embed + + @loop() + async def auto_poster_loop(self) -> None: + """Post the top 5 posts daily, and the top 5 posts weekly.""" + # once d.py get support for `time` parameter in loop decorator, + # this can be removed and the loop can use the `time=datetime.time.min` parameter + now = datetime.utcnow() + tomorrow = now + timedelta(days=1) + midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + + await sleep_until(midnight_tomorrow) + + await self.bot.wait_until_guild_available() + if not self.webhook: + await self.bot.fetch_webhook(RedditConfig.webhook) + + if datetime.utcnow().weekday() == 0: + await self.top_weekly_posts() + # if it's a monday send the top weekly posts + + for subreddit in RedditConfig.subreddits: + top_posts = await self.get_top_posts(subreddit=subreddit, time="day") + username = sub_clyde(f"{subreddit} Top Daily Posts") + message = await self.webhook.send(username=username, embed=top_posts, wait=True) + + if message.channel.is_news(): + await message.publish() + + async def top_weekly_posts(self) -> None: + """Post a summary of the top posts.""" + for subreddit in RedditConfig.subreddits: + # Send and pin the new weekly posts. + top_posts = await self.get_top_posts(subreddit=subreddit, time="week") + username = sub_clyde(f"{subreddit} Top Weekly Posts") + message = await self.webhook.send(wait=True, username=username, embed=top_posts) + + if subreddit.lower() == "r/python": + if not self.channel: + log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") + return + + # Remove the oldest pins so that only 12 remain at most. + pins = await self.channel.pins() + + while len(pins) >= 12: + await pins[-1].unpin() + del pins[-1] + + await message.pin() + + if message.channel.is_news(): + await message.publish() + + @group(name="reddit", invoke_without_command=True) + async def reddit_group(self, ctx: Context) -> None: + """View the top posts from various subreddits.""" + await ctx.send_help(ctx.command) + + @reddit_group.command(name="top") + async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of all time from a given subreddit.""" + async with ctx.typing(): + pages = await self.get_top_posts(subreddit=subreddit, time="all", paginate=True) + + embed = Embed( + title=f"{Emojis.reddit} {subreddit} - Top\n\n", + color=Colour.blurple() + ) + + await ImagePaginator.paginate(pages, ctx, embed) + + @reddit_group.command(name="daily") + async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of today from a given subreddit.""" + async with ctx.typing(): + pages = await self.get_top_posts(subreddit=subreddit, time="day", paginate=True) + + embed = Embed( + title=f"{Emojis.reddit} {subreddit} - Daily\n\n", + color=Colour.blurple() + ) + + await ImagePaginator.paginate(pages, ctx, embed) + + @reddit_group.command(name="weekly") + async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: + """Send the top posts of this week from a given subreddit.""" + async with ctx.typing(): + pages = await self.get_top_posts(subreddit=subreddit, time="week", paginate=True) + + embed = Embed( + title=f"{Emojis.reddit} {subreddit} - Weekly\n\n", + color=Colour.blurple() + ) + + await ImagePaginator.paginate(pages, ctx, embed) + + @has_any_role(*STAFF_ROLES) + @reddit_group.command(name="subreddits", aliases=("subs",)) + async def subreddits_command(self, ctx: Context) -> None: + """Send a paginated embed of all the subreddits we're relaying.""" + embed = Embed() + embed.title = "Relayed subreddits." + embed.colour = Colour.blurple() + + await LinePaginator.paginate( + RedditConfig.subreddits, + ctx, embed, + footer_text="Use the reddit commands along with these to view their posts.", + empty=False, + max_lines=15 + ) + + +def setup(bot: Bot) -> None: + """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return + bot.add_cog(Reddit(bot)) |