aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts/evergreen/reddit.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bot/exts/evergreen/reddit.py425
1 files changed, 332 insertions, 93 deletions
diff --git a/bot/exts/evergreen/reddit.py b/bot/exts/evergreen/reddit.py
index 49127bea..e57fa2c0 100644
--- a/bot/exts/evergreen/reddit.py
+++ b/bot/exts/evergreen/reddit.py
@@ -1,128 +1,367 @@
+import asyncio
import logging
import random
+import textwrap
+from collections import namedtuple
+from datetime import datetime, timedelta
+from typing import List, Union
-import discord
-from discord.ext import commands
-from discord.ext.commands.cooldowns import BucketType
+from aiohttp import BasicAuth, ClientError
+from discord import Colour, Embed, TextChannel
+from discord.ext.commands import Cog, Context, group, has_any_role
+from discord.ext.tasks import loop
+from discord.utils import escape_markdown, sleep_until
-from bot.utils.pagination import ImagePaginator
+from bot.bot import Bot
+from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES
+from bot.utils.converters import Subreddit
+from bot.utils.extensions import invoke_help_command
+from bot.utils.messages import sub_clyde
+from bot.utils.pagination import ImagePaginator, LinePaginator
log = logging.getLogger(__name__)
+AccessToken = namedtuple("AccessToken", ["token", "expires_at"])
-class Reddit(commands.Cog):
- """Fetches reddit posts."""
- def __init__(self, bot: commands.Bot):
+class Reddit(Cog):
+ """Track subreddit posts and show detailed statistics about them."""
+
+ HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}
+ URL = "https://www.reddit.com"
+ OAUTH_URL = "https://oauth.reddit.com"
+ MAX_RETRIES = 3
+
+ def __init__(self, bot: Bot):
self.bot = bot
- async def fetch(self, url: str) -> dict:
- """Send a get request to the reddit API and get json response."""
- session = self.bot.http_session
- params = {
- 'limit': 50
- }
- headers = {
- 'User-Agent': 'Iceman'
- }
-
- async with session.get(url=url, params=params, headers=headers) as response:
- return await response.json()
-
- @commands.command(name='reddit')
- @commands.cooldown(1, 10, BucketType.user)
- async def get_reddit(self, ctx: commands.Context, subreddit: str = 'python', sort: str = "hot") -> None:
- """
- Fetch reddit posts by using this command.
+ self.webhook = None
+ self.access_token = None
+ self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
- Gets a post from r/python by default.
- Usage:
- --> .reddit [subreddit_name] [hot/top/new]
- """
+ bot.loop.create_task(self.init_reddit_ready())
+ self.auto_poster_loop.start()
+
+ def cog_unload(self) -> None:
+ """Stop the loop task and revoke the access token when the cog is unloaded."""
+ self.auto_poster_loop.cancel()
+ if self.access_token and self.access_token.expires_at > datetime.utcnow():
+ asyncio.create_task(self.revoke_access_token())
+
+ async def init_reddit_ready(self) -> None:
+ """Sets the reddit webhook when the cog is loaded."""
+ await self.bot.wait_until_guild_available()
+ if not self.webhook:
+ self.webhook = await self.bot.fetch_webhook(RedditConfig.webhook)
+
+ @property
+ def channel(self) -> TextChannel:
+ """Get the #reddit channel object from the bot's cache."""
+ return self.bot.get_channel(Channels.reddit)
+
+ def build_pagination_pages(self, posts: List[dict], paginate: bool) -> Union[List[tuple], str]:
+ """Build embed pages required for Paginator."""
pages = []
- sort_list = ["hot", "new", "top", "rising"]
- if sort.lower() not in sort_list:
- await ctx.send(f"Invalid sorting: {sort}\nUsing default sorting: `Hot`")
- sort = "hot"
+ first_page = ""
+ for post in posts:
+ post_page = ""
+ image_url = ""
- data = await self.fetch(f'https://www.reddit.com/r/{subreddit}/{sort}/.json')
+ data = post["data"]
- try:
- posts = data["data"]["children"]
- except KeyError:
- return await ctx.send('Subreddit not found!')
- if not posts:
- return await ctx.send('No posts available!')
+ title = textwrap.shorten(data["title"], width=50, placeholder="...")
+
+ # Normal brackets interfere with Markdown.
+ title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌")
+ link = self.URL + data["permalink"]
+
+ first_page += f"**[{title.replace('*', '')}]({link})**\n"
+
+ text = data["selftext"]
+ if text:
+ first_page += textwrap.shorten(text, width=100, placeholder="...").replace("*", "") + "\n"
+
+ ups = data["ups"]
+ comments = data["num_comments"]
+ author = data["author"]
+
+ content_type = Emojis.reddit_post_text
+ if data["is_video"] or {"youtube", "youtu.be"}.issubset(set(data["url"].split("."))):
+ # This means the content type in the post is a video.
+ content_type = f"{Emojis.reddit_post_video}"
+
+ elif data["url"].endswith(("jpg", "png", "gif")):
+ # This means the content type in the post is an image.
+ content_type = f"{Emojis.reddit_post_photo}"
+ image_url = data["url"]
- if posts[1]["data"]["over_18"] is True:
- return await ctx.send(
- "You cannot access this Subreddit as it is ment for those who "
- "are 18 years or older."
+ first_page += (
+ f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}"
+ f"\u2002{comments}\u2003{Emojis.reddit_users}{author}\n\n"
)
- embed_titles = ""
+ if paginate:
+ post_page += f"**[{title}]({link})**\n\n"
+ if text:
+ post_page += textwrap.shorten(text, width=252, placeholder="...") + "\n\n"
+ post_page += (
+ f"{content_type}\u2003{Emojis.reddit_upvote}{ups}\u2003{Emojis.reddit_comments}\u2002"
+ f"{comments}\u2003{Emojis.reddit_users}{author}"
+ )
- # Chooses k unique random elements from a population sequence or set.
- random_posts = random.sample(posts, k=5)
+ pages.append((post_page, image_url))
- # -----------------------------------------------------------
- # This code below is bound of change when the emojis are added.
+ if not paginate:
+ # Return the first summery page if pagination is not required
+ return first_page
- upvote_emoji = self.bot.get_emoji(755845219890757644)
- comment_emoji = self.bot.get_emoji(755845255001014384)
- user_emoji = self.bot.get_emoji(755845303822974997)
- text_emoji = self.bot.get_emoji(676030265910493204)
- video_emoji = self.bot.get_emoji(676030265839190047)
- image_emoji = self.bot.get_emoji(676030265734201344)
- reddit_emoji = self.bot.get_emoji(676030265734332427)
+ pages.insert(0, (first_page, "")) # Using image paginator, hence settings image url to empty string
+ return pages
- # ------------------------------------------------------------
+ async def get_access_token(self) -> None:
+ """
+ Get a Reddit API OAuth2 access token and assign it to self.access_token.
- for i, post in enumerate(random_posts, start=1):
- post_title = post["data"]["title"][0:50]
- post_url = post['data']['url']
- if post_title == "":
- post_title = "No Title."
- elif post_title == post_url:
- post_title = "Title is itself a link."
+ A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog
+ will be unloaded and a ClientError raised if retrieval was still unsuccessful.
+ """
+ for i in range(1, self.MAX_RETRIES + 1):
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=self.HEADERS,
+ auth=self.client_auth,
+ data={
+ "grant_type": "client_credentials",
+ "duration": "temporary"
+ }
+ )
- # ------------------------------------------------------------------
- # Embed building.
+ if response.status == 200 and response.content_type == "application/json":
+ content = await response.json()
+ expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway.
+ self.access_token = AccessToken(
+ token=content["access_token"],
+ expires_at=datetime.utcnow() + timedelta(seconds=expiration)
+ )
- embed_titles += f"**{i}.[{post_title}]({post_url})**\n"
- image_url = " "
- post_stats = f"{text_emoji}" # Set default content type to text.
+ log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")
+ return
+ else:
+ log.debug(
+ f"Failed to get an access token: "
+ f"status {response.status} & content type {response.content_type}; "
+ f"retrying ({i}/{self.MAX_RETRIES})"
+ )
- if post["data"]["is_video"] is True or "youtube" in post_url.split("."):
- # This means the content type in the post is a video.
- post_stats = f"{video_emoji} "
+ await asyncio.sleep(3)
- elif post_url.endswith("jpg") or post_url.endswith("png") or post_url.endswith("gif"):
- # This means the content type in the post is an image.
- post_stats = f"{image_emoji} "
- image_url = post_url
-
- votes = f'{upvote_emoji}{post["data"]["ups"]}'
- comments = f'{comment_emoji}\u2002{ post["data"]["num_comments"]}'
- post_stats += (
- f"\u2002{votes}\u2003"
- f"{comments}"
- f'\u2003{user_emoji}\u2002{post["data"]["author"]}\n'
+ self.bot.remove_cog(self.qualified_name)
+ raise ClientError("Authentication with the Reddit API failed. Unloading the cog.")
+
+ async def revoke_access_token(self) -> None:
+ """
+ Revoke the OAuth2 access token for the Reddit API.
+
+ For security reasons, it's good practice to revoke the token when it's no longer being used.
+ """
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/revoke_token",
+ headers=self.HEADERS,
+ auth=self.client_auth,
+ data={
+ "token": self.access_token.token,
+ "token_type_hint": "access_token"
+ }
+ )
+
+ if response.status in [200, 204] and response.content_type == "application/json":
+ self.access_token = None
+ else:
+ log.warning(f"Unable to revoke access token: status {response.status}.")
+
+ async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
+ """A helper method to fetch a certain amount of Reddit posts at a given route."""
+ # Reddit's JSON responses only provide 25 posts at most.
+ if not 25 >= amount > 0:
+ raise ValueError("Invalid amount of subreddit posts requested.")
+
+ # Renew the token if necessary.
+ if not self.access_token or self.access_token.expires_at < datetime.utcnow():
+ await self.get_access_token()
+
+ url = f"{self.OAUTH_URL}/{route}"
+ for _ in range(self.MAX_RETRIES):
+ response = await self.bot.http_session.get(
+ url=url,
+ headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},
+ params=params
+ )
+ if response.status == 200 and response.content_type == 'application/json':
+ # Got appropriate response - process and return.
+ content = await response.json()
+ posts = content["data"]["children"]
+
+ filtered_posts = [post for post in posts if not post["data"]["over_18"]]
+
+ return filtered_posts[:amount]
+
+ await asyncio.sleep(3)
+
+ log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")
+ return list() # Failed to get appropriate response within allowed number of retries.
+
+ async def get_top_posts(
+ self, subreddit: Subreddit, time: str = "all", amount: int = 5, paginate: bool = False
+ ) -> Union[Embed, List[tuple]]:
+ """
+ Get the top amount of posts for a given subreddit within a specified timeframe.
+
+ A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top
+ weekly posts.
+
+ The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most.
+ """
+ embed = Embed()
+
+ posts = await self.fetch_posts(
+ route=f"{subreddit}/top",
+ amount=amount,
+ params={"t": time}
+ )
+ if not posts:
+ embed.title = random.choice(ERROR_REPLIES)
+ embed.colour = Colour.red()
+ embed.description = (
+ "Sorry! We couldn't find any SFW posts from that subreddit. "
+ "If this problem persists, please let us know."
)
- embed_titles += f"{post_stats}\n"
- page_text = f"**[{post_title}]({post_url})**\n{post_stats}\n{post['data']['selftext'][0:200]}"
- embed = discord.Embed()
- page_tuple = (page_text, image_url)
- pages.append(page_tuple)
+ return embed
+
+ if paginate:
+ return self.build_pagination_pages(posts, paginate=True)
+
+ # Use only starting summary page for #reddit channel posts.
+ embed.description = self.build_pagination_pages(posts, paginate=False)
+ embed.colour = Colour.blurple()
+ return embed
+
+ @loop()
+ async def auto_poster_loop(self) -> None:
+ """Post the top 5 posts daily, and the top 5 posts weekly."""
+ # once d.py get support for `time` parameter in loop decorator,
+ # this can be removed and the loop can use the `time=datetime.time.min` parameter
+ now = datetime.utcnow()
+ tomorrow = now + timedelta(days=1)
+ midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0)
+
+ await sleep_until(midnight_tomorrow)
+
+ await self.bot.wait_until_guild_available()
+ if not self.webhook:
+ await self.bot.fetch_webhook(RedditConfig.webhook)
+
+ if datetime.utcnow().weekday() == 0:
+ await self.top_weekly_posts()
+ # if it's a monday send the top weekly posts
+
+ for subreddit in RedditConfig.subreddits:
+ top_posts = await self.get_top_posts(subreddit=subreddit, time="day")
+ username = sub_clyde(f"{subreddit} Top Daily Posts")
+ message = await self.webhook.send(username=username, embed=top_posts, wait=True)
+
+ if message.channel.is_news():
+ await message.publish()
+
+ async def top_weekly_posts(self) -> None:
+ """Post a summary of the top posts."""
+ for subreddit in RedditConfig.subreddits:
+ # Send and pin the new weekly posts.
+ top_posts = await self.get_top_posts(subreddit=subreddit, time="week")
+ username = sub_clyde(f"{subreddit} Top Weekly Posts")
+ message = await self.webhook.send(wait=True, username=username, embed=top_posts)
- # ------------------------------------------------------------------
+ if subreddit.lower() == "r/python":
+ if not self.channel:
+ log.warning("Failed to get #reddit channel to remove pins in the weekly loop.")
+ return
+
+ # Remove the oldest pins so that only 12 remain at most.
+ pins = await self.channel.pins()
+
+ while len(pins) >= 12:
+ await pins[-1].unpin()
+ del pins[-1]
+
+ await message.pin()
+
+ if message.channel.is_news():
+ await message.publish()
+
+ @group(name="reddit", invoke_without_command=True)
+ async def reddit_group(self, ctx: Context) -> None:
+ """View the top posts from various subreddits."""
+ await invoke_help_command(ctx)
+
+ @reddit_group.command(name="top")
+ async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
+ """Send the top posts of all time from a given subreddit."""
+ async with ctx.typing():
+ pages = await self.get_top_posts(subreddit=subreddit, time="all", paginate=True)
+
+ await ctx.send(f"Here are the top {subreddit} posts of all time!")
+ embed = Embed(
+ color=Colour.blurple()
+ )
- pages.insert(0, (embed_titles, " "))
- embed.set_author(name=f"r/{posts[0]['data']['subreddit']} - {sort}", icon_url=reddit_emoji.url)
await ImagePaginator.paginate(pages, ctx, embed)
+ @reddit_group.command(name="daily")
+ async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
+ """Send the top posts of today from a given subreddit."""
+ async with ctx.typing():
+ pages = await self.get_top_posts(subreddit=subreddit, time="day", paginate=True)
+
+ await ctx.send(f"Here are today's top {subreddit} posts!")
+ embed = Embed(
+ color=Colour.blurple()
+ )
+
+ await ImagePaginator.paginate(pages, ctx, embed)
+
+ @reddit_group.command(name="weekly")
+ async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
+ """Send the top posts of this week from a given subreddit."""
+ async with ctx.typing():
+ pages = await self.get_top_posts(subreddit=subreddit, time="week", paginate=True)
+
+ await ctx.send(f"Here are this week's top {subreddit} posts!")
+ embed = Embed(
+ color=Colour.blurple()
+ )
+
+ await ImagePaginator.paginate(pages, ctx, embed)
+
+ @has_any_role(*STAFF_ROLES)
+ @reddit_group.command(name="subreddits", aliases=("subs",))
+ async def subreddits_command(self, ctx: Context) -> None:
+ """Send a paginated embed of all the subreddits we're relaying."""
+ embed = Embed()
+ embed.title = "Relayed subreddits."
+ embed.colour = Colour.blurple()
+
+ await LinePaginator.paginate(
+ RedditConfig.subreddits,
+ ctx, embed,
+ footer_text="Use the reddit commands along with these to view their posts.",
+ empty=False,
+ max_lines=15
+ )
+
-def setup(bot: commands.Bot) -> None:
- """Load the Cog."""
+def setup(bot: Bot) -> None:
+ """Load the Reddit cog."""
+ if not RedditConfig.secret or not RedditConfig.client_id:
+ log.error("Credentials not provided, cog not loaded.")
+ return
bot.add_cog(Reddit(bot))