diff options
| author | 2019-12-13 10:06:55 +1000 | |
|---|---|---|
| committer | 2019-12-13 10:06:55 +1000 | |
| commit | ea9299bc51dc56d1b9775ac757868a61bcb98ad2 (patch) | |
| tree | 74569b82d99cf0c54089e480137a832f5cec5e62 | |
| parent | Display time left until expiration of infraction (#679) (diff) | |
| parent | Merge branch 'master' into reddit-api-oauth (diff) | |
Use OAuth to be Reddit API compliant (#510)
Use OAuth to be Reddit API compliant
Co-authored-by: Jens <[email protected]>
Co-authored-by: Mark <[email protected]>
Co-authored-by: null <[email protected]>
| -rw-r--r-- | azure-pipelines.yml | 2 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 91 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | config-default.yml | 2 | ||||
| -rw-r--r-- | docker-compose.yml | 2 | 
5 files changed, 88 insertions, 11 deletions
| diff --git a/azure-pipelines.yml b/azure-pipelines.yml index da3b06201..0400ac4d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs:        - script: python -m flake8          displayName: 'Run linter' -      - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner +      - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner          displayName: Run tests        - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index bec316ae7..aa487f18e 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,9 +2,11 @@ import asyncio  import logging  import random  import textwrap +from collections import namedtuple  from datetime import datetime, timedelta  from typing import List +from aiohttp import BasicAuth, ClientError  from discord import Colour, Embed, TextChannel  from discord.ext.commands import Cog, Context, group  from discord.ext.tasks import loop @@ -17,25 +19,32 @@ from bot.pagination import LinePaginator  log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) +  class Reddit(Cog):      """Track subreddit posts and show detailed statistics about them.""" -    HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} +    HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}      URL = "https://www.reddit.com" -    MAX_FETCH_RETRIES = 3 +    OAUTH_URL = "https://oauth.reddit.com" +    MAX_RETRIES = 3      def __init__(self, bot: Bot):          self.bot = bot -        self.webhook = None  # set in on_ready -        bot.loop.create_task(self.init_reddit_ready()) +        self.webhook = None +        self.access_token = None +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) +        bot.loop.create_task(self.init_reddit_ready())          self.auto_poster_loop.start()      def cog_unload(self) -> None: -        """Stops the loops when the cog is unloaded.""" +        """Stop the loop task and revoke the access token when the cog is unloaded."""          self.auto_poster_loop.cancel() +        if self.access_token.expires_at < datetime.utcnow(): +            self.revoke_access_token()      async def init_reddit_ready(self) -> None:          """Sets the reddit webhook when the cog is loaded.""" @@ -48,20 +57,82 @@ class Reddit(Cog):          """Get the #reddit channel object from the bot's cache."""          return self.bot.get_channel(Channels.reddit) +    async def get_access_token(self) -> None: +        """ +        Get a Reddit API OAuth2 access token and assign it to self.access_token. + +        A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog +        will be unloaded and a ClientError raised if retrieval was still unsuccessful. +        """ +        for i in range(1, self.MAX_RETRIES + 1): +            response = await self.bot.http_session.post( +                url=f"{self.URL}/api/v1/access_token", +                headers=self.HEADERS, +                auth=self.client_auth, +                data={ +                    "grant_type": "client_credentials", +                    "duration": "temporary" +                } +            ) + +            if response.status == 200 and response.content_type == "application/json": +                content = await response.json() +                expiration = int(content["expires_in"]) - 60  # Subtract 1 minute for leeway. +                self.access_token = AccessToken( +                    token=content["access_token"], +                    expires_at=datetime.utcnow() + timedelta(seconds=expiration) +                ) + +                log.debug(f"New token acquired; expires on {self.access_token.expires_at}") +                return +            else: +                log.debug( +                    f"Failed to get an access token: " +                    f"status {response.status} & content type {response.content_type}; " +                    f"retrying ({i}/{self.MAX_RETRIES})" +                ) + +            await asyncio.sleep(3) + +        self.bot.remove_cog(self.qualified_name) +        raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + +    async def revoke_access_token(self) -> None: +        """ +        Revoke the OAuth2 access token for the Reddit API. + +        For security reasons, it's good practice to revoke the token when it's no longer being used. +        """ +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/revoke_token", +            headers=self.HEADERS, +            auth=self.client_auth, +            data={ +                "token": self.access_token.token, +                "token_type_hint": "access_token" +            } +        ) + +        if response.status == 204 and response.content_type == "application/json": +            self.access_token = None +        else: +            log.warning(f"Unable to revoke access token: status {response.status}.") +      async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:          """A helper method to fetch a certain amount of Reddit posts at a given route."""          # Reddit's JSON responses only provide 25 posts at most.          if not 25 >= amount > 0:              raise ValueError("Invalid amount of subreddit posts requested.") -        if params is None: -            params = {} +        # Renew the token if necessary. +        if not self.access_token or self.access_token.expires_at < datetime.utcnow(): +            await self.get_access_token() -        url = f"{self.URL}/{route}.json" -        for _ in range(self.MAX_FETCH_RETRIES): +        url = f"{self.OAUTH_URL}/{route}" +        for _ in range(self.MAX_RETRIES):              response = await self.bot.http_session.get(                  url=url, -                headers=self.HEADERS, +                headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},                  params=params              )              if response.status == 200 and response.content_type == 'application/json': diff --git a/bot/constants.py b/bot/constants.py index 89504a2e0..ed85adf6a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -465,6 +465,8 @@ class Reddit(metaclass=YAMLGetter):      section = "reddit"      subreddits: list +    client_id: str +    secret: str  class Wolfram(metaclass=YAMLGetter): diff --git a/config-default.yml b/config-default.yml index 930a1a0e6..e6f0fda21 100644 --- a/config-default.yml +++ b/config-default.yml @@ -365,6 +365,8 @@ anti_malware:  reddit:      subreddits:          - 'r/Python' +    client_id: !ENV "REDDIT_CLIENT_ID" +    secret:    !ENV "REDDIT_SECRET"  wolfram: diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services:      environment:        BOT_TOKEN: ${BOT_TOKEN}        BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 +      REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} +      REDDIT_SECRET: ${REDDIT_SECRET} | 
