diff options
-rw-r--r-- | bot/cogs/reddit.py | 229 | ||||
-rw-r--r-- | bot/constants.py | 2 | ||||
-rw-r--r-- | config-default.yml | 2 |
3 files changed, 85 insertions, 148 deletions
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0f575cece..7749d237f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,13 +2,14 @@ import asyncio import logging import random import textwrap -from datetime import datetime, timedelta +from datetime import datetime from typing import List -from discord import Colour, Embed, Message, TextChannel +from discord import Colour, Embed, TextChannel from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.tasks import loop -from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES +from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES, Webhooks from bot.converters import Subreddit from bot.decorators import with_role from bot.pagination import LinePaginator @@ -26,15 +27,25 @@ class Reddit(Cog): def __init__(self, bot: Bot): self.bot = bot - self.reddit_channel = None + self.webhook = None # set in on_ready + bot.loop.create_task(self.init_reddit_ready()) - self.prev_lengths = {} - self.last_ids = {} + self.auto_poster_loop.start() - self.new_posts_task = None - self.top_weekly_posts_task = None + def cog_unload(self) -> None: + """Stops the loops when the cog is unloaded.""" + self.auto_poster_loop.cancel() - self.bot.loop.create_task(self.init_reddit_polling()) + async def init_reddit_ready(self) -> None: + """Sets the reddit webhook when the cog is loaded.""" + await self.bot.wait_until_ready() + if not self.webhook: + self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) + + @property + def channel(self) -> TextChannel: + """Get the #reddit channel object from the bot's cache.""" + return self.bot.get_channel(Channels.reddit) async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -63,23 +74,22 @@ class Reddit(Cog): log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") return list() # Failed to get appropriate response within allowed number of retries. - async def send_top_posts( - self, channel: TextChannel, subreddit: Subreddit, content: str = None, time: str = "all" - ) -> Message: - """Create an embed for the top posts, then send it in a given TextChannel.""" - # Create the new spicy embed. - embed = Embed() - embed.description = "" - - # Get the posts - async with channel.typing(): - posts = await self.fetch_posts( - route=f"{subreddit}/top", - amount=5, - params={ - "t": time - } - ) + async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: + """ + Get the top amount of posts for a given subreddit within a specified timeframe. + + A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top + weekly posts. + + The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. + """ + embed = Embed(description="") + + posts = await self.fetch_posts( + route=f"{subreddit}/top", + amount=amount, + params={"t": time} + ) if not posts: embed.title = random.choice(ERROR_REPLIES) @@ -89,9 +99,7 @@ class Reddit(Cog): "If this problem persists, please let us know." ) - return await channel.send( - embed=embed - ) + return embed for post in posts: data = post["data"] @@ -115,103 +123,51 @@ class Reddit(Cog): ) embed.colour = Colour.blurple() + return embed - return await channel.send( - content=content, - embed=embed - ) - - async def poll_new_posts(self) -> None: - """Periodically search for new subreddit posts.""" - while True: - await asyncio.sleep(RedditConfig.request_delay) - - for subreddit in RedditConfig.subreddits: - # Make a HEAD request to the subreddit - head_response = await self.bot.http_session.head( - url=f"{self.URL}/{subreddit}/new.rss", - headers=self.HEADERS - ) - - content_length = head_response.headers["content-length"] - - # If the content is the same size as before, assume there's no new posts. - if content_length == self.prev_lengths.get(subreddit, None): - continue - - self.prev_lengths[subreddit] = content_length - - # Now we can actually fetch the new data - posts = await self.fetch_posts(f"{subreddit}/new") - new_posts = [] + @loop() + async def auto_poster_loop(self) -> None: + """Post the top 5 posts daily, and the top 5 posts weekly.""" + # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter + now = datetime.utcnow() + midnight_tomorrow = now.replace(day=now.day + 1, hour=0, minute=0, second=0) + seconds_until = (midnight_tomorrow - now).total_seconds() - # Only show new posts if we've checked before. - if subreddit in self.last_ids: - for post in posts: - data = post["data"] + await asyncio.sleep(seconds_until) - # Convert the ID to an integer for easy comparison. - int_id = int(data["id"], 36) - - # If we've already seen this post, finish checking - if int_id <= self.last_ids[subreddit]: - break - - embed_data = { - "title": textwrap.shorten(data["title"], width=64, placeholder="..."), - "text": textwrap.shorten(data["selftext"], width=128, placeholder="..."), - "url": self.URL + data["permalink"], - "author": data["author"] - } - - new_posts.append(embed_data) - - self.last_ids[subreddit] = int(posts[0]["data"]["id"], 36) - - # Send all of the new posts as spicy embeds - for data in new_posts: - embed = Embed() - - embed.title = data["title"] - embed.url = data["url"] - embed.description = data["text"] - embed.set_footer(text=f"Posted by u/{data['author']} in {subreddit}") - embed.colour = Colour.blurple() - - await self.reddit_channel.send(embed=embed) + await self.bot.wait_until_ready() + if not self.webhook: + await self.bot.fetch_webhook(Webhooks.reddit) - log.trace(f"Sent {len(new_posts)} new {subreddit} posts to channel {self.reddit_channel.id}.") + if datetime.utcnow().weekday() == 0: + await self.top_weekly_posts() + # if it's a monday send the top weekly posts - async def poll_top_weekly_posts(self) -> None: - """Post a summary of the top posts every week.""" - while True: - now = datetime.utcnow() + for subreddit in RedditConfig.subreddits: + top_posts = await self.get_top_posts(subreddit=subreddit, time="day") + await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) - # Calculate the amount of seconds until midnight next monday. - monday = now + timedelta(days=7 - now.weekday()) - monday = monday.replace(hour=0, minute=0, second=0) - until_monday = (monday - now).total_seconds() + async def top_weekly_posts(self) -> None: + """Post a summary of the top posts.""" + for subreddit in RedditConfig.subreddits: + # Send and pin the new weekly posts. + top_posts = await self.get_top_posts(subreddit=subreddit, time="week") - await asyncio.sleep(until_monday) + message = await self.webhook.send(wait=True, username=f"{subreddit} Top Weekly Posts", embed=top_posts) - for subreddit in RedditConfig.subreddits: - # Send and pin the new weekly posts. - message = await self.send_top_posts( - channel=self.reddit_channel, - subreddit=subreddit, - content=f"This week's top {subreddit} posts have arrived!", - time="week" - ) + if subreddit.lower() == "r/python": + if not self.channel: + log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") + return - if subreddit.lower() == "r/python": - # Remove the oldest pins so that only 5 remain at most. - pins = await self.reddit_channel.pins() + # Remove the oldest pins so that only 12 remain at most. + pins = await self.channel.pins() - while len(pins) >= 5: - await pins[-1].unpin() - del pins[-1] + while len(pins) >= 12: + await pins[-1].unpin() + del pins[-1] - await message.pin() + await message.pin() @group(name="reddit", invoke_without_command=True) async def reddit_group(self, ctx: Context) -> None: @@ -221,32 +177,26 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" - await self.send_top_posts( - channel=ctx.channel, - subreddit=subreddit, - content=f"Here are the top {subreddit} posts of all time!", - time="all" - ) + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="all") + + await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" - await self.send_top_posts( - channel=ctx.channel, - subreddit=subreddit, - content=f"Here are today's top {subreddit} posts!", - time="day" - ) + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="day") + + await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" - await self.send_top_posts( - channel=ctx.channel, - subreddit=subreddit, - content=f"Here are this week's top {subreddit} posts!", - time="week" - ) + async with ctx.typing(): + embed = await self.get_top_posts(subreddit=subreddit, time="week") + + await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) @with_role(*STAFF_ROLES) @reddit_group.command(name="subreddits", aliases=("subs",)) @@ -264,19 +214,6 @@ class Reddit(Cog): max_lines=15 ) - async def init_reddit_polling(self) -> None: - """Initiate reddit post event loop.""" - await self.bot.wait_until_ready() - self.reddit_channel = await self.bot.fetch_channel(Channels.reddit) - - if self.reddit_channel is not None: - if self.new_posts_task is None: - self.new_posts_task = self.bot.loop.create_task(self.poll_new_posts()) - if self.top_weekly_posts_task is None: - self.top_weekly_posts_task = self.bot.loop.create_task(self.poll_top_weekly_posts()) - else: - log.warning("Couldn't locate a channel for subreddit relaying.") - def setup(bot: Bot) -> None: """Reddit cog load.""" diff --git a/bot/constants.py b/bot/constants.py index 60fc1b723..838fe7a79 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -369,6 +369,7 @@ class Webhooks(metaclass=YAMLGetter): talent_pool: int big_brother: int + reddit: int class Roles(metaclass=YAMLGetter): @@ -444,7 +445,6 @@ class URLs(metaclass=YAMLGetter): class Reddit(metaclass=YAMLGetter): section = "reddit" - request_delay: int subreddits: list diff --git a/config-default.yml b/config-default.yml index 8e86234ac..4638a89ee 100644 --- a/config-default.yml +++ b/config-default.yml @@ -147,6 +147,7 @@ guild: webhooks: talent_pool: 569145364800602132 big_brother: 569133704568373283 + reddit: 635408384794951680 filter: @@ -350,7 +351,6 @@ anti_malware: reddit: - request_delay: 60 subreddits: - 'r/Python' |