diff options
| author | 2021-05-09 13:49:32 -0700 | |
|---|---|---|
| committer | 2021-05-09 13:49:32 -0700 | |
| commit | ba886c24fe6fa583cf090dace8870ff9d7263383 (patch) | |
| tree | fbbbcf61ef0797722a1ea8136aaa412cd6283365 | |
| parent | Merge pull request #724 from ToxicKidz/fix/mosaic-command (diff) | |
| parent | Pull upstream and resolve conflicts. (diff) | |
Merge pull request #569 from RohanJnr/reddit_migration
Reddit migration
| -rw-r--r-- | bot/constants.py | 19 | ||||
| -rw-r--r-- | bot/exts/evergreen/reddit.py | 425 | ||||
| -rw-r--r-- | bot/utils/converters.py | 32 | ||||
| -rw-r--r-- | bot/utils/messages.py | 19 | ||||
| -rw-r--r-- | bot/utils/pagination.py | 6 | 
5 files changed, 404 insertions, 97 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 6cbfa8db..549d01b6 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -20,6 +20,7 @@ __all__ = (      "Roles",      "Tokens",      "Wolfram", +    "Reddit",      "RedisConfig",      "MODERATION_ROLES",      "STAFF_ROLES", @@ -115,6 +116,7 @@ class Channels(NamedTuple):      voice_chat_0 = 412357430186344448      voice_chat_1 = 799647045886541885      staff_voice = 541638762007101470 +    reddit = int(environ.get("CHANNEL_REDDIT", 458224812528238616))  class Categories(NamedTuple): @@ -217,6 +219,15 @@ class Emojis:      status_dnd = "<:status_dnd:470326272082313216>"      status_offline = "<:status_offline:470326266537705472>" +    # Reddit emojis +    reddit = "<:reddit:676030265734332427>" +    reddit_post_text = "<:reddit_post_text:676030265910493204>" +    reddit_post_video = "<:reddit_post_video:676030265839190047>" +    reddit_post_photo = "<:reddit_post_photo:676030265734201344>" +    reddit_upvote = "<:reddit_upvote:755845219890757644>" +    reddit_comments = "<:reddit_comments:755845255001014384>" +    reddit_users = "<:reddit_users:755845303822974997>" +  class Icons:      questionmark = "https://cdn.discordapp.com/emojis/512367613339369475.png" @@ -293,6 +304,14 @@ class Source:      github_avatar_url = "https://avatars1.githubusercontent.com/u/9919" +class Reddit: +    subreddits = ["r/Python"] + +    client_id = environ.get("REDDIT_CLIENT_ID") +    secret = environ.get("REDDIT_SECRET") +    webhook = int(environ.get("REDDIT_WEBHOOK", 635408384794951680)) + +  # Default role combinations  MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner  STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner diff --git a/bot/exts/evergreen/reddit.py b/bot/exts/evergreen/reddit.py index 2be511c8..916563ac 100644 --- a/bot/exts/evergreen/reddit.py +++ b/bot/exts/evergreen/reddit.py @@ -1,128 +1,367 @@ +import asyncio  import logging  import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List, Union -import discord -from discord.ext import commands -from discord.ext.commands.cooldowns import BucketType +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group, has_any_role +from discord.ext.tasks import loop +from discord.utils import escape_markdown, sleep_until -from bot.utils.pagination import ImagePaginator +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES +from bot.utils.converters import Subreddit +from bot.utils.extensions import invoke_help_command +from bot.utils.messages import sub_clyde +from bot.utils.pagination import ImagePaginator, LinePaginator  log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) -class Reddit(commands.Cog): -    """Fetches reddit posts.""" -    def __init__(self, bot: commands.Bot): +class Reddit(Cog): +    """Track subreddit posts and show detailed statistics about them.""" + +    HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} +    URL = "https://www.reddit.com" +    OAUTH_URL = "https://oauth.reddit.com" +    MAX_RETRIES = 3 + +    def __init__(self, bot: Bot):          self.bot = bot -    async def fetch(self, url: str) -> dict: -        """Send a get request to the reddit API and get json response.""" -        session = self.bot.http_session -        params = { -            'limit': 50 -        } -        headers = { -            'User-Agent': 'Iceman' -        } - -        async with session.get(url=url, params=params, headers=headers) as response: -            return await response.json() - -    @commands.command(name='reddit') -    @commands.cooldown(1, 10, BucketType.user) -    async def get_reddit(self, ctx: commands.Context, subreddit: str = 'python', sort: str = "hot") -> None: -        """ -        Fetch reddit posts by using this command. +        self.webhook = None +        self.access_token = None +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) -        Gets a post from r/python by default. -        Usage: -        --> .reddit [subreddit_name] [hot/top/new] -        """ +        bot.loop.create_task(self.init_reddit_ready()) +        self.auto_poster_loop.start() + +    def cog_unload(self) -> None: +        """Stop the loop task and revoke the access token when the cog is unloaded.""" +        self.auto_poster_loop.cancel() +        if self.access_token and self.access_token.expires_at > datetime.utcnow(): +            asyncio.create_task(self.revoke_access_token()) + +    async def init_reddit_ready(self) -> None: +        """Sets the reddit webhook when the cog is loaded.""" +        await self.bot.wait_until_guild_available() +        if not self.webhook: +            self.webhook = await self.bot.fetch_webhook(RedditConfig.webhook) + +    @property +    def channel(self) -> TextChannel: +        """Get the #reddit channel object from the bot's cache.""" +        return self.bot.get_channel(Channels.reddit) + +    def build_pagination_pages(self, posts: List[dict], paginate: bool) -> Union[List[tuple], str]: +        """Build embed pages required for Paginator."""          pages = [] -        sort_list = ["hot", "new", "top", "rising"] -        if sort.lower() not in sort_list: -            await ctx.send(f"Invalid sorting: {sort}\nUsing default sorting: `Hot`") -            sort = "hot" +        first_page = "" +        for post in posts: +            post_page = "" +            image_url = "" -        data = await self.fetch(f'https://www.reddit.com/r/{subreddit}/{sort}/.json') +            data = post["data"] -        try: -            posts = data["data"]["children"] -        except KeyError: -            return await ctx.send('Subreddit not found!') -        if not posts: -            return await ctx.send('No posts available!') +            title = textwrap.shorten(data["title"], width=50, placeholder="...") + +            # Normal brackets interfere with Markdown. +            title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") +            link = self.URL + data["permalink"] + +            first_page += f"**[{title.replace('*', '')}]({link})**\n" + +            text = data["selftext"] +            if text: +                first_page += textwrap.shorten(text, width=100, placeholder="...").replace("*", "") + "\n" + +            ups = data["ups"] +            comments = data["num_comments"] +            author = data["author"] + +            content_type = Emojis.reddit_post_text +            if data["is_video"] or {"youtube", "youtu.be"}.issubset(set(data["url"].split("."))): +                # This means the content type in the post is a video. +                content_type = f"{Emojis.reddit_post_video}" + +            elif data["url"].endswith(("jpg", "png", "gif")): +                # This means the content type in the post is an image. +                content_type = f"{Emojis.reddit_post_photo}" +                image_url = data["url"] -        if posts[0]["data"]["over_18"] is True: -            return await ctx.send( -                "You cannot access this Subreddit as it is meant for those who " -                "are 18 years or older." +            first_page += ( +                f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}" +                f"\u2002{comments}\u2003{Emojis.reddit_users}{author}\n\n"              ) -        embed_titles = "" +            if paginate: +                post_page += f"**[{title}]({link})**\n\n" +                if text: +                    post_page += textwrap.shorten(text, width=252, placeholder="...") + "\n\n" +                post_page += ( +                    f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}\u2002" +                    f"{comments}\u2003{Emojis.reddit_users}{author}" +                ) -        # Chooses k unique random elements from a population sequence or set. -        random_posts = random.sample(posts, k=min(len(posts), 5)) +                pages.append((post_page, image_url)) -        # ----------------------------------------------------------- -        # This code below is bound of change when the emojis are added. +        if not paginate: +            # Return the first summery page if pagination is not required +            return first_page -        upvote_emoji = self.bot.get_emoji(755845219890757644) -        comment_emoji = self.bot.get_emoji(755845255001014384) -        user_emoji = self.bot.get_emoji(755845303822974997) -        text_emoji = self.bot.get_emoji(676030265910493204) -        video_emoji = self.bot.get_emoji(676030265839190047) -        image_emoji = self.bot.get_emoji(676030265734201344) -        reddit_emoji = self.bot.get_emoji(676030265734332427) +        pages.insert(0, (first_page, ""))  # Using image paginator, hence settings image url to empty string +        return pages -        # ------------------------------------------------------------ +    async def get_access_token(self) -> None: +        """ +        Get a Reddit API OAuth2 access token and assign it to self.access_token. -        for i, post in enumerate(random_posts, start=1): -            post_title = post["data"]["title"][0:50] -            post_url = post['data']['url'] -            if post_title == "": -                post_title = "No Title." -            elif post_title == post_url: -                post_title = "Title is itself a link." +        A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog +        will be unloaded and a ClientError raised if retrieval was still unsuccessful. +        """ +        for i in range(1, self.MAX_RETRIES + 1): +            response = await self.bot.http_session.post( +                url=f"{self.URL}/api/v1/access_token", +                headers=self.HEADERS, +                auth=self.client_auth, +                data={ +                    "grant_type": "client_credentials", +                    "duration": "temporary" +                } +            ) -            # ------------------------------------------------------------------ -            # Embed building. +            if response.status == 200 and response.content_type == "application/json": +                content = await response.json() +                expiration = int(content["expires_in"]) - 60  # Subtract 1 minute for leeway. +                self.access_token = AccessToken( +                    token=content["access_token"], +                    expires_at=datetime.utcnow() + timedelta(seconds=expiration) +                ) -            embed_titles += f"**{i}.[{post_title}]({post_url})**\n" -            image_url = " " -            post_stats = f"{text_emoji}"  # Set default content type to text. +                log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") +                return +            else: +                log.debug( +                    f"Failed to get an access token: " +                    f"status {response.status} & content type {response.content_type}; " +                    f"retrying ({i}/{self.MAX_RETRIES})" +                ) -            if post["data"]["is_video"] is True or "youtube" in post_url.split("."): -                # This means the content type in the post is a video. -                post_stats = f"{video_emoji} " +            await asyncio.sleep(3) -            elif post_url.endswith("jpg") or post_url.endswith("png") or post_url.endswith("gif"): -                # This means the content type in the post is an image. -                post_stats = f"{image_emoji} " -                image_url = post_url - -            votes = f'{upvote_emoji}{post["data"]["ups"]}' -            comments = f'{comment_emoji}\u2002{ post["data"]["num_comments"]}' -            post_stats += ( -                f"\u2002{votes}\u2003" -                f"{comments}" -                f'\u2003{user_emoji}\u2002{post["data"]["author"]}\n' +        self.bot.remove_cog(self.qualified_name) +        raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + +    async def revoke_access_token(self) -> None: +        """ +        Revoke the OAuth2 access token for the Reddit API. + +        For security reasons, it's good practice to revoke the token when it's no longer being used. +        """ +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/revoke_token", +            headers=self.HEADERS, +            auth=self.client_auth, +            data={ +                "token": self.access_token.token, +                "token_type_hint": "access_token" +            } +        ) + +        if response.status == 204 and response.content_type == "application/json": +            self.access_token = None +        else: +            log.warning(f"Unable to revoke access token: status {response.status}.") + +    async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: +        """A helper method to fetch a certain amount of Reddit posts at a given route.""" +        # Reddit's JSON responses only provide 25 posts at most. +        if not 25 >= amount > 0: +            raise ValueError("Invalid amount of subreddit posts requested.") + +        # Renew the token if necessary. +        if not self.access_token or self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token() + +        url = f"{self.OAUTH_URL}/{route}" +        for _ in range(self.MAX_RETRIES): +            response = await self.bot.http_session.get( +                url=url, +                headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, +                params=params +            ) +            if response.status == 200 and response.content_type == 'application/json': +                # Got appropriate response - process and return. +                content = await response.json() +                posts = content["data"]["children"] + +                filtered_posts = [post for post in posts if not post["data"]["over_18"]] + +                return filtered_posts[:amount] + +            await asyncio.sleep(3) + +        log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") +        return list()  # Failed to get appropriate response within allowed number of retries. + +    async def get_top_posts( +            self, subreddit: Subreddit, time: str = "all", amount: int = 5, paginate: bool = False +    ) -> Union[Embed, List[tuple]]: +        """ +        Get the top amount of posts for a given subreddit within a specified timeframe. + +        A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top +        weekly posts. + +        The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. +        """ +        embed = Embed() + +        posts = await self.fetch_posts( +            route=f"{subreddit}/top", +            amount=amount, +            params={"t": time} +        ) +        if not posts: +            embed.title = random.choice(ERROR_REPLIES) +            embed.colour = Colour.red() +            embed.description = ( +                "Sorry! We couldn't find any SFW posts from that subreddit. " +                "If this problem persists, please let us know."              ) -            embed_titles += f"{post_stats}\n" -            page_text = f"**[{post_title}]({post_url})**\n{post_stats}\n{post['data']['selftext'][0:200]}" -            embed = discord.Embed() -            page_tuple = (page_text, image_url) -            pages.append(page_tuple) +            return embed + +        if paginate: +            return self.build_pagination_pages(posts, paginate=True) + +        # Use only starting summary page for #reddit channel posts. +        embed.description = self.build_pagination_pages(posts, paginate=False) +        embed.colour = Colour.blurple() +        return embed + +    @loop() +    async def auto_poster_loop(self) -> None: +        """Post the top 5 posts daily, and the top 5 posts weekly.""" +        # once d.py get support for `time` parameter in loop decorator, +        # this can be removed and the loop can use the `time=datetime.time.min` parameter +        now = datetime.utcnow() +        tomorrow = now + timedelta(days=1) +        midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + +        await sleep_until(midnight_tomorrow) + +        await self.bot.wait_until_guild_available() +        if not self.webhook: +            await self.bot.fetch_webhook(RedditConfig.webhook) + +        if datetime.utcnow().weekday() == 0: +            await self.top_weekly_posts() +            # if it's a monday send the top weekly posts + +        for subreddit in RedditConfig.subreddits: +            top_posts = await self.get_top_posts(subreddit=subreddit, time="day") +            username = sub_clyde(f"{subreddit} Top Daily Posts") +            message = await self.webhook.send(username=username, embed=top_posts, wait=True) + +            if message.channel.is_news(): +                await message.publish() + +    async def top_weekly_posts(self) -> None: +        """Post a summary of the top posts.""" +        for subreddit in RedditConfig.subreddits: +            # Send and pin the new weekly posts. +            top_posts = await self.get_top_posts(subreddit=subreddit, time="week") +            username = sub_clyde(f"{subreddit} Top Weekly Posts") +            message = await self.webhook.send(wait=True, username=username, embed=top_posts) -            # ------------------------------------------------------------------ +            if subreddit.lower() == "r/python": +                if not self.channel: +                    log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") +                    return + +                # Remove the oldest pins so that only 12 remain at most. +                pins = await self.channel.pins() + +                while len(pins) >= 12: +                    await pins[-1].unpin() +                    del pins[-1] + +                await message.pin() + +                if message.channel.is_news(): +                    await message.publish() + +    @group(name="reddit", invoke_without_command=True) +    async def reddit_group(self, ctx: Context) -> None: +        """View the top posts from various subreddits.""" +        await invoke_help_command(ctx) + +    @reddit_group.command(name="top") +    async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of all time from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="all", paginate=True) + +        await ctx.send(f"Here are the top {subreddit} posts of all time!") +        embed = Embed( +            color=Colour.blurple() +        ) -        pages.insert(0, (embed_titles, " ")) -        embed.set_author(name=f"r/{posts[0]['data']['subreddit']} - {sort}", icon_url=reddit_emoji.url)          await ImagePaginator.paginate(pages, ctx, embed) +    @reddit_group.command(name="daily") +    async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of today from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="day", paginate=True) + +        await ctx.send(f"Here are today's top {subreddit} posts!") +        embed = Embed( +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @reddit_group.command(name="weekly") +    async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of this week from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="week", paginate=True) + +        await ctx.send(f"Here are this week's top {subreddit} posts!") +        embed = Embed( +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @has_any_role(*STAFF_ROLES) +    @reddit_group.command(name="subreddits", aliases=("subs",)) +    async def subreddits_command(self, ctx: Context) -> None: +        """Send a paginated embed of all the subreddits we're relaying.""" +        embed = Embed() +        embed.title = "Relayed subreddits." +        embed.colour = Colour.blurple() + +        await LinePaginator.paginate( +            RedditConfig.subreddits, +            ctx, embed, +            footer_text="Use the reddit commands along with these to view their posts.", +            empty=False, +            max_lines=15 +        ) + -def setup(bot: commands.Bot) -> None: -    """Load the Cog.""" +def setup(bot: Bot) -> None: +    """Load the Reddit cog.""" +    if not RedditConfig.secret or not RedditConfig.client_id: +        log.error("Credentials not provided, cog not loaded.") +        return      bot.add_cog(Reddit(bot)) diff --git a/bot/utils/converters.py b/bot/utils/converters.py index 228714c9..27804170 100644 --- a/bot/utils/converters.py +++ b/bot/utils/converters.py @@ -1,5 +1,6 @@  import discord -from discord.ext.commands.converter import MessageConverter +from discord.ext.commands import BadArgument, Context +from discord.ext.commands.converter import Converter, MessageConverter  class WrappedMessageConverter(MessageConverter): @@ -14,3 +15,32 @@ class WrappedMessageConverter(MessageConverter):              argument = argument[1:-1]          return await super().convert(ctx, argument) + + +class Subreddit(Converter): +    """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" + +    @staticmethod +    async def convert(ctx: Context, sub: str) -> str: +        """ +        Force sub to begin with "r/" and check if it's a valid subreddit. + +        If sub is a valid subreddit, return it prepended with "r/" +        """ +        sub = sub.lower() + +        if not sub.startswith("r/"): +            sub = f"r/{sub}" + +        resp = await ctx.bot.http_session.get( +            "https://www.reddit.com/subreddits/search.json", +            params={"q": sub} +        ) + +        json = await resp.json() +        if not json["data"]["children"]: +            raise BadArgument( +                f"The subreddit `{sub}` either doesn't exist, or it has no posts." +            ) + +        return sub diff --git a/bot/utils/messages.py b/bot/utils/messages.py new file mode 100644 index 00000000..a6c035f9 --- /dev/null +++ b/bot/utils/messages.py @@ -0,0 +1,19 @@ +import re +from typing import Optional + + +def sub_clyde(username: Optional[str]) -> Optional[str]: +    """ +    Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"E" and return the new string. + +    Discord disallows "clyde" anywhere in the username for webhooks. It will return a 400. +    Return None only if `username` is None. +    """ +    def replace_e(match: re.Match) -> str: +        char = "е" if match[2] == "e" else "Е" +        return match[1] + char + +    if username: +        return re.sub(r"(clyd)(e)", replace_e, username, flags=re.I) +    else: +        return username  # Empty string or None diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py index a4d0cc56..917275c0 100644 --- a/bot/utils/pagination.py +++ b/bot/utils/pagination.py @@ -4,6 +4,7 @@ from typing import Iterable, List, Optional, Tuple  from discord import Embed, Member, Reaction  from discord.abc import User +from discord.embeds import EmptyEmbed  from discord.ext.commands import Context, Paginator  from bot.constants import Emojis @@ -417,9 +418,8 @@ class ImagePaginator(Paginator):              await message.edit(embed=embed)              embed.description = paginator.pages[current_page] -            image = paginator.images[current_page] -            if image: -                embed.set_image(url=image) +            image = paginator.images[current_page] or EmptyEmbed +            embed.set_image(url=image)              embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")              log.debug(f"Got {reaction_type} page reaction - changing to page {current_page + 1}/{len(paginator.pages)}") | 
