diff options
| -rw-r--r-- | azure-pipelines.yml | 2 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 82 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | config-default.yml | 2 | ||||
| -rw-r--r-- | docker-compose.yml | 2 |
5 files changed, 83 insertions, 7 deletions
diff --git a/azure-pipelines.yml b/azure-pipelines.yml index da3b06201..0400ac4d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 7749d237f..7e2ba40d5 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -5,6 +5,7 @@ import textwrap from datetime import datetime from typing import List +from aiohttp import BasicAuth from discord import Colour, Embed, TextChannel from discord.ext.commands import Bot, Cog, Context, group from discord.ext.tasks import loop @@ -19,10 +20,14 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + # Change your client's User-Agent string to something unique and descriptive, + # including the target platform, a unique application identifier, a version string, + # and your username as contact information, in the following format: + # <platform>:<app ID>:<version string> (by /u/<reddit username>) + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):v?.?.? (by /u/PythonDiscord)" URL = "https://www.reddit.com" - MAX_FETCH_RETRIES = 3 + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot @@ -47,6 +52,61 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) + async def get_access_tokens(self) -> None: + """Get Reddit access tokens.""" + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "temporary" + } + + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + self.access_token = content["access_token"] + self.headers = { + "Authorization": "bearer " + self.access_token, + "User-Agent": self.USER_AGENT + } + return + + await asyncio.sleep(3) + + log.error("Authentication with Reddit API failed. Unloading extension.") + self.bot.remove_cog(self.__class__.__name__) + return + + async def revoke_access_token(self) -> None: + """Revoke the access token for Reddit API.""" + # Access tokens are valid for 1 hour. + # The token should be revoked, since the API is called only once a day. + headers = {"User-Agent": self.USER_AGENT} + data = { + "token": self.access_token, + "token_type_hint": "access_token" + } + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + self.headers = None + return + + log.warning(f"Unable to revoke access token, status code {response.status}.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -56,11 +116,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" - for _ in range(self.MAX_FETCH_RETRIES): + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -139,6 +199,8 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) + await self.get_access_tokens() + if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() # if it's a monday send the top weekly posts @@ -147,6 +209,8 @@ class Reddit(Cog): top_posts = await self.get_top_posts(subreddit=subreddit, time="day") await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts) + await self.revoke_access_token() + async def top_weekly_posts(self) -> None: """Post a summary of the top posts.""" for subreddit in RedditConfig.subreddits: @@ -178,25 +242,31 @@ class Reddit(Cog): async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" async with ctx.typing(): + await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="all") await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed) + await self.revoke_access_token() @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" async with ctx.typing(): + await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="day") await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed) + await self.revoke_access_token() @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" async with ctx.typing(): + await self.get_access_tokens() embed = await self.get_top_posts(subreddit=subreddit, time="week") await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) + await self.revoke_access_token() @with_role(*STAFF_ROLES) @reddit_group.command(name="subreddits", aliases=("subs",)) diff --git a/bot/constants.py b/bot/constants.py index 838fe7a79..b11ab65e9 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -446,6 +446,8 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 4638a89ee..bd85e1509 100644 --- a/config-default.yml +++ b/config-default.yml @@ -353,6 +353,8 @@ anti_malware: reddit: subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} |