diff options
60 files changed, 928 insertions, 351 deletions
diff --git a/azure-pipelines.yml b/azure-pipelines.yml index da3b06201..0400ac4d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/__main__.py b/bot/__main__.py index ea7c43a12..84bc7094b 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,18 +1,11 @@ -import asyncio -import logging -import socket - import discord -from aiohttp import AsyncResolver, ClientSession, TCPConnector -from discord.ext.commands import Bot, when_mentioned_or +from discord.ext.commands import when_mentioned_or from bot import patches -from bot.api import APIClient, APILoggingHandler +from bot.bot import Bot from bot.constants import Bot as BotConfig, DEBUG_MODE -log = logging.getLogger('bot') - bot = Bot( command_prefix=when_mentioned_or(BotConfig.prefix), activity=discord.Game(name="Commands: !help"), @@ -20,18 +13,6 @@ bot = Bot( max_messages=10_000, ) -# Global aiohttp session for all cogs -# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads -# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. -bot.http_session = ClientSession( - connector=TCPConnector( - resolver=AsyncResolver(), - family=socket.AF_INET, - ) -) -bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.addHandler(APILoggingHandler(bot.api_client)) - # Internal/debug bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") @@ -77,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'): patches.message_edited_at.apply_patch() bot.run(BotConfig.token) - -# This calls a coroutine, so it doesn't do anything at the moment. -# bot.http_session.close() # Close the aiohttp session when the bot finishes running diff --git a/bot/api.py b/bot/api.py index 7f26e5305..56db99828 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,7 +32,7 @@ class ResponseCodeError(ValueError): class APIClient: """Django Site API wrapper.""" - def __init__(self, **kwargs): + def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs): auth_headers = { 'Authorization': f"Token {Keys.site_api}" } @@ -42,12 +42,39 @@ class APIClient: else: kwargs['headers'] = auth_headers - self.session = aiohttp.ClientSession(**kwargs) + self.session: Optional[aiohttp.ClientSession] = None + self.loop = loop + + self._ready = asyncio.Event(loop=loop) + self._creation_task = None + self._session_args = kwargs + + self.recreate() @staticmethod def _url_for(endpoint: str) -> str: return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" + async def _create_session(self) -> None: + """Create the aiohttp session and set the ready event.""" + self.session = aiohttp.ClientSession(**self._session_args) + self._ready.set() + + async def close(self) -> None: + """Close the aiohttp session and unset the ready event.""" + if not self._ready.is_set(): + return + + await self.session.close() + self._ready.clear() + + def recreate(self) -> None: + """Schedule the aiohttp session to be created if it's been closed.""" + if self.session is None or self.session.closed: + # Don't schedule a task if one is already in progress. + if self._creation_task is None or self._creation_task.done(): + self._creation_task = self.loop.create_task(self._create_session()) + async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: """Raise ResponseCodeError for non-OK response if an exception should be raised.""" if should_raise and response.status >= 400: @@ -60,30 +87,40 @@ class APIClient: async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API GET.""" + await self._ready.wait() + async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PATCH.""" + await self._ready.wait() + async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API POST.""" + await self._ready.wait() + async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PUT.""" + await self._ready.wait() + async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: """Site API DELETE.""" + await self._ready.wait() + async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: if resp.status == 204: return None diff --git a/bot/bot.py b/bot/bot.py new file mode 100644 index 000000000..8f808272f --- /dev/null +++ b/bot/bot.py @@ -0,0 +1,53 @@ +import logging +import socket +from typing import Optional + +import aiohttp +from discord.ext import commands + +from bot import api + +log = logging.getLogger('bot') + + +class Bot(commands.Bot): + """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" + + def __init__(self, *args, **kwargs): + # Use asyncio for DNS resolution instead of threads so threads aren't spammed. + # Use AF_INET as its socket family to prevent HTTPS related problems both locally + # and in production. + self.connector = aiohttp.TCPConnector( + resolver=aiohttp.AsyncResolver(), + family=socket.AF_INET, + ) + + super().__init__(*args, connector=self.connector, **kwargs) + + self.http_session: Optional[aiohttp.ClientSession] = None + self.api_client = api.APIClient(loop=self.loop, connector=self.connector) + + log.addHandler(api.APILoggingHandler(self.api_client)) + + def add_cog(self, cog: commands.Cog) -> None: + """Adds a "cog" to the bot and logs the operation.""" + super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + def clear(self) -> None: + """Clears the internal state of the bot and resets the API client.""" + super().clear() + self.api_client.recreate() + + async def close(self) -> None: + """Close the aiohttp session after closing the Discord connection.""" + await super().close() + + await self.http_session.close() + await self.api_client.close() + + async def start(self, *args, **kwargs) -> None: + """Open an aiohttp session before logging in and connecting to Discord.""" + self.http_session = aiohttp.ClientSession(connector=self.connector) + + await super().start(*args, **kwargs) diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 5190c559b..c1db38462 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -3,8 +3,9 @@ import logging from typing import Union from discord import Colour, Embed, Member, User -from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group +from discord.ext.commands import Cog, Command, Context, clean_content, command, group +from bot.bot import Bot from bot.cogs.extensions import Extension from bot.cogs.watchchannels.watchchannel import proxy_user from bot.converters import TagNameConverter @@ -147,6 +148,5 @@ class Alias (Cog): def setup(bot: Bot) -> None: - """Alias cog load.""" + """Load the Alias cog.""" bot.add_cog(Alias(bot)) - log.info("Cog loaded: Alias") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 602819191..28e3e5d96 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,8 +1,9 @@ import logging from discord import Embed, Message, NotFound -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs log = logging.getLogger(__name__) @@ -49,6 +50,5 @@ class AntiMalware(Cog): def setup(bot: Bot) -> None: - """Antimalware cog load.""" + """Load the AntiMalware cog.""" bot.add_cog(AntiMalware(bot)) - log.info("Cog loaded: AntiMalware") diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index 1340eb608..f454061a6 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -7,9 +7,10 @@ from operator import itemgetter from typing import Dict, Iterable, List, Set from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from bot import rules +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( AntiSpam as AntiSpamConfig, Channels, @@ -276,7 +277,6 @@ def validate_config(rules: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: def setup(bot: Bot) -> None: - """Antispam cog load.""" + """Validate the AntiSpam configs and load the AntiSpam cog.""" validation_errors = validate_config() bot.add_cog(AntiSpam(bot, validation_errors)) - log.info("Cog loaded: AntiSpam") diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index ee0a463de..73b1e8f41 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -5,8 +5,10 @@ import time from typing import Optional, Tuple from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Bot, Cog, Context, command, group +from discord.ext.commands import Cog, Context, command, group +from bot.bot import Bot +from bot.cogs.token_remover import TokenRemover from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs from bot.decorators import with_role from bot.utils.messages import wait_for_deletion @@ -16,7 +18,7 @@ log = logging.getLogger(__name__) RE_MARKDOWN = re.compile(r'([*_~`|>])') -class Bot(Cog): +class BotCog(Cog, name="Bot"): """Bot information commands.""" def __init__(self, bot: Bot): @@ -238,9 +240,10 @@ class Bot(Cog): ) and not msg.author.bot and len(msg.content.splitlines()) > 3 + and not TokenRemover.is_token_in_message(msg) ) - if parse_codeblock: + if parse_codeblock: # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: @@ -373,10 +376,9 @@ class Bot(Cog): bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) await bot_message.delete() del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") def setup(bot: Bot) -> None: - """Bot cog load.""" - bot.add_cog(Bot(bot)) - log.info("Cog loaded: Bot") + """Load the Bot cog.""" + bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index dca411d01..2104efe57 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -3,9 +3,10 @@ import random import re from typing import Optional -from discord import Colour, Embed, Message, User -from discord.ext.commands import Bot, Cog, Context, group +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, CleanMessages, Colours, Event, @@ -37,9 +38,13 @@ class Clean(Cog): return self.bot.get_cog("ModLog") async def _clean_messages( - self, amount: int, ctx: Context, - bots_only: bool = False, user: User = None, - regex: Optional[str] = None + self, + amount: int, + ctx: Context, + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + channel: Optional[TextChannel] = None ) -> None: """A helper function that does the actual message cleaning.""" def predicate_bots_only(message: Message) -> bool: @@ -104,6 +109,10 @@ class Clean(Cog): else: predicate = None # Delete all messages + # Default to using the invoking context's channel + if not channel: + channel = ctx.channel + # Look through the history and retrieve message data messages = [] message_ids = [] @@ -111,7 +120,7 @@ class Clean(Cog): invocation_deleted = False # To account for the invocation message, we index `amount + 1` messages. - async for message in ctx.channel.history(limit=amount + 1): + async for message in channel.history(limit=amount + 1): # If at any point the cancel command is invoked, we should stop. if not self.cleaning: @@ -135,7 +144,7 @@ class Clean(Cog): self.mod_log.ignore(Event.message_delete, *message_ids) # Use bulk delete to actually do the cleaning. It's far faster. - await ctx.channel.purge( + await channel.purge( limit=amount, check=predicate ) @@ -155,7 +164,7 @@ class Clean(Cog): # Build the embed and send it message = ( - f"**{len(message_ids)}** messages deleted in <#{ctx.channel.id}> by **{ctx.author.name}**\n\n" + f"**{len(message_ids)}** messages deleted in <#{channel.id}> by **{ctx.author.name}**\n\n" f"A log of the deleted messages can be found [here]({log_url})." ) @@ -167,7 +176,7 @@ class Clean(Cog): channel_id=Channels.modlog, ) - @group(invoke_without_command=True, name="clean", hidden=True) + @group(invoke_without_command=True, name="clean", aliases=["purge"]) @with_role(*MODERATION_ROLES) async def clean_group(self, ctx: Context) -> None: """Commands for cleaning messages in channels.""" @@ -175,27 +184,49 @@ class Clean(Cog): @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) - async def clean_user(self, ctx: Context, user: User, amount: int = 10) -> None: + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user) + await self._clean_messages(amount, ctx, user=user, channel=channel) @clean_group.command(name="all", aliases=["everything"]) @with_role(*MODERATION_ROLES) - async def clean_all(self, ctx: Context, amount: int = 10) -> None: + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx) + await self._clean_messages(amount, ctx, channel=channel) @clean_group.command(name="bots", aliases=["bot"]) @with_role(*MODERATION_ROLES) - async def clean_bots(self, ctx: Context, amount: int = 10) -> None: + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True) + await self._clean_messages(amount, ctx, bots_only=True, channel=channel) @clean_group.command(name="regex", aliases=["word", "expression"]) @with_role(*MODERATION_ROLES) - async def clean_regex(self, ctx: Context, regex: str, amount: int = 10) -> None: + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex) + await self._clean_messages(amount, ctx, regex=regex, channel=channel) @clean_group.command(name="stop", aliases=["cancel", "abort"]) @with_role(*MODERATION_ROLES) @@ -211,6 +242,5 @@ class Clean(Cog): def setup(bot: Bot) -> None: - """Clean cog load.""" + """Load the Clean cog.""" bot.add_cog(Clean(bot)) - log.info("Cog loaded: Clean") diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index bedd70c86..3e7350fcc 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -6,8 +6,9 @@ from datetime import datetime, timedelta from enum import Enum from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles from bot.decorators import with_role @@ -236,6 +237,5 @@ class Defcon(Cog): def setup(bot: Bot) -> None: - """DEFCON cog load.""" + """Load the Defcon cog.""" bot.add_cog(Defcon(bot)) - log.info("Cog loaded: Defcon") diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index e5b3a4062..9506b195a 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -17,6 +17,7 @@ from requests import ConnectTimeout, ConnectionError, HTTPError from sphinx.ext import intersphinx from urllib3.exceptions import ProtocolError +from bot.bot import Bot from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import ValidPythonIdentifier, ValidURL from bot.decorators import with_role @@ -147,7 +148,7 @@ class InventoryURL(commands.Converter): class Doc(commands.Cog): """A set of commands for querying & displaying documentation.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.base_urls = {} self.bot = bot self.inventories = {} @@ -506,7 +507,6 @@ class Doc(commands.Cog): return tag.name == "table" -def setup(bot: commands.Bot) -> None: - """Doc cog load.""" +def setup(bot: Bot) -> None: + """Load the Doc cog.""" bot.add_cog(Doc(bot)) - log.info("Cog loaded: Doc") diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 2d25cd17e..345d2856c 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -3,9 +3,10 @@ from typing import Optional, Union import discord from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from bot import constants +from bot.bot import Bot from bot.utils.messages import send_attachments log = logging.getLogger(__name__) @@ -177,6 +178,5 @@ class DuckPond(Cog): def setup(bot: Bot) -> None: - """Load the duck pond cog.""" + """Load the DuckPond cog.""" bot.add_cog(DuckPond(bot)) - log.info("Cog loaded: DuckPond") diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 49411814c..52893b2ee 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -14,9 +14,10 @@ from discord.ext.commands import ( NoPrivateMessage, UserInputError, ) -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels from bot.decorators import InChannelCheckFailure @@ -75,6 +76,16 @@ class ErrorHandler(Cog): tags_get_command = self.bot.get_command("tags get") ctx.invoked_from_error_handler = True + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + # Return to not raise the exception with contextlib.suppress(ResponseCodeError): await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) @@ -143,6 +154,5 @@ class ErrorHandler(Cog): def setup(bot: Bot) -> None: - """Error handler cog load.""" + """Load the ErrorHandler cog.""" bot.add_cog(ErrorHandler(bot)) - log.info("Cog loaded: Events") diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 00b988dde..9c729f28a 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -9,8 +9,9 @@ from io import StringIO from typing import Any, Optional, Tuple import discord -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Roles from bot.decorators import with_role from bot.interpreter import Interpreter @@ -197,6 +198,5 @@ async def func(): # (None,) -> Any def setup(bot: Bot) -> None: - """Code eval cog load.""" + """Load the CodeEval cog.""" bot.add_cog(CodeEval(bot)) - log.info("Cog loaded: Eval") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index bb66e0b8e..f16e79fb7 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -6,8 +6,9 @@ from pkgutil import iter_modules from discord import Colour, Embed from discord.ext import commands -from discord.ext.commands import Bot, Context, group +from discord.ext.commands import Context, group +from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs from bot.pagination import LinePaginator from bot.utils.checks import with_role_check @@ -233,4 +234,3 @@ class Extensions(commands.Cog): def setup(bot: Bot) -> None: """Load the Extensions cog.""" bot.add_cog(Extensions(bot)) - log.info("Cog loaded: Extensions") diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 1e7521054..74538542a 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -5,8 +5,9 @@ from typing import Optional, Union import discord.errors from dateutil.relativedelta import relativedelta from discord import Colour, DMChannel, Member, Message, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, Colours, @@ -370,6 +371,5 @@ class Filtering(Cog): def setup(bot: Bot) -> None: - """Filtering cog load.""" + """Load the Filtering cog.""" bot.add_cog(Filtering(bot)) - log.info("Cog loaded: Filtering") diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 82285656b..49cab6172 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -3,8 +3,9 @@ from datetime import datetime from operator import itemgetter from discord import Colour, Embed, Member, utils -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.constants import Categories, Channels, Free, STAFF_ROLES from bot.decorators import redirect_output @@ -98,6 +99,5 @@ class Free(Cog): def setup(bot: Bot) -> None: - """Free cog load.""" + """Load the Free cog.""" bot.add_cog(Free()) - log.info("Cog loaded: Free") diff --git a/bot/cogs/help.py b/bot/cogs/help.py index 9607dbd8d..6385fa467 100644 --- a/bot/cogs/help.py +++ b/bot/cogs/help.py @@ -6,10 +6,11 @@ from typing import Union from discord import Colour, Embed, HTTPException, Message, Reaction, User from discord.ext import commands -from discord.ext.commands import Bot, CheckFailure, Cog as DiscordCog, Command, Context +from discord.ext.commands import CheckFailure, Cog as DiscordCog, Command, Context from fuzzywuzzy import fuzz, process from bot import constants +from bot.bot import Bot from bot.constants import Channels, STAFF_ROLES from bot.decorators import redirect_output from bot.pagination import ( diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 530453600..1ede95ff4 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -9,10 +9,11 @@ from typing import Any, Mapping, Optional import discord from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, command, group +from discord.ext.commands import BucketType, Cog, Context, command, group from discord.utils import escape_markdown from bot import constants +from bot.bot import Bot from bot.decorators import InChannelCheckFailure, in_channel, with_role from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since @@ -391,6 +392,5 @@ class Information(Cog): def setup(bot: Bot) -> None: - """Information cog load.""" + """Load the Information cog.""" bot.add_cog(Information(bot)) - log.info("Cog loaded: Information") diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index be9d33e3e..985f28ce5 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -4,6 +4,7 @@ from discord import Member, PermissionOverwrite, utils from discord.ext import commands from more_itertools import unique_everseen +from bot.bot import Bot from bot.constants import Roles from bot.decorators import with_role @@ -13,7 +14,7 @@ log = logging.getLogger(__name__) class CodeJams(commands.Cog): """Manages the code-jam related parts of our server.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @commands.command() @@ -108,7 +109,6 @@ class CodeJams(commands.Cog): ) -def setup(bot: commands.Bot) -> None: - """Code Jams cog load.""" +def setup(bot: Bot) -> None: + """Load the CodeJams cog.""" bot.add_cog(CodeJams(bot)) - log.info("Cog loaded: CodeJams") diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index c92b619ff..d1b7dcab3 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -1,8 +1,9 @@ import logging from discord import Embed -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE @@ -37,6 +38,5 @@ class Logging(Cog): def setup(bot: Bot) -> None: - """Logging cog load.""" + """Load the Logging cog.""" bot.add_cog(Logging(bot)) - log.info("Cog loaded: Logging") diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 7383ed44e..5243cb92d 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,25 +1,13 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot from .infractions import Infractions from .management import ModManagement from .modlog import ModLog from .superstarify import Superstarify -log = logging.getLogger(__name__) - def setup(bot: Bot) -> None: - """Load the moderation extension (Infractions, ModManagement, ModLog, & Superstarify cogs).""" + """Load the Infractions, ModManagement, ModLog, and Superstarify cogs.""" bot.add_cog(Infractions(bot)) - log.info("Cog loaded: Infractions") - bot.add_cog(ModLog(bot)) - log.info("Cog loaded: ModLog") - bot.add_cog(ModManagement(bot)) - log.info("Cog loaded: ModManagement") - bot.add_cog(Superstarify(bot)) - log.info("Cog loaded: Superstarify") diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 2713a1b68..3536a3d38 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -7,6 +7,7 @@ from discord.ext import commands from discord.ext.commands import Context, command from bot import constants +from bot.bot import Bot from bot.constants import Event from bot.decorators import respect_role_hierarchy from bot.utils.checks import with_role_check @@ -25,7 +26,7 @@ class Infractions(InfractionScheduler, commands.Cog): category = "Moderation" category_description = "Server moderation tools." - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) self.category = "Moderation" @@ -208,8 +209,13 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_update, user.id) - action = user.add_roles(self._muted_role, reason=reason) - await self.apply_infraction(ctx, infraction, user, action) + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) @respect_role_hierarchy() async def apply_kick(self, ctx: Context, user: Member, reason: str, **kwargs) -> None: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..9605d47b2 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,8 @@ from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery +from bot.bot import Bot +from bot.converters import InfractionSearchQuery, allowed_strings from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -22,21 +23,12 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" category = "Moderation" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @property @@ -60,8 +52,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], + duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None], *, reason: str = None ) -> None: @@ -78,21 +70,40 @@ class ModManagement(commands.Cog): \u2003`M` - minutes∗ \u2003`s` - seconds - Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp - can be provided for the duration. + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. """ if duration is None and reason is None: # Unlike UserInputError, the error handler will show a specified message for BadArgument raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + f":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") request_data = {} confirm_messages = [] log_text = "" - if duration == "permanent": + if isinstance(duration, str): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: @@ -129,7 +140,8 @@ class ModManagement(commands.Cog): New expiry: {new_infraction['expires_at'] or "Permanent"} """.rstrip() - await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") # Get information about the infraction's user user_id = new_infraction['user'] @@ -232,6 +244,12 @@ class ModManagement(commands.Cog): user_id = infraction["user"] hidden = infraction["hidden"] created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + if infraction["expires_at"] is None: expires = "*Permanent*" else: @@ -247,6 +265,7 @@ class ModManagement(commands.Cog): Reason: {infraction["reason"] or "*None*"} Created: {created} Expires: {expires} + Remaining: {remaining} Actor: {actor.mention if actor else actor_id} ID: `{infraction["id"]}` {"**===============**" if active else "==============="} diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 0df752a97..35ef6cbcc 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -10,8 +10,9 @@ from dateutil.relativedelta import relativedelta from deepdiff import DeepDiff from discord import Colour from discord.abc import GuildChannel -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context +from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs from bot.utils.time import humanize_delta from .utils import UserTypes diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 3e0968121..01e4b1fe7 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -7,10 +7,11 @@ from gettext import ngettext import dateutil.parser import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context from bot import constants from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Colours, STAFF_CHANNELS from bot.utils import time from bot.utils.scheduling import Scheduler @@ -113,8 +114,8 @@ class InfractionScheduler(Scheduler): dm_result = ":incoming_envelope: " dm_log_text = "\nDM: Sent" else: + dm_result = f"{constants.Emojis.failmail} " dm_log_text = "\nDM: **Failed**" - log_content = ctx.author.mention if infraction["actor"] == self.bot.user.id: log.trace( @@ -146,14 +147,18 @@ class InfractionScheduler(Scheduler): if expiry: # Schedule the expiration of the infraction. self.schedule_task(ctx.bot.loop, infraction["id"], infraction) - except discord.Forbidden: + except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = f":x: failed to apply" expiry_msg = "" log_content = ctx.author.mention log_title = "failed to apply" - log.warning(f"Failed to apply {infr_type} infraction #{id_} to {user}.") + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) # Send a confirmation message to the invoking context. log.trace(f"Sending infraction #{id_} confirmation message.") @@ -250,8 +255,7 @@ class InfractionScheduler(Scheduler): if log_text.get("DM") == "Sent": dm_emoji = ":incoming_envelope: " elif "DM" in log_text: - # Mention the actor because the DM failed to send. - log_content = ctx.author.mention + dm_emoji = f"{constants.Emojis.failmail} " # Accordingly display whether the pardon failed. if "Failure" in log_text: @@ -324,12 +328,12 @@ class InfractionScheduler(Scheduler): f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions") + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention except discord.HTTPException as e: log.exception(f"Failed to deactivate infraction #{id_} ({type_})") - log_text["Failure"] = f"HTTPException with code {e.code}." + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." log_content = mod_role.mention # Check if the user is currently being watched by Big Brother. diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 9b3c62403..7631d9bbe 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -6,9 +6,10 @@ import typing as t from pathlib import Path from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command from bot import constants +from bot.bot import Bot from bot.utils.checks import with_role_check from bot.utils.time import format_infraction from . import utils diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 78792240f..bf777ea5a 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -4,9 +4,10 @@ import logging from datetime import datetime, timedelta from discord import Colour, Embed -from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group +from discord.ext.commands import BadArgument, Cog, Context, Converter, group from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES from bot.decorators import with_role from bot.pagination import LinePaginator @@ -184,6 +185,5 @@ class OffTopicNames(Cog): def setup(bot: Bot) -> None: - """Off topic names cog load.""" + """Load the OffTopicNames cog.""" bot.add_cog(OffTopicNames(bot)) - log.info("Cog loaded: OffTopicNames") diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0d06e9c26..aa487f18e 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,13 +2,16 @@ import asyncio import logging import random import textwrap +from collections import namedtuple from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group from discord.ext.tasks import loop +from bot.bot import Bot from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks from bot.converters import Subreddit from bot.decorators import with_role @@ -16,25 +19,32 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} URL = "https://www.reddit.com" - MAX_FETCH_RETRIES = 3 + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready - bot.loop.create_task(self.init_reddit_ready()) + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stops the loops when the cog is unloaded.""" + """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -47,20 +57,82 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - if params is None: - params = {} + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() - url = f"{self.URL}/{route}.json" - for _ in range(self.MAX_FETCH_RETRIES): + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -217,6 +289,5 @@ class Reddit(Cog): def setup(bot: Bot) -> None: - """Reddit cog load.""" + """Load the Reddit cog.""" bot.add_cog(Reddit(bot)) - log.info("Cog loaded: Reddit") diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 81990704b..45bf9a8f4 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -8,8 +8,9 @@ from typing import Optional from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Message -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES from bot.converters import Duration from bot.pagination import LinePaginator @@ -290,6 +291,5 @@ class Reminders(Scheduler, Cog): def setup(bot: Bot) -> None: - """Reminders cog load.""" + """Load the Reminders cog.""" bot.add_cog(Reminders(bot)) - log.info("Cog loaded: Reminders") diff --git a/bot/cogs/security.py b/bot/cogs/security.py index 316b33d6b..c680c5e27 100644 --- a/bot/cogs/security.py +++ b/bot/cogs/security.py @@ -1,6 +1,8 @@ import logging -from discord.ext.commands import Bot, Cog, Context, NoPrivateMessage +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot log = logging.getLogger(__name__) @@ -25,6 +27,5 @@ class Security(Cog): def setup(bot: Bot) -> None: - """Security cog load.""" + """Load the Security cog.""" bot.add_cog(Security(bot)) - log.info("Cog loaded: Security") diff --git a/bot/cogs/site.py b/bot/cogs/site.py index 683613788..2ea8c7a2e 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -1,8 +1,9 @@ import logging from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import URLs from bot.pagination import LinePaginator @@ -138,6 +139,5 @@ class Site(Cog): def setup(bot: Bot) -> None: - """Site cog load.""" + """Load the Site cog.""" bot.add_cog(Site(bot)) - log.info("Cog loaded: Site") diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 55a187ac1..da33e27b2 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -5,8 +5,9 @@ import textwrap from signal import Signals from typing import Optional, Tuple -from discord.ext.commands import Bot, Cog, Context, command, guild_only +from discord.ext.commands import Cog, Context, command, guild_only +from bot.bot import Bot from bot.constants import Channels, Roles, URLs from bot.decorators import in_channel from bot.utils.messages import wait_for_deletion @@ -227,6 +228,5 @@ class Snekbox(Cog): def setup(bot: Bot) -> None: - """Snekbox cog load.""" + """Load the Snekbox cog.""" bot.add_cog(Snekbox(bot)) - log.info("Cog loaded: Snekbox") diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py index d4565f848..fe7df4e9b 100644 --- a/bot/cogs/sync/__init__.py +++ b/bot/cogs/sync/__init__.py @@ -1,13 +1,7 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot from .cog import Sync -log = logging.getLogger(__name__) - def setup(bot: Bot) -> None: - """Sync cog load.""" + """Load the Sync cog.""" bot.add_cog(Sync(bot)) - log.info("Cog loaded: Sync") diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index aaa581f96..4e6ed156b 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -1,12 +1,13 @@ import logging -from typing import Callable, Iterable +from typing import Callable, Dict, Iterable, Union -from discord import Guild, Member, Role +from discord import Guild, Member, Role, User from discord.ext import commands -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context from bot import constants from bot.api import ResponseCodeError +from bot.bot import Bot from bot.cogs.sync import syncers log = logging.getLogger(__name__) @@ -50,6 +51,15 @@ class Sync(Cog): f"deleted `{total_deleted}`." ) + async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None: + """Send a PATCH request to partially update a user in the database.""" + try: + await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information) + except ResponseCodeError as e: + if e.response.status != 404: + raise + log.warning("Unable to update user, got 404. Assuming race condition from join event.") + @Cog.listener() async def on_guild_role_create(self, role: Role) -> None: """Adds newly create role to the database table over the API.""" @@ -142,33 +152,21 @@ class Sync(Cog): @Cog.listener() async def on_member_update(self, before: Member, after: Member) -> None: - """Updates the user information if any of relevant attributes have changed.""" - if ( - before.name != after.name - or before.avatar != after.avatar - or before.discriminator != after.discriminator - or before.roles != after.roles - ): - try: - await self.bot.api_client.put( - 'bot/users/' + str(after.id), - json={ - 'avatar_hash': after.avatar, - 'discriminator': int(after.discriminator), - 'id': after.id, - 'in_guild': True, - 'name': after.name, - 'roles': sorted(role.id for role in after.roles) - } - ) - except ResponseCodeError as e: - if e.response.status != 404: - raise - - log.warning( - "Unable to update user, got 404. " - "Assuming race condition from join event." - ) + """Update the roles of the member in the database if a change is detected.""" + if before.roles != after.roles: + updated_information = {"roles": sorted(role.id for role in after.roles)} + await self.patch_user(after.id, updated_information=updated_information) + + @Cog.listener() + async def on_user_update(self, before: User, after: User) -> None: + """Update the user information in the database if a relevant change is detected.""" + if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")): + updated_information = { + "name": after.name, + "discriminator": int(after.discriminator), + "avatar_hash": after.avatar, + } + await self.patch_user(after.id, updated_information=updated_information) @commands.group(name='sync') @commands.has_permissions(administrator=True) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 2cc5a66e1..14cf51383 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -2,7 +2,8 @@ from collections import namedtuple from typing import Dict, Set, Tuple from discord import Guild -from discord.ext.commands import Bot + +from bot.bot import Bot # These objects are declared as namedtuples because tuples are hashable, # something that we make use of when diffing site roles against guild roles. @@ -52,7 +53,7 @@ async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]: Synchronize roles found on the given `guild` with the ones on the API. Arguments: - bot (discord.ext.commands.Bot): + bot (bot.bot.Bot): The bot instance that we're running with. guild (discord.Guild): @@ -169,7 +170,7 @@ async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]: Synchronize users found in the given `guild` with the ones in the API. Arguments: - bot (discord.ext.commands.Bot): + bot (bot.bot.Bot): The bot instance that we're running with. guild (discord.Guild): diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index cd70e783a..970301013 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -2,8 +2,9 @@ import logging import time from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Channels, Cooldowns, MODERATION_ROLES, Roles from bot.converters import TagContentConverter, TagNameConverter from bot.decorators import with_role @@ -160,6 +161,5 @@ class Tags(Cog): def setup(bot: Bot) -> None: - """Tags cog load.""" + """Load the Tags cog.""" bot.add_cog(Tags(bot)) - log.info("Cog loaded: Tags") diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5a0d20e57..82c01ae96 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -6,9 +6,10 @@ import struct from datetime import datetime from discord import Colour, Message -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from discord.utils import snowflake_time +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import Channels, Colours, Event, Icons @@ -52,39 +53,60 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ + if self.is_token_in_message(msg): + await self.take_action(msg) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + if self.is_token_in_message(after): + await self.take_action(after) + + async def take_action(self, msg: Message) -> None: + """Remove the `msg` containing a token an send a mod_log message.""" + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') + self.mod_log.ignore(Event.message_delete, msg.id) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + message = ( + "Censored a seemingly valid token sent by " + f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " + f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" + ) + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + @classmethod + def is_token_in_message(cls, msg: Message) -> bool: + """Check if `msg` contains a seemly valid token.""" if msg.author.bot: - return + return False maybe_match = TOKEN_RE.search(msg.content) if maybe_match is None: - return + return False try: user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') except ValueError: - return - - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): - self.mod_log.ignore(Event.message_delete, msg.id) - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - message = ( - "Censored a seemingly valid token sent by " - f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " - f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" - ) - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) + return False + + if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): + return True @staticmethod def is_valid_user_id(b64_content: str) -> bool: @@ -119,6 +141,5 @@ class TokenRemover(Cog): def setup(bot: Bot) -> None: - """Token Remover cog load.""" + """Load the TokenRemover cog.""" bot.add_cog(TokenRemover(bot)) - log.info("Cog loaded: TokenRemover") diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 793fe4c1a..47a59db66 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -8,8 +8,9 @@ from typing import Tuple from dateutil import relativedelta from discord import Colour, Embed, Message, Role -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES from bot.decorators import in_channel, with_role from bot.utils.time import humanize_delta @@ -176,6 +177,5 @@ class Utils(Cog): def setup(bot: Bot) -> None: - """Utils cog load.""" + """Load the Utils cog.""" bot.add_cog(Utils(bot)) - log.info("Cog loaded: Utils") diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b5e8d4357..988e0d49a 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -3,15 +3,17 @@ from datetime import datetime from discord import Colour, Message, NotFound, Object from discord.ext import tasks -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Bot as BotConfig, Channels, Colours, Event, - Filter, Icons, Roles + Filter, Icons, MODERATION_ROLES, Roles ) from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.utils.checks import without_role_check log = logging.getLogger(__name__) @@ -37,6 +39,7 @@ PERIODIC_PING = ( f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`." f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel." ) +BOT_MESSAGE_DELETE_DELAY = 10 class Verification(Cog): @@ -54,12 +57,16 @@ class Verification(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Check new message event for messages to the checkpoint channel & process.""" - if message.author.bot: - return # They're a bot, ignore - if message.channel.id != Channels.verification: return # Only listen for #checkpoint messages + if message.author.bot: + # They're a bot, delete their message after the delay. + # But not the periodic ping; we like that one. + if message.content != PERIODIC_PING: + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return + # if a user mentions a role or guild member # alert the mods in mod-alerts channel if message.mentions or message.role_mentions: @@ -189,7 +196,7 @@ class Verification(Cog): @staticmethod def bot_check(ctx: Context) -> bool: """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == Channels.verification: + if ctx.channel.id == Channels.verification and without_role_check(ctx, *MODERATION_ROLES): return ctx.command.name == "accept" else: return True @@ -224,6 +231,5 @@ class Verification(Cog): def setup(bot: Bot) -> None: - """Verification cog load.""" + """Load the Verification cog.""" bot.add_cog(Verification(bot)) - log.info("Cog loaded: Verification") diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py index 86e1050fa..69d118df6 100644 --- a/bot/cogs/watchchannels/__init__.py +++ b/bot/cogs/watchchannels/__init__.py @@ -1,18 +1,9 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot from .bigbrother import BigBrother from .talentpool import TalentPool -log = logging.getLogger(__name__) - - def setup(bot: Bot) -> None: - """Monitoring cogs load.""" + """Load the BigBrother and TalentPool cogs.""" bot.add_cog(BigBrother(bot)) - log.info("Cog loaded: BigBrother") - bot.add_cog(TalentPool(bot)) - log.info("Cog loaded: TalentPool") diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index 49783bb09..306ed4c64 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -3,8 +3,9 @@ from collections import ChainMap from typing import Union from discord import User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation.utils import post_infraction from bot.constants import Channels, MODERATION_ROLES, Webhooks from bot.decorators import with_role diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index 4ec42dcc1..cc8feeeee 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -4,9 +4,10 @@ from collections import ChainMap from typing import Union from discord import Color, Embed, Member, User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks from bot.decorators import with_role from bot.pagination import LinePaginator diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 0bf75a924..bd0622554 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -10,9 +10,10 @@ from typing import Optional import dateutil.parser import discord from discord import Color, Embed, HTTPException, Message, Object, errors -from discord.ext.commands import BadArgument, Bot, Cog, Context +from discord.ext.commands import BadArgument, Cog, Context from bot.api import ResponseCodeError +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.pagination import LinePaginator diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py index ab0ed2472..5d6b4630b 100644 --- a/bot/cogs/wolfram.py +++ b/bot/cogs/wolfram.py @@ -7,8 +7,9 @@ import discord from dateutil.relativedelta import relativedelta from discord import Embed from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, check, group +from discord.ext.commands import BucketType, Cog, Context, check, group +from bot.bot import Bot from bot.constants import Colours, STAFF_ROLES, Wolfram from bot.pagination import ImagePaginator from bot.utils.time import humanize_delta @@ -151,7 +152,7 @@ async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tup class Wolfram(Cog): """Commands for interacting with the Wolfram|Alpha API.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) @@ -266,7 +267,6 @@ class Wolfram(Cog): await send_embed(ctx, message, color) -def setup(bot: commands.Bot) -> None: - """Wolfram cog load.""" +def setup(bot: Bot) -> None: + """Load the Wolfram cog.""" bot.add_cog(Wolfram(bot)) - log.info("Cog loaded: Wolfram") diff --git a/bot/constants.py b/bot/constants.py index 89504a2e0..2c0e3b10b 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -256,6 +256,8 @@ class Emojis(metaclass=YAMLGetter): status_idle: str status_dnd: str + failmail: str + bullet: str new: str pencil: str @@ -268,6 +270,12 @@ class Emojis(metaclass=YAMLGetter): ducky_ninja: int ducky_devil: int ducky_tube: int + ducky_hunt: int + ducky_wizard: int + ducky_party: int + ducky_angel: int + ducky_maul: int + ducky_santa: int upvotes: str comments: str @@ -465,6 +473,8 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converted to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") diff --git a/bot/interpreter.py b/bot/interpreter.py index 76a3fc293..8b7268746 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -2,7 +2,9 @@ from code import InteractiveInterpreter from io import StringIO from typing import Any -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot CODE_TEMPLATE = """ async def _func(): diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index c550aed76..00bb2a949 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -7,14 +7,14 @@ async def apply( last_message: Message, recent_messages: List[Message], config: Dict[str, int] ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: """Detects total attachments exceeding the limit sent by a single user.""" - relevant_messages = [last_message] + [ + relevant_messages = tuple( msg for msg in recent_messages if ( msg.author == last_message.author and len(msg.attachments) > 0 ) - ] + ) total_recent_attachments = sum(len(msg.attachments) for msg in relevant_messages) if total_recent_attachments > config['max']: diff --git a/bot/utils/time.py b/bot/utils/time.py index a024674ac..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: +def format_infraction_with_duration( + expiry: Optional[str], + date_from: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: """ Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. @@ -134,3 +138,28 @@ def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = duration_formatted = f" ({duration})" if duration else '' return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration( + expiry: Optional[str], + now: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: + """ + Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + + Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. + """ + if not expiry: + return None + + now = now or datetime.datetime.utcnow() + since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + + if since < now: + return None + + return humanize_delta(relativedelta(since, now), max_units=max_units) diff --git a/config-default.yml b/config-default.yml index f8337b6da..f28c79712 100644 --- a/config-default.yml +++ b/config-default.yml @@ -27,6 +27,8 @@ style: status_dnd: "<:status_dnd:470326272082313216>" status_offline: "<:status_offline:470326266537705472>" + failmail: "<:failmail:633660039931887616>" + bullet: "\u2022" pencil: "\u270F" new: "\U0001F195" @@ -39,6 +41,12 @@ style: ducky_ninja: &DUCKY_NINJA 637923502535606293 ducky_devil: &DUCKY_DEVIL 637925314982576139 ducky_tube: &DUCKY_TUBE 637881368008851456 + ducky_hunt: &DUCKY_HUNT 639355090909528084 + ducky_wizard: &DUCKY_WIZARD 639355996954689536 + ducky_party: &DUCKY_PARTY 639468753440210977 + ducky_angel: &DUCKY_ANGEL 640121935610511361 + ducky_maul: &DUCKY_MAUL 640137724958867467 + ducky_santa: &DUCKY_SANTA 655360331002019870 upvotes: "<:upvotes:638729835245731840>" comments: "<:comments:638729835073765387>" @@ -361,11 +369,18 @@ anti_malware: - '.png' - '.tiff' - '.wmv' + - '.svg' + - '.psd' # Photoshop + - '.ai' # Illustrator + - '.aep' # After Effects + - '.xcf' # GIMP reddit: subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: @@ -397,7 +412,7 @@ redirect_output: duck_pond: threshold: 5 - custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE] + custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA] config: required_keys: ['bot.token'] diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index b801e86f1..d07b2bce1 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -578,15 +578,7 @@ class DuckPondSetupTests(unittest.TestCase): """Tests setup of the `DuckPond` cog.""" def test_setup(self): - """Setup of the cog should log a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = helpers.MockBot() - log = logging.getLogger('bot.cogs.duck_pond') - - with self.assertLogs(logger=log, level=logging.INFO) as log_watcher: - duck_pond.setup(bot) - - self.assertEqual(len(log_watcher.records), 1) - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.INFO) - + duck_pond.setup(bot) bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index efa7a50b1..9d1a62f7e 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -1,4 +1,3 @@ -import logging import unittest from unittest.mock import MagicMock @@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase): """Tests loading the `Security` cog.""" def test_security_cog_load(self): - """Cog loading logs a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MagicMock() - with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: - security.setup(bot) - bot.add_cog.assert_called_once() - - [line] = cm.output - self.assertIn("Cog loaded: Security", line) + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3276cf5a5..a54b839d7 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase): """Tests setup of the `TokenRemover` cog.""" def test_setup(self): - """Setup of the cog should log a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MockBot() - with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: - setup_cog(bot) - - [line] = cm.output + setup_cog(bot) bot.add_cog.assert_called_once() - self.assertIn("Cog loaded: TokenRemover", line) diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index 4bb0acf7c..d7187f315 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,52 +1,98 @@ -import asyncio import unittest -from dataclasses import dataclass -from typing import Any, List +from typing import List, NamedTuple, Tuple from bot.rules import attachments +from tests.helpers import MockMessage, async_test -# Using `MagicMock` sadly doesn't work for this usecase -# since it's __eq__ compares the MagicMock's ID. We just -# want to compare the actual attributes we set. -@dataclass -class FakeMessage: - author: str - attachments: List[Any] +class Case(NamedTuple): + recent_messages: List[MockMessage] + culprit: Tuple[str] + total_attachments: int -def msg(total_attachments: int) -> FakeMessage: - return FakeMessage(author='lemon', attachments=list(range(total_attachments))) +def msg(author: str, total_attachments: int) -> MockMessage: + """Builds a message with `total_attachments` attachments.""" + return MockMessage(author=author, attachments=list(range(total_attachments))) class AttachmentRuleTests(unittest.TestCase): - """Tests applying the `attachment` antispam rule.""" + """Tests applying the `attachments` antispam rule.""" - def test_allows_messages_without_too_many_attachments(self): + def setUp(self): + self.config = {"max": 5} + + @async_test + async def test_allows_messages_without_too_many_attachments(self): """Messages without too many attachments are allowed as-is.""" cases = ( - (msg(0), msg(0), msg(0)), - (msg(2), msg(2)), - (msg(0),), + [msg("bob", 0), msg("bob", 0), msg("bob", 0)], + [msg("bob", 2), msg("bob", 2)], + [msg("bob", 2), msg("alice", 2), msg("bob", 2)], ) - for last_message, *recent_messages in cases: - with self.subTest(last_message=last_message, recent_messages=recent_messages): - coro = attachments.apply(last_message, recent_messages, {'max': 5}) - self.assertIsNone(asyncio.run(coro)) + for recent_messages in cases: + last_message = recent_messages[0] + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + config=self.config + ): + self.assertIsNone( + await attachments.apply(last_message, recent_messages, self.config) + ) - def test_disallows_messages_with_too_many_attachments(self): + @async_test + async def test_disallows_messages_with_too_many_attachments(self): """Messages with too many attachments trigger the rule.""" cases = ( - ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), - ((msg(6),), [msg(6)], 6), - ((msg(1),) * 6, [msg(1)] * 6, 6), + Case( + [msg("bob", 4), msg("bob", 0), msg("bob", 6)], + ("bob",), + 10 + ), + Case( + [msg("bob", 4), msg("alice", 6), msg("bob", 2)], + ("bob",), + 6 + ), + Case( + [msg("alice", 6)], + ("alice",), + 6 + ), + ( + [msg("alice", 1) for _ in range(6)], + ("alice",), + 6 + ), ) - for messages, relevant_messages, total in cases: - with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total): - last_message, *recent_messages = messages - coro = attachments.apply(last_message, recent_messages, {'max': 5}) - self.assertEqual( - asyncio.run(coro), - (f"sent {total} attachments in 5s", ('lemon',), relevant_messages) + + for recent_messages, culprit, total_attachments in cases: + last_message = recent_messages[0] + relevant_messages = tuple( + msg + for msg in recent_messages + if ( + msg.author == last_message.author + and len(msg.attachments) > 0 + ) + ) + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + relevant_messages=relevant_messages, + total_attachments=total_attachments, + config=self.config + ): + desired_output = ( + f"sent {total_attachments} attachments in {self.config['max']}s", + culprit, + relevant_messages + ) + self.assertTupleEqual( + await attachments.apply(last_message, recent_messages, self.config), + desired_output ) diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index be832843b..02a5d5501 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -2,25 +2,19 @@ import unittest from typing import List, NamedTuple, Tuple from bot.rules import links -from tests.helpers import async_test - - -class FakeMessage(NamedTuple): - author: str - content: str +from tests.helpers import MockMessage, async_test class Case(NamedTuple): - recent_messages: List[FakeMessage] - relevant_messages: Tuple[FakeMessage] + recent_messages: List[MockMessage] culprit: Tuple[str] total_links: int -def msg(author: str, total_links: int) -> FakeMessage: - """Makes a message with *total_links* links.""" +def msg(author: str, total_links: int) -> MockMessage: + """Makes a message with `total_links` links.""" content = " ".join(["https://pydis.com"] * total_links) - return FakeMessage(author=author, content=content) + return MockMessage(author=author, content=content) class LinksTests(unittest.TestCase): @@ -61,26 +55,28 @@ class LinksTests(unittest.TestCase): cases = ( Case( [msg("bob", 1), msg("bob", 2)], - (msg("bob", 1), msg("bob", 2)), ("bob",), 3 ), Case( [msg("alice", 1), msg("alice", 1), msg("alice", 1)], - (msg("alice", 1), msg("alice", 1), msg("alice", 1)), ("alice",), 3 ), Case( [msg("alice", 2), msg("bob", 3), msg("alice", 1)], - (msg("alice", 2), msg("alice", 1)), ("alice",), 3 ) ) - for recent_messages, relevant_messages, culprit, total_links in cases: + for recent_messages, culprit, total_links in cases: last_message = recent_messages[0] + relevant_messages = tuple( + msg + for msg in recent_messages + if msg.author == last_message.author + ) with self.subTest( last_message=last_message, diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py new file mode 100644 index 000000000..ad49ead32 --- /dev/null +++ b/tests/bot/rules/test_mentions.py @@ -0,0 +1,95 @@ +import unittest +from typing import List, NamedTuple, Tuple + +from bot.rules import mentions +from tests.helpers import MockMessage, async_test + + +class Case(NamedTuple): + recent_messages: List[MockMessage] + culprit: Tuple[str] + total_mentions: int + + +def msg(author: str, total_mentions: int) -> MockMessage: + """Makes a message with `total_mentions` mentions.""" + return MockMessage(author=author, mentions=list(range(total_mentions))) + + +class TestMentions(unittest.TestCase): + """Tests applying the `mentions` antispam rule.""" + + def setUp(self): + self.config = { + "max": 2, + "interval": 10 + } + + @async_test + async def test_mentions_within_limit(self): + """Messages with an allowed amount of mentions.""" + cases = ( + [msg("bob", 0)], + [msg("bob", 2)], + [msg("bob", 1), msg("bob", 1)], + [msg("bob", 1), msg("alice", 2)] + ) + + for recent_messages in cases: + last_message = recent_messages[0] + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + config=self.config + ): + self.assertIsNone( + await mentions.apply(last_message, recent_messages, self.config) + ) + + @async_test + async def test_mentions_exceeding_limit(self): + """Messages with a higher than allowed amount of mentions.""" + cases = ( + Case( + [msg("bob", 3)], + ("bob",), + 3 + ), + Case( + [msg("alice", 2), msg("alice", 0), msg("alice", 1)], + ("alice",), + 3 + ), + Case( + [msg("bob", 2), msg("alice", 3), msg("bob", 2)], + ("bob",), + 4 + ) + ) + + for recent_messages, culprit, total_mentions in cases: + last_message = recent_messages[0] + relevant_messages = tuple( + msg + for msg in recent_messages + if msg.author == last_message.author + ) + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + relevant_messages=relevant_messages, + culprit=culprit, + total_mentions=total_mentions, + cofig=self.config + ): + desired_output = ( + f"sent {total_mentions} mentions in {self.config['interval']}s", + culprit, + relevant_messages + ) + self.assertTupleEqual( + await mentions.apply(last_message, recent_messages, self.config), + desired_output + ) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..69f35f2f5 --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,162 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): + """Test helper functions in bot.utils.time.""" + + def test_humanize_delta_handle_unknown_units(self): + """humanize_delta should be able to handle unknown units, and will not abort.""" + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + + def test_humanize_delta_handle_high_units(self): + """humanize_delta should be able to handle very high units.""" + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + + def test_humanize_delta_should_normal_usage(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + ) + + for delta, precision, max_units, expected in test_cases: + with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_raises_for_invalid_max_units(self): + """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" + test_cases = (-1, 0) + + for max_units in test_cases: + with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: + time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + self.assertEqual(str(error), 'max_units must be positive') + + def test_parse_rfc1123(self): + """Testing parse_rfc1123.""" + self.assertEqual( + time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), + datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) + ) + + def test_format_infraction(self): + """Testing format_infraction.""" + self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') + + @patch('asyncio.sleep', new_callable=AsyncMock) + def test_wait_until(self, mock): + """Testing wait_until.""" + start = datetime(2019, 1, 1, 0, 0) + then = datetime(2019, 1, 1, 0, 10) + + # No return value + self.assertIs(asyncio.run(time.wait_until(then, start)), None) + + mock.assert_called_once_with(10 * 60) + + def test_format_infraction_with_duration_none_expiry(self): + """format_infraction_with_duration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_custom_units(self): + """format_infraction_with_duration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_normal_usage(self): + """format_infraction_with_duration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, + '2019-11-23 23:59 (9 minutes and 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_until_expiration_with_duration_none_expiry(self): + """until_expiration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that now and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_with_duration_custom_units(self): + """until_expiration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_normal_usage(self): + """until_expiration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) diff --git a/tests/helpers.py b/tests/helpers.py index b2daae92d..5df796c23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -10,7 +10,9 @@ import unittest.mock from typing import Any, Iterable, Optional import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot for logger in logging.Logger.manager.loggerDict.values(): |