diff options
52 files changed, 604 insertions, 258 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index ea7c43a12..84bc7094b 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,18 +1,11 @@ -import asyncio -import logging -import socket -  import discord -from aiohttp import AsyncResolver, ClientSession, TCPConnector -from discord.ext.commands import Bot, when_mentioned_or +from discord.ext.commands import when_mentioned_or  from bot import patches -from bot.api import APIClient, APILoggingHandler +from bot.bot import Bot  from bot.constants import Bot as BotConfig, DEBUG_MODE -log = logging.getLogger('bot') -  bot = Bot(      command_prefix=when_mentioned_or(BotConfig.prefix),      activity=discord.Game(name="Commands: !help"), @@ -20,18 +13,6 @@ bot = Bot(      max_messages=10_000,  ) -# Global aiohttp session for all cogs -# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads -# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. -bot.http_session = ClientSession( -    connector=TCPConnector( -        resolver=AsyncResolver(), -        family=socket.AF_INET, -    ) -) -bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.addHandler(APILoggingHandler(bot.api_client)) -  # Internal/debug  bot.load_extension("bot.cogs.error_handler")  bot.load_extension("bot.cogs.filtering") @@ -77,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'):      patches.message_edited_at.apply_patch()  bot.run(BotConfig.token) - -# This calls a coroutine, so it doesn't do anything at the moment. -# bot.http_session.close()  # Close the aiohttp session when the bot finishes running diff --git a/bot/api.py b/bot/api.py index 7f26e5305..56db99828 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,7 +32,7 @@ class ResponseCodeError(ValueError):  class APIClient:      """Django Site API wrapper.""" -    def __init__(self, **kwargs): +    def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):          auth_headers = {              'Authorization': f"Token {Keys.site_api}"          } @@ -42,12 +42,39 @@ class APIClient:          else:              kwargs['headers'] = auth_headers -        self.session = aiohttp.ClientSession(**kwargs) +        self.session: Optional[aiohttp.ClientSession] = None +        self.loop = loop + +        self._ready = asyncio.Event(loop=loop) +        self._creation_task = None +        self._session_args = kwargs + +        self.recreate()      @staticmethod      def _url_for(endpoint: str) -> str:          return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" +    async def _create_session(self) -> None: +        """Create the aiohttp session and set the ready event.""" +        self.session = aiohttp.ClientSession(**self._session_args) +        self._ready.set() + +    async def close(self) -> None: +        """Close the aiohttp session and unset the ready event.""" +        if not self._ready.is_set(): +            return + +        await self.session.close() +        self._ready.clear() + +    def recreate(self) -> None: +        """Schedule the aiohttp session to be created if it's been closed.""" +        if self.session is None or self.session.closed: +            # Don't schedule a task if one is already in progress. +            if self._creation_task is None or self._creation_task.done(): +                self._creation_task = self.loop.create_task(self._create_session()) +      async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None:          """Raise ResponseCodeError for non-OK response if an exception should be raised."""          if should_raise and response.status >= 400: @@ -60,30 +87,40 @@ class APIClient:      async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API GET.""" +        await self._ready.wait() +          async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API PATCH.""" +        await self._ready.wait() +          async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API POST.""" +        await self._ready.wait() +          async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:          """Site API PUT.""" +        await self._ready.wait() +          async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp:              await self.maybe_raise_for_status(resp, raise_for_status)              return await resp.json()      async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]:          """Site API DELETE.""" +        await self._ready.wait() +          async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp:              if resp.status == 204:                  return None diff --git a/bot/bot.py b/bot/bot.py new file mode 100644 index 000000000..8f808272f --- /dev/null +++ b/bot/bot.py @@ -0,0 +1,53 @@ +import logging +import socket +from typing import Optional + +import aiohttp +from discord.ext import commands + +from bot import api + +log = logging.getLogger('bot') + + +class Bot(commands.Bot): +    """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" + +    def __init__(self, *args, **kwargs): +        # Use asyncio for DNS resolution instead of threads so threads aren't spammed. +        # Use AF_INET as its socket family to prevent HTTPS related problems both locally +        # and in production. +        self.connector = aiohttp.TCPConnector( +            resolver=aiohttp.AsyncResolver(), +            family=socket.AF_INET, +        ) + +        super().__init__(*args, connector=self.connector, **kwargs) + +        self.http_session: Optional[aiohttp.ClientSession] = None +        self.api_client = api.APIClient(loop=self.loop, connector=self.connector) + +        log.addHandler(api.APILoggingHandler(self.api_client)) + +    def add_cog(self, cog: commands.Cog) -> None: +        """Adds a "cog" to the bot and logs the operation.""" +        super().add_cog(cog) +        log.info(f"Cog loaded: {cog.qualified_name}") + +    def clear(self) -> None: +        """Clears the internal state of the bot and resets the API client.""" +        super().clear() +        self.api_client.recreate() + +    async def close(self) -> None: +        """Close the aiohttp session after closing the Discord connection.""" +        await super().close() + +        await self.http_session.close() +        await self.api_client.close() + +    async def start(self, *args, **kwargs) -> None: +        """Open an aiohttp session before logging in and connecting to Discord.""" +        self.http_session = aiohttp.ClientSession(connector=self.connector) + +        await super().start(*args, **kwargs) diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 5190c559b..c1db38462 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -3,8 +3,9 @@ import logging  from typing import Union  from discord import Colour, Embed, Member, User -from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group +from discord.ext.commands import Cog, Command, Context, clean_content, command, group +from bot.bot import Bot  from bot.cogs.extensions import Extension  from bot.cogs.watchchannels.watchchannel import proxy_user  from bot.converters import TagNameConverter @@ -147,6 +148,5 @@ class Alias (Cog):  def setup(bot: Bot) -> None: -    """Alias cog load.""" +    """Load the Alias cog."""      bot.add_cog(Alias(bot)) -    log.info("Cog loaded: Alias") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 602819191..28e3e5d96 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,8 +1,9 @@  import logging  from discord import Embed, Message, NotFound -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot  from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs  log = logging.getLogger(__name__) @@ -49,6 +50,5 @@ class AntiMalware(Cog):  def setup(bot: Bot) -> None: -    """Antimalware cog load.""" +    """Load the AntiMalware cog."""      bot.add_cog(AntiMalware(bot)) -    log.info("Cog loaded: AntiMalware") diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index 1340eb608..f454061a6 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -7,9 +7,10 @@ from operator import itemgetter  from typing import Dict, Iterable, List, Set  from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog  from bot import rules +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import (      AntiSpam as AntiSpamConfig, Channels, @@ -276,7 +277,6 @@ def validate_config(rules: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:  def setup(bot: Bot) -> None: -    """Antispam cog load.""" +    """Validate the AntiSpam configs and load the AntiSpam cog."""      validation_errors = validate_config()      bot.add_cog(AntiSpam(bot, validation_errors)) -    log.info("Cog loaded: AntiSpam") diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index ee0a463de..73b1e8f41 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -5,8 +5,10 @@ import time  from typing import Optional, Tuple  from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Bot, Cog, Context, command, group +from discord.ext.commands import Cog, Context, command, group +from bot.bot import Bot +from bot.cogs.token_remover import TokenRemover  from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs  from bot.decorators import with_role  from bot.utils.messages import wait_for_deletion @@ -16,7 +18,7 @@ log = logging.getLogger(__name__)  RE_MARKDOWN = re.compile(r'([*_~`|>])') -class Bot(Cog): +class BotCog(Cog, name="Bot"):      """Bot information commands."""      def __init__(self, bot: Bot): @@ -238,9 +240,10 @@ class Bot(Cog):              )              and not msg.author.bot              and len(msg.content.splitlines()) > 3 +            and not TokenRemover.is_token_in_message(msg)          ) -        if parse_codeblock: +        if parse_codeblock:  # no token in the msg              on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300              if not on_cooldown or DEBUG_MODE:                  try: @@ -373,10 +376,9 @@ class Bot(Cog):              bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id])              await bot_message.delete()              del self.codeblock_message_ids[payload.message_id] -            log.trace("User's incorrect code block has been fixed.  Removing bot formatting message.") +            log.trace("User's incorrect code block has been fixed. Removing bot formatting message.")  def setup(bot: Bot) -> None: -    """Bot cog load.""" -    bot.add_cog(Bot(bot)) -    log.info("Cog loaded: Bot") +    """Load the Bot cog.""" +    bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index dca411d01..2104efe57 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -3,9 +3,10 @@ import random  import re  from typing import Optional -from discord import Colour, Embed, Message, User -from discord.ext.commands import Bot, Cog, Context, group +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import (      Channels, CleanMessages, Colours, Event, @@ -37,9 +38,13 @@ class Clean(Cog):          return self.bot.get_cog("ModLog")      async def _clean_messages( -            self, amount: int, ctx: Context, -            bots_only: bool = False, user: User = None, -            regex: Optional[str] = None +        self, +        amount: int, +        ctx: Context, +        bots_only: bool = False, +        user: User = None, +        regex: Optional[str] = None, +        channel: Optional[TextChannel] = None      ) -> None:          """A helper function that does the actual message cleaning."""          def predicate_bots_only(message: Message) -> bool: @@ -104,6 +109,10 @@ class Clean(Cog):          else:              predicate = None                     # Delete all messages +        # Default to using the invoking context's channel +        if not channel: +            channel = ctx.channel +          # Look through the history and retrieve message data          messages = []          message_ids = [] @@ -111,7 +120,7 @@ class Clean(Cog):          invocation_deleted = False          # To account for the invocation message, we index `amount + 1` messages. -        async for message in ctx.channel.history(limit=amount + 1): +        async for message in channel.history(limit=amount + 1):              # If at any point the cancel command is invoked, we should stop.              if not self.cleaning: @@ -135,7 +144,7 @@ class Clean(Cog):          self.mod_log.ignore(Event.message_delete, *message_ids)          # Use bulk delete to actually do the cleaning. It's far faster. -        await ctx.channel.purge( +        await channel.purge(              limit=amount,              check=predicate          ) @@ -155,7 +164,7 @@ class Clean(Cog):          # Build the embed and send it          message = ( -            f"**{len(message_ids)}** messages deleted in <#{ctx.channel.id}> by **{ctx.author.name}**\n\n" +            f"**{len(message_ids)}** messages deleted in <#{channel.id}> by **{ctx.author.name}**\n\n"              f"A log of the deleted messages can be found [here]({log_url})."          ) @@ -167,7 +176,7 @@ class Clean(Cog):              channel_id=Channels.modlog,          ) -    @group(invoke_without_command=True, name="clean", hidden=True) +    @group(invoke_without_command=True, name="clean", aliases=["purge"])      @with_role(*MODERATION_ROLES)      async def clean_group(self, ctx: Context) -> None:          """Commands for cleaning messages in channels.""" @@ -175,27 +184,49 @@ class Clean(Cog):      @clean_group.command(name="user", aliases=["users"])      @with_role(*MODERATION_ROLES) -    async def clean_user(self, ctx: Context, user: User, amount: int = 10) -> None: +    async def clean_user( +        self, +        ctx: Context, +        user: User, +        amount: Optional[int] = 10, +        channel: TextChannel = None +    ) -> None:          """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" -        await self._clean_messages(amount, ctx, user=user) +        await self._clean_messages(amount, ctx, user=user, channel=channel)      @clean_group.command(name="all", aliases=["everything"])      @with_role(*MODERATION_ROLES) -    async def clean_all(self, ctx: Context, amount: int = 10) -> None: +    async def clean_all( +        self, +        ctx: Context, +        amount: Optional[int] = 10, +        channel: TextChannel = None +    ) -> None:          """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" -        await self._clean_messages(amount, ctx) +        await self._clean_messages(amount, ctx, channel=channel)      @clean_group.command(name="bots", aliases=["bot"])      @with_role(*MODERATION_ROLES) -    async def clean_bots(self, ctx: Context, amount: int = 10) -> None: +    async def clean_bots( +        self, +        ctx: Context, +        amount: Optional[int] = 10, +        channel: TextChannel = None +    ) -> None:          """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" -        await self._clean_messages(amount, ctx, bots_only=True) +        await self._clean_messages(amount, ctx, bots_only=True, channel=channel)      @clean_group.command(name="regex", aliases=["word", "expression"])      @with_role(*MODERATION_ROLES) -    async def clean_regex(self, ctx: Context, regex: str, amount: int = 10) -> None: +    async def clean_regex( +        self, +        ctx: Context, +        regex: str, +        amount: Optional[int] = 10, +        channel: TextChannel = None +    ) -> None:          """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" -        await self._clean_messages(amount, ctx, regex=regex) +        await self._clean_messages(amount, ctx, regex=regex, channel=channel)      @clean_group.command(name="stop", aliases=["cancel", "abort"])      @with_role(*MODERATION_ROLES) @@ -211,6 +242,5 @@ class Clean(Cog):  def setup(bot: Bot) -> None: -    """Clean cog load.""" +    """Load the Clean cog."""      bot.add_cog(Clean(bot)) -    log.info("Cog loaded: Clean") diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index bedd70c86..3e7350fcc 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -6,8 +6,9 @@ from datetime import datetime, timedelta  from enum import Enum  from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles  from bot.decorators import with_role @@ -236,6 +237,5 @@ class Defcon(Cog):  def setup(bot: Bot) -> None: -    """DEFCON cog load.""" +    """Load the Defcon cog."""      bot.add_cog(Defcon(bot)) -    log.info("Cog loaded: Defcon") diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index e5b3a4062..9506b195a 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -17,6 +17,7 @@ from requests import ConnectTimeout, ConnectionError, HTTPError  from sphinx.ext import intersphinx  from urllib3.exceptions import ProtocolError +from bot.bot import Bot  from bot.constants import MODERATION_ROLES, RedirectOutput  from bot.converters import ValidPythonIdentifier, ValidURL  from bot.decorators import with_role @@ -147,7 +148,7 @@ class InventoryURL(commands.Converter):  class Doc(commands.Cog):      """A set of commands for querying & displaying documentation.""" -    def __init__(self, bot: commands.Bot): +    def __init__(self, bot: Bot):          self.base_urls = {}          self.bot = bot          self.inventories = {} @@ -506,7 +507,6 @@ class Doc(commands.Cog):          return tag.name == "table" -def setup(bot: commands.Bot) -> None: -    """Doc cog load.""" +def setup(bot: Bot) -> None: +    """Load the Doc cog."""      bot.add_cog(Doc(bot)) -    log.info("Cog loaded: Doc") diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 2d25cd17e..345d2856c 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -3,9 +3,10 @@ from typing import Optional, Union  import discord  from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog  from bot import constants +from bot.bot import Bot  from bot.utils.messages import send_attachments  log = logging.getLogger(__name__) @@ -177,6 +178,5 @@ class DuckPond(Cog):  def setup(bot: Bot) -> None: -    """Load the duck pond cog.""" +    """Load the DuckPond cog."""      bot.add_cog(DuckPond(bot)) -    log.info("Cog loaded: DuckPond") diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 49411814c..52893b2ee 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -14,9 +14,10 @@ from discord.ext.commands import (      NoPrivateMessage,      UserInputError,  ) -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context  from bot.api import ResponseCodeError +from bot.bot import Bot  from bot.constants import Channels  from bot.decorators import InChannelCheckFailure @@ -75,6 +76,16 @@ class ErrorHandler(Cog):                  tags_get_command = self.bot.get_command("tags get")                  ctx.invoked_from_error_handler = True +                log_msg = "Cancelling attempt to fall back to a tag due to failed checks." +                try: +                    if not await tags_get_command.can_run(ctx): +                        log.debug(log_msg) +                        return +                except CommandError as tag_error: +                    log.debug(log_msg) +                    await self.on_command_error(ctx, tag_error) +                    return +                  # Return to not raise the exception                  with contextlib.suppress(ResponseCodeError):                      await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) @@ -143,6 +154,5 @@ class ErrorHandler(Cog):  def setup(bot: Bot) -> None: -    """Error handler cog load.""" +    """Load the ErrorHandler cog."""      bot.add_cog(ErrorHandler(bot)) -    log.info("Cog loaded: Events") diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 00b988dde..9c729f28a 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -9,8 +9,9 @@ from io import StringIO  from typing import Any, Optional, Tuple  import discord -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.constants import Roles  from bot.decorators import with_role  from bot.interpreter import Interpreter @@ -197,6 +198,5 @@ async def func():  # (None,) -> Any  def setup(bot: Bot) -> None: -    """Code eval cog load.""" +    """Load the CodeEval cog."""      bot.add_cog(CodeEval(bot)) -    log.info("Cog loaded: Eval") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index bb66e0b8e..f16e79fb7 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -6,8 +6,9 @@ from pkgutil import iter_modules  from discord import Colour, Embed  from discord.ext import commands -from discord.ext.commands import Bot, Context, group +from discord.ext.commands import Context, group +from bot.bot import Bot  from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs  from bot.pagination import LinePaginator  from bot.utils.checks import with_role_check @@ -233,4 +234,3 @@ class Extensions(commands.Cog):  def setup(bot: Bot) -> None:      """Load the Extensions cog."""      bot.add_cog(Extensions(bot)) -    log.info("Cog loaded: Extensions") diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 1e7521054..74538542a 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -5,8 +5,9 @@ from typing import Optional, Union  import discord.errors  from dateutil.relativedelta import relativedelta  from discord import Colour, DMChannel, Member, Message, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import (      Channels, Colours, @@ -370,6 +371,5 @@ class Filtering(Cog):  def setup(bot: Bot) -> None: -    """Filtering cog load.""" +    """Load the Filtering cog."""      bot.add_cog(Filtering(bot)) -    log.info("Cog loaded: Filtering") diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 82285656b..49cab6172 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -3,8 +3,9 @@ from datetime import datetime  from operator import itemgetter  from discord import Colour, Embed, Member, utils -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot  from bot.constants import Categories, Channels, Free, STAFF_ROLES  from bot.decorators import redirect_output @@ -98,6 +99,5 @@ class Free(Cog):  def setup(bot: Bot) -> None: -    """Free cog load.""" +    """Load the Free cog."""      bot.add_cog(Free()) -    log.info("Cog loaded: Free") diff --git a/bot/cogs/help.py b/bot/cogs/help.py index 9607dbd8d..6385fa467 100644 --- a/bot/cogs/help.py +++ b/bot/cogs/help.py @@ -6,10 +6,11 @@ from typing import Union  from discord import Colour, Embed, HTTPException, Message, Reaction, User  from discord.ext import commands -from discord.ext.commands import Bot, CheckFailure, Cog as DiscordCog, Command, Context +from discord.ext.commands import CheckFailure, Cog as DiscordCog, Command, Context  from fuzzywuzzy import fuzz, process  from bot import constants +from bot.bot import Bot  from bot.constants import Channels, STAFF_ROLES  from bot.decorators import redirect_output  from bot.pagination import ( diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 530453600..1ede95ff4 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -9,10 +9,11 @@ from typing import Any, Mapping, Optional  import discord  from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils  from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, command, group +from discord.ext.commands import BucketType, Cog, Context, command, group  from discord.utils import escape_markdown  from bot import constants +from bot.bot import Bot  from bot.decorators import InChannelCheckFailure, in_channel, with_role  from bot.utils.checks import cooldown_with_role_bypass, with_role_check  from bot.utils.time import time_since @@ -391,6 +392,5 @@ class Information(Cog):  def setup(bot: Bot) -> None: -    """Information cog load.""" +    """Load the Information cog."""      bot.add_cog(Information(bot)) -    log.info("Cog loaded: Information") diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index be9d33e3e..985f28ce5 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -4,6 +4,7 @@ from discord import Member, PermissionOverwrite, utils  from discord.ext import commands  from more_itertools import unique_everseen +from bot.bot import Bot  from bot.constants import Roles  from bot.decorators import with_role @@ -13,7 +14,7 @@ log = logging.getLogger(__name__)  class CodeJams(commands.Cog):      """Manages the code-jam related parts of our server.""" -    def __init__(self, bot: commands.Bot): +    def __init__(self, bot: Bot):          self.bot = bot      @commands.command() @@ -108,7 +109,6 @@ class CodeJams(commands.Cog):          ) -def setup(bot: commands.Bot) -> None: -    """Code Jams cog load.""" +def setup(bot: Bot) -> None: +    """Load the CodeJams cog."""      bot.add_cog(CodeJams(bot)) -    log.info("Cog loaded: CodeJams") diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index c92b619ff..d1b7dcab3 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -1,8 +1,9 @@  import logging  from discord import Embed -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot  from bot.constants import Channels, DEBUG_MODE @@ -37,6 +38,5 @@ class Logging(Cog):  def setup(bot: Bot) -> None: -    """Logging cog load.""" +    """Load the Logging cog."""      bot.add_cog(Logging(bot)) -    log.info("Cog loaded: Logging") diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 7383ed44e..5243cb92d 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,25 +1,13 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot  from .infractions import Infractions  from .management import ModManagement  from .modlog import ModLog  from .superstarify import Superstarify -log = logging.getLogger(__name__) -  def setup(bot: Bot) -> None: -    """Load the moderation extension (Infractions, ModManagement, ModLog, & Superstarify cogs).""" +    """Load the Infractions, ModManagement, ModLog, and Superstarify cogs."""      bot.add_cog(Infractions(bot)) -    log.info("Cog loaded: Infractions") -      bot.add_cog(ModLog(bot)) -    log.info("Cog loaded: ModLog") -      bot.add_cog(ModManagement(bot)) -    log.info("Cog loaded: ModManagement") -      bot.add_cog(Superstarify(bot)) -    log.info("Cog loaded: Superstarify") diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 2713a1b68..3536a3d38 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -7,6 +7,7 @@ from discord.ext import commands  from discord.ext.commands import Context, command  from bot import constants +from bot.bot import Bot  from bot.constants import Event  from bot.decorators import respect_role_hierarchy  from bot.utils.checks import with_role_check @@ -25,7 +26,7 @@ class Infractions(InfractionScheduler, commands.Cog):      category = "Moderation"      category_description = "Server moderation tools." -    def __init__(self, bot: commands.Bot): +    def __init__(self, bot: Bot):          super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"})          self.category = "Moderation" @@ -208,8 +209,13 @@ class Infractions(InfractionScheduler, commands.Cog):          self.mod_log.ignore(Event.member_update, user.id) -        action = user.add_roles(self._muted_role, reason=reason) -        await self.apply_infraction(ctx, infraction, user, action) +        async def action() -> None: +            await user.add_roles(self._muted_role, reason=reason) + +            log.trace(f"Attempting to kick {user} from voice because they've been muted.") +            await user.move_to(None, reason=reason) + +        await self.apply_infraction(ctx, infraction, user, action())      @respect_role_hierarchy()      async def apply_kick(self, ctx: Context, user: Member, reason: str, **kwargs) -> None: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index abfe5c2b3..9605d47b2 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -9,7 +9,8 @@ from discord.ext import commands  from discord.ext.commands import Context  from bot import constants -from bot.converters import InfractionSearchQuery +from bot.bot import Bot +from bot.converters import InfractionSearchQuery, allowed_strings  from bot.pagination import LinePaginator  from bot.utils import time  from bot.utils.checks import in_channel_check, with_role_check @@ -22,21 +23,12 @@ log = logging.getLogger(__name__)  UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: -    """Only allow an expiration to be 'permanent' if it is a string.""" -    expires_at = expires_at.lower() -    if expires_at != "permanent": -        raise commands.BadArgument -    else: -        return expires_at - -  class ModManagement(commands.Cog):      """Management of infractions."""      category = "Moderation" -    def __init__(self, bot: commands.Bot): +    def __init__(self, bot: Bot):          self.bot = bot      @property @@ -60,8 +52,8 @@ class ModManagement(commands.Cog):      async def infraction_edit(          self,          ctx: Context, -        infraction_id: int, -        duration: t.Union[utils.Expiry, permanent_duration, None], +        infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], +        duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None],          *,          reason: str = None      ) -> None: @@ -78,21 +70,40 @@ class ModManagement(commands.Cog):          \u2003`M` - minutes∗          \u2003`s` - seconds -        Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp -        can be provided for the duration. +        Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction +        authored by the command invoker should be edited. + +        Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 +        timestamp can be provided for the duration.          """          if duration is None and reason is None:              # Unlike UserInputError, the error handler will show a specified message for BadArgument              raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")          # Retrieve the previous infraction for its information. -        old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') +        if isinstance(infraction_id, str): +            params = { +                "actor__id": ctx.author.id, +                "ordering": "-inserted_at" +            } +            infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + +            if infractions: +                old_infraction = infractions[0] +                infraction_id = old_infraction["id"] +            else: +                await ctx.send( +                    f":x: Couldn't find most recent infraction; you have never given an infraction." +                ) +                return +        else: +            old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}")          request_data = {}          confirm_messages = []          log_text = "" -        if duration == "permanent": +        if isinstance(duration, str):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: @@ -129,7 +140,8 @@ class ModManagement(commands.Cog):                  New expiry: {new_infraction['expires_at'] or "Permanent"}              """.rstrip() -        await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") +        changes = ' & '.join(confirm_messages) +        await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}")          # Get information about the infraction's user          user_id = new_infraction['user'] @@ -232,6 +244,12 @@ class ModManagement(commands.Cog):          user_id = infraction["user"]          hidden = infraction["hidden"]          created = time.format_infraction(infraction["inserted_at"]) + +        if active: +            remaining = time.until_expiration(infraction["expires_at"]) or "Expired" +        else: +            remaining = "Inactive" +          if infraction["expires_at"] is None:              expires = "*Permanent*"          else: @@ -247,6 +265,7 @@ class ModManagement(commands.Cog):              Reason: {infraction["reason"] or "*None*"}              Created: {created}              Expires: {expires} +            Remaining: {remaining}              Actor: {actor.mention if actor else actor_id}              ID: `{infraction["id"]}`              {"**===============**" if active else "==============="} diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 0df752a97..35ef6cbcc 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -10,8 +10,9 @@ from dateutil.relativedelta import relativedelta  from deepdiff import DeepDiff  from discord import Colour  from discord.abc import GuildChannel -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context +from bot.bot import Bot  from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs  from bot.utils.time import humanize_delta  from .utils import UserTypes diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 0ab1fe997..01e4b1fe7 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -7,10 +7,11 @@ from gettext import ngettext  import dateutil.parser  import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context  from bot import constants  from bot.api import ResponseCodeError +from bot.bot import Bot  from bot.constants import Colours, STAFF_CHANNELS  from bot.utils import time  from bot.utils.scheduling import Scheduler @@ -146,14 +147,18 @@ class InfractionScheduler(Scheduler):                  if expiry:                      # Schedule the expiration of the infraction.                      self.schedule_task(ctx.bot.loop, infraction["id"], infraction) -            except discord.Forbidden: +            except discord.HTTPException as e:                  # Accordingly display that applying the infraction failed.                  confirm_msg = f":x: failed to apply"                  expiry_msg = ""                  log_content = ctx.author.mention                  log_title = "failed to apply" -                log.warning(f"Failed to apply {infr_type} infraction #{id_} to {user}.") +                log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" +                if isinstance(e, discord.Forbidden): +                    log.warning(f"{log_msg}: bot lacks permissions.") +                else: +                    log.exception(log_msg)          # Send a confirmation message to the invoking context.          log.trace(f"Sending infraction #{id_} confirmation message.") @@ -323,12 +328,12 @@ class InfractionScheduler(Scheduler):                      f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!"                  )          except discord.Forbidden: -            log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions") +            log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.")              log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)"              log_content = mod_role.mention          except discord.HTTPException as e:              log.exception(f"Failed to deactivate infraction #{id_} ({type_})") -            log_text["Failure"] = f"HTTPException with code {e.code}." +            log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}."              log_content = mod_role.mention          # Check if the user is currently being watched by Big Brother. diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 9b3c62403..7631d9bbe 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -6,9 +6,10 @@ import typing as t  from pathlib import Path  from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command  from bot import constants +from bot.bot import Bot  from bot.utils.checks import with_role_check  from bot.utils.time import format_infraction  from . import utils diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 78792240f..bf777ea5a 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -4,9 +4,10 @@ import logging  from datetime import datetime, timedelta  from discord import Colour, Embed -from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group +from discord.ext.commands import BadArgument, Cog, Context, Converter, group  from bot.api import ResponseCodeError +from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES  from bot.decorators import with_role  from bot.pagination import LinePaginator @@ -184,6 +185,5 @@ class OffTopicNames(Cog):  def setup(bot: Bot) -> None: -    """Off topic names cog load.""" +    """Load the OffTopicNames cog."""      bot.add_cog(OffTopicNames(bot)) -    log.info("Cog loaded: OffTopicNames") diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0d06e9c26..bec316ae7 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -6,9 +6,10 @@ from datetime import datetime, timedelta  from typing import List  from discord import Colour, Embed, TextChannel -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group  from discord.ext.tasks import loop +from bot.bot import Bot  from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks  from bot.converters import Subreddit  from bot.decorators import with_role @@ -217,6 +218,5 @@ class Reddit(Cog):  def setup(bot: Bot) -> None: -    """Reddit cog load.""" +    """Load the Reddit cog."""      bot.add_cog(Reddit(bot)) -    log.info("Cog loaded: Reddit") diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 81990704b..45bf9a8f4 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -8,8 +8,9 @@ from typing import Optional  from dateutil.relativedelta import relativedelta  from discord import Colour, Embed, Message -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES  from bot.converters import Duration  from bot.pagination import LinePaginator @@ -290,6 +291,5 @@ class Reminders(Scheduler, Cog):  def setup(bot: Bot) -> None: -    """Reminders cog load.""" +    """Load the Reminders cog."""      bot.add_cog(Reminders(bot)) -    log.info("Cog loaded: Reminders") diff --git a/bot/cogs/security.py b/bot/cogs/security.py index 316b33d6b..c680c5e27 100644 --- a/bot/cogs/security.py +++ b/bot/cogs/security.py @@ -1,6 +1,8 @@  import logging -from discord.ext.commands import Bot, Cog, Context, NoPrivateMessage +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot  log = logging.getLogger(__name__) @@ -25,6 +27,5 @@ class Security(Cog):  def setup(bot: Bot) -> None: -    """Security cog load.""" +    """Load the Security cog."""      bot.add_cog(Security(bot)) -    log.info("Cog loaded: Security") diff --git a/bot/cogs/site.py b/bot/cogs/site.py index 683613788..2ea8c7a2e 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -1,8 +1,9 @@  import logging  from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.constants import URLs  from bot.pagination import LinePaginator @@ -138,6 +139,5 @@ class Site(Cog):  def setup(bot: Bot) -> None: -    """Site cog load.""" +    """Load the Site cog."""      bot.add_cog(Site(bot)) -    log.info("Cog loaded: Site") diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 55a187ac1..da33e27b2 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -5,8 +5,9 @@ import textwrap  from signal import Signals  from typing import Optional, Tuple -from discord.ext.commands import Bot, Cog, Context, command, guild_only +from discord.ext.commands import Cog, Context, command, guild_only +from bot.bot import Bot  from bot.constants import Channels, Roles, URLs  from bot.decorators import in_channel  from bot.utils.messages import wait_for_deletion @@ -227,6 +228,5 @@ class Snekbox(Cog):  def setup(bot: Bot) -> None: -    """Snekbox cog load.""" +    """Load the Snekbox cog."""      bot.add_cog(Snekbox(bot)) -    log.info("Cog loaded: Snekbox") diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py index d4565f848..fe7df4e9b 100644 --- a/bot/cogs/sync/__init__.py +++ b/bot/cogs/sync/__init__.py @@ -1,13 +1,7 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot  from .cog import Sync -log = logging.getLogger(__name__) -  def setup(bot: Bot) -> None: -    """Sync cog load.""" +    """Load the Sync cog."""      bot.add_cog(Sync(bot)) -    log.info("Cog loaded: Sync") diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index aaa581f96..90d4c40fe 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -3,10 +3,11 @@ from typing import Callable, Iterable  from discord import Guild, Member, Role  from discord.ext import commands -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context  from bot import constants  from bot.api import ResponseCodeError +from bot.bot import Bot  from bot.cogs.sync import syncers  log = logging.getLogger(__name__) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 2cc5a66e1..14cf51383 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -2,7 +2,8 @@ from collections import namedtuple  from typing import Dict, Set, Tuple  from discord import Guild -from discord.ext.commands import Bot + +from bot.bot import Bot  # These objects are declared as namedtuples because tuples are hashable,  # something that we make use of when diffing site roles against guild roles. @@ -52,7 +53,7 @@ async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]:      Synchronize roles found on the given `guild` with the ones on the API.      Arguments: -        bot (discord.ext.commands.Bot): +        bot (bot.bot.Bot):              The bot instance that we're running with.          guild (discord.Guild): @@ -169,7 +170,7 @@ async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]:      Synchronize users found in the given `guild` with the ones in the API.      Arguments: -        bot (discord.ext.commands.Bot): +        bot (bot.bot.Bot):              The bot instance that we're running with.          guild (discord.Guild): diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index cd70e783a..970301013 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -2,8 +2,9 @@ import logging  import time  from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.constants import Channels, Cooldowns, MODERATION_ROLES, Roles  from bot.converters import TagContentConverter, TagNameConverter  from bot.decorators import with_role @@ -160,6 +161,5 @@ class Tags(Cog):  def setup(bot: Bot) -> None: -    """Tags cog load.""" +    """Load the Tags cog."""      bot.add_cog(Tags(bot)) -    log.info("Cog loaded: Tags") diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5a0d20e57..82c01ae96 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -6,9 +6,10 @@ import struct  from datetime import datetime  from discord import Colour, Message -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog  from discord.utils import snowflake_time +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import Channels, Colours, Event, Icons @@ -52,39 +53,60 @@ class TokenRemover(Cog):          See: https://discordapp.com/developers/docs/reference#snowflakes          """ +        if self.is_token_in_message(msg): +            await self.take_action(msg) + +    @Cog.listener() +    async def on_message_edit(self, before: Message, after: Message) -> None: +        """ +        Check each edit for a string that matches Discord's token pattern. + +        See: https://discordapp.com/developers/docs/reference#snowflakes +        """ +        if self.is_token_in_message(after): +            await self.take_action(after) + +    async def take_action(self, msg: Message) -> None: +        """Remove the `msg` containing a token an send a mod_log message.""" +        user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') +        self.mod_log.ignore(Event.message_delete, msg.id) +        await msg.delete() +        await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + +        message = ( +            "Censored a seemingly valid token sent by " +            f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " +            f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" +        ) +        log.debug(message) + +        # Send pretty mod log embed to mod-alerts +        await self.mod_log.send_log_message( +            icon_url=Icons.token_removed, +            colour=Colour(Colours.soft_red), +            title="Token removed!", +            text=message, +            thumbnail=msg.author.avatar_url_as(static_format="png"), +            channel_id=Channels.mod_alerts, +        ) + +    @classmethod +    def is_token_in_message(cls, msg: Message) -> bool: +        """Check if `msg` contains a seemly valid token."""          if msg.author.bot: -            return +            return False          maybe_match = TOKEN_RE.search(msg.content)          if maybe_match is None: -            return +            return False          try:              user_id, creation_timestamp, hmac = maybe_match.group(0).split('.')          except ValueError: -            return - -        if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): -            self.mod_log.ignore(Event.message_delete, msg.id) -            await msg.delete() -            await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - -            message = ( -                "Censored a seemingly valid token sent by " -                f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " -                f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" -            ) -            log.debug(message) - -            # Send pretty mod log embed to mod-alerts -            await self.mod_log.send_log_message( -                icon_url=Icons.token_removed, -                colour=Colour(Colours.soft_red), -                title="Token removed!", -                text=message, -                thumbnail=msg.author.avatar_url_as(static_format="png"), -                channel_id=Channels.mod_alerts, -            ) +            return False + +        if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): +            return True      @staticmethod      def is_valid_user_id(b64_content: str) -> bool: @@ -119,6 +141,5 @@ class TokenRemover(Cog):  def setup(bot: Bot) -> None: -    """Token Remover cog load.""" +    """Load the TokenRemover cog."""      bot.add_cog(TokenRemover(bot)) -    log.info("Cog loaded: TokenRemover") diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 793fe4c1a..47a59db66 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -8,8 +8,9 @@ from typing import Tuple  from dateutil import relativedelta  from discord import Colour, Embed, Message, Role -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES  from bot.decorators import in_channel, with_role  from bot.utils.time import humanize_delta @@ -176,6 +177,5 @@ class Utils(Cog):  def setup(bot: Bot) -> None: -    """Utils cog load.""" +    """Load the Utils cog."""      bot.add_cog(Utils(bot)) -    log.info("Cog loaded: Utils") diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b5e8d4357..988e0d49a 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -3,15 +3,17 @@ from datetime import datetime  from discord import Colour, Message, NotFound, Object  from discord.ext import tasks -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import (      Bot as BotConfig,      Channels, Colours, Event, -    Filter, Icons, Roles +    Filter, Icons, MODERATION_ROLES, Roles  )  from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.utils.checks import without_role_check  log = logging.getLogger(__name__) @@ -37,6 +39,7 @@ PERIODIC_PING = (      f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`."      f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel."  ) +BOT_MESSAGE_DELETE_DELAY = 10  class Verification(Cog): @@ -54,12 +57,16 @@ class Verification(Cog):      @Cog.listener()      async def on_message(self, message: Message) -> None:          """Check new message event for messages to the checkpoint channel & process.""" -        if message.author.bot: -            return  # They're a bot, ignore -          if message.channel.id != Channels.verification:              return  # Only listen for #checkpoint messages +        if message.author.bot: +            # They're a bot, delete their message after the delay. +            # But not the periodic ping; we like that one. +            if message.content != PERIODIC_PING: +                await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) +            return +          # if a user mentions a role or guild member          # alert the mods in mod-alerts channel          if message.mentions or message.role_mentions: @@ -189,7 +196,7 @@ class Verification(Cog):      @staticmethod      def bot_check(ctx: Context) -> bool:          """Block any command within the verification channel that is not !accept.""" -        if ctx.channel.id == Channels.verification: +        if ctx.channel.id == Channels.verification and without_role_check(ctx, *MODERATION_ROLES):              return ctx.command.name == "accept"          else:              return True @@ -224,6 +231,5 @@ class Verification(Cog):  def setup(bot: Bot) -> None: -    """Verification cog load.""" +    """Load the Verification cog."""      bot.add_cog(Verification(bot)) -    log.info("Cog loaded: Verification") diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py index 86e1050fa..69d118df6 100644 --- a/bot/cogs/watchchannels/__init__.py +++ b/bot/cogs/watchchannels/__init__.py @@ -1,18 +1,9 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot  from .bigbrother import BigBrother  from .talentpool import TalentPool -log = logging.getLogger(__name__) - -  def setup(bot: Bot) -> None: -    """Monitoring cogs load.""" +    """Load the BigBrother and TalentPool cogs."""      bot.add_cog(BigBrother(bot)) -    log.info("Cog loaded: BigBrother") -      bot.add_cog(TalentPool(bot)) -    log.info("Cog loaded: TalentPool") diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index 49783bb09..306ed4c64 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -3,8 +3,9 @@ from collections import ChainMap  from typing import Union  from discord import User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot  from bot.cogs.moderation.utils import post_infraction  from bot.constants import Channels, MODERATION_ROLES, Webhooks  from bot.decorators import with_role diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index 4ec42dcc1..cc8feeeee 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -4,9 +4,10 @@ from collections import ChainMap  from typing import Union  from discord import Color, Embed, Member, User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group  from bot.api import ResponseCodeError +from bot.bot import Bot  from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks  from bot.decorators import with_role  from bot.pagination import LinePaginator diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 0bf75a924..bd0622554 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -10,9 +10,10 @@ from typing import Optional  import dateutil.parser  import discord  from discord import Color, Embed, HTTPException, Message, Object, errors -from discord.ext.commands import BadArgument, Bot, Cog, Context +from discord.ext.commands import BadArgument, Cog, Context  from bot.api import ResponseCodeError +from bot.bot import Bot  from bot.cogs.moderation import ModLog  from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons  from bot.pagination import LinePaginator diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py index ab0ed2472..5d6b4630b 100644 --- a/bot/cogs/wolfram.py +++ b/bot/cogs/wolfram.py @@ -7,8 +7,9 @@ import discord  from dateutil.relativedelta import relativedelta  from discord import Embed  from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, check, group +from discord.ext.commands import BucketType, Cog, Context, check, group +from bot.bot import Bot  from bot.constants import Colours, STAFF_ROLES, Wolfram  from bot.pagination import ImagePaginator  from bot.utils.time import humanize_delta @@ -151,7 +152,7 @@ async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tup  class Wolfram(Cog):      """Commands for interacting with the Wolfram|Alpha API.""" -    def __init__(self, bot: commands.Bot): +    def __init__(self, bot: Bot):          self.bot = bot      @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) @@ -266,7 +267,6 @@ class Wolfram(Cog):              await send_embed(ctx, message, color) -def setup(bot: commands.Bot) -> None: -    """Wolfram cog load.""" +def setup(bot: Bot) -> None: +    """Load the Wolfram cog."""      bot.add_cog(Wolfram(bot)) -    log.info("Cog loaded: Wolfram") diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@  import logging  import re +import typing as t  from datetime import datetime  from ssl import CertificateError -from typing import Union  import dateutil.parser  import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter  log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: +    """ +    Return a converter which only allows arguments equal to one of the given values. + +    Unless preserve_case is True, the argument is converted to lowercase. All values are then +    expected to have already been given in lowercase too. +    """ +    def converter(arg: str) -> str: +        if not preserve_case: +            arg = arg.lower() + +        if arg not in values: +            raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") +        else: +            return arg + +    return converter + +  class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):      """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""      @staticmethod -    async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: +    async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:          """Check if the argument is a Discord user, and if not, falls back to a string."""          try:              maybe_snowflake = arg.strip("<@!>") diff --git a/bot/interpreter.py b/bot/interpreter.py index 76a3fc293..8b7268746 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -2,7 +2,9 @@ from code import InteractiveInterpreter  from io import StringIO  from typing import Any -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot  CODE_TEMPLATE = """  async def _func(): diff --git a/bot/utils/time.py b/bot/utils/time.py index a024674ac..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str:      return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: +def format_infraction_with_duration( +    expiry: Optional[str], +    date_from: Optional[datetime.datetime] = None, +    max_units: int = 2 +) -> Optional[str]:      """      Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. @@ -134,3 +138,28 @@ def format_infraction_with_duration(expiry: str, date_from: datetime.datetime =      duration_formatted = f" ({duration})" if duration else ''      return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration( +    expiry: Optional[str], +    now: Optional[datetime.datetime] = None, +    max_units: int = 2 +) -> Optional[str]: +    """ +    Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + +    Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. +    Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. +    `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). +    By default, max_units is 2. +    """ +    if not expiry: +        return None + +    now = now or datetime.datetime.utcnow() +    since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + +    if since < now: +        return None + +    return humanize_delta(relativedelta(since, now), max_units=max_units) diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index b801e86f1..d07b2bce1 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -578,15 +578,7 @@ class DuckPondSetupTests(unittest.TestCase):      """Tests setup of the `DuckPond` cog."""      def test_setup(self): -        """Setup of the cog should log a message at `INFO` level.""" +        """Setup of the extension should call add_cog."""          bot = helpers.MockBot() -        log = logging.getLogger('bot.cogs.duck_pond') - -        with self.assertLogs(logger=log, level=logging.INFO) as log_watcher: -            duck_pond.setup(bot) - -        self.assertEqual(len(log_watcher.records), 1) -        record = log_watcher.records[0] -        self.assertEqual(record.levelno, logging.INFO) - +        duck_pond.setup(bot)          bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index efa7a50b1..9d1a62f7e 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -1,4 +1,3 @@ -import logging  import unittest  from unittest.mock import MagicMock @@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase):      """Tests loading the `Security` cog."""      def test_security_cog_load(self): -        """Cog loading logs a message at `INFO` level.""" +        """Setup of the extension should call add_cog."""          bot = MagicMock() -        with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: -            security.setup(bot) -            bot.add_cog.assert_called_once() - -        [line] = cm.output -        self.assertIn("Cog loaded: Security", line) +        security.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3276cf5a5..a54b839d7 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase):      """Tests setup of the `TokenRemover` cog."""      def test_setup(self): -        """Setup of the cog should log a message at `INFO` level.""" +        """Setup of the extension should call add_cog."""          bot = MockBot() -        with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: -            setup_cog(bot) - -        [line] = cm.output +        setup_cog(bot)          bot.add_cog.assert_called_once() -        self.assertIn("Cog loaded: TokenRemover", line) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..69f35f2f5 --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,162 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): +    """Test helper functions in bot.utils.time.""" + +    def test_humanize_delta_handle_unknown_units(self): +        """humanize_delta should be able to handle unknown units, and will not abort.""" +        # Does not abort for unknown units, as the unit name is checked +        # against the attribute of the relativedelta instance. +        self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + +    def test_humanize_delta_handle_high_units(self): +        """humanize_delta should be able to handle very high units.""" +        # Very high maximum units, but it only ever iterates over +        # each value the relativedelta might have. +        self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + +    def test_humanize_delta_should_normal_usage(self): +        """Testing humanize delta.""" +        test_cases = ( +            (relativedelta(days=2), 'seconds', 1, '2 days'), +            (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), +            (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), +            (relativedelta(days=2, hours=2), 'days', 2, '2 days'), +        ) + +        for delta, precision, max_units, expected in test_cases: +            with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): +                self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + +    def test_humanize_delta_raises_for_invalid_max_units(self): +        """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" +        test_cases = (-1, 0) + +        for max_units in test_cases: +            with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: +                time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) +                self.assertEqual(str(error), 'max_units must be positive') + +    def test_parse_rfc1123(self): +        """Testing parse_rfc1123.""" +        self.assertEqual( +            time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), +            datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) +        ) + +    def test_format_infraction(self): +        """Testing format_infraction.""" +        self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') + +    @patch('asyncio.sleep', new_callable=AsyncMock) +    def test_wait_until(self, mock): +        """Testing wait_until.""" +        start = datetime(2019, 1, 1, 0, 0) +        then = datetime(2019, 1, 1, 0, 10) + +        # No return value +        self.assertIs(asyncio.run(time.wait_until(then, start)), None) + +        mock.assert_called_once_with(10 * 60) + +    def test_format_infraction_with_duration_none_expiry(self): +        """format_infraction_with_duration should work for None expiry.""" +        test_cases = ( +            (None, None, None, None), + +            # To make sure that date_from and max_units are not touched +            (None, 'Why hello there!', None, None), +            (None, None, float('inf'), None), +            (None, 'Why hello there!', float('inf'), None), +        ) + +        for expiry, date_from, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): +                self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + +    def test_format_infraction_with_duration_custom_units(self): +        """format_infraction_with_duration should work for custom max_units.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, +             '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, +             '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') +        ) + +        for expiry, date_from, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): +                self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + +    def test_format_infraction_with_duration_normal_usage(self): +        """format_infraction_with_duration should work for normal usage, across various durations.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), +            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), +            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), +            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), +            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), +            ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, +             '2019-11-23 23:59 (9 minutes and 55 seconds)'), +            (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), +        ) + +        for expiry, date_from, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): +                self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + +    def test_until_expiration_with_duration_none_expiry(self): +        """until_expiration should work for None expiry.""" +        test_cases = ( +            (None, None, None, None), + +            # To make sure that now and max_units are not touched +            (None, 'Why hello there!', None, None), +            (None, None, float('inf'), None), +            (None, 'Why hello there!', float('inf'), None), +        ) + +        for expiry, now, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): +                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + +    def test_until_expiration_with_duration_custom_units(self): +        """until_expiration should work for custom max_units.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') +        ) + +        for expiry, now, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): +                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + +    def test_until_expiration_normal_usage(self): +        """until_expiration should work for normal usage, across various durations.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), +            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), +            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), +            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), +            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), +            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), +            ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), +            (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), +        ) + +        for expiry, now, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): +                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) diff --git a/tests/helpers.py b/tests/helpers.py index b2daae92d..5df796c23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -10,7 +10,9 @@ import unittest.mock  from typing import Any, Iterable, Optional  import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot  for logger in logging.Logger.manager.loggerDict.values(): | 
