diff options
| -rw-r--r-- | bot/constants.py | 13 | ||||
| -rw-r--r-- | bot/converters.py | 29 | ||||
| -rw-r--r-- | bot/exts/info/reddit.py | 308 | ||||
| -rw-r--r-- | bot/utils/time.py | 17 | ||||
| -rw-r--r-- | config-default.yml | 13 | ||||
| -rw-r--r-- | tests/bot/utils/test_time.py | 13 | 
6 files changed, 0 insertions, 393 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 7b2a38079..e1c3ade5a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -307,10 +307,6 @@ class Emojis(metaclass=YAMLGetter):      new: str      pencil: str -    comments: str -    upvotes: str -    user: str -      ok_hand: str @@ -471,7 +467,6 @@ class Webhooks(metaclass=YAMLGetter):      dev_log: int      duck_pond: int      incidents_archive: int -    reddit: int      talent_pool: int @@ -551,14 +546,6 @@ class URLs(metaclass=YAMLGetter):      paste_service: str -class Reddit(metaclass=YAMLGetter): -    section = "reddit" - -    client_id: Optional[str] -    secret: Optional[str] -    subreddits: list - -  class AntiSpam(metaclass=YAMLGetter):      section = 'anti_spam' diff --git a/bot/converters.py b/bot/converters.py index 3bf05cfb3..2a3943831 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -236,35 +236,6 @@ class Snowflake(IDConverter):          return snowflake -class Subreddit(Converter): -    """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" - -    @staticmethod -    async def convert(ctx: Context, sub: str) -> str: -        """ -        Force sub to begin with "r/" and check if it's a valid subreddit. - -        If sub is a valid subreddit, return it prepended with "r/" -        """ -        sub = sub.lower() - -        if not sub.startswith("r/"): -            sub = f"r/{sub}" - -        resp = await ctx.bot.http_session.get( -            "https://www.reddit.com/subreddits/search.json", -            params={"q": sub} -        ) - -        json = await resp.json() -        if not json["data"]["children"]: -            raise BadArgument( -                f"The subreddit `{sub}` either doesn't exist, or it has no posts." -            ) - -        return sub - -  class TagNameConverter(Converter):      """      Ensure that a proposed tag name is valid. diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py deleted file mode 100644 index e1f0c5f9f..000000000 --- a/bot/exts/info/reddit.py +++ /dev/null @@ -1,308 +0,0 @@ -import asyncio -import logging -import random -import textwrap -from collections import namedtuple -from datetime import datetime, timedelta -from html import unescape -from typing import List - -from aiohttp import BasicAuth, ClientError -from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group, has_any_role -from discord.ext.tasks import loop -from discord.utils import escape_markdown, sleep_until - -from bot.bot import Bot -from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks -from bot.converters import Subreddit -from bot.pagination import LinePaginator -from bot.utils.messages import sub_clyde - -log = logging.getLogger(__name__) - -AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) - - -class Reddit(Cog): -    """Track subreddit posts and show detailed statistics about them.""" - -    HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} -    URL = "https://www.reddit.com" -    OAUTH_URL = "https://oauth.reddit.com" -    MAX_RETRIES = 3 - -    def __init__(self, bot: Bot): -        self.bot = bot - -        self.webhook = None -        self.access_token = None -        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) - -        bot.loop.create_task(self.init_reddit_ready()) -        self.auto_poster_loop.start() - -    def cog_unload(self) -> None: -        """Stop the loop task and revoke the access token when the cog is unloaded.""" -        self.auto_poster_loop.cancel() -        if self.access_token and self.access_token.expires_at > datetime.utcnow(): -            self.bot.closing_tasks.append(asyncio.create_task(self.revoke_access_token())) - -    async def init_reddit_ready(self) -> None: -        """Sets the reddit webhook when the cog is loaded.""" -        await self.bot.wait_until_guild_available() -        if not self.webhook: -            self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) - -    @property -    def channel(self) -> TextChannel: -        """Get the #reddit channel object from the bot's cache.""" -        return self.bot.get_channel(Channels.reddit) - -    async def get_access_token(self) -> None: -        """ -        Get a Reddit API OAuth2 access token and assign it to self.access_token. - -        A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog -        will be unloaded and a ClientError raised if retrieval was still unsuccessful. -        """ -        for i in range(1, self.MAX_RETRIES + 1): -            response = await self.bot.http_session.post( -                url=f"{self.URL}/api/v1/access_token", -                headers=self.HEADERS, -                auth=self.client_auth, -                data={ -                    "grant_type": "client_credentials", -                    "duration": "temporary" -                } -            ) - -            if response.status == 200 and response.content_type == "application/json": -                content = await response.json() -                expiration = int(content["expires_in"]) - 60  # Subtract 1 minute for leeway. -                self.access_token = AccessToken( -                    token=content["access_token"], -                    expires_at=datetime.utcnow() + timedelta(seconds=expiration) -                ) - -                log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") -                return -            else: -                log.debug( -                    f"Failed to get an access token: " -                    f"status {response.status} & content type {response.content_type}; " -                    f"retrying ({i}/{self.MAX_RETRIES})" -                ) - -            await asyncio.sleep(3) - -        self.bot.remove_cog(self.qualified_name) -        raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") - -    async def revoke_access_token(self) -> None: -        """ -        Revoke the OAuth2 access token for the Reddit API. - -        For security reasons, it's good practice to revoke the token when it's no longer being used. -        """ -        response = await self.bot.http_session.post( -            url=f"{self.URL}/api/v1/revoke_token", -            headers=self.HEADERS, -            auth=self.client_auth, -            data={ -                "token": self.access_token.token, -                "token_type_hint": "access_token" -            } -        ) - -        if response.status == 204 and response.content_type == "application/json": -            self.access_token = None -        else: -            log.warning(f"Unable to revoke access token: status {response.status}.") - -    async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: -        """A helper method to fetch a certain amount of Reddit posts at a given route.""" -        # Reddit's JSON responses only provide 25 posts at most. -        if not 25 >= amount > 0: -            raise ValueError("Invalid amount of subreddit posts requested.") - -        # Renew the token if necessary. -        if not self.access_token or self.access_token.expires_at < datetime.utcnow(): -            await self.get_access_token() - -        url = f"{self.OAUTH_URL}/{route}" -        for _ in range(self.MAX_RETRIES): -            response = await self.bot.http_session.get( -                url=url, -                headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, -                params=params -            ) -            if response.status == 200 and response.content_type == 'application/json': -                # Got appropriate response - process and return. -                content = await response.json() -                posts = content["data"]["children"] - -                filtered_posts = [post for post in posts if not post["data"]["over_18"]] - -                return filtered_posts[:amount] - -            await asyncio.sleep(3) - -        log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}") -        return list()  # Failed to get appropriate response within allowed number of retries. - -    async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed: -        """ -        Get the top amount of posts for a given subreddit within a specified timeframe. - -        A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top -        weekly posts. - -        The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most. -        """ -        embed = Embed(description="") - -        posts = await self.fetch_posts( -            route=f"{subreddit}/top", -            amount=amount, -            params={"t": time} -        ) -        if not posts: -            embed.title = random.choice(ERROR_REPLIES) -            embed.colour = Colour.red() -            embed.description = ( -                "Sorry! We couldn't find any SFW posts from that subreddit. " -                "If this problem persists, please let us know." -            ) - -            return embed - -        for post in posts: -            data = post["data"] - -            if text := unescape(data["selftext"]): -                text = textwrap.shorten(text, width=128, placeholder="...") -                text += "\n"  # Add newline to separate embed info - -            ups = data["ups"] -            comments = data["num_comments"] -            author = data["author"] - -            title = textwrap.shorten(unescape(data["title"]), width=64, placeholder="...") -            # Normal brackets interfere with Markdown. -            title = escape_markdown(title).replace("[", "⦋").replace("]", "⦌") -            link = self.URL + data["permalink"] - -            embed.description += ( -                f"**[{title}]({link})**\n" -                f"{text}" -                f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n" -            ) - -        embed.colour = Colour.blurple() -        return embed - -    @loop() -    async def auto_poster_loop(self) -> None: -        """Post the top 5 posts daily, and the top 5 posts weekly.""" -        # once d.py get support for `time` parameter in loop decorator, -        # this can be removed and the loop can use the `time=datetime.time.min` parameter -        now = datetime.utcnow() -        tomorrow = now + timedelta(days=1) -        midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0) - -        await sleep_until(midnight_tomorrow) - -        await self.bot.wait_until_guild_available() -        if not self.webhook: -            await self.bot.fetch_webhook(Webhooks.reddit) - -        if datetime.utcnow().weekday() == 0: -            await self.top_weekly_posts() -            # if it's a monday send the top weekly posts - -        for subreddit in RedditConfig.subreddits: -            top_posts = await self.get_top_posts(subreddit=subreddit, time="day") -            username = sub_clyde(f"{subreddit} Top Daily Posts") -            message = await self.webhook.send(username=username, embed=top_posts, wait=True) - -            if message.channel.is_news(): -                await message.publish() - -    async def top_weekly_posts(self) -> None: -        """Post a summary of the top posts.""" -        for subreddit in RedditConfig.subreddits: -            # Send and pin the new weekly posts. -            top_posts = await self.get_top_posts(subreddit=subreddit, time="week") -            username = sub_clyde(f"{subreddit} Top Weekly Posts") -            message = await self.webhook.send(wait=True, username=username, embed=top_posts) - -            if subreddit.lower() == "r/python": -                if not self.channel: -                    log.warning("Failed to get #reddit channel to remove pins in the weekly loop.") -                    return - -                # Remove the oldest pins so that only 12 remain at most. -                pins = await self.channel.pins() - -                while len(pins) >= 12: -                    await pins[-1].unpin() -                    del pins[-1] - -                await message.pin() - -                if message.channel.is_news(): -                    await message.publish() - -    @group(name="reddit", invoke_without_command=True) -    async def reddit_group(self, ctx: Context) -> None: -        """View the top posts from various subreddits.""" -        await ctx.send_help(ctx.command) - -    @reddit_group.command(name="top") -    async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: -        """Send the top posts of all time from a given subreddit.""" -        async with ctx.typing(): -            embed = await self.get_top_posts(subreddit=subreddit, time="all") - -        await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) - -    @reddit_group.command(name="daily") -    async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: -        """Send the top posts of today from a given subreddit.""" -        async with ctx.typing(): -            embed = await self.get_top_posts(subreddit=subreddit, time="day") - -        await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) - -    @reddit_group.command(name="weekly") -    async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: -        """Send the top posts of this week from a given subreddit.""" -        async with ctx.typing(): -            embed = await self.get_top_posts(subreddit=subreddit, time="week") - -        await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) - -    @has_any_role(*STAFF_ROLES) -    @reddit_group.command(name="subreddits", aliases=("subs",)) -    async def subreddits_command(self, ctx: Context) -> None: -        """Send a paginated embed of all the subreddits we're relaying.""" -        embed = Embed() -        embed.title = "Relayed subreddits." -        embed.colour = Colour.blurple() - -        await LinePaginator.paginate( -            RedditConfig.subreddits, -            ctx, embed, -            footer_text="Use the reddit commands along with these to view their posts.", -            empty=False, -            max_lines=15 -        ) - - -def setup(bot: Bot) -> None: -    """Load the Reddit cog.""" -    if not RedditConfig.secret or not RedditConfig.client_id: -        log.error("Credentials not provided, cog not loaded.") -        return -    bot.add_cog(Reddit(bot)) diff --git a/bot/utils/time.py b/bot/utils/time.py index 466f0adc2..d55a0e532 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,4 +1,3 @@ -import asyncio  import datetime  import re  from typing import Optional @@ -144,22 +143,6 @@ def parse_rfc1123(stamp: str) -> datetime.datetime:      return datetime.datetime.strptime(stamp, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc) -# Hey, this could actually be used in the off_topic_names and reddit cogs :) -async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] = None) -> None: -    """ -    Wait until a given time. - -    :param time: A datetime.datetime object to wait until. -    :param start: The start from which to calculate the waiting duration. Defaults to UTC time. -    """ -    delay = time - (start or datetime.datetime.utcnow()) -    delay_seconds = delay.total_seconds() - -    # Incorporate a small delay so we don't rapid-fire the event due to time precision errors -    if delay_seconds > 1.0: -        await asyncio.sleep(delay_seconds) - -  def format_infraction(timestamp: str) -> str:      """Format an infraction timestamp to a more readable ISO 8601 format."""      return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) diff --git a/config-default.yml b/config-default.yml index 46475f845..c5c9b12ce 100644 --- a/config-default.yml +++ b/config-default.yml @@ -73,11 +73,6 @@ style:          new:        "\U0001F195"          pencil:     "\u270F" -        # emotes used for #reddit -        comments:       "<:reddit_comments:755845255001014384>" -        upvotes:        "<:reddit_upvotes:755845219890757644>" -        user:           "<:reddit_users:755845303822974997>" -          ok_hand: ":ok_hand:"      icons: @@ -293,7 +288,6 @@ guild:          duck_pond:                          637821475327311927          incidents_archive:                  720671599790915702          python_news:        &PYNEWS_WEBHOOK 704381182279942324 -        reddit:                             635408384794951680          talent_pool:                        569145364800602132 @@ -423,13 +417,6 @@ anti_spam:              max: 3 -reddit: -    client_id: !ENV "REDDIT_CLIENT_ID" -    secret: !ENV "REDDIT_SECRET" -    subreddits: -        - 'r/Python' - -  big_brother:      header_message_limit: 15      log_delay: 15 diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 694d3a40f..115ddfb0d 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,7 +1,5 @@ -import asyncio  import unittest  from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch  from dateutil.relativedelta import relativedelta @@ -56,17 +54,6 @@ class TimeTests(unittest.TestCase):          """Testing format_infraction."""          self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') -    @patch('asyncio.sleep', new_callable=AsyncMock) -    def test_wait_until(self, mock): -        """Testing wait_until.""" -        start = datetime(2019, 1, 1, 0, 0) -        then = datetime(2019, 1, 1, 0, 10) - -        # No return value -        self.assertIs(asyncio.run(time.wait_until(then, start)), None) - -        mock.assert_called_once_with(10 * 60) -      def test_format_infraction_with_duration_none_expiry(self):          """format_infraction_with_duration should work for None expiry."""          test_cases = ( | 
