diff options
| -rw-r--r-- | bot/__main__.py | 1 | ||||
| -rw-r--r-- | bot/cogs/config_verifier.py | 40 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 5 | ||||
| -rw-r--r-- | bot/constants.py | 4 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 12 | ||||
| -rw-r--r-- | tests/bot/test_utils.py | 15 |
6 files changed, 50 insertions, 27 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index 490163739..79f89b467 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -31,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") bot.load_extension("bot.cogs.logging") bot.load_extension("bot.cogs.security") +bot.load_extension("bot.cogs.config_verifier") # Commands, etc bot.load_extension("bot.cogs.antimalware") diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py new file mode 100644 index 000000000..cc19f7423 --- /dev/null +++ b/bot/cogs/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_ready() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 4f6584aba..e93e4de0c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -290,4 +290,7 @@ class Reddit(Cog): def setup(bot: Bot) -> None: """Load the Reddit cog.""" - bot.add_cog(Reddit(bot)) + if None not in (RedditConfig.client_id, RedditConfig.secret): + bot.add_cog(Reddit(bot)) + return + log.error("Credentials not provided, cog not loaded.") diff --git a/bot/constants.py b/bot/constants.py index 9bc331dc4..3776ceb84 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -186,6 +186,10 @@ class YAMLGetter(type): def __getitem__(cls, name): return cls.__getattr__(name) + def __iter__(cls): + """Returns iterator of key: value pairs of current constants class.""" + return iter(_CONFIG_YAML[cls.section][cls.subsection].items()) + # Dataclasses class Bot(metaclass=YAMLGetter): diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 8184be824..3e4b15ce4 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Any, Generator, Hashable, Iterable +from typing import Any, Hashable from discord.ext.commands import CogMeta @@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict): for k in list(self.keys()): v = super(CaseInsensitiveDict, self).pop(k) self.__setitem__(k, v) - - -def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]: - """ - Generator that allows you to iterate over any indexable collection in `size`-length chunks. - - Found: https://stackoverflow.com/a/312464/4022104 - """ - for i in range(0, len(iterable), size): - yield iterable[i:i + size] diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase): instance = utils.CaseInsensitiveDict() instance.update({'FOO': 'bar'}) self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): - """Tests the `chunk` method.""" - - def test_empty_chunking(self): - """Tests chunking on an empty iterable.""" - generator = utils.chunks(iterable=[], size=5) - self.assertEqual(list(generator), []) - - def test_list_chunking(self): - """Tests chunking a non-empty list.""" - iterable = [1, 2, 3, 4, 5] - generator = utils.chunks(iterable=iterable, size=2) - self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) |