aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--azure-pipelines.yml2
-rw-r--r--bot/__main__.py26
-rw-r--r--bot/api.py41
-rw-r--r--bot/bot.py53
-rw-r--r--bot/cogs/alias.py6
-rw-r--r--bot/cogs/antimalware.py6
-rw-r--r--bot/cogs/antispam.py6
-rw-r--r--bot/cogs/bot.py16
-rw-r--r--bot/cogs/clean.py68
-rw-r--r--bot/cogs/defcon.py6
-rw-r--r--bot/cogs/doc.py8
-rw-r--r--bot/cogs/duck_pond.py6
-rw-r--r--bot/cogs/error_handler.py16
-rw-r--r--bot/cogs/eval.py6
-rw-r--r--bot/cogs/extensions.py4
-rw-r--r--bot/cogs/filtering.py6
-rw-r--r--bot/cogs/free.py6
-rw-r--r--bot/cogs/help.py3
-rw-r--r--bot/cogs/information.py6
-rw-r--r--bot/cogs/jams.py8
-rw-r--r--bot/cogs/logging.py6
-rw-r--r--bot/cogs/moderation/__init__.py16
-rw-r--r--bot/cogs/moderation/infractions.py12
-rw-r--r--bot/cogs/moderation/management.py55
-rw-r--r--bot/cogs/moderation/modlog.py3
-rw-r--r--bot/cogs/moderation/scheduler.py20
-rw-r--r--bot/cogs/moderation/superstarify.py3
-rw-r--r--bot/cogs/off_topic_names.py6
-rw-r--r--bot/cogs/reddit.py97
-rw-r--r--bot/cogs/reminders.py6
-rw-r--r--bot/cogs/security.py7
-rw-r--r--bot/cogs/site.py6
-rw-r--r--bot/cogs/snekbox.py6
-rw-r--r--bot/cogs/sync/__init__.py10
-rw-r--r--bot/cogs/sync/cog.py58
-rw-r--r--bot/cogs/sync/syncers.py7
-rw-r--r--bot/cogs/tags.py6
-rw-r--r--bot/cogs/token_remover.py77
-rw-r--r--bot/cogs/utils.py6
-rw-r--r--bot/cogs/verification.py22
-rw-r--r--bot/cogs/watchchannels/__init__.py13
-rw-r--r--bot/cogs/watchchannels/bigbrother.py3
-rw-r--r--bot/cogs/watchchannels/talentpool.py3
-rw-r--r--bot/cogs/watchchannels/watchchannel.py3
-rw-r--r--bot/cogs/wolfram.py10
-rw-r--r--bot/constants.py10
-rw-r--r--bot/converters.py23
-rw-r--r--bot/interpreter.py4
-rw-r--r--bot/rules/attachments.py4
-rw-r--r--bot/utils/time.py31
-rw-r--r--config-default.yml17
-rw-r--r--docker-compose.yml2
-rw-r--r--tests/bot/cogs/test_duck_pond.py12
-rw-r--r--tests/bot/cogs/test_security.py11
-rw-r--r--tests/bot/cogs/test_token_remover.py8
-rw-r--r--tests/bot/rules/test_attachments.py110
-rw-r--r--tests/bot/rules/test_links.py26
-rw-r--r--tests/bot/rules/test_mentions.py95
-rw-r--r--tests/bot/utils/test_time.py162
-rw-r--r--tests/helpers.py4
60 files changed, 928 insertions, 351 deletions
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index da3b06201..0400ac4d2 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -30,7 +30,7 @@ jobs:
- script: python -m flake8
displayName: 'Run linter'
- - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner
+ - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
displayName: Run tests
- script: coverage report -m && coverage xml -o coverage.xml
diff --git a/bot/__main__.py b/bot/__main__.py
index ea7c43a12..84bc7094b 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -1,18 +1,11 @@
-import asyncio
-import logging
-import socket
-
import discord
-from aiohttp import AsyncResolver, ClientSession, TCPConnector
-from discord.ext.commands import Bot, when_mentioned_or
+from discord.ext.commands import when_mentioned_or
from bot import patches
-from bot.api import APIClient, APILoggingHandler
+from bot.bot import Bot
from bot.constants import Bot as BotConfig, DEBUG_MODE
-log = logging.getLogger('bot')
-
bot = Bot(
command_prefix=when_mentioned_or(BotConfig.prefix),
activity=discord.Game(name="Commands: !help"),
@@ -20,18 +13,6 @@ bot = Bot(
max_messages=10_000,
)
-# Global aiohttp session for all cogs
-# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads
-# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod.
-bot.http_session = ClientSession(
- connector=TCPConnector(
- resolver=AsyncResolver(),
- family=socket.AF_INET,
- )
-)
-bot.api_client = APIClient(loop=asyncio.get_event_loop())
-log.addHandler(APILoggingHandler(bot.api_client))
-
# Internal/debug
bot.load_extension("bot.cogs.error_handler")
bot.load_extension("bot.cogs.filtering")
@@ -77,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'):
patches.message_edited_at.apply_patch()
bot.run(BotConfig.token)
-
-# This calls a coroutine, so it doesn't do anything at the moment.
-# bot.http_session.close() # Close the aiohttp session when the bot finishes running
diff --git a/bot/api.py b/bot/api.py
index 7f26e5305..56db99828 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -32,7 +32,7 @@ class ResponseCodeError(ValueError):
class APIClient:
"""Django Site API wrapper."""
- def __init__(self, **kwargs):
+ def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):
auth_headers = {
'Authorization': f"Token {Keys.site_api}"
}
@@ -42,12 +42,39 @@ class APIClient:
else:
kwargs['headers'] = auth_headers
- self.session = aiohttp.ClientSession(**kwargs)
+ self.session: Optional[aiohttp.ClientSession] = None
+ self.loop = loop
+
+ self._ready = asyncio.Event(loop=loop)
+ self._creation_task = None
+ self._session_args = kwargs
+
+ self.recreate()
@staticmethod
def _url_for(endpoint: str) -> str:
return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}"
+ async def _create_session(self) -> None:
+ """Create the aiohttp session and set the ready event."""
+ self.session = aiohttp.ClientSession(**self._session_args)
+ self._ready.set()
+
+ async def close(self) -> None:
+ """Close the aiohttp session and unset the ready event."""
+ if not self._ready.is_set():
+ return
+
+ await self.session.close()
+ self._ready.clear()
+
+ def recreate(self) -> None:
+ """Schedule the aiohttp session to be created if it's been closed."""
+ if self.session is None or self.session.closed:
+ # Don't schedule a task if one is already in progress.
+ if self._creation_task is None or self._creation_task.done():
+ self._creation_task = self.loop.create_task(self._create_session())
+
async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None:
"""Raise ResponseCodeError for non-OK response if an exception should be raised."""
if should_raise and response.status >= 400:
@@ -60,30 +87,40 @@ class APIClient:
async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API GET."""
+ await self._ready.wait()
+
async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp:
await self.maybe_raise_for_status(resp, raise_for_status)
return await resp.json()
async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API PATCH."""
+ await self._ready.wait()
+
async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp:
await self.maybe_raise_for_status(resp, raise_for_status)
return await resp.json()
async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API POST."""
+ await self._ready.wait()
+
async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp:
await self.maybe_raise_for_status(resp, raise_for_status)
return await resp.json()
async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API PUT."""
+ await self._ready.wait()
+
async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp:
await self.maybe_raise_for_status(resp, raise_for_status)
return await resp.json()
async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]:
"""Site API DELETE."""
+ await self._ready.wait()
+
async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp:
if resp.status == 204:
return None
diff --git a/bot/bot.py b/bot/bot.py
new file mode 100644
index 000000000..8f808272f
--- /dev/null
+++ b/bot/bot.py
@@ -0,0 +1,53 @@
+import logging
+import socket
+from typing import Optional
+
+import aiohttp
+from discord.ext import commands
+
+from bot import api
+
+log = logging.getLogger('bot')
+
+
+class Bot(commands.Bot):
+ """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client."""
+
+ def __init__(self, *args, **kwargs):
+ # Use asyncio for DNS resolution instead of threads so threads aren't spammed.
+ # Use AF_INET as its socket family to prevent HTTPS related problems both locally
+ # and in production.
+ self.connector = aiohttp.TCPConnector(
+ resolver=aiohttp.AsyncResolver(),
+ family=socket.AF_INET,
+ )
+
+ super().__init__(*args, connector=self.connector, **kwargs)
+
+ self.http_session: Optional[aiohttp.ClientSession] = None
+ self.api_client = api.APIClient(loop=self.loop, connector=self.connector)
+
+ log.addHandler(api.APILoggingHandler(self.api_client))
+
+ def add_cog(self, cog: commands.Cog) -> None:
+ """Adds a "cog" to the bot and logs the operation."""
+ super().add_cog(cog)
+ log.info(f"Cog loaded: {cog.qualified_name}")
+
+ def clear(self) -> None:
+ """Clears the internal state of the bot and resets the API client."""
+ super().clear()
+ self.api_client.recreate()
+
+ async def close(self) -> None:
+ """Close the aiohttp session after closing the Discord connection."""
+ await super().close()
+
+ await self.http_session.close()
+ await self.api_client.close()
+
+ async def start(self, *args, **kwargs) -> None:
+ """Open an aiohttp session before logging in and connecting to Discord."""
+ self.http_session = aiohttp.ClientSession(connector=self.connector)
+
+ await super().start(*args, **kwargs)
diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py
index 5190c559b..c1db38462 100644
--- a/bot/cogs/alias.py
+++ b/bot/cogs/alias.py
@@ -3,8 +3,9 @@ import logging
from typing import Union
from discord import Colour, Embed, Member, User
-from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group
+from discord.ext.commands import Cog, Command, Context, clean_content, command, group
+from bot.bot import Bot
from bot.cogs.extensions import Extension
from bot.cogs.watchchannels.watchchannel import proxy_user
from bot.converters import TagNameConverter
@@ -147,6 +148,5 @@ class Alias (Cog):
def setup(bot: Bot) -> None:
- """Alias cog load."""
+ """Load the Alias cog."""
bot.add_cog(Alias(bot))
- log.info("Cog loaded: Alias")
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 602819191..28e3e5d96 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -1,8 +1,9 @@
import logging
from discord import Embed, Message, NotFound
-from discord.ext.commands import Bot, Cog
+from discord.ext.commands import Cog
+from bot.bot import Bot
from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs
log = logging.getLogger(__name__)
@@ -49,6 +50,5 @@ class AntiMalware(Cog):
def setup(bot: Bot) -> None:
- """Antimalware cog load."""
+ """Load the AntiMalware cog."""
bot.add_cog(AntiMalware(bot))
- log.info("Cog loaded: AntiMalware")
diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py
index 1340eb608..f454061a6 100644
--- a/bot/cogs/antispam.py
+++ b/bot/cogs/antispam.py
@@ -7,9 +7,10 @@ from operator import itemgetter
from typing import Dict, Iterable, List, Set
from discord import Colour, Member, Message, NotFound, Object, TextChannel
-from discord.ext.commands import Bot, Cog
+from discord.ext.commands import Cog
from bot import rules
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import (
AntiSpam as AntiSpamConfig, Channels,
@@ -276,7 +277,6 @@ def validate_config(rules: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:
def setup(bot: Bot) -> None:
- """Antispam cog load."""
+ """Validate the AntiSpam configs and load the AntiSpam cog."""
validation_errors = validate_config()
bot.add_cog(AntiSpam(bot, validation_errors))
- log.info("Cog loaded: AntiSpam")
diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py
index ee0a463de..73b1e8f41 100644
--- a/bot/cogs/bot.py
+++ b/bot/cogs/bot.py
@@ -5,8 +5,10 @@ import time
from typing import Optional, Tuple
from discord import Embed, Message, RawMessageUpdateEvent, TextChannel
-from discord.ext.commands import Bot, Cog, Context, command, group
+from discord.ext.commands import Cog, Context, command, group
+from bot.bot import Bot
+from bot.cogs.token_remover import TokenRemover
from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs
from bot.decorators import with_role
from bot.utils.messages import wait_for_deletion
@@ -16,7 +18,7 @@ log = logging.getLogger(__name__)
RE_MARKDOWN = re.compile(r'([*_~`|>])')
-class Bot(Cog):
+class BotCog(Cog, name="Bot"):
"""Bot information commands."""
def __init__(self, bot: Bot):
@@ -238,9 +240,10 @@ class Bot(Cog):
)
and not msg.author.bot
and len(msg.content.splitlines()) > 3
+ and not TokenRemover.is_token_in_message(msg)
)
- if parse_codeblock:
+ if parse_codeblock: # no token in the msg
on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300
if not on_cooldown or DEBUG_MODE:
try:
@@ -373,10 +376,9 @@ class Bot(Cog):
bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id])
await bot_message.delete()
del self.codeblock_message_ids[payload.message_id]
- log.trace("User's incorrect code block has been fixed. Removing bot formatting message.")
+ log.trace("User's incorrect code block has been fixed. Removing bot formatting message.")
def setup(bot: Bot) -> None:
- """Bot cog load."""
- bot.add_cog(Bot(bot))
- log.info("Cog loaded: Bot")
+ """Load the Bot cog."""
+ bot.add_cog(BotCog(bot))
diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py
index dca411d01..2104efe57 100644
--- a/bot/cogs/clean.py
+++ b/bot/cogs/clean.py
@@ -3,9 +3,10 @@ import random
import re
from typing import Optional
-from discord import Colour, Embed, Message, User
-from discord.ext.commands import Bot, Cog, Context, group
+from discord import Colour, Embed, Message, TextChannel, User
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import (
Channels, CleanMessages, Colours, Event,
@@ -37,9 +38,13 @@ class Clean(Cog):
return self.bot.get_cog("ModLog")
async def _clean_messages(
- self, amount: int, ctx: Context,
- bots_only: bool = False, user: User = None,
- regex: Optional[str] = None
+ self,
+ amount: int,
+ ctx: Context,
+ bots_only: bool = False,
+ user: User = None,
+ regex: Optional[str] = None,
+ channel: Optional[TextChannel] = None
) -> None:
"""A helper function that does the actual message cleaning."""
def predicate_bots_only(message: Message) -> bool:
@@ -104,6 +109,10 @@ class Clean(Cog):
else:
predicate = None # Delete all messages
+ # Default to using the invoking context's channel
+ if not channel:
+ channel = ctx.channel
+
# Look through the history and retrieve message data
messages = []
message_ids = []
@@ -111,7 +120,7 @@ class Clean(Cog):
invocation_deleted = False
# To account for the invocation message, we index `amount + 1` messages.
- async for message in ctx.channel.history(limit=amount + 1):
+ async for message in channel.history(limit=amount + 1):
# If at any point the cancel command is invoked, we should stop.
if not self.cleaning:
@@ -135,7 +144,7 @@ class Clean(Cog):
self.mod_log.ignore(Event.message_delete, *message_ids)
# Use bulk delete to actually do the cleaning. It's far faster.
- await ctx.channel.purge(
+ await channel.purge(
limit=amount,
check=predicate
)
@@ -155,7 +164,7 @@ class Clean(Cog):
# Build the embed and send it
message = (
- f"**{len(message_ids)}** messages deleted in <#{ctx.channel.id}> by **{ctx.author.name}**\n\n"
+ f"**{len(message_ids)}** messages deleted in <#{channel.id}> by **{ctx.author.name}**\n\n"
f"A log of the deleted messages can be found [here]({log_url})."
)
@@ -167,7 +176,7 @@ class Clean(Cog):
channel_id=Channels.modlog,
)
- @group(invoke_without_command=True, name="clean", hidden=True)
+ @group(invoke_without_command=True, name="clean", aliases=["purge"])
@with_role(*MODERATION_ROLES)
async def clean_group(self, ctx: Context) -> None:
"""Commands for cleaning messages in channels."""
@@ -175,27 +184,49 @@ class Clean(Cog):
@clean_group.command(name="user", aliases=["users"])
@with_role(*MODERATION_ROLES)
- async def clean_user(self, ctx: Context, user: User, amount: int = 10) -> None:
+ async def clean_user(
+ self,
+ ctx: Context,
+ user: User,
+ amount: Optional[int] = 10,
+ channel: TextChannel = None
+ ) -> None:
"""Delete messages posted by the provided user, stop cleaning after traversing `amount` messages."""
- await self._clean_messages(amount, ctx, user=user)
+ await self._clean_messages(amount, ctx, user=user, channel=channel)
@clean_group.command(name="all", aliases=["everything"])
@with_role(*MODERATION_ROLES)
- async def clean_all(self, ctx: Context, amount: int = 10) -> None:
+ async def clean_all(
+ self,
+ ctx: Context,
+ amount: Optional[int] = 10,
+ channel: TextChannel = None
+ ) -> None:
"""Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages."""
- await self._clean_messages(amount, ctx)
+ await self._clean_messages(amount, ctx, channel=channel)
@clean_group.command(name="bots", aliases=["bot"])
@with_role(*MODERATION_ROLES)
- async def clean_bots(self, ctx: Context, amount: int = 10) -> None:
+ async def clean_bots(
+ self,
+ ctx: Context,
+ amount: Optional[int] = 10,
+ channel: TextChannel = None
+ ) -> None:
"""Delete all messages posted by a bot, stop cleaning after traversing `amount` messages."""
- await self._clean_messages(amount, ctx, bots_only=True)
+ await self._clean_messages(amount, ctx, bots_only=True, channel=channel)
@clean_group.command(name="regex", aliases=["word", "expression"])
@with_role(*MODERATION_ROLES)
- async def clean_regex(self, ctx: Context, regex: str, amount: int = 10) -> None:
+ async def clean_regex(
+ self,
+ ctx: Context,
+ regex: str,
+ amount: Optional[int] = 10,
+ channel: TextChannel = None
+ ) -> None:
"""Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages."""
- await self._clean_messages(amount, ctx, regex=regex)
+ await self._clean_messages(amount, ctx, regex=regex, channel=channel)
@clean_group.command(name="stop", aliases=["cancel", "abort"])
@with_role(*MODERATION_ROLES)
@@ -211,6 +242,5 @@ class Clean(Cog):
def setup(bot: Bot) -> None:
- """Clean cog load."""
+ """Load the Clean cog."""
bot.add_cog(Clean(bot))
- log.info("Cog loaded: Clean")
diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py
index bedd70c86..3e7350fcc 100644
--- a/bot/cogs/defcon.py
+++ b/bot/cogs/defcon.py
@@ -6,8 +6,9 @@ from datetime import datetime, timedelta
from enum import Enum
from discord import Colour, Embed, Member
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles
from bot.decorators import with_role
@@ -236,6 +237,5 @@ class Defcon(Cog):
def setup(bot: Bot) -> None:
- """DEFCON cog load."""
+ """Load the Defcon cog."""
bot.add_cog(Defcon(bot))
- log.info("Cog loaded: Defcon")
diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py
index e5b3a4062..9506b195a 100644
--- a/bot/cogs/doc.py
+++ b/bot/cogs/doc.py
@@ -17,6 +17,7 @@ from requests import ConnectTimeout, ConnectionError, HTTPError
from sphinx.ext import intersphinx
from urllib3.exceptions import ProtocolError
+from bot.bot import Bot
from bot.constants import MODERATION_ROLES, RedirectOutput
from bot.converters import ValidPythonIdentifier, ValidURL
from bot.decorators import with_role
@@ -147,7 +148,7 @@ class InventoryURL(commands.Converter):
class Doc(commands.Cog):
"""A set of commands for querying & displaying documentation."""
- def __init__(self, bot: commands.Bot):
+ def __init__(self, bot: Bot):
self.base_urls = {}
self.bot = bot
self.inventories = {}
@@ -506,7 +507,6 @@ class Doc(commands.Cog):
return tag.name == "table"
-def setup(bot: commands.Bot) -> None:
- """Doc cog load."""
+def setup(bot: Bot) -> None:
+ """Load the Doc cog."""
bot.add_cog(Doc(bot))
- log.info("Cog loaded: Doc")
diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py
index 2d25cd17e..345d2856c 100644
--- a/bot/cogs/duck_pond.py
+++ b/bot/cogs/duck_pond.py
@@ -3,9 +3,10 @@ from typing import Optional, Union
import discord
from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors
-from discord.ext.commands import Bot, Cog
+from discord.ext.commands import Cog
from bot import constants
+from bot.bot import Bot
from bot.utils.messages import send_attachments
log = logging.getLogger(__name__)
@@ -177,6 +178,5 @@ class DuckPond(Cog):
def setup(bot: Bot) -> None:
- """Load the duck pond cog."""
+ """Load the DuckPond cog."""
bot.add_cog(DuckPond(bot))
- log.info("Cog loaded: DuckPond")
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py
index 49411814c..52893b2ee 100644
--- a/bot/cogs/error_handler.py
+++ b/bot/cogs/error_handler.py
@@ -14,9 +14,10 @@ from discord.ext.commands import (
NoPrivateMessage,
UserInputError,
)
-from discord.ext.commands import Bot, Cog, Context
+from discord.ext.commands import Cog, Context
from bot.api import ResponseCodeError
+from bot.bot import Bot
from bot.constants import Channels
from bot.decorators import InChannelCheckFailure
@@ -75,6 +76,16 @@ class ErrorHandler(Cog):
tags_get_command = self.bot.get_command("tags get")
ctx.invoked_from_error_handler = True
+ log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
+ try:
+ if not await tags_get_command.can_run(ctx):
+ log.debug(log_msg)
+ return
+ except CommandError as tag_error:
+ log.debug(log_msg)
+ await self.on_command_error(ctx, tag_error)
+ return
+
# Return to not raise the exception
with contextlib.suppress(ResponseCodeError):
await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with)
@@ -143,6 +154,5 @@ class ErrorHandler(Cog):
def setup(bot: Bot) -> None:
- """Error handler cog load."""
+ """Load the ErrorHandler cog."""
bot.add_cog(ErrorHandler(bot))
- log.info("Cog loaded: Events")
diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py
index 00b988dde..9c729f28a 100644
--- a/bot/cogs/eval.py
+++ b/bot/cogs/eval.py
@@ -9,8 +9,9 @@ from io import StringIO
from typing import Any, Optional, Tuple
import discord
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.constants import Roles
from bot.decorators import with_role
from bot.interpreter import Interpreter
@@ -197,6 +198,5 @@ async def func(): # (None,) -> Any
def setup(bot: Bot) -> None:
- """Code eval cog load."""
+ """Load the CodeEval cog."""
bot.add_cog(CodeEval(bot))
- log.info("Cog loaded: Eval")
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
index bb66e0b8e..f16e79fb7 100644
--- a/bot/cogs/extensions.py
+++ b/bot/cogs/extensions.py
@@ -6,8 +6,9 @@ from pkgutil import iter_modules
from discord import Colour, Embed
from discord.ext import commands
-from discord.ext.commands import Bot, Context, group
+from discord.ext.commands import Context, group
+from bot.bot import Bot
from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
from bot.pagination import LinePaginator
from bot.utils.checks import with_role_check
@@ -233,4 +234,3 @@ class Extensions(commands.Cog):
def setup(bot: Bot) -> None:
"""Load the Extensions cog."""
bot.add_cog(Extensions(bot))
- log.info("Cog loaded: Extensions")
diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py
index 1e7521054..74538542a 100644
--- a/bot/cogs/filtering.py
+++ b/bot/cogs/filtering.py
@@ -5,8 +5,9 @@ from typing import Optional, Union
import discord.errors
from dateutil.relativedelta import relativedelta
from discord import Colour, DMChannel, Member, Message, TextChannel
-from discord.ext.commands import Bot, Cog
+from discord.ext.commands import Cog
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import (
Channels, Colours,
@@ -370,6 +371,5 @@ class Filtering(Cog):
def setup(bot: Bot) -> None:
- """Filtering cog load."""
+ """Load the Filtering cog."""
bot.add_cog(Filtering(bot))
- log.info("Cog loaded: Filtering")
diff --git a/bot/cogs/free.py b/bot/cogs/free.py
index 82285656b..49cab6172 100644
--- a/bot/cogs/free.py
+++ b/bot/cogs/free.py
@@ -3,8 +3,9 @@ from datetime import datetime
from operator import itemgetter
from discord import Colour, Embed, Member, utils
-from discord.ext.commands import Bot, Cog, Context, command
+from discord.ext.commands import Cog, Context, command
+from bot.bot import Bot
from bot.constants import Categories, Channels, Free, STAFF_ROLES
from bot.decorators import redirect_output
@@ -98,6 +99,5 @@ class Free(Cog):
def setup(bot: Bot) -> None:
- """Free cog load."""
+ """Load the Free cog."""
bot.add_cog(Free())
- log.info("Cog loaded: Free")
diff --git a/bot/cogs/help.py b/bot/cogs/help.py
index 9607dbd8d..6385fa467 100644
--- a/bot/cogs/help.py
+++ b/bot/cogs/help.py
@@ -6,10 +6,11 @@ from typing import Union
from discord import Colour, Embed, HTTPException, Message, Reaction, User
from discord.ext import commands
-from discord.ext.commands import Bot, CheckFailure, Cog as DiscordCog, Command, Context
+from discord.ext.commands import CheckFailure, Cog as DiscordCog, Command, Context
from fuzzywuzzy import fuzz, process
from bot import constants
+from bot.bot import Bot
from bot.constants import Channels, STAFF_ROLES
from bot.decorators import redirect_output
from bot.pagination import (
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index 530453600..1ede95ff4 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -9,10 +9,11 @@ from typing import Any, Mapping, Optional
import discord
from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils
from discord.ext import commands
-from discord.ext.commands import Bot, BucketType, Cog, Context, command, group
+from discord.ext.commands import BucketType, Cog, Context, command, group
from discord.utils import escape_markdown
from bot import constants
+from bot.bot import Bot
from bot.decorators import InChannelCheckFailure, in_channel, with_role
from bot.utils.checks import cooldown_with_role_bypass, with_role_check
from bot.utils.time import time_since
@@ -391,6 +392,5 @@ class Information(Cog):
def setup(bot: Bot) -> None:
- """Information cog load."""
+ """Load the Information cog."""
bot.add_cog(Information(bot))
- log.info("Cog loaded: Information")
diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py
index be9d33e3e..985f28ce5 100644
--- a/bot/cogs/jams.py
+++ b/bot/cogs/jams.py
@@ -4,6 +4,7 @@ from discord import Member, PermissionOverwrite, utils
from discord.ext import commands
from more_itertools import unique_everseen
+from bot.bot import Bot
from bot.constants import Roles
from bot.decorators import with_role
@@ -13,7 +14,7 @@ log = logging.getLogger(__name__)
class CodeJams(commands.Cog):
"""Manages the code-jam related parts of our server."""
- def __init__(self, bot: commands.Bot):
+ def __init__(self, bot: Bot):
self.bot = bot
@commands.command()
@@ -108,7 +109,6 @@ class CodeJams(commands.Cog):
)
-def setup(bot: commands.Bot) -> None:
- """Code Jams cog load."""
+def setup(bot: Bot) -> None:
+ """Load the CodeJams cog."""
bot.add_cog(CodeJams(bot))
- log.info("Cog loaded: CodeJams")
diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py
index c92b619ff..d1b7dcab3 100644
--- a/bot/cogs/logging.py
+++ b/bot/cogs/logging.py
@@ -1,8 +1,9 @@
import logging
from discord import Embed
-from discord.ext.commands import Bot, Cog
+from discord.ext.commands import Cog
+from bot.bot import Bot
from bot.constants import Channels, DEBUG_MODE
@@ -37,6 +38,5 @@ class Logging(Cog):
def setup(bot: Bot) -> None:
- """Logging cog load."""
+ """Load the Logging cog."""
bot.add_cog(Logging(bot))
- log.info("Cog loaded: Logging")
diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py
index 7383ed44e..5243cb92d 100644
--- a/bot/cogs/moderation/__init__.py
+++ b/bot/cogs/moderation/__init__.py
@@ -1,25 +1,13 @@
-import logging
-
-from discord.ext.commands import Bot
-
+from bot.bot import Bot
from .infractions import Infractions
from .management import ModManagement
from .modlog import ModLog
from .superstarify import Superstarify
-log = logging.getLogger(__name__)
-
def setup(bot: Bot) -> None:
- """Load the moderation extension (Infractions, ModManagement, ModLog, & Superstarify cogs)."""
+ """Load the Infractions, ModManagement, ModLog, and Superstarify cogs."""
bot.add_cog(Infractions(bot))
- log.info("Cog loaded: Infractions")
-
bot.add_cog(ModLog(bot))
- log.info("Cog loaded: ModLog")
-
bot.add_cog(ModManagement(bot))
- log.info("Cog loaded: ModManagement")
-
bot.add_cog(Superstarify(bot))
- log.info("Cog loaded: Superstarify")
diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py
index 2713a1b68..3536a3d38 100644
--- a/bot/cogs/moderation/infractions.py
+++ b/bot/cogs/moderation/infractions.py
@@ -7,6 +7,7 @@ from discord.ext import commands
from discord.ext.commands import Context, command
from bot import constants
+from bot.bot import Bot
from bot.constants import Event
from bot.decorators import respect_role_hierarchy
from bot.utils.checks import with_role_check
@@ -25,7 +26,7 @@ class Infractions(InfractionScheduler, commands.Cog):
category = "Moderation"
category_description = "Server moderation tools."
- def __init__(self, bot: commands.Bot):
+ def __init__(self, bot: Bot):
super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"})
self.category = "Moderation"
@@ -208,8 +209,13 @@ class Infractions(InfractionScheduler, commands.Cog):
self.mod_log.ignore(Event.member_update, user.id)
- action = user.add_roles(self._muted_role, reason=reason)
- await self.apply_infraction(ctx, infraction, user, action)
+ async def action() -> None:
+ await user.add_roles(self._muted_role, reason=reason)
+
+ log.trace(f"Attempting to kick {user} from voice because they've been muted.")
+ await user.move_to(None, reason=reason)
+
+ await self.apply_infraction(ctx, infraction, user, action())
@respect_role_hierarchy()
async def apply_kick(self, ctx: Context, user: Member, reason: str, **kwargs) -> None:
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index abfe5c2b3..9605d47b2 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -9,7 +9,8 @@ from discord.ext import commands
from discord.ext.commands import Context
from bot import constants
-from bot.converters import InfractionSearchQuery
+from bot.bot import Bot
+from bot.converters import InfractionSearchQuery, allowed_strings
from bot.pagination import LinePaginator
from bot.utils import time
from bot.utils.checks import in_channel_check, with_role_check
@@ -22,21 +23,12 @@ log = logging.getLogger(__name__)
UserConverter = t.Union[discord.User, utils.proxy_user]
-def permanent_duration(expires_at: str) -> str:
- """Only allow an expiration to be 'permanent' if it is a string."""
- expires_at = expires_at.lower()
- if expires_at != "permanent":
- raise commands.BadArgument
- else:
- return expires_at
-
-
class ModManagement(commands.Cog):
"""Management of infractions."""
category = "Moderation"
- def __init__(self, bot: commands.Bot):
+ def __init__(self, bot: Bot):
self.bot = bot
@property
@@ -60,8 +52,8 @@ class ModManagement(commands.Cog):
async def infraction_edit(
self,
ctx: Context,
- infraction_id: int,
- duration: t.Union[utils.Expiry, permanent_duration, None],
+ infraction_id: t.Union[int, allowed_strings("l", "last", "recent")],
+ duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None],
*,
reason: str = None
) -> None:
@@ -78,21 +70,40 @@ class ModManagement(commands.Cog):
\u2003`M` - minutes∗
\u2003`s` - seconds
- Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp
- can be provided for the duration.
+ Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction
+ authored by the command invoker should be edited.
+
+ Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601
+ timestamp can be provided for the duration.
"""
if duration is None and reason is None:
# Unlike UserInputError, the error handler will show a specified message for BadArgument
raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")
# Retrieve the previous infraction for its information.
- old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}')
+ if isinstance(infraction_id, str):
+ params = {
+ "actor__id": ctx.author.id,
+ "ordering": "-inserted_at"
+ }
+ infractions = await self.bot.api_client.get(f"bot/infractions", params=params)
+
+ if infractions:
+ old_infraction = infractions[0]
+ infraction_id = old_infraction["id"]
+ else:
+ await ctx.send(
+ f":x: Couldn't find most recent infraction; you have never given an infraction."
+ )
+ return
+ else:
+ old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}")
request_data = {}
confirm_messages = []
log_text = ""
- if duration == "permanent":
+ if isinstance(duration, str):
request_data['expires_at'] = None
confirm_messages.append("marked as permanent")
elif duration is not None:
@@ -129,7 +140,8 @@ class ModManagement(commands.Cog):
New expiry: {new_infraction['expires_at'] or "Permanent"}
""".rstrip()
- await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}")
+ changes = ' & '.join(confirm_messages)
+ await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}")
# Get information about the infraction's user
user_id = new_infraction['user']
@@ -232,6 +244,12 @@ class ModManagement(commands.Cog):
user_id = infraction["user"]
hidden = infraction["hidden"]
created = time.format_infraction(infraction["inserted_at"])
+
+ if active:
+ remaining = time.until_expiration(infraction["expires_at"]) or "Expired"
+ else:
+ remaining = "Inactive"
+
if infraction["expires_at"] is None:
expires = "*Permanent*"
else:
@@ -247,6 +265,7 @@ class ModManagement(commands.Cog):
Reason: {infraction["reason"] or "*None*"}
Created: {created}
Expires: {expires}
+ Remaining: {remaining}
Actor: {actor.mention if actor else actor_id}
ID: `{infraction["id"]}`
{"**===============**" if active else "==============="}
diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py
index 0df752a97..35ef6cbcc 100644
--- a/bot/cogs/moderation/modlog.py
+++ b/bot/cogs/moderation/modlog.py
@@ -10,8 +10,9 @@ from dateutil.relativedelta import relativedelta
from deepdiff import DeepDiff
from discord import Colour
from discord.abc import GuildChannel
-from discord.ext.commands import Bot, Cog, Context
+from discord.ext.commands import Cog, Context
+from bot.bot import Bot
from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs
from bot.utils.time import humanize_delta
from .utils import UserTypes
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py
index 3e0968121..01e4b1fe7 100644
--- a/bot/cogs/moderation/scheduler.py
+++ b/bot/cogs/moderation/scheduler.py
@@ -7,10 +7,11 @@ from gettext import ngettext
import dateutil.parser
import discord
-from discord.ext.commands import Bot, Context
+from discord.ext.commands import Context
from bot import constants
from bot.api import ResponseCodeError
+from bot.bot import Bot
from bot.constants import Colours, STAFF_CHANNELS
from bot.utils import time
from bot.utils.scheduling import Scheduler
@@ -113,8 +114,8 @@ class InfractionScheduler(Scheduler):
dm_result = ":incoming_envelope: "
dm_log_text = "\nDM: Sent"
else:
+ dm_result = f"{constants.Emojis.failmail} "
dm_log_text = "\nDM: **Failed**"
- log_content = ctx.author.mention
if infraction["actor"] == self.bot.user.id:
log.trace(
@@ -146,14 +147,18 @@ class InfractionScheduler(Scheduler):
if expiry:
# Schedule the expiration of the infraction.
self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
- except discord.Forbidden:
+ except discord.HTTPException as e:
# Accordingly display that applying the infraction failed.
confirm_msg = f":x: failed to apply"
expiry_msg = ""
log_content = ctx.author.mention
log_title = "failed to apply"
- log.warning(f"Failed to apply {infr_type} infraction #{id_} to {user}.")
+ log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}"
+ if isinstance(e, discord.Forbidden):
+ log.warning(f"{log_msg}: bot lacks permissions.")
+ else:
+ log.exception(log_msg)
# Send a confirmation message to the invoking context.
log.trace(f"Sending infraction #{id_} confirmation message.")
@@ -250,8 +255,7 @@ class InfractionScheduler(Scheduler):
if log_text.get("DM") == "Sent":
dm_emoji = ":incoming_envelope: "
elif "DM" in log_text:
- # Mention the actor because the DM failed to send.
- log_content = ctx.author.mention
+ dm_emoji = f"{constants.Emojis.failmail} "
# Accordingly display whether the pardon failed.
if "Failure" in log_text:
@@ -324,12 +328,12 @@ class InfractionScheduler(Scheduler):
f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!"
)
except discord.Forbidden:
- log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions")
+ log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.")
log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)"
log_content = mod_role.mention
except discord.HTTPException as e:
log.exception(f"Failed to deactivate infraction #{id_} ({type_})")
- log_text["Failure"] = f"HTTPException with code {e.code}."
+ log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}."
log_content = mod_role.mention
# Check if the user is currently being watched by Big Brother.
diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py
index 9b3c62403..7631d9bbe 100644
--- a/bot/cogs/moderation/superstarify.py
+++ b/bot/cogs/moderation/superstarify.py
@@ -6,9 +6,10 @@ import typing as t
from pathlib import Path
from discord import Colour, Embed, Member
-from discord.ext.commands import Bot, Cog, Context, command
+from discord.ext.commands import Cog, Context, command
from bot import constants
+from bot.bot import Bot
from bot.utils.checks import with_role_check
from bot.utils.time import format_infraction
from . import utils
diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py
index 78792240f..bf777ea5a 100644
--- a/bot/cogs/off_topic_names.py
+++ b/bot/cogs/off_topic_names.py
@@ -4,9 +4,10 @@ import logging
from datetime import datetime, timedelta
from discord import Colour, Embed
-from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group
+from discord.ext.commands import BadArgument, Cog, Context, Converter, group
from bot.api import ResponseCodeError
+from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES
from bot.decorators import with_role
from bot.pagination import LinePaginator
@@ -184,6 +185,5 @@ class OffTopicNames(Cog):
def setup(bot: Bot) -> None:
- """Off topic names cog load."""
+ """Load the OffTopicNames cog."""
bot.add_cog(OffTopicNames(bot))
- log.info("Cog loaded: OffTopicNames")
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 0d06e9c26..aa487f18e 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -2,13 +2,16 @@ import asyncio
import logging
import random
import textwrap
+from collections import namedtuple
from datetime import datetime, timedelta
from typing import List
+from aiohttp import BasicAuth, ClientError
from discord import Colour, Embed, TextChannel
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
from discord.ext.tasks import loop
+from bot.bot import Bot
from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks
from bot.converters import Subreddit
from bot.decorators import with_role
@@ -16,25 +19,32 @@ from bot.pagination import LinePaginator
log = logging.getLogger(__name__)
+AccessToken = namedtuple("AccessToken", ["token", "expires_at"])
+
class Reddit(Cog):
"""Track subreddit posts and show detailed statistics about them."""
- HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"}
+ HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"}
URL = "https://www.reddit.com"
- MAX_FETCH_RETRIES = 3
+ OAUTH_URL = "https://oauth.reddit.com"
+ MAX_RETRIES = 3
def __init__(self, bot: Bot):
self.bot = bot
- self.webhook = None # set in on_ready
- bot.loop.create_task(self.init_reddit_ready())
+ self.webhook = None
+ self.access_token = None
+ self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret)
+ bot.loop.create_task(self.init_reddit_ready())
self.auto_poster_loop.start()
def cog_unload(self) -> None:
- """Stops the loops when the cog is unloaded."""
+ """Stop the loop task and revoke the access token when the cog is unloaded."""
self.auto_poster_loop.cancel()
+ if self.access_token.expires_at < datetime.utcnow():
+ self.revoke_access_token()
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
@@ -47,20 +57,82 @@ class Reddit(Cog):
"""Get the #reddit channel object from the bot's cache."""
return self.bot.get_channel(Channels.reddit)
+ async def get_access_token(self) -> None:
+ """
+ Get a Reddit API OAuth2 access token and assign it to self.access_token.
+
+ A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog
+ will be unloaded and a ClientError raised if retrieval was still unsuccessful.
+ """
+ for i in range(1, self.MAX_RETRIES + 1):
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/access_token",
+ headers=self.HEADERS,
+ auth=self.client_auth,
+ data={
+ "grant_type": "client_credentials",
+ "duration": "temporary"
+ }
+ )
+
+ if response.status == 200 and response.content_type == "application/json":
+ content = await response.json()
+ expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway.
+ self.access_token = AccessToken(
+ token=content["access_token"],
+ expires_at=datetime.utcnow() + timedelta(seconds=expiration)
+ )
+
+ log.debug(f"New token acquired; expires on {self.access_token.expires_at}")
+ return
+ else:
+ log.debug(
+ f"Failed to get an access token: "
+ f"status {response.status} & content type {response.content_type}; "
+ f"retrying ({i}/{self.MAX_RETRIES})"
+ )
+
+ await asyncio.sleep(3)
+
+ self.bot.remove_cog(self.qualified_name)
+ raise ClientError("Authentication with the Reddit API failed. Unloading the cog.")
+
+ async def revoke_access_token(self) -> None:
+ """
+ Revoke the OAuth2 access token for the Reddit API.
+
+ For security reasons, it's good practice to revoke the token when it's no longer being used.
+ """
+ response = await self.bot.http_session.post(
+ url=f"{self.URL}/api/v1/revoke_token",
+ headers=self.HEADERS,
+ auth=self.client_auth,
+ data={
+ "token": self.access_token.token,
+ "token_type_hint": "access_token"
+ }
+ )
+
+ if response.status == 204 and response.content_type == "application/json":
+ self.access_token = None
+ else:
+ log.warning(f"Unable to revoke access token: status {response.status}.")
+
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
# Reddit's JSON responses only provide 25 posts at most.
if not 25 >= amount > 0:
raise ValueError("Invalid amount of subreddit posts requested.")
- if params is None:
- params = {}
+ # Renew the token if necessary.
+ if not self.access_token or self.access_token.expires_at < datetime.utcnow():
+ await self.get_access_token()
- url = f"{self.URL}/{route}.json"
- for _ in range(self.MAX_FETCH_RETRIES):
+ url = f"{self.OAUTH_URL}/{route}"
+ for _ in range(self.MAX_RETRIES):
response = await self.bot.http_session.get(
url=url,
- headers=self.HEADERS,
+ headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"},
params=params
)
if response.status == 200 and response.content_type == 'application/json':
@@ -217,6 +289,5 @@ class Reddit(Cog):
def setup(bot: Bot) -> None:
- """Reddit cog load."""
+ """Load the Reddit cog."""
bot.add_cog(Reddit(bot))
- log.info("Cog loaded: Reddit")
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 81990704b..45bf9a8f4 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -8,8 +8,9 @@ from typing import Optional
from dateutil.relativedelta import relativedelta
from discord import Colour, Embed, Message
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES
from bot.converters import Duration
from bot.pagination import LinePaginator
@@ -290,6 +291,5 @@ class Reminders(Scheduler, Cog):
def setup(bot: Bot) -> None:
- """Reminders cog load."""
+ """Load the Reminders cog."""
bot.add_cog(Reminders(bot))
- log.info("Cog loaded: Reminders")
diff --git a/bot/cogs/security.py b/bot/cogs/security.py
index 316b33d6b..c680c5e27 100644
--- a/bot/cogs/security.py
+++ b/bot/cogs/security.py
@@ -1,6 +1,8 @@
import logging
-from discord.ext.commands import Bot, Cog, Context, NoPrivateMessage
+from discord.ext.commands import Cog, Context, NoPrivateMessage
+
+from bot.bot import Bot
log = logging.getLogger(__name__)
@@ -25,6 +27,5 @@ class Security(Cog):
def setup(bot: Bot) -> None:
- """Security cog load."""
+ """Load the Security cog."""
bot.add_cog(Security(bot))
- log.info("Cog loaded: Security")
diff --git a/bot/cogs/site.py b/bot/cogs/site.py
index 683613788..2ea8c7a2e 100644
--- a/bot/cogs/site.py
+++ b/bot/cogs/site.py
@@ -1,8 +1,9 @@
import logging
from discord import Colour, Embed
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.constants import URLs
from bot.pagination import LinePaginator
@@ -138,6 +139,5 @@ class Site(Cog):
def setup(bot: Bot) -> None:
- """Site cog load."""
+ """Load the Site cog."""
bot.add_cog(Site(bot))
- log.info("Cog loaded: Site")
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index 55a187ac1..da33e27b2 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -5,8 +5,9 @@ import textwrap
from signal import Signals
from typing import Optional, Tuple
-from discord.ext.commands import Bot, Cog, Context, command, guild_only
+from discord.ext.commands import Cog, Context, command, guild_only
+from bot.bot import Bot
from bot.constants import Channels, Roles, URLs
from bot.decorators import in_channel
from bot.utils.messages import wait_for_deletion
@@ -227,6 +228,5 @@ class Snekbox(Cog):
def setup(bot: Bot) -> None:
- """Snekbox cog load."""
+ """Load the Snekbox cog."""
bot.add_cog(Snekbox(bot))
- log.info("Cog loaded: Snekbox")
diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py
index d4565f848..fe7df4e9b 100644
--- a/bot/cogs/sync/__init__.py
+++ b/bot/cogs/sync/__init__.py
@@ -1,13 +1,7 @@
-import logging
-
-from discord.ext.commands import Bot
-
+from bot.bot import Bot
from .cog import Sync
-log = logging.getLogger(__name__)
-
def setup(bot: Bot) -> None:
- """Sync cog load."""
+ """Load the Sync cog."""
bot.add_cog(Sync(bot))
- log.info("Cog loaded: Sync")
diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py
index aaa581f96..4e6ed156b 100644
--- a/bot/cogs/sync/cog.py
+++ b/bot/cogs/sync/cog.py
@@ -1,12 +1,13 @@
import logging
-from typing import Callable, Iterable
+from typing import Callable, Dict, Iterable, Union
-from discord import Guild, Member, Role
+from discord import Guild, Member, Role, User
from discord.ext import commands
-from discord.ext.commands import Bot, Cog, Context
+from discord.ext.commands import Cog, Context
from bot import constants
from bot.api import ResponseCodeError
+from bot.bot import Bot
from bot.cogs.sync import syncers
log = logging.getLogger(__name__)
@@ -50,6 +51,15 @@ class Sync(Cog):
f"deleted `{total_deleted}`."
)
+ async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None:
+ """Send a PATCH request to partially update a user in the database."""
+ try:
+ await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information)
+ except ResponseCodeError as e:
+ if e.response.status != 404:
+ raise
+ log.warning("Unable to update user, got 404. Assuming race condition from join event.")
+
@Cog.listener()
async def on_guild_role_create(self, role: Role) -> None:
"""Adds newly create role to the database table over the API."""
@@ -142,33 +152,21 @@ class Sync(Cog):
@Cog.listener()
async def on_member_update(self, before: Member, after: Member) -> None:
- """Updates the user information if any of relevant attributes have changed."""
- if (
- before.name != after.name
- or before.avatar != after.avatar
- or before.discriminator != after.discriminator
- or before.roles != after.roles
- ):
- try:
- await self.bot.api_client.put(
- 'bot/users/' + str(after.id),
- json={
- 'avatar_hash': after.avatar,
- 'discriminator': int(after.discriminator),
- 'id': after.id,
- 'in_guild': True,
- 'name': after.name,
- 'roles': sorted(role.id for role in after.roles)
- }
- )
- except ResponseCodeError as e:
- if e.response.status != 404:
- raise
-
- log.warning(
- "Unable to update user, got 404. "
- "Assuming race condition from join event."
- )
+ """Update the roles of the member in the database if a change is detected."""
+ if before.roles != after.roles:
+ updated_information = {"roles": sorted(role.id for role in after.roles)}
+ await self.patch_user(after.id, updated_information=updated_information)
+
+ @Cog.listener()
+ async def on_user_update(self, before: User, after: User) -> None:
+ """Update the user information in the database if a relevant change is detected."""
+ if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")):
+ updated_information = {
+ "name": after.name,
+ "discriminator": int(after.discriminator),
+ "avatar_hash": after.avatar,
+ }
+ await self.patch_user(after.id, updated_information=updated_information)
@commands.group(name='sync')
@commands.has_permissions(administrator=True)
diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py
index 2cc5a66e1..14cf51383 100644
--- a/bot/cogs/sync/syncers.py
+++ b/bot/cogs/sync/syncers.py
@@ -2,7 +2,8 @@ from collections import namedtuple
from typing import Dict, Set, Tuple
from discord import Guild
-from discord.ext.commands import Bot
+
+from bot.bot import Bot
# These objects are declared as namedtuples because tuples are hashable,
# something that we make use of when diffing site roles against guild roles.
@@ -52,7 +53,7 @@ async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]:
Synchronize roles found on the given `guild` with the ones on the API.
Arguments:
- bot (discord.ext.commands.Bot):
+ bot (bot.bot.Bot):
The bot instance that we're running with.
guild (discord.Guild):
@@ -169,7 +170,7 @@ async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]:
Synchronize users found in the given `guild` with the ones in the API.
Arguments:
- bot (discord.ext.commands.Bot):
+ bot (bot.bot.Bot):
The bot instance that we're running with.
guild (discord.Guild):
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py
index cd70e783a..970301013 100644
--- a/bot/cogs/tags.py
+++ b/bot/cogs/tags.py
@@ -2,8 +2,9 @@ import logging
import time
from discord import Colour, Embed
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.constants import Channels, Cooldowns, MODERATION_ROLES, Roles
from bot.converters import TagContentConverter, TagNameConverter
from bot.decorators import with_role
@@ -160,6 +161,5 @@ class Tags(Cog):
def setup(bot: Bot) -> None:
- """Tags cog load."""
+ """Load the Tags cog."""
bot.add_cog(Tags(bot))
- log.info("Cog loaded: Tags")
diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py
index 5a0d20e57..82c01ae96 100644
--- a/bot/cogs/token_remover.py
+++ b/bot/cogs/token_remover.py
@@ -6,9 +6,10 @@ import struct
from datetime import datetime
from discord import Colour, Message
-from discord.ext.commands import Bot, Cog
+from discord.ext.commands import Cog
from discord.utils import snowflake_time
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import Channels, Colours, Event, Icons
@@ -52,39 +53,60 @@ class TokenRemover(Cog):
See: https://discordapp.com/developers/docs/reference#snowflakes
"""
+ if self.is_token_in_message(msg):
+ await self.take_action(msg)
+
+ @Cog.listener()
+ async def on_message_edit(self, before: Message, after: Message) -> None:
+ """
+ Check each edit for a string that matches Discord's token pattern.
+
+ See: https://discordapp.com/developers/docs/reference#snowflakes
+ """
+ if self.is_token_in_message(after):
+ await self.take_action(after)
+
+ async def take_action(self, msg: Message) -> None:
+ """Remove the `msg` containing a token an send a mod_log message."""
+ user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.')
+ self.mod_log.ignore(Event.message_delete, msg.id)
+ await msg.delete()
+ await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention))
+
+ message = (
+ "Censored a seemingly valid token sent by "
+ f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was "
+ f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`"
+ )
+ log.debug(message)
+
+ # Send pretty mod log embed to mod-alerts
+ await self.mod_log.send_log_message(
+ icon_url=Icons.token_removed,
+ colour=Colour(Colours.soft_red),
+ title="Token removed!",
+ text=message,
+ thumbnail=msg.author.avatar_url_as(static_format="png"),
+ channel_id=Channels.mod_alerts,
+ )
+
+ @classmethod
+ def is_token_in_message(cls, msg: Message) -> bool:
+ """Check if `msg` contains a seemly valid token."""
if msg.author.bot:
- return
+ return False
maybe_match = TOKEN_RE.search(msg.content)
if maybe_match is None:
- return
+ return False
try:
user_id, creation_timestamp, hmac = maybe_match.group(0).split('.')
except ValueError:
- return
-
- if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp):
- self.mod_log.ignore(Event.message_delete, msg.id)
- await msg.delete()
- await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention))
-
- message = (
- "Censored a seemingly valid token sent by "
- f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was "
- f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`"
- )
- log.debug(message)
-
- # Send pretty mod log embed to mod-alerts
- await self.mod_log.send_log_message(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Token removed!",
- text=message,
- thumbnail=msg.author.avatar_url_as(static_format="png"),
- channel_id=Channels.mod_alerts,
- )
+ return False
+
+ if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp):
+ return True
@staticmethod
def is_valid_user_id(b64_content: str) -> bool:
@@ -119,6 +141,5 @@ class TokenRemover(Cog):
def setup(bot: Bot) -> None:
- """Token Remover cog load."""
+ """Load the TokenRemover cog."""
bot.add_cog(TokenRemover(bot))
- log.info("Cog loaded: TokenRemover")
diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py
index 793fe4c1a..47a59db66 100644
--- a/bot/cogs/utils.py
+++ b/bot/cogs/utils.py
@@ -8,8 +8,9 @@ from typing import Tuple
from dateutil import relativedelta
from discord import Colour, Embed, Message, Role
-from discord.ext.commands import Bot, Cog, Context, command
+from discord.ext.commands import Cog, Context, command
+from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES
from bot.decorators import in_channel, with_role
from bot.utils.time import humanize_delta
@@ -176,6 +177,5 @@ class Utils(Cog):
def setup(bot: Bot) -> None:
- """Utils cog load."""
+ """Load the Utils cog."""
bot.add_cog(Utils(bot))
- log.info("Cog loaded: Utils")
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index b5e8d4357..988e0d49a 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -3,15 +3,17 @@ from datetime import datetime
from discord import Colour, Message, NotFound, Object
from discord.ext import tasks
-from discord.ext.commands import Bot, Cog, Context, command
+from discord.ext.commands import Cog, Context, command
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import (
Bot as BotConfig,
Channels, Colours, Event,
- Filter, Icons, Roles
+ Filter, Icons, MODERATION_ROLES, Roles
)
from bot.decorators import InChannelCheckFailure, in_channel, without_role
+from bot.utils.checks import without_role_check
log = logging.getLogger(__name__)
@@ -37,6 +39,7 @@ PERIODIC_PING = (
f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`."
f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel."
)
+BOT_MESSAGE_DELETE_DELAY = 10
class Verification(Cog):
@@ -54,12 +57,16 @@ class Verification(Cog):
@Cog.listener()
async def on_message(self, message: Message) -> None:
"""Check new message event for messages to the checkpoint channel & process."""
- if message.author.bot:
- return # They're a bot, ignore
-
if message.channel.id != Channels.verification:
return # Only listen for #checkpoint messages
+ if message.author.bot:
+ # They're a bot, delete their message after the delay.
+ # But not the periodic ping; we like that one.
+ if message.content != PERIODIC_PING:
+ await message.delete(delay=BOT_MESSAGE_DELETE_DELAY)
+ return
+
# if a user mentions a role or guild member
# alert the mods in mod-alerts channel
if message.mentions or message.role_mentions:
@@ -189,7 +196,7 @@ class Verification(Cog):
@staticmethod
def bot_check(ctx: Context) -> bool:
"""Block any command within the verification channel that is not !accept."""
- if ctx.channel.id == Channels.verification:
+ if ctx.channel.id == Channels.verification and without_role_check(ctx, *MODERATION_ROLES):
return ctx.command.name == "accept"
else:
return True
@@ -224,6 +231,5 @@ class Verification(Cog):
def setup(bot: Bot) -> None:
- """Verification cog load."""
+ """Load the Verification cog."""
bot.add_cog(Verification(bot))
- log.info("Cog loaded: Verification")
diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py
index 86e1050fa..69d118df6 100644
--- a/bot/cogs/watchchannels/__init__.py
+++ b/bot/cogs/watchchannels/__init__.py
@@ -1,18 +1,9 @@
-import logging
-
-from discord.ext.commands import Bot
-
+from bot.bot import Bot
from .bigbrother import BigBrother
from .talentpool import TalentPool
-log = logging.getLogger(__name__)
-
-
def setup(bot: Bot) -> None:
- """Monitoring cogs load."""
+ """Load the BigBrother and TalentPool cogs."""
bot.add_cog(BigBrother(bot))
- log.info("Cog loaded: BigBrother")
-
bot.add_cog(TalentPool(bot))
- log.info("Cog loaded: TalentPool")
diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py
index 49783bb09..306ed4c64 100644
--- a/bot/cogs/watchchannels/bigbrother.py
+++ b/bot/cogs/watchchannels/bigbrother.py
@@ -3,8 +3,9 @@ from collections import ChainMap
from typing import Union
from discord import User
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
+from bot.bot import Bot
from bot.cogs.moderation.utils import post_infraction
from bot.constants import Channels, MODERATION_ROLES, Webhooks
from bot.decorators import with_role
diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py
index 4ec42dcc1..cc8feeeee 100644
--- a/bot/cogs/watchchannels/talentpool.py
+++ b/bot/cogs/watchchannels/talentpool.py
@@ -4,9 +4,10 @@ from collections import ChainMap
from typing import Union
from discord import Color, Embed, Member, User
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import Cog, Context, group
from bot.api import ResponseCodeError
+from bot.bot import Bot
from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks
from bot.decorators import with_role
from bot.pagination import LinePaginator
diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py
index 0bf75a924..bd0622554 100644
--- a/bot/cogs/watchchannels/watchchannel.py
+++ b/bot/cogs/watchchannels/watchchannel.py
@@ -10,9 +10,10 @@ from typing import Optional
import dateutil.parser
import discord
from discord import Color, Embed, HTTPException, Message, Object, errors
-from discord.ext.commands import BadArgument, Bot, Cog, Context
+from discord.ext.commands import BadArgument, Cog, Context
from bot.api import ResponseCodeError
+from bot.bot import Bot
from bot.cogs.moderation import ModLog
from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons
from bot.pagination import LinePaginator
diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py
index ab0ed2472..5d6b4630b 100644
--- a/bot/cogs/wolfram.py
+++ b/bot/cogs/wolfram.py
@@ -7,8 +7,9 @@ import discord
from dateutil.relativedelta import relativedelta
from discord import Embed
from discord.ext import commands
-from discord.ext.commands import Bot, BucketType, Cog, Context, check, group
+from discord.ext.commands import BucketType, Cog, Context, check, group
+from bot.bot import Bot
from bot.constants import Colours, STAFF_ROLES, Wolfram
from bot.pagination import ImagePaginator
from bot.utils.time import humanize_delta
@@ -151,7 +152,7 @@ async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tup
class Wolfram(Cog):
"""Commands for interacting with the Wolfram|Alpha API."""
- def __init__(self, bot: commands.Bot):
+ def __init__(self, bot: Bot):
self.bot = bot
@group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True)
@@ -266,7 +267,6 @@ class Wolfram(Cog):
await send_embed(ctx, message, color)
-def setup(bot: commands.Bot) -> None:
- """Wolfram cog load."""
+def setup(bot: Bot) -> None:
+ """Load the Wolfram cog."""
bot.add_cog(Wolfram(bot))
- log.info("Cog loaded: Wolfram")
diff --git a/bot/constants.py b/bot/constants.py
index 89504a2e0..2c0e3b10b 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -256,6 +256,8 @@ class Emojis(metaclass=YAMLGetter):
status_idle: str
status_dnd: str
+ failmail: str
+
bullet: str
new: str
pencil: str
@@ -268,6 +270,12 @@ class Emojis(metaclass=YAMLGetter):
ducky_ninja: int
ducky_devil: int
ducky_tube: int
+ ducky_hunt: int
+ ducky_wizard: int
+ ducky_party: int
+ ducky_angel: int
+ ducky_maul: int
+ ducky_santa: int
upvotes: str
comments: str
@@ -465,6 +473,8 @@ class Reddit(metaclass=YAMLGetter):
section = "reddit"
subreddits: list
+ client_id: str
+ secret: str
class Wolfram(metaclass=YAMLGetter):
diff --git a/bot/converters.py b/bot/converters.py
index cf0496541..8d2ab7eb8 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,8 +1,8 @@
import logging
import re
+import typing as t
from datetime import datetime
from ssl import CertificateError
-from typing import Union
import dateutil.parser
import dateutil.tz
@@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter
log = logging.getLogger(__name__)
+def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
+ """
+ Return a converter which only allows arguments equal to one of the given values.
+
+ Unless preserve_case is True, the argument is converted to lowercase. All values are then
+ expected to have already been given in lowercase too.
+ """
+ def converter(arg: str) -> str:
+ if not preserve_case:
+ arg = arg.lower()
+
+ if arg not in values:
+ raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```")
+ else:
+ return arg
+
+ return converter
+
+
class ValidPythonIdentifier(Converter):
"""
A converter that checks whether the given string is a valid Python identifier.
@@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):
"""A converter that checks if the argument is a Discord user, and if not, falls back to a string."""
@staticmethod
- async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]:
+ async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:
"""Check if the argument is a Discord user, and if not, falls back to a string."""
try:
maybe_snowflake = arg.strip("<@!>")
diff --git a/bot/interpreter.py b/bot/interpreter.py
index 76a3fc293..8b7268746 100644
--- a/bot/interpreter.py
+++ b/bot/interpreter.py
@@ -2,7 +2,9 @@ from code import InteractiveInterpreter
from io import StringIO
from typing import Any
-from discord.ext.commands import Bot, Context
+from discord.ext.commands import Context
+
+from bot.bot import Bot
CODE_TEMPLATE = """
async def _func():
diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py
index c550aed76..00bb2a949 100644
--- a/bot/rules/attachments.py
+++ b/bot/rules/attachments.py
@@ -7,14 +7,14 @@ async def apply(
last_message: Message, recent_messages: List[Message], config: Dict[str, int]
) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]:
"""Detects total attachments exceeding the limit sent by a single user."""
- relevant_messages = [last_message] + [
+ relevant_messages = tuple(
msg
for msg in recent_messages
if (
msg.author == last_message.author
and len(msg.attachments) > 0
)
- ]
+ )
total_recent_attachments = sum(len(msg.attachments) for msg in relevant_messages)
if total_recent_attachments > config['max']:
diff --git a/bot/utils/time.py b/bot/utils/time.py
index a024674ac..7416f36e0 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str:
return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT)
-def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str:
+def format_infraction_with_duration(
+ expiry: Optional[str],
+ date_from: Optional[datetime.datetime] = None,
+ max_units: int = 2
+) -> Optional[str]:
"""
Format an infraction timestamp to a more readable ISO 8601 format WITH the duration.
@@ -134,3 +138,28 @@ def format_infraction_with_duration(expiry: str, date_from: datetime.datetime =
duration_formatted = f" ({duration})" if duration else ''
return f"{expiry_formatted}{duration_formatted}"
+
+
+def until_expiration(
+ expiry: Optional[str],
+ now: Optional[datetime.datetime] = None,
+ max_units: int = 2
+) -> Optional[str]:
+ """
+ Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta.
+
+ Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry.
+ Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it.
+ `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
+ By default, max_units is 2.
+ """
+ if not expiry:
+ return None
+
+ now = now or datetime.datetime.utcnow()
+ since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0)
+
+ if since < now:
+ return None
+
+ return humanize_delta(relativedelta(since, now), max_units=max_units)
diff --git a/config-default.yml b/config-default.yml
index f8337b6da..f28c79712 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -27,6 +27,8 @@ style:
status_dnd: "<:status_dnd:470326272082313216>"
status_offline: "<:status_offline:470326266537705472>"
+ failmail: "<:failmail:633660039931887616>"
+
bullet: "\u2022"
pencil: "\u270F"
new: "\U0001F195"
@@ -39,6 +41,12 @@ style:
ducky_ninja: &DUCKY_NINJA 637923502535606293
ducky_devil: &DUCKY_DEVIL 637925314982576139
ducky_tube: &DUCKY_TUBE 637881368008851456
+ ducky_hunt: &DUCKY_HUNT 639355090909528084
+ ducky_wizard: &DUCKY_WIZARD 639355996954689536
+ ducky_party: &DUCKY_PARTY 639468753440210977
+ ducky_angel: &DUCKY_ANGEL 640121935610511361
+ ducky_maul: &DUCKY_MAUL 640137724958867467
+ ducky_santa: &DUCKY_SANTA 655360331002019870
upvotes: "<:upvotes:638729835245731840>"
comments: "<:comments:638729835073765387>"
@@ -361,11 +369,18 @@ anti_malware:
- '.png'
- '.tiff'
- '.wmv'
+ - '.svg'
+ - '.psd' # Photoshop
+ - '.ai' # Illustrator
+ - '.aep' # After Effects
+ - '.xcf' # GIMP
reddit:
subreddits:
- 'r/Python'
+ client_id: !ENV "REDDIT_CLIENT_ID"
+ secret: !ENV "REDDIT_SECRET"
wolfram:
@@ -397,7 +412,7 @@ redirect_output:
duck_pond:
threshold: 5
- custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE]
+ custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA]
config:
required_keys: ['bot.token']
diff --git a/docker-compose.yml b/docker-compose.yml
index f79fdba58..7281c7953 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -42,3 +42,5 @@ services:
environment:
BOT_TOKEN: ${BOT_TOKEN}
BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531
+ REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID}
+ REDDIT_SECRET: ${REDDIT_SECRET}
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index b801e86f1..d07b2bce1 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -578,15 +578,7 @@ class DuckPondSetupTests(unittest.TestCase):
"""Tests setup of the `DuckPond` cog."""
def test_setup(self):
- """Setup of the cog should log a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = helpers.MockBot()
- log = logging.getLogger('bot.cogs.duck_pond')
-
- with self.assertLogs(logger=log, level=logging.INFO) as log_watcher:
- duck_pond.setup(bot)
-
- self.assertEqual(len(log_watcher.records), 1)
- record = log_watcher.records[0]
- self.assertEqual(record.levelno, logging.INFO)
-
+ duck_pond.setup(bot)
bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py
index efa7a50b1..9d1a62f7e 100644
--- a/tests/bot/cogs/test_security.py
+++ b/tests/bot/cogs/test_security.py
@@ -1,4 +1,3 @@
-import logging
import unittest
from unittest.mock import MagicMock
@@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase):
"""Tests loading the `Security` cog."""
def test_security_cog_load(self):
- """Cog loading logs a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = MagicMock()
- with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm:
- security.setup(bot)
- bot.add_cog.assert_called_once()
-
- [line] = cm.output
- self.assertIn("Cog loaded: Security", line)
+ security.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index 3276cf5a5..a54b839d7 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase):
"""Tests setup of the `TokenRemover` cog."""
def test_setup(self):
- """Setup of the cog should log a message at `INFO` level."""
+ """Setup of the extension should call add_cog."""
bot = MockBot()
- with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm:
- setup_cog(bot)
-
- [line] = cm.output
+ setup_cog(bot)
bot.add_cog.assert_called_once()
- self.assertIn("Cog loaded: TokenRemover", line)
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index 4bb0acf7c..d7187f315 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -1,52 +1,98 @@
-import asyncio
import unittest
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List, NamedTuple, Tuple
from bot.rules import attachments
+from tests.helpers import MockMessage, async_test
-# Using `MagicMock` sadly doesn't work for this usecase
-# since it's __eq__ compares the MagicMock's ID. We just
-# want to compare the actual attributes we set.
-@dataclass
-class FakeMessage:
- author: str
- attachments: List[Any]
+class Case(NamedTuple):
+ recent_messages: List[MockMessage]
+ culprit: Tuple[str]
+ total_attachments: int
-def msg(total_attachments: int) -> FakeMessage:
- return FakeMessage(author='lemon', attachments=list(range(total_attachments)))
+def msg(author: str, total_attachments: int) -> MockMessage:
+ """Builds a message with `total_attachments` attachments."""
+ return MockMessage(author=author, attachments=list(range(total_attachments)))
class AttachmentRuleTests(unittest.TestCase):
- """Tests applying the `attachment` antispam rule."""
+ """Tests applying the `attachments` antispam rule."""
- def test_allows_messages_without_too_many_attachments(self):
+ def setUp(self):
+ self.config = {"max": 5}
+
+ @async_test
+ async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
- (msg(0), msg(0), msg(0)),
- (msg(2), msg(2)),
- (msg(0),),
+ [msg("bob", 0), msg("bob", 0), msg("bob", 0)],
+ [msg("bob", 2), msg("bob", 2)],
+ [msg("bob", 2), msg("alice", 2), msg("bob", 2)],
)
- for last_message, *recent_messages in cases:
- with self.subTest(last_message=last_message, recent_messages=recent_messages):
- coro = attachments.apply(last_message, recent_messages, {'max': 5})
- self.assertIsNone(asyncio.run(coro))
+ for recent_messages in cases:
+ last_message = recent_messages[0]
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ config=self.config
+ ):
+ self.assertIsNone(
+ await attachments.apply(last_message, recent_messages, self.config)
+ )
- def test_disallows_messages_with_too_many_attachments(self):
+ @async_test
+ async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
- ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10),
- ((msg(6),), [msg(6)], 6),
- ((msg(1),) * 6, [msg(1)] * 6, 6),
+ Case(
+ [msg("bob", 4), msg("bob", 0), msg("bob", 6)],
+ ("bob",),
+ 10
+ ),
+ Case(
+ [msg("bob", 4), msg("alice", 6), msg("bob", 2)],
+ ("bob",),
+ 6
+ ),
+ Case(
+ [msg("alice", 6)],
+ ("alice",),
+ 6
+ ),
+ (
+ [msg("alice", 1) for _ in range(6)],
+ ("alice",),
+ 6
+ ),
)
- for messages, relevant_messages, total in cases:
- with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total):
- last_message, *recent_messages = messages
- coro = attachments.apply(last_message, recent_messages, {'max': 5})
- self.assertEqual(
- asyncio.run(coro),
- (f"sent {total} attachments in 5s", ('lemon',), relevant_messages)
+
+ for recent_messages, culprit, total_attachments in cases:
+ last_message = recent_messages[0]
+ relevant_messages = tuple(
+ msg
+ for msg in recent_messages
+ if (
+ msg.author == last_message.author
+ and len(msg.attachments) > 0
+ )
+ )
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ relevant_messages=relevant_messages,
+ total_attachments=total_attachments,
+ config=self.config
+ ):
+ desired_output = (
+ f"sent {total_attachments} attachments in {self.config['max']}s",
+ culprit,
+ relevant_messages
+ )
+ self.assertTupleEqual(
+ await attachments.apply(last_message, recent_messages, self.config),
+ desired_output
)
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index be832843b..02a5d5501 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -2,25 +2,19 @@ import unittest
from typing import List, NamedTuple, Tuple
from bot.rules import links
-from tests.helpers import async_test
-
-
-class FakeMessage(NamedTuple):
- author: str
- content: str
+from tests.helpers import MockMessage, async_test
class Case(NamedTuple):
- recent_messages: List[FakeMessage]
- relevant_messages: Tuple[FakeMessage]
+ recent_messages: List[MockMessage]
culprit: Tuple[str]
total_links: int
-def msg(author: str, total_links: int) -> FakeMessage:
- """Makes a message with *total_links* links."""
+def msg(author: str, total_links: int) -> MockMessage:
+ """Makes a message with `total_links` links."""
content = " ".join(["https://pydis.com"] * total_links)
- return FakeMessage(author=author, content=content)
+ return MockMessage(author=author, content=content)
class LinksTests(unittest.TestCase):
@@ -61,26 +55,28 @@ class LinksTests(unittest.TestCase):
cases = (
Case(
[msg("bob", 1), msg("bob", 2)],
- (msg("bob", 1), msg("bob", 2)),
("bob",),
3
),
Case(
[msg("alice", 1), msg("alice", 1), msg("alice", 1)],
- (msg("alice", 1), msg("alice", 1), msg("alice", 1)),
("alice",),
3
),
Case(
[msg("alice", 2), msg("bob", 3), msg("alice", 1)],
- (msg("alice", 2), msg("alice", 1)),
("alice",),
3
)
)
- for recent_messages, relevant_messages, culprit, total_links in cases:
+ for recent_messages, culprit, total_links in cases:
last_message = recent_messages[0]
+ relevant_messages = tuple(
+ msg
+ for msg in recent_messages
+ if msg.author == last_message.author
+ )
with self.subTest(
last_message=last_message,
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
new file mode 100644
index 000000000..ad49ead32
--- /dev/null
+++ b/tests/bot/rules/test_mentions.py
@@ -0,0 +1,95 @@
+import unittest
+from typing import List, NamedTuple, Tuple
+
+from bot.rules import mentions
+from tests.helpers import MockMessage, async_test
+
+
+class Case(NamedTuple):
+ recent_messages: List[MockMessage]
+ culprit: Tuple[str]
+ total_mentions: int
+
+
+def msg(author: str, total_mentions: int) -> MockMessage:
+ """Makes a message with `total_mentions` mentions."""
+ return MockMessage(author=author, mentions=list(range(total_mentions)))
+
+
+class TestMentions(unittest.TestCase):
+ """Tests applying the `mentions` antispam rule."""
+
+ def setUp(self):
+ self.config = {
+ "max": 2,
+ "interval": 10
+ }
+
+ @async_test
+ async def test_mentions_within_limit(self):
+ """Messages with an allowed amount of mentions."""
+ cases = (
+ [msg("bob", 0)],
+ [msg("bob", 2)],
+ [msg("bob", 1), msg("bob", 1)],
+ [msg("bob", 1), msg("alice", 2)]
+ )
+
+ for recent_messages in cases:
+ last_message = recent_messages[0]
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ config=self.config
+ ):
+ self.assertIsNone(
+ await mentions.apply(last_message, recent_messages, self.config)
+ )
+
+ @async_test
+ async def test_mentions_exceeding_limit(self):
+ """Messages with a higher than allowed amount of mentions."""
+ cases = (
+ Case(
+ [msg("bob", 3)],
+ ("bob",),
+ 3
+ ),
+ Case(
+ [msg("alice", 2), msg("alice", 0), msg("alice", 1)],
+ ("alice",),
+ 3
+ ),
+ Case(
+ [msg("bob", 2), msg("alice", 3), msg("bob", 2)],
+ ("bob",),
+ 4
+ )
+ )
+
+ for recent_messages, culprit, total_mentions in cases:
+ last_message = recent_messages[0]
+ relevant_messages = tuple(
+ msg
+ for msg in recent_messages
+ if msg.author == last_message.author
+ )
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ relevant_messages=relevant_messages,
+ culprit=culprit,
+ total_mentions=total_mentions,
+ cofig=self.config
+ ):
+ desired_output = (
+ f"sent {total_mentions} mentions in {self.config['interval']}s",
+ culprit,
+ relevant_messages
+ )
+ self.assertTupleEqual(
+ await mentions.apply(last_message, recent_messages, self.config),
+ desired_output
+ )
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
new file mode 100644
index 000000000..69f35f2f5
--- /dev/null
+++ b/tests/bot/utils/test_time.py
@@ -0,0 +1,162 @@
+import asyncio
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from dateutil.relativedelta import relativedelta
+
+from bot.utils import time
+from tests.helpers import AsyncMock
+
+
+class TimeTests(unittest.TestCase):
+ """Test helper functions in bot.utils.time."""
+
+ def test_humanize_delta_handle_unknown_units(self):
+ """humanize_delta should be able to handle unknown units, and will not abort."""
+ # Does not abort for unknown units, as the unit name is checked
+ # against the attribute of the relativedelta instance.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours')
+
+ def test_humanize_delta_handle_high_units(self):
+ """humanize_delta should be able to handle very high units."""
+ # Very high maximum units, but it only ever iterates over
+ # each value the relativedelta might have.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours')
+
+ def test_humanize_delta_should_normal_usage(self):
+ """Testing humanize delta."""
+ test_cases = (
+ (relativedelta(days=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
+ (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
+ )
+
+ for delta, precision, max_units, expected in test_cases:
+ with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected):
+ self.assertEqual(time.humanize_delta(delta, precision, max_units), expected)
+
+ def test_humanize_delta_raises_for_invalid_max_units(self):
+ """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units."""
+ test_cases = (-1, 0)
+
+ for max_units in test_cases:
+ with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
+ time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
+ self.assertEqual(str(error), 'max_units must be positive')
+
+ def test_parse_rfc1123(self):
+ """Testing parse_rfc1123."""
+ self.assertEqual(
+ time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'),
+ datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)
+ )
+
+ def test_format_infraction(self):
+ """Testing format_infraction."""
+ self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01')
+
+ @patch('asyncio.sleep', new_callable=AsyncMock)
+ def test_wait_until(self, mock):
+ """Testing wait_until."""
+ start = datetime(2019, 1, 1, 0, 0)
+ then = datetime(2019, 1, 1, 0, 10)
+
+ # No return value
+ self.assertIs(asyncio.run(time.wait_until(then, start)), None)
+
+ mock.assert_called_once_with(10 * 60)
+
+ def test_format_infraction_with_duration_none_expiry(self):
+ """format_infraction_with_duration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that date_from and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_custom_units(self):
+ """format_infraction_with_duration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6,
+ '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20,
+ '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)')
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_normal_usage(self):
+ """format_infraction_with_duration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2,
+ '2019-11-23 23:59 (9 minutes and 55 seconds)'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_until_expiration_with_duration_none_expiry(self):
+ """until_expiration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that now and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_with_duration_custom_units(self):
+ """until_expiration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes')
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_normal_usage(self):
+ """until_expiration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
diff --git a/tests/helpers.py b/tests/helpers.py
index b2daae92d..5df796c23 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -10,7 +10,9 @@ import unittest.mock
from typing import Any, Iterable, Optional
import discord
-from discord.ext.commands import Bot, Context
+from discord.ext.commands import Context
+
+from bot.bot import Bot
for logger in logging.Logger.manager.loggerDict.values():