aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/__main__.py3
-rw-r--r--bot/cogs/antimalware.py10
-rw-r--r--bot/cogs/config_verifier.py40
-rw-r--r--bot/cogs/moderation/management.py4
-rw-r--r--bot/cogs/reddit.py9
-rw-r--r--bot/constants.py5
-rw-r--r--bot/utils/__init__.py12
-rw-r--r--tests/bot/test_utils.py15
8 files changed, 65 insertions, 33 deletions
diff --git a/bot/__main__.py b/bot/__main__.py
index 490163739..d21a1bcbc 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -10,7 +10,7 @@ from bot.bot import Bot
from bot.constants import Bot as BotConfig, DEBUG_MODE
sentry_logging = LoggingIntegration(
- level=logging.TRACE,
+ level=logging.DEBUG,
event_level=logging.WARNING
)
@@ -31,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler")
bot.load_extension("bot.cogs.filtering")
bot.load_extension("bot.cogs.logging")
bot.load_extension("bot.cogs.security")
+bot.load_extension("bot.cogs.config_verifier")
# Commands, etc
bot.load_extension("bot.cogs.antimalware")
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 28e3e5d96..9e9e81364 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -4,7 +4,7 @@ from discord import Embed, Message, NotFound
from discord.ext.commands import Cog
from bot.bot import Bot
-from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs
log = logging.getLogger(__name__)
@@ -18,7 +18,13 @@ class AntiMalware(Cog):
@Cog.listener()
async def on_message(self, message: Message) -> None:
"""Identify messages with prohibited attachments."""
- if not message.attachments:
+ # Return when message don't have attachment and don't moderate DMs
+ if not message.attachments or not message.guild:
+ return
+
+ # Check if user is staff, if is, return
+ # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
+ if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles):
return
embed = Embed()
diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py
new file mode 100644
index 000000000..d72c6c22e
--- /dev/null
+++ b/bot/cogs/config_verifier.py
@@ -0,0 +1,40 @@
+import logging
+
+from discord.ext.commands import Cog
+
+from bot import constants
+from bot.bot import Bot
+
+
+log = logging.getLogger(__name__)
+
+
+class ConfigVerifier(Cog):
+ """Verify config on startup."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self.channel_verify_task = self.bot.loop.create_task(self.verify_channels())
+
+ async def verify_channels(self) -> None:
+ """
+ Verify channels.
+
+ If any channels in config aren't present in server, log them in a warning.
+ """
+ await self.bot.wait_until_guild_available()
+ server = self.bot.get_guild(constants.Guild.id)
+
+ server_channel_ids = {channel.id for channel in server.channels}
+ invalid_channels = [
+ channel_name for channel_name, channel_id in constants.Channels
+ if channel_id not in server_channel_ids
+ ]
+
+ if invalid_channels:
+ log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.")
+
+
+def setup(bot: Bot) -> None:
+ """Load the ConfigVerifier cog."""
+ bot.add_cog(ConfigVerifier(bot))
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index f2964cd78..f74089056 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -129,7 +129,9 @@ class ModManagement(commands.Cog):
# Re-schedule infraction if the expiration has been updated
if 'expires_at' in request_data:
- self.infractions_cog.cancel_task(new_infraction['id'])
+ # A scheduled task should only exist if the old infraction wasn't permanent
+ if old_infraction['expires_at']:
+ self.infractions_cog.cancel_task(new_infraction['id'])
# If the infraction was not marked as permanent, schedule a new expiration task
if request_data['expires_at']:
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 4f6584aba..5a7fa100f 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -43,8 +43,8 @@ class Reddit(Cog):
def cog_unload(self) -> None:
"""Stop the loop task and revoke the access token when the cog is unloaded."""
self.auto_poster_loop.cancel()
- if self.access_token.expires_at < datetime.utcnow():
- self.revoke_access_token()
+ if self.access_token and self.access_token.expires_at > datetime.utcnow():
+ asyncio.create_task(self.revoke_access_token())
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
@@ -83,7 +83,7 @@ class Reddit(Cog):
expires_at=datetime.utcnow() + timedelta(seconds=expiration)
)
- log.debug(f"New token acquired; expires on {self.access_token.expires_at}")
+ log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")
return
else:
log.debug(
@@ -290,4 +290,7 @@ class Reddit(Cog):
def setup(bot: Bot) -> None:
"""Load the Reddit cog."""
+ if not RedditConfig.secret or not RedditConfig.client_id:
+ log.error("Credentials not provided, cog not loaded.")
+ return
bot.add_cog(Reddit(bot))
diff --git a/bot/constants.py b/bot/constants.py
index 9bc331dc4..ebd3b3d96 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -186,6 +186,11 @@ class YAMLGetter(type):
def __getitem__(cls, name):
return cls.__getattr__(name)
+ def __iter__(cls):
+ """Return generator of key: value pairs of current constants class' config values."""
+ for name in cls.__annotations__:
+ yield name, getattr(cls, name)
+
# Dataclasses
class Bot(metaclass=YAMLGetter):
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 8184be824..3e4b15ce4 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -1,5 +1,5 @@
from abc import ABCMeta
-from typing import Any, Generator, Hashable, Iterable
+from typing import Any, Hashable
from discord.ext.commands import CogMeta
@@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict):
for k in list(self.keys()):
v = super(CaseInsensitiveDict, self).pop(k)
self.__setitem__(k, v)
-
-
-def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]:
- """
- Generator that allows you to iterate over any indexable collection in `size`-length chunks.
-
- Found: https://stackoverflow.com/a/312464/4022104
- """
- for i in range(0, len(iterable), size):
- yield iterable[i:i + size]
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
index 58ae2a81a..d7bcc3ba6 100644
--- a/tests/bot/test_utils.py
+++ b/tests/bot/test_utils.py
@@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):
instance = utils.CaseInsensitiveDict()
instance.update({'FOO': 'bar'})
self.assertEqual(instance['foo'], 'bar')
-
-
-class ChunkTests(unittest.TestCase):
- """Tests the `chunk` method."""
-
- def test_empty_chunking(self):
- """Tests chunking on an empty iterable."""
- generator = utils.chunks(iterable=[], size=5)
- self.assertEqual(list(generator), [])
-
- def test_list_chunking(self):
- """Tests chunking a non-empty list."""
- iterable = [1, 2, 3, 4, 5]
- generator = utils.chunks(iterable=iterable, size=2)
- self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])