diff options
| -rw-r--r-- | bot/constants.py | 19 | ||||
| -rw-r--r-- | bot/exts/evergreen/avatar_modification/avatar_modify.py | 70 | ||||
| -rw-r--r-- | bot/exts/evergreen/reddit.py | 425 | ||||
| -rw-r--r-- | bot/utils/converters.py | 32 | ||||
| -rw-r--r-- | bot/utils/messages.py | 19 | ||||
| -rw-r--r-- | bot/utils/pagination.py | 6 | 
6 files changed, 440 insertions, 131 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 884cf3a8..9b45b89a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -20,6 +20,7 @@ __all__ = (      "Roles",      "Tokens",      "Wolfram", +    "Reddit",      "RedisConfig",      "MODERATION_ROLES",      "STAFF_ROLES", @@ -115,6 +116,7 @@ class Channels(NamedTuple):      voice_chat_0 = 412357430186344448      voice_chat_1 = 799647045886541885      staff_voice = 541638762007101470 +    reddit = int(environ.get("CHANNEL_REDDIT", 458224812528238616))  class Categories(NamedTuple): @@ -218,6 +220,15 @@ class Emojis:      status_dnd = "<:status_dnd:470326272082313216>"      status_offline = "<:status_offline:470326266537705472>" +    # Reddit emojis +    reddit = "<:reddit:676030265734332427>" +    reddit_post_text = "<:reddit_post_text:676030265910493204>" +    reddit_post_video = "<:reddit_post_video:676030265839190047>" +    reddit_post_photo = "<:reddit_post_photo:676030265734201344>" +    reddit_upvote = "<:reddit_upvote:755845219890757644>" +    reddit_comments = "<:reddit_comments:755845255001014384>" +    reddit_users = "<:reddit_users:755845303822974997>" +  class Icons:      questionmark = "https://cdn.discordapp.com/emojis/512367613339369475.png" @@ -294,6 +305,14 @@ class Source:      github_avatar_url = "https://avatars1.githubusercontent.com/u/9919" +class Reddit: +    subreddits = ["r/Python"] + +    client_id = environ.get("REDDIT_CLIENT_ID") +    secret = environ.get("REDDIT_SECRET") +    webhook = int(environ.get("REDDIT_WEBHOOK", 635408384794951680)) + +  # Default role combinations  MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner  STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner diff --git a/bot/exts/evergreen/avatar_modification/avatar_modify.py b/bot/exts/evergreen/avatar_modification/avatar_modify.py index 2afc3b74..693d15c7 100644 --- a/bot/exts/evergreen/avatar_modification/avatar_modify.py +++ b/bot/exts/evergreen/avatar_modification/avatar_modify.py @@ -12,7 +12,7 @@ from aiohttp import client_exceptions  from discord.ext import commands  from discord.ext.commands.errors import BadArgument -from bot.constants import Client, Colours, Emojis +from bot.constants import Colours, Emojis  from bot.exts.evergreen.avatar_modification._effects import PfpEffects  from bot.utils.extensions import invoke_help_command  from bot.utils.halloween import spookifications @@ -66,23 +66,25 @@ class AvatarModify(commands.Cog):      def __init__(self, bot: commands.Bot) -> None:          self.bot = bot -    async def _fetch_member(self, member_id: int) -> t.Optional[discord.Member]: +    async def _fetch_user(self, user_id: int) -> t.Optional[discord.User]:          """ -        Fetches a member and handles errors. +        Fetches a user and handles errors. -        This helper funciton is required as the member cache doesn't always have the most up to date +        This helper function is required as the member cache doesn't always have the most up to date          profile picture. This can lead to errors if the image is delted from the Discord CDN. +        fetch_member can't be used due to the avatar url being part of the user object, and +        some weird caching that D.py does          """          try: -            member = await self.bot.get_guild(Client.guild).fetch_member(member_id) +            user = await self.bot.fetch_user(user_id)          except discord.errors.NotFound: -            log.debug(f"Member {member_id} left the guild before we could get their pfp.") +            log.debug(f"User {user_id} could not be found.")              return None          except discord.HTTPException: -            log.exception(f"Exception while trying to retrieve member {member_id} from Discord.") +            log.exception(f"Exception while trying to retrieve user {user_id} from Discord.")              return None -        return member +        return user      @commands.group(aliases=("avatar_mod", "pfp_mod", "avatarmod", "pfpmod"))      async def avatar_modify(self, ctx: commands.Context) -> None: @@ -94,13 +96,13 @@ class AvatarModify(commands.Cog):      async def eightbit_command(self, ctx: commands.Context) -> None:          """Pixelates your avatar and changes the palette to an 8bit one."""          async with ctx.typing(): -            member = await self._fetch_member(ctx.author.id) -            if not member: -                await ctx.send(f"{Emojis.cross_mark} Could not get member info.") +            user = await self._fetch_user(ctx.author.id) +            if not user: +                await ctx.send(f"{Emojis.cross_mark} Could not get user info.")                  return -            image_bytes = await member.avatar_url.read() -            file_name = file_safe_name("eightbit_avatar", member.display_name) +            image_bytes = await user.avatar_url_as(size=1024).read() +            file_name = file_safe_name("eightbit_avatar", ctx.author.display_name)              file = await in_executor(                  PfpEffects.apply_effect, @@ -115,7 +117,7 @@ class AvatarModify(commands.Cog):              )              embed.set_image(url=f"attachment://{file_name}") -            embed.set_footer(text=f"Made by {member.display_name}.", icon_url=member.avatar_url) +            embed.set_footer(text=f"Made by {ctx.author.display_name}.", icon_url=user.avatar_url)          await ctx.send(embed=embed, file=file) @@ -140,9 +142,9 @@ class AvatarModify(commands.Cog):                  return args[0]          async with ctx.typing(): -            member = await self._fetch_member(ctx.author.id) -            if not member: -                await ctx.send(f"{Emojis.cross_mark} Could not get member info.") +            user = await self._fetch_user(ctx.author.id) +            if not user: +                await ctx.send(f"{Emojis.cross_mark} Could not get user info.")                  return              egg = None @@ -155,8 +157,8 @@ class AvatarModify(commands.Cog):                      return                  ctx.send = send_message  # Reassigns ctx.send -            image_bytes = await member.avatar_url_as(size=256).read() -            file_name = file_safe_name("easterified_avatar", member.display_name) +            image_bytes = await user.avatar_url_as(size=256).read() +            file_name = file_safe_name("easterified_avatar", ctx.author.display_name)              file = await in_executor(                  PfpEffects.apply_effect, @@ -171,7 +173,7 @@ class AvatarModify(commands.Cog):                  description="Here is your lovely avatar, all bright and colourful\nwith Easter pastel colours. Enjoy :D"              )              embed.set_image(url=f"attachment://{file_name}") -            embed.set_footer(text=f"Made by {member.display_name}.", icon_url=member.avatar_url) +            embed.set_footer(text=f"Made by {ctx.author.display_name}.", icon_url=user.avatar_url)          await ctx.send(file=file, embed=embed) @@ -226,11 +228,11 @@ class AvatarModify(commands.Cog):              return          async with ctx.typing(): -            member = await self._fetch_member(ctx.author.id) -            if not member: -                await ctx.send(f"{Emojis.cross_mark} Could not get member info.") +            user = await self._fetch_user(ctx.author.id) +            if not user: +                await ctx.send(f"{Emojis.cross_mark} Could not get user info.")                  return -            image_bytes = await member.avatar_url_as(size=1024).read() +            image_bytes = await user.avatar_url_as(size=1024).read()              await self.send_pride_image(ctx, image_bytes, pixels, flag, option)      @prideavatar.command() @@ -286,13 +288,13 @@ class AvatarModify(commands.Cog):          if member is None:              member = ctx.author -        member = await self._fetch_member(member.id) -        if not member: -            await ctx.send(f"{Emojis.cross_mark} Could not get member info.") +        user = await self._fetch_user(member.id) +        if not user: +            await ctx.send(f"{Emojis.cross_mark} Could not get user info.")              return          async with ctx.typing(): -            image_bytes = await member.avatar_url.read() +            image_bytes = await user.avatar_url_as(size=1024).read()              file_name = file_safe_name("spooky_avatar", member.display_name) @@ -317,9 +319,9 @@ class AvatarModify(commands.Cog):      async def mosaic_command(self, ctx: commands.Context, squares: int = 16) -> None:          """Splits your avatar into x squares, randomizes them and stitches them back into a new image!"""          async with ctx.typing(): -            member = await self._fetch_member(ctx.author.id) -            if not member: -                await ctx.send(f"{Emojis.cross_mark} Could not get member info.") +            user = await self._fetch_user(ctx.author.id) +            if not user: +                await ctx.send(f"{Emojis.cross_mark} Could not get user info.")                  return              if not 1 <= squares <= MAX_SQUARES: @@ -332,7 +334,7 @@ class AvatarModify(commands.Cog):              file_name = file_safe_name("mosaic_avatar", ctx.author.display_name) -            img_bytes = await member.avatar_url.read() +            img_bytes = await user.avatar_url_as(size=1024).read()              file = await in_executor(                  PfpEffects.mosaic_effect, @@ -349,7 +351,7 @@ class AvatarModify(commands.Cog):                  description = "What a masterpiece. :star:"              else:                  title = "Your mosaic avatar" -                description = "Here is your avatar. I think it looks a bit *puzzling*" +                description = f"Here is your avatar. I think it looks a bit *puzzling*\nMade with {squares} squares."              embed = discord.Embed(                  title=title, @@ -358,7 +360,7 @@ class AvatarModify(commands.Cog):              )              embed.set_image(url=f"attachment://{file_name}") -            embed.set_footer(text=f"Made by {ctx.author.display_name}", icon_url=ctx.author.avatar_url) +            embed.set_footer(text=f"Made by {ctx.author.display_name}", icon_url=user.avatar_url)              await ctx.send(file=file, embed=embed) diff --git a/bot/exts/evergreen/reddit.py b/bot/exts/evergreen/reddit.py index 2be511c8..916563ac 100644 --- a/bot/exts/evergreen/reddit.py +++ b/bot/exts/evergreen/reddit.py @@ -1,128 +1,367 @@ +import asyncio  import logging  import random +import textwrap +from collections import namedtuple +from datetime import datetime, timedelta +from typing import List, Union -import discord -from discord.ext import commands -from discord.ext.commands.cooldowns import BucketType +from aiohttp import BasicAuth, ClientError +from discord import Colour, Embed, TextChannel +from discord.ext.commands import Cog, Context, group, has_any_role +from discord.ext.tasks import loop +from discord.utils import escape_markdown, sleep_until -from bot.utils.pagination import ImagePaginator +from bot.bot import Bot +from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES +from bot.utils.converters import Subreddit +from bot.utils.extensions import invoke_help_command +from bot.utils.messages import sub_clyde +from bot.utils.pagination import ImagePaginator, LinePaginator  log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) -class Reddit(commands.Cog): -    """Fetches reddit posts.""" -    def __init__(self, bot: commands.Bot): +class Reddit(Cog): +    """Track subreddit posts and show detailed statistics about them.""" + +    HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} +    URL = "https://www.reddit.com" +    OAUTH_URL = "https://oauth.reddit.com" +    MAX_RETRIES = 3 + +    def __init__(self, bot: Bot):          self.bot = bot -    async def fetch(self, url: str) -> dict: -        """Send a get request to the reddit API and get json response.""" -        session = self.bot.http_session -        params = { -            'limit': 50 -        } -        headers = { -            'User-Agent': 'Iceman' -        } - -        async with session.get(url=url, params=params, headers=headers) as response: -            return await response.json() - -    @commands.command(name='reddit') -    @commands.cooldown(1, 10, BucketType.user) -    async def get_reddit(self, ctx: commands.Context, subreddit: str = 'python', sort: str = "hot") -> None: -        """ -        Fetch reddit posts by using this command. +        self.webhook = None +        self.access_token = None +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) -        Gets a post from r/python by default. -        Usage: -        --> .reddit [subreddit_name] [hot/top/new] -        """ +        bot.loop.create_task(self.init_reddit_ready()) +        self.auto_poster_loop.start() + +    def cog_unload(self) -> None: +        """Stop the loop task and revoke the access token when the cog is unloaded.""" +        self.auto_poster_loop.cancel() +        if self.access_token and self.access_token.expires_at > datetime.utcnow(): +            asyncio.create_task(self.revoke_access_token()) + +    async def init_reddit_ready(self) -> None: +        """Sets the reddit webhook when the cog is loaded.""" +        await self.bot.wait_until_guild_available() +        if not self.webhook: +            self.webhook = await self.bot.fetch_webhook(RedditConfig.webhook) + +    @property +    def channel(self) -> TextChannel: +        """Get the #reddit channel object from the bot's cache.""" +        return self.bot.get_channel(Channels.reddit) + +    def build_pagination_pages(self, posts: List[dict], paginate: bool) -> Union[List[tuple], str]: +        """Build embed pages required for Paginator."""          pages = [] -        sort_list = ["hot", "new", "top", "rising"] -        if sort.lower() not in sort_list: -            await ctx.send(f"Invalid sorting: {sort}\nUsing default sorting: `Hot`") -            sort = "hot" +        first_page = "" +        for post in posts: +            post_page = "" +            image_url = "" -        data = await self.fetch(f'https://www.reddit.com/r/{subreddit}/{sort}/.json') +            data = post["data"] -        try: -            posts = data["data"]["children"] -        except KeyError: -            return await ctx.send('Subreddit not found!') -        if not posts: -            return await ctx.send('No posts available!') +            title = textwrap.shorten(data["title"], width=50, placeholder="...") + +            # Normal brackets interfere with Markdown. +            title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") +            link = self.URL + data["permalink"] + +            first_page += f"**[{title.replace('*', '')}]({link})**\n" + +            text = data["selftext"] +            if text: +                first_page += textwrap.shorten(text, width=100, placeholder="...").replace("*", "") + "\n" + +            ups = data["ups"] +            comments = data["num_comments"] +            author = data["author"] + +            content_type = Emojis.reddit_post_text +            if data["is_video"] or {"youtube", "youtu.be"}.issubset(set(data["url"].split("."))): +                # This means the content type in the post is a video. +                content_type = f"{Emojis.reddit_post_video}" + +            elif data["url"].endswith(("jpg", "png", "gif")): +                # This means the content type in the post is an image. +                content_type = f"{Emojis.reddit_post_photo}" +                image_url = data["url"] -        if posts[0]["data"]["over_18"] is True: -            return await ctx.send( -                "You cannot access this Subreddit as it is meant for those who " -                "are 18 years or older." +            first_page += ( +                f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}" +                f"\u2002{comments}\u2003{Emojis.reddit_users}{author}\n\n"              ) -        embed_titles = "" +            if paginate: +                post_page += f"**[{title}]({link})**\n\n" +                if text: +                    post_page += textwrap.shorten(text, width=252, placeholder="...") + "\n\n" +                post_page += ( +                    f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}\u2002" +                    f"{comments}\u2003{Emojis.reddit_users}{author}" +                ) -        # Chooses k unique random elements from a population sequence or set. -        random_posts = random.sample(posts, k=min(len(posts), 5)) +                pages.append((post_page, image_url)) -        # ----------------------------------------------------------- -        # This code below is bound of change when the emojis are added. +        if not paginate: +            # Return the first summery page if pagination is not required +            return first_page -        upvote_emoji = self.bot.get_emoji(755845219890757644) -        comment_emoji = self.bot.get_emoji(755845255001014384) -        user_emoji = self.bot.get_emoji(755845303822974997) -        text_emoji = self.bot.get_emoji(676030265910493204) -        video_emoji = self.bot.get_emoji(676030265839190047) -        image_emoji = self.bot.get_emoji(676030265734201344) -        reddit_emoji = self.bot.get_emoji(676030265734332427) +        pages.insert(0, (first_page, ""))  # Using image paginator, hence settings image url to empty string +        return pages -        # ------------------------------------------------------------ +    async def get_access_token(self) -> None: +        """ +        Get a Reddit API OAuth2 access token and assign it to self.access_token. -        for i, post in enumerate(random_posts, start=1): -            post_title = post["data"]["title"][0:50] -            post_url = post['data']['url'] -            if post_title == "": -                post_title = "No Title." -            elif post_title == post_url: -                post_title = "Title is itself a link." +        A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog +        will be unloaded and a ClientError raised if retrieval was still unsuccessful. +        """ +        for i in range(1, self.MAX_RETRIES + 1): +            response = await self.bot.http_session.post( +                url=f"{self.URL}/api/v1/access_token", +                headers=self.HEADERS, +                auth=self.client_auth, +                data={ +                    "grant_type": "client_credentials", +                    "duration": "temporary" +                } +            ) -            # ------------------------------------------------------------------ -            # Embed building. +            if response.status == 200 and response.content_type == "application/json": +                content = await response.json() +                expiration = int(content["expires_in"]) - 60  # Subtract 1 minute for leeway. +                self.access_token = AccessToken( +                    token=content["access_token"], +                    expires_at=datetime.utcnow() + timedelta(seconds=expiration) +                ) -            embed_titles += f"**{i}.[{post_title}]({post_url})**\n" -            image_url = " " -            post_stats = f"{text_emoji}"  # Set default content type to text. +                log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") +                return +            else: +                log.debug( +                    f"Failed to get an access token: " +                    f"status {response.status} & content type {response.content_type}; " +                    f"retrying ({i}/{self.MAX_RETRIES})" +                ) -            if post["data"]["is_video"] is True or "youtube" in post_url.split("."): -                # This means the content type in the post is a video. -                post_stats = f"{video_emoji} " +            await asyncio.sleep(3) -            elif post_url.endswith("jpg") or post_url.endswith("png") or post_url.endswith("gif"): -                # This means the content type in the post is an image. -                post_stats = f"{image_emoji} " -                image_url = post_url - -            votes = f'{upvote_emoji}{post["data"]["ups"]}' -            comments = f'{comment_emoji}\u2002{ post["data"]["num_comments"]}' -            post_stats += ( -                f"\u2002{votes}\u2003" -                f"{comments}" -                f'\u2003{user_emoji}\u2002{post["data"]["author"]}\n' +        self.bot.remove_cog(self.qualified_name) +        raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + +    async def revoke_access_token(self) -> None: +        """ +        Revoke the OAuth2 access token for the Reddit API. + +        For security reasons, it's good practice to revoke the token when it's no longer being used. +        """ +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/revoke_token", +            headers=self.HEADERS, +            auth=self.client_auth, +            data={ +                "token": self.access_token.token, +                "token_type_hint": "access_token" +            } +        ) + +        if response.status == 204 and response.content_type == "application/json": +            self.access_token = None +        else: +            log.warning(f"Unable to revoke access token: status {response.status}.") + +    async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: +        """A helper method to fetch a certain amount of Reddit posts at a given route.""" +        # Reddit's JSON responses only provide 25 posts at most. +        if not 25 >= amount > 0: +            raise ValueError("Invalid amount of subreddit posts requested.") + +        # Renew the token if necessary. +        if not self.access_token or self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token() + +        url = f"{self.OAUTH_URL}/{route}" +        for _ in range(self.MAX_RETRIES): +            response = await self.bot.http_session.get( +                url=url, +                headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, +                params=params +            ) +            if response.status == 200 and response.content_type == 'application/json': +                # Got appropriate response - process and return. +                content = await response.json() +                posts = content["data"]["children"] + +                filtered_posts = [post for post in posts if not post["data"]["over_18"]] + +                return filtered_posts[:amount] + +            await asyncio.sleep(3) + +        log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") +        return list()  # Failed to get appropriate response within allowed number of retries. + +    async def get_top_posts( +            self, subreddit: Subreddit, time: str = "all", amount: int = 5, paginate: bool = False +    ) -> Union[Embed, List[tuple]]: +        """ +        Get the top amount of posts for a given subreddit within a specified timeframe. + +        A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top +        weekly posts. + +        The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. +        """ +        embed = Embed() + +        posts = await self.fetch_posts( +            route=f"{subreddit}/top", +            amount=amount, +            params={"t": time} +        ) +        if not posts: +            embed.title = random.choice(ERROR_REPLIES) +            embed.colour = Colour.red() +            embed.description = ( +                "Sorry! We couldn't find any SFW posts from that subreddit. " +                "If this problem persists, please let us know."              ) -            embed_titles += f"{post_stats}\n" -            page_text = f"**[{post_title}]({post_url})**\n{post_stats}\n{post['data']['selftext'][0:200]}" -            embed = discord.Embed() -            page_tuple = (page_text, image_url) -            pages.append(page_tuple) +            return embed + +        if paginate: +            return self.build_pagination_pages(posts, paginate=True) + +        # Use only starting summary page for #reddit channel posts. +        embed.description = self.build_pagination_pages(posts, paginate=False) +        embed.colour = Colour.blurple() +        return embed + +    @loop() +    async def auto_poster_loop(self) -> None: +        """Post the top 5 posts daily, and the top 5 posts weekly.""" +        # once d.py get support for `time` parameter in loop decorator, +        # this can be removed and the loop can use the `time=datetime.time.min` parameter +        now = datetime.utcnow() +        tomorrow = now + timedelta(days=1) +        midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) + +        await sleep_until(midnight_tomorrow) + +        await self.bot.wait_until_guild_available() +        if not self.webhook: +            await self.bot.fetch_webhook(RedditConfig.webhook) + +        if datetime.utcnow().weekday() == 0: +            await self.top_weekly_posts() +            # if it's a monday send the top weekly posts + +        for subreddit in RedditConfig.subreddits: +            top_posts = await self.get_top_posts(subreddit=subreddit, time="day") +            username = sub_clyde(f"{subreddit} Top Daily Posts") +            message = await self.webhook.send(username=username, embed=top_posts, wait=True) + +            if message.channel.is_news(): +                await message.publish() + +    async def top_weekly_posts(self) -> None: +        """Post a summary of the top posts.""" +        for subreddit in RedditConfig.subreddits: +            # Send and pin the new weekly posts. +            top_posts = await self.get_top_posts(subreddit=subreddit, time="week") +            username = sub_clyde(f"{subreddit} Top Weekly Posts") +            message = await self.webhook.send(wait=True, username=username, embed=top_posts) -            # ------------------------------------------------------------------ +            if subreddit.lower() == "r/python": +                if not self.channel: +                    log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") +                    return + +                # Remove the oldest pins so that only 12 remain at most. +                pins = await self.channel.pins() + +                while len(pins) >= 12: +                    await pins[-1].unpin() +                    del pins[-1] + +                await message.pin() + +                if message.channel.is_news(): +                    await message.publish() + +    @group(name="reddit", invoke_without_command=True) +    async def reddit_group(self, ctx: Context) -> None: +        """View the top posts from various subreddits.""" +        await invoke_help_command(ctx) + +    @reddit_group.command(name="top") +    async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of all time from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="all", paginate=True) + +        await ctx.send(f"Here are the top {subreddit} posts of all time!") +        embed = Embed( +            color=Colour.blurple() +        ) -        pages.insert(0, (embed_titles, " ")) -        embed.set_author(name=f"r/{posts[0]['data']['subreddit']} - {sort}", icon_url=reddit_emoji.url)          await ImagePaginator.paginate(pages, ctx, embed) +    @reddit_group.command(name="daily") +    async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of today from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="day", paginate=True) + +        await ctx.send(f"Here are today's top {subreddit} posts!") +        embed = Embed( +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @reddit_group.command(name="weekly") +    async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: +        """Send the top posts of this week from a given subreddit.""" +        async with ctx.typing(): +            pages = await self.get_top_posts(subreddit=subreddit, time="week", paginate=True) + +        await ctx.send(f"Here are this week's top {subreddit} posts!") +        embed = Embed( +            color=Colour.blurple() +        ) + +        await ImagePaginator.paginate(pages, ctx, embed) + +    @has_any_role(*STAFF_ROLES) +    @reddit_group.command(name="subreddits", aliases=("subs",)) +    async def subreddits_command(self, ctx: Context) -> None: +        """Send a paginated embed of all the subreddits we're relaying.""" +        embed = Embed() +        embed.title = "Relayed subreddits." +        embed.colour = Colour.blurple() + +        await LinePaginator.paginate( +            RedditConfig.subreddits, +            ctx, embed, +            footer_text="Use the reddit commands along with these to view their posts.", +            empty=False, +            max_lines=15 +        ) + -def setup(bot: commands.Bot) -> None: -    """Load the Cog.""" +def setup(bot: Bot) -> None: +    """Load the Reddit cog.""" +    if not RedditConfig.secret or not RedditConfig.client_id: +        log.error("Credentials not provided, cog not loaded.") +        return      bot.add_cog(Reddit(bot)) diff --git a/bot/utils/converters.py b/bot/utils/converters.py index 228714c9..27804170 100644 --- a/bot/utils/converters.py +++ b/bot/utils/converters.py @@ -1,5 +1,6 @@  import discord -from discord.ext.commands.converter import MessageConverter +from discord.ext.commands import BadArgument, Context +from discord.ext.commands.converter import Converter, MessageConverter  class WrappedMessageConverter(MessageConverter): @@ -14,3 +15,32 @@ class WrappedMessageConverter(MessageConverter):              argument = argument[1:-1]          return await super().convert(ctx, argument) + + +class Subreddit(Converter): +    """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" + +    @staticmethod +    async def convert(ctx: Context, sub: str) -> str: +        """ +        Force sub to begin with "r/" and check if it's a valid subreddit. + +        If sub is a valid subreddit, return it prepended with "r/" +        """ +        sub = sub.lower() + +        if not sub.startswith("r/"): +            sub = f"r/{sub}" + +        resp = await ctx.bot.http_session.get( +            "https://www.reddit.com/subreddits/search.json", +            params={"q": sub} +        ) + +        json = await resp.json() +        if not json["data"]["children"]: +            raise BadArgument( +                f"The subreddit `{sub}` either doesn't exist, or it has no posts." +            ) + +        return sub diff --git a/bot/utils/messages.py b/bot/utils/messages.py new file mode 100644 index 00000000..a6c035f9 --- /dev/null +++ b/bot/utils/messages.py @@ -0,0 +1,19 @@ +import re +from typing import Optional + + +def sub_clyde(username: Optional[str]) -> Optional[str]: +    """ +    Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"E" and return the new string. + +    Discord disallows "clyde" anywhere in the username for webhooks. It will return a 400. +    Return None only if `username` is None. +    """ +    def replace_e(match: re.Match) -> str: +        char = "е" if match[2] == "e" else "Е" +        return match[1] + char + +    if username: +        return re.sub(r"(clyd)(e)", replace_e, username, flags=re.I) +    else: +        return username  # Empty string or None diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py index a4d0cc56..917275c0 100644 --- a/bot/utils/pagination.py +++ b/bot/utils/pagination.py @@ -4,6 +4,7 @@ from typing import Iterable, List, Optional, Tuple  from discord import Embed, Member, Reaction  from discord.abc import User +from discord.embeds import EmptyEmbed  from discord.ext.commands import Context, Paginator  from bot.constants import Emojis @@ -417,9 +418,8 @@ class ImagePaginator(Paginator):              await message.edit(embed=embed)              embed.description = paginator.pages[current_page] -            image = paginator.images[current_page] -            if image: -                embed.set_image(url=image) +            image = paginator.images[current_page] or EmptyEmbed +            embed.set_image(url=image)              embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")              log.debug(f"Got {reaction_type} page reaction - changing to page {current_page + 1}/{len(paginator.pages)}") | 
