diff options
| -rw-r--r-- | bot/__main__.py | 1 | ||||
| -rw-r--r-- | bot/cogs/config_verifier.py | 40 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 9 | ||||
| -rw-r--r-- | bot/constants.py | 5 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 12 | ||||
| -rw-r--r-- | tests/bot/test_utils.py | 15 | 
6 files changed, 53 insertions, 29 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index a3f1855b5..d21a1bcbc 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -31,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler")  bot.load_extension("bot.cogs.filtering")  bot.load_extension("bot.cogs.logging")  bot.load_extension("bot.cogs.security") +bot.load_extension("bot.cogs.config_verifier")  # Commands, etc  bot.load_extension("bot.cogs.antimalware") diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/cogs/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): +    """Verify config on startup.""" + +    def __init__(self, bot: Bot): +        self.bot = bot +        self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + +    async def verify_channels(self) -> None: +        """ +        Verify channels. + +        If any channels in config aren't present in server, log them in a warning. +        """ +        await self.bot.wait_until_guild_available() +        server = self.bot.get_guild(constants.Guild.id) + +        server_channel_ids = {channel.id for channel in server.channels} +        invalid_channels = [ +            channel_name for channel_name, channel_id in constants.Channels +            if channel_id not in server_channel_ids +        ] + +        if invalid_channels: +            log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: +    """Load the ConfigVerifier cog.""" +    bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 4f6584aba..5a7fa100f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -43,8 +43,8 @@ class Reddit(Cog):      def cog_unload(self) -> None:          """Stop the loop task and revoke the access token when the cog is unloaded."""          self.auto_poster_loop.cancel() -        if self.access_token.expires_at < datetime.utcnow(): -            self.revoke_access_token() +        if self.access_token and self.access_token.expires_at > datetime.utcnow(): +            asyncio.create_task(self.revoke_access_token())      async def init_reddit_ready(self) -> None:          """Sets the reddit webhook when the cog is loaded.""" @@ -83,7 +83,7 @@ class Reddit(Cog):                      expires_at=datetime.utcnow() + timedelta(seconds=expiration)                  ) -                log.debug(f"New token acquired; expires on {self.access_token.expires_at}") +                log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")                  return              else:                  log.debug( @@ -290,4 +290,7 @@ class Reddit(Cog):  def setup(bot: Bot) -> None:      """Load the Reddit cog.""" +    if not RedditConfig.secret or not RedditConfig.client_id: +        log.error("Credentials not provided, cog not loaded.") +        return      bot.add_cog(Reddit(bot)) diff --git a/bot/constants.py b/bot/constants.py index 9bc331dc4..ebd3b3d96 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -186,6 +186,11 @@ class YAMLGetter(type):      def __getitem__(cls, name):          return cls.__getattr__(name) +    def __iter__(cls): +        """Return generator of key: value pairs of current constants class' config values.""" +        for name in cls.__annotations__: +            yield name, getattr(cls, name) +  # Dataclasses  class Bot(metaclass=YAMLGetter): diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 8184be824..3e4b15ce4 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,5 @@  from abc import ABCMeta -from typing import Any, Generator, Hashable, Iterable +from typing import Any, Hashable  from discord.ext.commands import CogMeta @@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict):          for k in list(self.keys()):              v = super(CaseInsensitiveDict, self).pop(k)              self.__setitem__(k, v) - - -def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]: -    """ -    Generator that allows you to iterate over any indexable collection in `size`-length chunks. - -    Found: https://stackoverflow.com/a/312464/4022104 -    """ -    for i in range(0, len(iterable), size): -        yield iterable[i:i + size] diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):          instance = utils.CaseInsensitiveDict()          instance.update({'FOO': 'bar'})          self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): -    """Tests the `chunk` method.""" - -    def test_empty_chunking(self): -        """Tests chunking on an empty iterable.""" -        generator = utils.chunks(iterable=[], size=5) -        self.assertEqual(list(generator), []) - -    def test_list_chunking(self): -        """Tests chunking a non-empty list.""" -        iterable = [1, 2, 3, 4, 5] -        generator = utils.chunks(iterable=iterable, size=2) -        self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) | 
