diff options
| -rw-r--r-- | bot/__main__.py | 3 | ||||
| -rw-r--r-- | bot/cogs/antimalware.py | 10 | ||||
| -rw-r--r-- | bot/cogs/config_verifier.py | 40 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 4 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 9 | ||||
| -rw-r--r-- | bot/constants.py | 5 | ||||
| -rw-r--r-- | bot/utils/__init__.py | 12 | ||||
| -rw-r--r-- | tests/bot/test_utils.py | 15 | 
8 files changed, 65 insertions, 33 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index 0079a9381..3df477a6d 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -10,7 +10,7 @@ from bot.bot import Bot  from bot.constants import Bot as BotConfig  sentry_logging = LoggingIntegration( -    level=logging.TRACE, +    level=logging.DEBUG,      event_level=logging.WARNING  ) @@ -31,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler")  bot.load_extension("bot.cogs.filtering")  bot.load_extension("bot.cogs.logging")  bot.load_extension("bot.cogs.security") +bot.load_extension("bot.cogs.config_verifier")  # Commands, etc  bot.load_extension("bot.cogs.antimalware") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 28e3e5d96..9e9e81364 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -4,7 +4,7 @@ from discord import Embed, Message, NotFound  from discord.ext.commands import Cog  from bot.bot import Bot -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs  log = logging.getLogger(__name__) @@ -18,7 +18,13 @@ class AntiMalware(Cog):      @Cog.listener()      async def on_message(self, message: Message) -> None:          """Identify messages with prohibited attachments.""" -        if not message.attachments: +        # Return when message don't have attachment and don't moderate DMs +        if not message.attachments or not message.guild: +            return + +        # Check if user is staff, if is, return +        # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance +        if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles):              return          embed = Embed() diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/cogs/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): +    """Verify config on startup.""" + +    def __init__(self, bot: Bot): +        self.bot = bot +        self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + +    async def verify_channels(self) -> None: +        """ +        Verify channels. + +        If any channels in config aren't present in server, log them in a warning. +        """ +        await self.bot.wait_until_guild_available() +        server = self.bot.get_guild(constants.Guild.id) + +        server_channel_ids = {channel.id for channel in server.channels} +        invalid_channels = [ +            channel_name for channel_name, channel_id in constants.Channels +            if channel_id not in server_channel_ids +        ] + +        if invalid_channels: +            log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: +    """Load the ConfigVerifier cog.""" +    bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index f2964cd78..f74089056 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -129,7 +129,9 @@ class ModManagement(commands.Cog):          # Re-schedule infraction if the expiration has been updated          if 'expires_at' in request_data: -            self.infractions_cog.cancel_task(new_infraction['id']) +            # A scheduled task should only exist if the old infraction wasn't permanent +            if old_infraction['expires_at']: +                self.infractions_cog.cancel_task(new_infraction['id'])              # If the infraction was not marked as permanent, schedule a new expiration task              if request_data['expires_at']: diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 4f6584aba..5a7fa100f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -43,8 +43,8 @@ class Reddit(Cog):      def cog_unload(self) -> None:          """Stop the loop task and revoke the access token when the cog is unloaded."""          self.auto_poster_loop.cancel() -        if self.access_token.expires_at < datetime.utcnow(): -            self.revoke_access_token() +        if self.access_token and self.access_token.expires_at > datetime.utcnow(): +            asyncio.create_task(self.revoke_access_token())      async def init_reddit_ready(self) -> None:          """Sets the reddit webhook when the cog is loaded.""" @@ -83,7 +83,7 @@ class Reddit(Cog):                      expires_at=datetime.utcnow() + timedelta(seconds=expiration)                  ) -                log.debug(f"New token acquired; expires on {self.access_token.expires_at}") +                log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")                  return              else:                  log.debug( @@ -290,4 +290,7 @@ class Reddit(Cog):  def setup(bot: Bot) -> None:      """Load the Reddit cog.""" +    if not RedditConfig.secret or not RedditConfig.client_id: +        log.error("Credentials not provided, cog not loaded.") +        return      bot.add_cog(Reddit(bot)) diff --git a/bot/constants.py b/bot/constants.py index b1713aa60..14f8dc094 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -186,6 +186,11 @@ class YAMLGetter(type):      def __getitem__(cls, name):          return cls.__getattr__(name) +    def __iter__(cls): +        """Return generator of key: value pairs of current constants class' config values.""" +        for name in cls.__annotations__: +            yield name, getattr(cls, name) +  # Dataclasses  class Bot(metaclass=YAMLGetter): diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 8184be824..3e4b15ce4 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,5 @@  from abc import ABCMeta -from typing import Any, Generator, Hashable, Iterable +from typing import Any, Hashable  from discord.ext.commands import CogMeta @@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict):          for k in list(self.keys()):              v = super(CaseInsensitiveDict, self).pop(k)              self.__setitem__(k, v) - - -def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]: -    """ -    Generator that allows you to iterate over any indexable collection in `size`-length chunks. - -    Found: https://stackoverflow.com/a/312464/4022104 -    """ -    for i in range(0, len(iterable), size): -        yield iterable[i:i + size] diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):          instance = utils.CaseInsensitiveDict()          instance.update({'FOO': 'bar'})          self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): -    """Tests the `chunk` method.""" - -    def test_empty_chunking(self): -        """Tests chunking on an empty iterable.""" -        generator = utils.chunks(iterable=[], size=5) -        self.assertEqual(list(generator), []) - -    def test_list_chunking(self): -        """Tests chunking a non-empty list.""" -        iterable = [1, 2, 3, 4, 5] -        generator = utils.chunks(iterable=iterable, size=2) -        self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) | 
