aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Xithrius <[email protected]>2021-05-09 13:49:31 -0700
committerGravatar GitHub <[email protected]>2021-05-09 13:49:31 -0700
commit5cc54b6eb9b3e20d23792d3b761ca85b4b0f22c4 (patch)
treec3233bfac93877105a0e0bf5b4ff1425cbbd1414
parentMerge pull request #1574 from python-discord/ping-bugs (diff)
parentMerge branch 'main' into annihilate_reddit (diff)
Merge pull request #1542 from RohanJnr/annihilate_reddit
Annihilate reddit cog
-rw-r--r--bot/constants.py13
-rw-r--r--bot/converters.py29
-rw-r--r--bot/exts/info/reddit.py308
-rw-r--r--bot/utils/time.py17
-rw-r--r--config-default.yml13
-rw-r--r--tests/bot/utils/test_time.py13
6 files changed, 0 insertions, 393 deletions
diff --git a/bot/constants.py b/bot/constants.py
index 7b2a38079..e1c3ade5a 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -307,10 +307,6 @@ class Emojis(metaclass=YAMLGetter):
new: str
pencil: str
- comments: str
- upvotes: str
- user: str
-
ok_hand: str
@@ -471,7 +467,6 @@ class Webhooks(metaclass=YAMLGetter):
dev_log: int
duck_pond: int
incidents_archive: int
- reddit: int
talent_pool: int
@@ -551,14 +546,6 @@ class URLs(metaclass=YAMLGetter):
paste_service: str
-class Reddit(metaclass=YAMLGetter):
- section = "reddit"
-
- client_id: Optional[str]
- secret: Optional[str]
- subreddits: list
-
-
class AntiSpam(metaclass=YAMLGetter):
section = 'anti_spam'
diff --git a/bot/converters.py b/bot/converters.py
index 3bf05cfb3..2a3943831 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -236,35 +236,6 @@ class Snowflake(IDConverter):
return snowflake
-class Subreddit(Converter):
- """Forces a string to begin with "r/" and checks if it's a valid subreddit."""
-
- @staticmethod
- async def convert(ctx: Context, sub: str) -> str:
- """
- Force sub to begin with "r/" and check if it's a valid subreddit.
-
- If sub is a valid subreddit, return it prepended with "r/"
- """
- sub = sub.lower()
-
- if not sub.startswith("r/"):
- sub = f"r/{sub}"
-
- resp = await ctx.bot.http_session.get(
- "https://www.reddit.com/subreddits/search.json",
- params={"q": sub}
- )
-
- json = await resp.json()
- if not json["data"]["children"]:
- raise BadArgument(
- f"The subreddit `{sub}` either doesn't exist, or it has no posts."
- )
-
- return sub
-
-
class TagNameConverter(Converter):
"""
Ensure that a proposed tag name is valid.
diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py
deleted file mode 100644
index e1f0c5f9f..000000000
--- a/bot/exts/info/reddit.py
+++ /dev/null
@@ -1,308 +0,0 @@
-import asyncio
-import logging
-import random
-import textwrap
-from collections import namedtuple
-from datetime import datetime, timedelta
-from html import unescape
-from typing import List
-
-from aiohttp import BasicAuth, ClientError
-from discord import Colour, Embed, TextChannel
-from discord.ext.commands import Cog, Context, group, has_any_role
-from discord.ext.tasks import loop
-from discord.utils import escape_markdown, sleep_until
-
-from bot.bot import Bot
-from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks
-from bot.converters import Subreddit
-from bot.pagination import LinePaginator
-from bot.utils.messages import sub_clyde
-
-log = logging.getLogger(__name__)
-
-AccessToken = namedtuple("AccessToken", ["token", "expires_at"])
-
-
-class Reddit(Cog):
- """Track subreddit posts and show detailed statistics about them."""
-
- HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}
- URL = "https://www.reddit.com"
- OAUTH_URL = "https://oauth.reddit.com"
- MAX_RETRIES = 3
-
- def __init__(self, bot: Bot):
- self.bot = bot
-
- self.webhook = None
- self.access_token = None
- self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
-
- bot.loop.create_task(self.init_reddit_ready())
- self.auto_poster_loop.start()
-
- def cog_unload(self) -> None:
- """Stop the loop task and revoke the access token when the cog is unloaded."""
- self.auto_poster_loop.cancel()
- if self.access_token and self.access_token.expires_at > datetime.utcnow():
- self.bot.closing_tasks.append(asyncio.create_task(self.revoke_access_token()))
-
- async def init_reddit_ready(self) -> None:
- """Sets the reddit webhook when the cog is loaded."""
- await self.bot.wait_until_guild_available()
- if not self.webhook:
- self.webhook = await self.bot.fetch_webhook(Webhooks.reddit)
-
- @property
- def channel(self) -> TextChannel:
- """Get the #reddit channel object from the bot's cache."""
- return self.bot.get_channel(Channels.reddit)
-
- async def get_access_token(self) -> None:
- """
- Get a Reddit API OAuth2 access token and assign it to self.access_token.
-
- A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog
- will be unloaded and a ClientError raised if retrieval was still unsuccessful.
- """
- for i in range(1, self.MAX_RETRIES + 1):
- response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/access_token",
- headers=self.HEADERS,
- auth=self.client_auth,
- data={
- "grant_type": "client_credentials",
- "duration": "temporary"
- }
- )
-
- if response.status == 200 and response.content_type == "application/json":
- content = await response.json()
- expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway.
- self.access_token = AccessToken(
- token=content["access_token"],
- expires_at=datetime.utcnow() + timedelta(seconds=expiration)
- )
-
- log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")
- return
- else:
- log.debug(
- f"Failed to get an access token: "
- f"status {response.status} & content type {response.content_type}; "
- f"retrying ({i}/{self.MAX_RETRIES})"
- )
-
- await asyncio.sleep(3)
-
- self.bot.remove_cog(self.qualified_name)
- raise ClientError("Authentication with the Reddit API failed. Unloading the cog.")
-
- async def revoke_access_token(self) -> None:
- """
- Revoke the OAuth2 access token for the Reddit API.
-
- For security reasons, it's good practice to revoke the token when it's no longer being used.
- """
- response = await self.bot.http_session.post(
- url=f"{self.URL}/api/v1/revoke_token",
- headers=self.HEADERS,
- auth=self.client_auth,
- data={
- "token": self.access_token.token,
- "token_type_hint": "access_token"
- }
- )
-
- if response.status == 204 and response.content_type == "application/json":
- self.access_token = None
- else:
- log.warning(f"Unable to revoke access token: status {response.status}.")
-
- async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
- """A helper method to fetch a certain amount of Reddit posts at a given route."""
- # Reddit's JSON responses only provide 25 posts at most.
- if not 25 >= amount > 0:
- raise ValueError("Invalid amount of subreddit posts requested.")
-
- # Renew the token if necessary.
- if not self.access_token or self.access_token.expires_at < datetime.utcnow():
- await self.get_access_token()
-
- url = f"{self.OAUTH_URL}/{route}"
- for _ in range(self.MAX_RETRIES):
- response = await self.bot.http_session.get(
- url=url,
- headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},
- params=params
- )
- if response.status == 200 and response.content_type == 'application/json':
- # Got appropriate response - process and return.
- content = await response.json()
- posts = content["data"]["children"]
-
- filtered_posts = [post for post in posts if not post["data"]["over_18"]]
-
- return filtered_posts[:amount]
-
- await asyncio.sleep(3)
-
- log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")
- return list() # Failed to get appropriate response within allowed number of retries.
-
- async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed:
- """
- Get the top amount of posts for a given subreddit within a specified timeframe.
-
- A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top
- weekly posts.
-
- The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most.
- """
- embed = Embed(description="")
-
- posts = await self.fetch_posts(
- route=f"{subreddit}/top",
- amount=amount,
- params={"t": time}
- )
- if not posts:
- embed.title = random.choice(ERROR_REPLIES)
- embed.colour = Colour.red()
- embed.description = (
- "Sorry! We couldn't find any SFW posts from that subreddit. "
- "If this problem persists, please let us know."
- )
-
- return embed
-
- for post in posts:
- data = post["data"]
-
- if text := unescape(data["selftext"]):
- text = textwrap.shorten(text, width=128, placeholder="...")
- text += "\n" # Add newline to separate embed info
-
- ups = data["ups"]
- comments = data["num_comments"]
- author = data["author"]
-
- title = textwrap.shorten(unescape(data["title"]), width=64, placeholder="...")
- # Normal brackets interfere with Markdown.
- title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌")
- link = self.URL + data["permalink"]
-
- embed.description += (
- f"**[{title}]({link})**\n"
- f"{text}"
- f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n"
- )
-
- embed.colour = Colour.blurple()
- return embed
-
- @loop()
- async def auto_poster_loop(self) -> None:
- """Post the top 5 posts daily, and the top 5 posts weekly."""
- # once d.py get support for `time` parameter in loop decorator,
- # this can be removed and the loop can use the `time=datetime.time.min` parameter
- now = datetime.utcnow()
- tomorrow = now + timedelta(days=1)
- midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0)
-
- await sleep_until(midnight_tomorrow)
-
- await self.bot.wait_until_guild_available()
- if not self.webhook:
- await self.bot.fetch_webhook(Webhooks.reddit)
-
- if datetime.utcnow().weekday() == 0:
- await self.top_weekly_posts()
- # if it's a monday send the top weekly posts
-
- for subreddit in RedditConfig.subreddits:
- top_posts = await self.get_top_posts(subreddit=subreddit, time="day")
- username = sub_clyde(f"{subreddit} Top Daily Posts")
- message = await self.webhook.send(username=username, embed=top_posts, wait=True)
-
- if message.channel.is_news():
- await message.publish()
-
- async def top_weekly_posts(self) -> None:
- """Post a summary of the top posts."""
- for subreddit in RedditConfig.subreddits:
- # Send and pin the new weekly posts.
- top_posts = await self.get_top_posts(subreddit=subreddit, time="week")
- username = sub_clyde(f"{subreddit} Top Weekly Posts")
- message = await self.webhook.send(wait=True, username=username, embed=top_posts)
-
- if subreddit.lower() == "r/python":
- if not self.channel:
- log.warning("Failed to get #reddit channel to remove pins in the weekly loop.")
- return
-
- # Remove the oldest pins so that only 12 remain at most.
- pins = await self.channel.pins()
-
- while len(pins) >= 12:
- await pins[-1].unpin()
- del pins[-1]
-
- await message.pin()
-
- if message.channel.is_news():
- await message.publish()
-
- @group(name="reddit", invoke_without_command=True)
- async def reddit_group(self, ctx: Context) -> None:
- """View the top posts from various subreddits."""
- await ctx.send_help(ctx.command)
-
- @reddit_group.command(name="top")
- async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
- """Send the top posts of all time from a given subreddit."""
- async with ctx.typing():
- embed = await self.get_top_posts(subreddit=subreddit, time="all")
-
- await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed)
-
- @reddit_group.command(name="daily")
- async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
- """Send the top posts of today from a given subreddit."""
- async with ctx.typing():
- embed = await self.get_top_posts(subreddit=subreddit, time="day")
-
- await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed)
-
- @reddit_group.command(name="weekly")
- async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
- """Send the top posts of this week from a given subreddit."""
- async with ctx.typing():
- embed = await self.get_top_posts(subreddit=subreddit, time="week")
-
- await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed)
-
- @has_any_role(*STAFF_ROLES)
- @reddit_group.command(name="subreddits", aliases=("subs",))
- async def subreddits_command(self, ctx: Context) -> None:
- """Send a paginated embed of all the subreddits we're relaying."""
- embed = Embed()
- embed.title = "Relayed subreddits."
- embed.colour = Colour.blurple()
-
- await LinePaginator.paginate(
- RedditConfig.subreddits,
- ctx, embed,
- footer_text="Use the reddit commands along with these to view their posts.",
- empty=False,
- max_lines=15
- )
-
-
-def setup(bot: Bot) -> None:
- """Load the Reddit cog."""
- if not RedditConfig.secret or not RedditConfig.client_id:
- log.error("Credentials not provided, cog not loaded.")
- return
- bot.add_cog(Reddit(bot))
diff --git a/bot/utils/time.py b/bot/utils/time.py
index 466f0adc2..d55a0e532 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -1,4 +1,3 @@
-import asyncio
import datetime
import re
from typing import Optional
@@ -144,22 +143,6 @@ def parse_rfc1123(stamp: str) -> datetime.datetime:
return datetime.datetime.strptime(stamp, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc)
-# Hey, this could actually be used in the off_topic_names and reddit cogs :)
-async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] = None) -> None:
- """
- Wait until a given time.
-
- :param time: A datetime.datetime object to wait until.
- :param start: The start from which to calculate the waiting duration. Defaults to UTC time.
- """
- delay = time - (start or datetime.datetime.utcnow())
- delay_seconds = delay.total_seconds()
-
- # Incorporate a small delay so we don't rapid-fire the event due to time precision errors
- if delay_seconds > 1.0:
- await asyncio.sleep(delay_seconds)
-
-
def format_infraction(timestamp: str) -> str:
"""Format an infraction timestamp to a more readable ISO 8601 format."""
return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT)
diff --git a/config-default.yml b/config-default.yml
index 46475f845..c5c9b12ce 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -73,11 +73,6 @@ style:
new: "\U0001F195"
pencil: "\u270F"
- # emotes used for #reddit
- comments: "<:reddit_comments:755845255001014384>"
- upvotes: "<:reddit_upvotes:755845219890757644>"
- user: "<:reddit_users:755845303822974997>"
-
ok_hand: ":ok_hand:"
icons:
@@ -293,7 +288,6 @@ guild:
duck_pond: 637821475327311927
incidents_archive: 720671599790915702
python_news: &PYNEWS_WEBHOOK 704381182279942324
- reddit: 635408384794951680
talent_pool: 569145364800602132
@@ -423,13 +417,6 @@ anti_spam:
max: 3
-reddit:
- client_id: !ENV "REDDIT_CLIENT_ID"
- secret: !ENV "REDDIT_SECRET"
- subreddits:
- - 'r/Python'
-
-
big_brother:
header_message_limit: 15
log_delay: 15
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
index 694d3a40f..115ddfb0d 100644
--- a/tests/bot/utils/test_time.py
+++ b/tests/bot/utils/test_time.py
@@ -1,7 +1,5 @@
-import asyncio
import unittest
from datetime import datetime, timezone
-from unittest.mock import AsyncMock, patch
from dateutil.relativedelta import relativedelta
@@ -56,17 +54,6 @@ class TimeTests(unittest.TestCase):
"""Testing format_infraction."""
self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01')
- @patch('asyncio.sleep', new_callable=AsyncMock)
- def test_wait_until(self, mock):
- """Testing wait_until."""
- start = datetime(2019, 1, 1, 0, 0)
- then = datetime(2019, 1, 1, 0, 10)
-
- # No return value
- self.assertIs(asyncio.run(time.wait_until(then, start)), None)
-
- mock.assert_called_once_with(10 * 60)
-
def test_format_infraction_with_duration_none_expiry(self):
"""format_infraction_with_duration should work for None expiry."""
test_cases = (