diff options
author | 2021-05-09 13:49:31 -0700 | |
---|---|---|
committer | 2021-05-09 13:49:31 -0700 | |
commit | 5cc54b6eb9b3e20d23792d3b761ca85b4b0f22c4 (patch) | |
tree | c3233bfac93877105a0e0bf5b4ff1425cbbd1414 | |
parent | Merge pull request #1574 from python-discord/ping-bugs (diff) | |
parent | Merge branch 'main' into annihilate_reddit (diff) |
Merge pull request #1542 from RohanJnr/annihilate_reddit
Annihilate reddit cog
-rw-r--r-- | bot/constants.py | 13 | ||||
-rw-r--r-- | bot/converters.py | 29 | ||||
-rw-r--r-- | bot/exts/info/reddit.py | 308 | ||||
-rw-r--r-- | bot/utils/time.py | 17 | ||||
-rw-r--r-- | config-default.yml | 13 | ||||
-rw-r--r-- | tests/bot/utils/test_time.py | 13 |
6 files changed, 0 insertions, 393 deletions
diff --git a/bot/constants.py b/bot/constants.py index 7b2a38079..e1c3ade5a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -307,10 +307,6 @@ class Emojis(metaclass=YAMLGetter): new: str pencil: str - comments: str - upvotes: str - user: str - ok_hand: str @@ -471,7 +467,6 @@ class Webhooks(metaclass=YAMLGetter): dev_log: int duck_pond: int incidents_archive: int - reddit: int talent_pool: int @@ -551,14 +546,6 @@ class URLs(metaclass=YAMLGetter): paste_service: str -class Reddit(metaclass=YAMLGetter): - section = "reddit" - - client_id: Optional[str] - secret: Optional[str] - subreddits: list - - class AntiSpam(metaclass=YAMLGetter): section = 'anti_spam' diff --git a/bot/converters.py b/bot/converters.py index 3bf05cfb3..2a3943831 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -236,35 +236,6 @@ class Snowflake(IDConverter): return snowflake -class Subreddit(Converter): - """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" - - @staticmethod - async def convert(ctx: Context, sub: str) -> str: - """ - Force sub to begin with "r/" and check if it's a valid subreddit. - - If sub is a valid subreddit, return it prepended with "r/" - """ - sub = sub.lower() - - if not sub.startswith("r/"): - sub = f"r/{sub}" - - resp = await ctx.bot.http_session.get( - "https://www.reddit.com/subreddits/search.json", - params={"q": sub} - ) - - json = await resp.json() - if not json["data"]["children"]: - raise BadArgument( - f"The subreddit `{sub}` either doesn't exist, or it has no posts." - ) - - return sub - - class TagNameConverter(Converter): """ Ensure that a proposed tag name is valid. diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py deleted file mode 100644 index e1f0c5f9f..000000000 --- a/bot/exts/info/reddit.py +++ /dev/null @@ -1,308 +0,0 @@ -import asyncio -import logging -import random -import textwrap -from collections import namedtuple -from datetime import datetime, timedelta -from html import unescape -from typing import List - -from aiohttp import BasicAuth, ClientError -from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group, has_any_role -from discord.ext.tasks import loop -from discord.utils import escape_markdown, sleep_until - -from bot.bot import Bot -from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks -from bot.converters import Subreddit -from bot.pagination import LinePaginator -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - - -class Reddit(Cog): - """Track subreddit posts and show detailed statistics about them.""" - - HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} - URL = "https://www.reddit.com" - OAUTH_URL = "https://oauth.reddit.com" - MAX_RETRIES = 3 - - def __init__(self, bot: Bot): - self.bot = bot - - self.webhook = None - self.access_token = None - self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - - bot.loop.create_task(self.init_reddit_ready()) - self.auto_poster_loop.start() - - def cog_unload(self) -> None: - """Stop the loop task and revoke the access token when the cog is unloaded.""" - self.auto_poster_loop.cancel() - if self.access_token and self.access_token.expires_at > datetime.utcnow(): - self.bot.closing_tasks.append(asyncio.create_task(self.revoke_access_token())) - - async def init_reddit_ready(self) -> None: - """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_guild_available() - if not self.webhook: - self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) - - @property - def channel(self) -> TextChannel: - """Get the #reddit channel object from the bot's cache.""" - return self.bot.get_channel(Channels.reddit) - - async def get_access_token(self) -> None: - """ - Get a Reddit API OAuth2 access token and assign it to self.access_token. - - A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog - will be unloaded and a ClientError raised if retrieval was still unsuccessful. - """ - for i in range(1, self.MAX_RETRIES + 1): - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/access_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "grant_type": "client_credentials", - "duration": "temporary" - } - ) - - if response.status == 200 and response.content_type == "application/json": - content = await response.json() - expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. - self.access_token = AccessToken( - token=content["access_token"], - expires_at=datetime.utcnow() + timedelta(seconds=expiration) - ) - - log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") - return - else: - log.debug( - f"Failed to get an access token: " - f"status {response.status} & content type {response.content_type}; " - f"retrying ({i}/{self.MAX_RETRIES})" - ) - - await asyncio.sleep(3) - - self.bot.remove_cog(self.qualified_name) - raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - - async def revoke_access_token(self) -> None: - """ - Revoke the OAuth2 access token for the Reddit API. - - For security reasons, it's good practice to revoke the token when it's no longer being used. - """ - response = await self.bot.http_session.post( - url=f"{self.URL}/api/v1/revoke_token", - headers=self.HEADERS, - auth=self.client_auth, - data={ - "token": self.access_token.token, - "token_type_hint": "access_token" - } - ) - - if response.status == 204 and response.content_type == "application/json": - self.access_token = None - else: - log.warning(f"Unable to revoke access token: status {response.status}.") - - async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: - """A helper method to fetch a certain amount of Reddit posts at a given route.""" - # Reddit's JSON responses only provide 25 posts at most. - if not 25 >= amount > 0: - raise ValueError("Invalid amount of subreddit posts requested.") - - # Renew the token if necessary. - if not self.access_token or self.access_token.expires_at < datetime.utcnow(): - await self.get_access_token() - - url = f"{self.OAUTH_URL}/{route}" - for _ in range(self.MAX_RETRIES): - response = await self.bot.http_session.get( - url=url, - headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, - params=params - ) - if response.status == 200 and response.content_type == 'application/json': - # Got appropriate response - process and return. - content = await response.json() - posts = content["data"]["children"] - - filtered_posts = [post for post in posts if not post["data"]["over_18"]] - - return filtered_posts[:amount] - - await asyncio.sleep(3) - - log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") - return list() # Failed to get appropriate response within allowed number of retries. - - async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: - """ - Get the top amount of posts for a given subreddit within a specified timeframe. - - A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top - weekly posts. - - The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. - """ - embed = Embed(description="") - - posts = await self.fetch_posts( - route=f"{subreddit}/top", - amount=amount, - params={"t": time} - ) - if not posts: - embed.title = random.choice(ERROR_REPLIES) - embed.colour = Colour.red() - embed.description = ( - "Sorry! We couldn't find any SFW posts from that subreddit. " - "If this problem persists, please let us know." - ) - - return embed - - for post in posts: - data = post["data"] - - if text := unescape(data["selftext"]): - text = textwrap.shorten(text, width=128, placeholder="...") - text += "\n" # Add newline to separate embed info - - ups = data["ups"] - comments = data["num_comments"] - author = data["author"] - - title = textwrap.shorten(unescape(data["title"]), width=64, placeholder="...") - # Normal brackets interfere with Markdown. - title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") - link = self.URL + data["permalink"] - - embed.description += ( - f"**[{title}]({link})**\n" - f"{text}" - f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" - ) - - embed.colour = Colour.blurple() - return embed - - @loop() - async def auto_poster_loop(self) -> None: - """Post the top 5 posts daily, and the top 5 posts weekly.""" - # once d.py get support for `time` parameter in loop decorator, - # this can be removed and the loop can use the `time=datetime.time.min` parameter - now = datetime.utcnow() - tomorrow = now + timedelta(days=1) - midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) - - await sleep_until(midnight_tomorrow) - - await self.bot.wait_until_guild_available() - if not self.webhook: - await self.bot.fetch_webhook(Webhooks.reddit) - - if datetime.utcnow().weekday() == 0: - await self.top_weekly_posts() - # if it's a monday send the top weekly posts - - for subreddit in RedditConfig.subreddits: - top_posts = await self.get_top_posts(subreddit=subreddit, time="day") - username = sub_clyde(f"{subreddit} Top Daily Posts") - message = await self.webhook.send(username=username, embed=top_posts, wait=True) - - if message.channel.is_news(): - await message.publish() - - async def top_weekly_posts(self) -> None: - """Post a summary of the top posts.""" - for subreddit in RedditConfig.subreddits: - # Send and pin the new weekly posts. - top_posts = await self.get_top_posts(subreddit=subreddit, time="week") - username = sub_clyde(f"{subreddit} Top Weekly Posts") - message = await self.webhook.send(wait=True, username=username, embed=top_posts) - - if subreddit.lower() == "r/python": - if not self.channel: - log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") - return - - # Remove the oldest pins so that only 12 remain at most. - pins = await self.channel.pins() - - while len(pins) >= 12: - await pins[-1].unpin() - del pins[-1] - - await message.pin() - - if message.channel.is_news(): - await message.publish() - - @group(name="reddit", invoke_without_command=True) - async def reddit_group(self, ctx: Context) -> None: - """View the top posts from various subreddits.""" - await ctx.send_help(ctx.command) - - @reddit_group.command(name="top") - async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of all time from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="all") - - await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - - @reddit_group.command(name="daily") - async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of today from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="day") - - await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - - @reddit_group.command(name="weekly") - async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: - """Send the top posts of this week from a given subreddit.""" - async with ctx.typing(): - embed = await self.get_top_posts(subreddit=subreddit, time="week") - - await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - - @has_any_role(*STAFF_ROLES) - @reddit_group.command(name="subreddits", aliases=("subs",)) - async def subreddits_command(self, ctx: Context) -> None: - """Send a paginated embed of all the subreddits we're relaying.""" - embed = Embed() - embed.title = "Relayed subreddits." - embed.colour = Colour.blurple() - - await LinePaginator.paginate( - RedditConfig.subreddits, - ctx, embed, - footer_text="Use the reddit commands along with these to view their posts.", - empty=False, - max_lines=15 - ) - - -def setup(bot: Bot) -> None: - """Load the Reddit cog.""" - if not RedditConfig.secret or not RedditConfig.client_id: - log.error("Credentials not provided, cog not loaded.") - return - bot.add_cog(Reddit(bot)) diff --git a/bot/utils/time.py b/bot/utils/time.py index 466f0adc2..d55a0e532 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,4 +1,3 @@ -import asyncio import datetime import re from typing import Optional @@ -144,22 +143,6 @@ def parse_rfc1123(stamp: str) -> datetime.datetime: return datetime.datetime.strptime(stamp, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc) -# Hey, this could actually be used in the off_topic_names and reddit cogs :) -async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] = None) -> None: - """ - Wait until a given time. - - :param time: A datetime.datetime object to wait until. - :param start: The start from which to calculate the waiting duration. Defaults to UTC time. - """ - delay = time - (start or datetime.datetime.utcnow()) - delay_seconds = delay.total_seconds() - - # Incorporate a small delay so we don't rapid-fire the event due to time precision errors - if delay_seconds > 1.0: - await asyncio.sleep(delay_seconds) - - def format_infraction(timestamp: str) -> str: """Format an infraction timestamp to a more readable ISO 8601 format.""" return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) diff --git a/config-default.yml b/config-default.yml index 46475f845..c5c9b12ce 100644 --- a/config-default.yml +++ b/config-default.yml @@ -73,11 +73,6 @@ style: new: "\U0001F195" pencil: "\u270F" - # emotes used for #reddit - comments: "<:reddit_comments:755845255001014384>" - upvotes: "<:reddit_upvotes:755845219890757644>" - user: "<:reddit_users:755845303822974997>" - ok_hand: ":ok_hand:" icons: @@ -293,7 +288,6 @@ guild: duck_pond: 637821475327311927 incidents_archive: 720671599790915702 python_news: &PYNEWS_WEBHOOK 704381182279942324 - reddit: 635408384794951680 talent_pool: 569145364800602132 @@ -423,13 +417,6 @@ anti_spam: max: 3 -reddit: - client_id: !ENV "REDDIT_CLIENT_ID" - secret: !ENV "REDDIT_SECRET" - subreddits: - - 'r/Python' - - big_brother: header_message_limit: 15 log_delay: 15 diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 694d3a40f..115ddfb0d 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,7 +1,5 @@ -import asyncio import unittest from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch from dateutil.relativedelta import relativedelta @@ -56,17 +54,6 @@ class TimeTests(unittest.TestCase): """Testing format_infraction.""" self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') - @patch('asyncio.sleep', new_callable=AsyncMock) - def test_wait_until(self, mock): - """Testing wait_until.""" - start = datetime(2019, 1, 1, 0, 0) - then = datetime(2019, 1, 1, 0, 10) - - # No return value - self.assertIs(asyncio.run(time.wait_until(then, start)), None) - - mock.assert_called_once_with(10 * 60) - def test_format_infraction_with_duration_none_expiry(self): """format_infraction_with_duration should work for None expiry.""" test_cases = ( |