diff options
Diffstat (limited to 'bot/exts/evergreen')
| -rw-r--r-- | bot/exts/evergreen/reddit.py | 360 | 
1 files changed, 360 insertions, 0 deletions
| diff --git a/bot/exts/evergreen/reddit.py b/bot/exts/evergreen/reddit.py new file mode 100644 index 00000000..fb447cda --- /dev/null +++ b/bot/exts/evergreen/reddit.py @@ -0,0 +1,360 @@ +import asyncio +import logging +import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List, Union + +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group, has_any_role +from discord.ext.tasks import loop +from discord.utils import escape_markdown, sleep_until + +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES +from bot.utils.converters import Subreddit +from bot.utils.messages import sub_clyde +from bot.utils.pagination import ImagePaginator, LinePaginator + +log = logging.getLogger(__name__) + +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + + +class Reddit(Cog): +    """Track subreddit posts and show detailed statistics about them.""" + +    HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} +    URL = "https://www.reddit.com" +    OAUTH_URL = "https://oauth.reddit.com" +    MAX_RETRIES = 3 + +    def __init__(self, bot: Bot): +        self.bot = bot + +        self.webhook = None +        self.access_token = None +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + +        bot.loop.create_task(self.init_reddit_ready()) +        self.auto_poster_loop.start() + +    def cog_unload(self) -> None: +        """Stop the loop task and revoke the access token when the cog is unloaded.""" +        self.auto_poster_loop.cancel() +        if self.access_token and self.access_token.expires_at > datetime.utcnow(): +            asyncio.create_task(self.revoke_access_token()) + +    async def init_reddit_ready(self) -> None: +        """Sets the reddit webhook when the cog is loaded.""" +        await self.bot.wait_until_guild_available() +        if not self.webhook: +            self.webhook = await self.bot.fetch_webhook(RedditConfig.webhook) + +    @property +    def channel(self) -> TextChannel: +        """Get the #reddit channel object from the bot's cache.""" +        return self.bot.get_channel(Channels.reddit) + +    def build_pagination_pages(self, posts: List[dict]) -> List[tuple]: +        """Build embed pages required for Paginator.""" +        pages = [] +        first_page = "" +        for i, post in enumerate(posts, start=1): +            post_page = "" +            image_url = "" + +            data = post["data"] + +            title = textwrap.shorten(data["title"], width=64, placeholder="...") + +            # Normal brackets interfere with Markdown. +            title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") +            link = self.URL + data["permalink"] + +            first_page += f"**{i}. [{title.replace('*', '')}]({link})**\n" +            post_page += f"**{i}. [{title}]({link})**\n\n" + +            text = data["selftext"] +            if text: +                first_page += textwrap.shorten(text, width=128, placeholder="...").replace("*", "") + "\n" +                post_page += textwrap.shorten(text, width=252, placeholder="...") + "\n\n" + +            ups = data["ups"] +            comments = data["num_comments"] +            author = data["author"] + +            content_type = Emojis.reddit_post_text +            if data["is_video"] is True or "youtube" in data["url"].split("."): +                # This means the content type in the post is a video. +                content_type = f"{Emojis.reddit_post_video}" + +            elif any(data["url"].endswith(pic_format) for pic_format in ("jpg", "png", "gif")): +                # This means the content type in the post is an image. +                content_type = f"{Emojis.reddit_post_photo}" +                image_url = data["url"] + +            first_page += ( +                f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}" +                f"\u2002{comments}\u2003{Emojis.reddit_users}{author}\n\n" +            ) +            post_page += ( +                f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}\u2002" +                f"{comments}\u2003{Emojis.reddit_users}{author}" +            ) + +            pages.append((post_page, image_url)) + +        pages.insert(0, (first_page, "")) +        return pages + +    async def get_access_token(self) -> None: +        """ +        Get a Reddit API OAuth2 access token and assign it to self.access_token. + +        A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog +        will be unloaded and a ClientError raised if retrieval was still unsuccessful. +        """ +        for i in range(1, self.MAX_RETRIES + 1): +            response = await self.bot.http_session.post( +                url=f"{self.URL}/api/v1/access_token", +                headers=self.HEADERS, +                auth=self.client_auth, +                data={ +                    "grant_type": "client_credentials", +                    "duration": "temporary" +                } +            ) + +            if response.status == 200 and response.content_type == "application/json": +                content = await response.json() +                expiration = int(content["expires_in"]) - 60  # Subtract 1 minute for leeway. +                self.access_token = AccessToken( +                    token=content["access_token"], +                    expires_at=datetime.utcnow() + timedelta(seconds=expiration) +                ) + +                log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") +                return +            else: +                log.debug( +                    f"Failed to get an access token: " +                    f"status {response.status} & content type {response.content_type}; " +                    f"retrying ({i}/{self.MAX_RETRIES})" +                ) + +            await asyncio.sleep(3) + +        self.bot.remove_cog(self.qualified_name) +        raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + +    async def revoke_access_token(self) -> None: +        """ +        Revoke the OAuth2 access token for the Reddit API. + +        For security reasons, it's good practice to revoke the token when it's no longer being used. +        """ +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/revoke_token", +            headers=self.HEADERS, +            auth=self.client_auth, +            data={ +                "token": self.access_token.token, +                "token_type_hint": "access_token" +            } +        ) + +        if response.status == 204 and response.content_type == "application/json": +            self.access_token = None +        else: +            log.warning(f"Unable to revoke access token: status {response.status}.") + +    async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: +        """A helper method to fetch a certain amount of Reddit posts at a given route.""" +        # Reddit's JSON responses only provide 25 posts at most. +        if not 25 >= amount > 0: +            raise ValueError("Invalid amount of subreddit posts requested.") + +        # Renew the token if necessary. +        if not self.access_token or self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token() + +        url = f"{self.OAUTH_URL}/{route}" +        for _ in range(self.MAX_RETRIES): +            response = await self.bot.http_session.get( +                url=url, +                headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, +                params=params +            ) +            if response.status == 200 and response.content_type == 'application/json': +                # Got appropriate response - process and return. +                content = await response.json() +                posts = content["data"]["children"] + +                filtered_posts = [post for post in posts if not post["data"]["over_18"]] + +                return filtered_posts[:amount] + +            await asyncio.sleep(3) + +        log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") +        return list()  # Failed to get appropriate response within allowed number of retries. + +    async def get_top_posts( +            self, subreddit: Subreddit, time: str = "all", amount: int = 5, paginate: bool = False +    ) -> Union[Embed, List[tuple]]: +        """ +        Get the top amount of posts for a given subreddit within a specified timeframe. + +        A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top +        weekly posts. + +        The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. +        """ +        embed = Embed(description="") + +        posts = await self.fetch_posts( +            route=f"{subreddit}/top", +            amount=amount, +            params={"t": time} +        ) +        if not posts: +            embed.title = random.choice(ERROR_REPLIES) +            embed.colour = Colour.red() +            embed.description = ( +                "Sorry! We couldn't find any SFW posts from that subreddit. " +                "If this problem persists, please let us know." +            ) + +            return embed + +        pages = self.build_pagination_pages(posts) + +        if paginate: +            return pages + +        embed.description += pages[0] +        embed.colour = Colour.blurple() +        return embed + +    @loop() +    async def auto_poster_loop(self) -> None: +        """Post the top 5 posts daily, and the top 5 posts weekly.""" +        # once d.py get support for `time` parameter in loop decorator, +        # this can be removed and the loop can use the `time=datetime.time.min` parameter +        now = datetime.utcnow() +        tomorrow = now + timedelta(days=1) +        midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + +        await sleep_until(midnight_tomorrow) + +        await self.bot.wait_until_guild_available() +        if not self.webhook: +            await self.bot.fetch_webhook(RedditConfig.webhook) + +        if datetime.utcnow().weekday() == 0: +            await self.top_weekly_posts() +            # if it's a monday send the top weekly posts + +        for subreddit in RedditConfig.subreddits: +            top_posts = await self.get_top_posts(subreddit=subreddit, time="day") +            username = sub_clyde(f"{subreddit} Top Daily Posts") +            message = await self.webhook.send(username=username, embed=top_posts, wait=True) + +            if message.channel.is_news(): +                await message.publish() + +    async def top_weekly_posts(self) -> None: +        """Post a summary of the top posts.""" +        for subreddit in RedditConfig.subreddits: +            # Send and pin the new weekly posts. +            top_posts = await self.get_top_posts(subreddit=subreddit, time="week") +            username = sub_clyde(f"{subreddit} Top Weekly Posts") +            message = await self.webhook.send(wait=True, username=username, embed=top_posts) + +            if subreddit.lower() == "r/python": +                if not self.channel: +                    log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") +                    return + +                # Remove the oldest pins so that only 12 remain at most. +                pins = await self.channel.pins() + +                while len(pins) >= 12: +                    await pins[-1].unpin() +                    del pins[-1] + +                await message.pin() + +                if message.channel.is_news(): +                    await message.publish() + +    @group(name="reddit", invoke_without_command=True) +    async def reddit_group(self, ctx: Context) -> None: +        """View the top posts from various subreddits.""" +        await ctx.send_help(ctx.command) + +    @reddit_group.command(name="top") +    async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of all time from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="all", paginate=True) + +        embed = Embed( +            title=f"{Emojis.reddit} {subreddit} - Top\n\n", +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @reddit_group.command(name="daily") +    async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of today from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="day", paginate=True) + +        embed = Embed( +            title=f"{Emojis.reddit} {subreddit} - Daily\n\n", +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @reddit_group.command(name="weekly") +    async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of this week from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="week", paginate=True) + +        embed = Embed( +            title=f"{Emojis.reddit} {subreddit} - Weekly\n\n", +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @has_any_role(*STAFF_ROLES) +    @reddit_group.command(name="subreddits", aliases=("subs",)) +    async def subreddits_command(self, ctx: Context) -> None: +        """Send a paginated embed of all the subreddits we're relaying.""" +        embed = Embed() +        embed.title = "Relayed subreddits." +        embed.colour = Colour.blurple() + +        await LinePaginator.paginate( +            RedditConfig.subreddits, +            ctx, embed, +            footer_text="Use the reddit commands along with these to view their posts.", +            empty=False, +            max_lines=15 +        ) + + +def setup(bot: Bot) -> None: +    """Load the Reddit cog.""" +    if not RedditConfig.secret or not RedditConfig.client_id: +        log.error("Credentials not provided, cog not loaded.") +        return +    bot.add_cog(Reddit(bot)) | 
