diff options
| -rw-r--r-- | azure-pipelines.yml | 2 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 95 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | config-default.yml | 2 | ||||
| -rw-r--r-- | docker-compose.yml | 2 |
5 files changed, 96 insertions, 7 deletions
diff --git a/azure-pipelines.yml b/azure-pipelines.yml index da3b06201..0400ac4d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0d06e9c26..22bb66bf0 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,9 +2,11 @@ import asyncio import logging import random import textwrap +from collections import namedtuple from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth from discord import Colour, Embed, TextChannel from discord.ext.commands import Bot, Cog, Context, group from discord.ext.tasks import loop @@ -20,14 +22,16 @@ log = logging.getLogger(__name__) class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + USER_AGENT = "docker-python3:Discord Bot of PythonDiscord (https://pythondiscord.com/):1.0.0 (by /u/PythonDiscord)" URL = "https://www.reddit.com" - MAX_FETCH_RETRIES = 3 + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready + self.webhook = None + self.access_token = None bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() @@ -35,6 +39,8 @@ class Reddit(Cog): def cog_unload(self) -> None: """Stops the loops when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -47,6 +53,66 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) + async def get_access_token(self) -> None: + """Get Reddit access tokens.""" + headers = {"User-Agent": self.USER_AGENT} + data = { + "grant_type": "client_credentials", + "duration": "temporary" + } + + log.info(f"{RedditConfig.client_id}, {RedditConfig.secret}") + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + + for _ in range(self.MAX_RETRIES): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(hours=1) + ) + self.headers = { + "Authorization": "bearer " + self.access_token.token, + "User-Agent": self.USER_AGENT + } + return + + await asyncio.sleep(3) + + log.error("Authentication with Reddit API failed. Unloading extension.") + self.bot.remove_cog(self.__class__.__name__) + return + + async def revoke_access_token(self) -> None: + """Revoke the access token for Reddit API.""" + # Access tokens are valid for 1 hour. + # The token should be revoked, since the API is called only once a day. + headers = {"User-Agent": self.USER_AGENT} + data = { + "token": self.access_token.token, + "token_type_hint": "access_token" + } + + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=headers, + auth=self.client_auth, + data=data + ) + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + self.headers = None + return + + log.warning(f"Unable to revoke access token, status code {response.status}.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. @@ -56,11 +122,11 @@ class Reddit(Cog): if params is None: params = {} - url = f"{self.URL}/{route}.json" - for _ in range(self.MAX_FETCH_RETRIES): + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers=self.headers, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -140,6 +206,11 @@ class Reddit(Cog): if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() + if datetime.utcnow().weekday() == 0: await self.top_weekly_posts() # if it's a monday send the top weekly posts @@ -178,6 +249,10 @@ class Reddit(Cog): @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of all time from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="all") @@ -186,6 +261,10 @@ class Reddit(Cog): @reddit_group.command(name="daily") async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of today from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="day") @@ -194,6 +273,10 @@ class Reddit(Cog): @reddit_group.command(name="weekly") async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: """Send the top posts of this week from a given subreddit.""" + if not self.access_token: + await self.get_access_token() + elif self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() async with ctx.typing(): embed = await self.get_top_posts(subreddit=subreddit, time="week") diff --git a/bot/constants.py b/bot/constants.py index 89504a2e0..ed85adf6a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -465,6 +465,8 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 930a1a0e6..e6f0fda21 100644 --- a/config-default.yml +++ b/config-default.yml @@ -365,6 +365,8 @@ anti_malware: reddit: subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} |