diff options
Diffstat (limited to '')
| -rw-r--r-- | .gitignore | 3 | ||||
| -rw-r--r-- | bot/__main__.py | 2 | ||||
| -rw-r--r-- | bot/cogs/alias.py | 7 | ||||
| -rw-r--r-- | bot/cogs/antispam.py | 6 | ||||
| -rw-r--r-- | bot/cogs/cogs.py | 298 | ||||
| -rw-r--r-- | bot/cogs/defcon.py | 6 | ||||
| -rw-r--r-- | bot/cogs/doc.py | 8 | ||||
| -rw-r--r-- | bot/cogs/extensions.py | 236 | ||||
| -rw-r--r-- | bot/cogs/logging.py | 6 | ||||
| -rw-r--r-- | bot/cogs/moderation/infractions.py | 57 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 17 | ||||
| -rw-r--r-- | bot/cogs/moderation/modlog.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/superstarify.py | 22 | ||||
| -rw-r--r-- | bot/cogs/moderation/utils.py | 2 | ||||
| -rw-r--r-- | bot/cogs/off_topic_names.py | 6 | ||||
| -rw-r--r-- | bot/cogs/reddit.py | 42 | ||||
| -rw-r--r-- | bot/cogs/reminders.py | 6 | ||||
| -rw-r--r-- | bot/cogs/sync/cog.py | 6 | ||||
| -rw-r--r-- | bot/cogs/token_remover.py | 10 | ||||
| -rw-r--r-- | bot/cogs/verification.py | 38 | ||||
| -rw-r--r-- | bot/utils/time.py | 19 | ||||
| -rw-r--r-- | docker-compose.yml | 2 | ||||
| -rw-r--r-- | tests/helpers.py | 4 | ||||
| -rw-r--r-- | tests/utils/test_time.py | 62 | 
24 files changed, 488 insertions, 381 deletions
| diff --git a/.gitignore b/.gitignore index 261fa179f..a191523b6 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,6 @@ config.yml  # JUnit XML reports from pytest  junit.xml + +# Mac OS .DS_Store, which is a file that stores custom attributes of its containing folder +.DS_Store diff --git a/bot/__main__.py b/bot/__main__.py index d0924be78..19a7e5ec6 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -42,7 +42,7 @@ bot.load_extension("bot.cogs.security")  bot.load_extension("bot.cogs.antispam")  bot.load_extension("bot.cogs.bot")  bot.load_extension("bot.cogs.clean") -bot.load_extension("bot.cogs.cogs") +bot.load_extension("bot.cogs.extensions")  bot.load_extension("bot.cogs.help")  # Only load this in production diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 0f49a400c..6648805e9 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -5,6 +5,7 @@ from typing import Union  from discord import Colour, Embed, Member, User  from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group +from bot.cogs.extensions import Extension  from bot.cogs.watchchannels.watchchannel import proxy_user  from bot.converters import TagNameConverter  from bot.pagination import LinePaginator @@ -84,9 +85,9 @@ class Alias (Cog):          await self.invoke(ctx, "site rules")      @command(name="reload", hidden=True) -    async def cogs_reload_alias(self, ctx: Context, *, cog_name: str) -> None: -        """Alias for invoking <prefix>cogs reload [cog_name].""" -        await self.invoke(ctx, "cogs reload", cog_name) +    async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: +        """Alias for invoking <prefix>extensions reload [extensions...].""" +        await self.invoke(ctx, "extensions reload", *extensions)      @command(name="defon", hidden=True)      async def defcon_enable_alias(self, ctx: Context) -> None: diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index fd7e4edb0..1b394048a 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -107,14 +107,16 @@ class AntiSpam(Cog):          self.message_deletion_queue = dict()          self.queue_consumption_tasks = dict() +        self.bot.loop.create_task(self.alert_on_validation_error()) +      @property      def mod_log(self) -> ModLog:          """Allows for easy access of the ModLog cog."""          return self.bot.get_cog("ModLog") -    @Cog.listener() -    async def on_ready(self) -> None: +    async def alert_on_validation_error(self) -> None:          """Unloads the cog and alerts admins if configuration validation failed.""" +        await self.bot.wait_until_ready()          if self.validation_errors:              body = "**The following errors were encountered:**\n"              body += "\n".join(f"- {error}" for error in self.validation_errors.values()) diff --git a/bot/cogs/cogs.py b/bot/cogs/cogs.py deleted file mode 100644 index 1f6ccd09c..000000000 --- a/bot/cogs/cogs.py +++ /dev/null @@ -1,298 +0,0 @@ -import logging -import os - -from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group - -from bot.constants import ( -    Emojis, MODERATION_ROLES, Roles, URLs -) -from bot.decorators import with_role -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -KEEP_LOADED = ["bot.cogs.cogs", "bot.cogs.modlog"] - - -class Cogs(Cog): -    """Cog management commands.""" - -    def __init__(self, bot: Bot): -        self.bot = bot -        self.cogs = {} - -        # Load up the cog names -        log.info("Initializing cog names...") -        for filename in os.listdir("bot/cogs"): -            if filename.endswith(".py") and "_" not in filename: -                if os.path.isfile(f"bot/cogs/{filename}"): -                    cog = filename[:-3] - -                    self.cogs[cog] = f"bot.cogs.{cog}" - -        # Allow reverse lookups by reversing the pairs -        self.cogs.update({v: k for k, v in self.cogs.items()}) - -    @group(name='cogs', aliases=('c',), invoke_without_command=True) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def cogs_group(self, ctx: Context) -> None: -        """Load, unload, reload, and list active cogs.""" -        await ctx.invoke(self.bot.get_command("help"), "cogs") - -    @cogs_group.command(name='load', aliases=('l',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def load_command(self, ctx: Context, cog: str) -> None: -        """ -        Load up an unloaded cog, given the module containing it. - -        You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the -        entire module directly. -        """ -        cog = cog.lower() - -        embed = Embed() -        embed.colour = Colour.red() - -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        if cog in self.cogs: -            full_cog = self.cogs[cog] -        elif "." in cog: -            full_cog = cog -        else: -            full_cog = None -            log.warning(f"{ctx.author} requested we load the '{cog}' cog, but that cog doesn't exist.") -            embed.description = f"Unknown cog: {cog}" - -        if full_cog: -            if full_cog not in self.bot.extensions: -                try: -                    self.bot.load_extension(full_cog) -                except ImportError: -                    log.exception(f"{ctx.author} requested we load the '{cog}' cog, " -                                  f"but the cog module {full_cog} could not be found!") -                    embed.description = f"Invalid cog: {cog}\n\nCould not find cog module {full_cog}" -                except Exception as e: -                    log.exception(f"{ctx.author} requested we load the '{cog}' cog, " -                                  "but the loading failed") -                    embed.description = f"Failed to load cog: {cog}\n\n{e.__class__.__name__}: {e}" -                else: -                    log.debug(f"{ctx.author} requested we load the '{cog}' cog. Cog loaded!") -                    embed.description = f"Cog loaded: {cog}" -                    embed.colour = Colour.green() -            else: -                log.warning(f"{ctx.author} requested we load the '{cog}' cog, but the cog was already loaded!") -                embed.description = f"Cog {cog} is already loaded" - -        await ctx.send(embed=embed) - -    @cogs_group.command(name='unload', aliases=('ul',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def unload_command(self, ctx: Context, cog: str) -> None: -        """ -        Unload an already-loaded cog, given the module containing it. - -        You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the -        entire module directly. -        """ -        cog = cog.lower() - -        embed = Embed() -        embed.colour = Colour.red() - -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        if cog in self.cogs: -            full_cog = self.cogs[cog] -        elif "." in cog: -            full_cog = cog -        else: -            full_cog = None -            log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but that cog doesn't exist.") -            embed.description = f"Unknown cog: {cog}" - -        if full_cog: -            if full_cog in KEEP_LOADED: -                log.warning(f"{ctx.author} requested we unload `{full_cog}`, that sneaky pete. We said no.") -                embed.description = f"You may not unload `{full_cog}`!" -            elif full_cog in self.bot.extensions: -                try: -                    self.bot.unload_extension(full_cog) -                except Exception as e: -                    log.exception(f"{ctx.author} requested we unload the '{cog}' cog, " -                                  "but the unloading failed") -                    embed.description = f"Failed to unload cog: {cog}\n\n```{e}```" -                else: -                    log.debug(f"{ctx.author} requested we unload the '{cog}' cog. Cog unloaded!") -                    embed.description = f"Cog unloaded: {cog}" -                    embed.colour = Colour.green() -            else: -                log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but the cog wasn't loaded!") -                embed.description = f"Cog {cog} is not loaded" - -        await ctx.send(embed=embed) - -    @cogs_group.command(name='reload', aliases=('r',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def reload_command(self, ctx: Context, cog: str) -> None: -        """ -        Reload an unloaded cog, given the module containing it. - -        You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the -        entire module directly. - -        If you specify "*" as the cog, every cog currently loaded will be unloaded, and then every cog present in the -        bot/cogs directory will be loaded. -        """ -        cog = cog.lower() - -        embed = Embed() -        embed.colour = Colour.red() - -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        if cog == "*": -            full_cog = cog -        elif cog in self.cogs: -            full_cog = self.cogs[cog] -        elif "." in cog: -            full_cog = cog -        else: -            full_cog = None -            log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but that cog doesn't exist.") -            embed.description = f"Unknown cog: {cog}" - -        if full_cog: -            if full_cog == "*": -                all_cogs = [ -                    f"bot.cogs.{fn[:-3]}" for fn in os.listdir("bot/cogs") -                    if os.path.isfile(f"bot/cogs/{fn}") and fn.endswith(".py") and "_" not in fn -                ] - -                failed_unloads = {} -                failed_loads = {} - -                unloaded = 0 -                loaded = 0 - -                for loaded_cog in self.bot.extensions.copy().keys(): -                    try: -                        self.bot.unload_extension(loaded_cog) -                    except Exception as e: -                        failed_unloads[loaded_cog] = f"{e.__class__.__name__}: {e}" -                    else: -                        unloaded += 1 - -                for unloaded_cog in all_cogs: -                    try: -                        self.bot.load_extension(unloaded_cog) -                    except Exception as e: -                        failed_loads[unloaded_cog] = f"{e.__class__.__name__}: {e}" -                    else: -                        loaded += 1 - -                lines = [ -                    "**All cogs reloaded**", -                    f"**Unloaded**: {unloaded} / **Loaded**: {loaded}" -                ] - -                if failed_unloads: -                    lines.append("\n**Unload failures**") - -                    for cog, error in failed_unloads: -                        lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`") - -                if failed_loads: -                    lines.append("\n**Load failures**") - -                    for cog, error in failed_loads.items(): -                        lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`") - -                log.debug(f"{ctx.author} requested we reload all cogs. Here are the results: \n" -                          f"{lines}") - -                await LinePaginator.paginate(lines, ctx, embed, empty=False) -                return - -            elif full_cog in self.bot.extensions: -                try: -                    self.bot.unload_extension(full_cog) -                    self.bot.load_extension(full_cog) -                except Exception as e: -                    log.exception(f"{ctx.author} requested we reload the '{cog}' cog, " -                                  "but the unloading failed") -                    embed.description = f"Failed to reload cog: {cog}\n\n```{e}```" -                else: -                    log.debug(f"{ctx.author} requested we reload the '{cog}' cog. Cog reloaded!") -                    embed.description = f"Cog reload: {cog}" -                    embed.colour = Colour.green() -            else: -                log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but the cog wasn't loaded!") -                embed.description = f"Cog {cog} is not loaded" - -        await ctx.send(embed=embed) - -    @cogs_group.command(name='list', aliases=('all',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def list_command(self, ctx: Context) -> None: -        """ -        Get a list of all cogs, including their loaded status. - -        Gray indicates that the cog is unloaded. Green indicates that the cog is currently loaded. -        """ -        embed = Embed() -        lines = [] -        cogs = {} - -        embed.colour = Colour.blurple() -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        for key, _value in self.cogs.items(): -            if "." not in key: -                continue - -            if key in self.bot.extensions: -                cogs[key] = True -            else: -                cogs[key] = False - -        for key in self.bot.extensions.keys(): -            if key not in self.cogs: -                cogs[key] = True - -        for cog, loaded in sorted(cogs.items(), key=lambda x: x[0]): -            if cog in self.cogs: -                cog = self.cogs[cog] - -            if loaded: -                status = Emojis.status_online -            else: -                status = Emojis.status_offline - -            lines.append(f"{status}  {cog}") - -        log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") -        await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) - - -def setup(bot: Bot) -> None: -    """Cogs cog load.""" -    bot.add_cog(Cogs(bot)) -    log.info("Cog loaded: Cogs") diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index ae0332688..70e101baa 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -35,14 +35,16 @@ class Defcon(Cog):          self.channel = None          self.days = timedelta(days=0) +        self.bot.loop.create_task(self.sync_settings()) +      @property      def mod_log(self) -> ModLog:          """Get currently loaded ModLog cog instance."""          return self.bot.get_cog("ModLog") -    @Cog.listener() -    async def on_ready(self) -> None: +    async def sync_settings(self) -> None:          """On cog load, try to synchronize DEFCON settings to the API.""" +        await self.bot.wait_until_ready()          self.channel = await self.bot.fetch_channel(Channels.defcon)          try:              response = await self.bot.api_client.get('bot/bot-settings/defcon') diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 0c5a8fce3..a13464bff 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -126,9 +126,11 @@ class Doc(commands.Cog):          self.bot = bot          self.inventories = {} -    @commands.Cog.listener() -    async def on_ready(self) -> None: -        """Refresh documentation inventory.""" +        self.bot.loop.create_task(self.init_refresh_inventory()) + +    async def init_refresh_inventory(self) -> None: +        """Refresh documentation inventory on cog initialization.""" +        await self.bot.wait_until_ready()          await self.refresh_inventory()      async def update_single( diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py new file mode 100644 index 000000000..bb66e0b8e --- /dev/null +++ b/bot/cogs/extensions.py @@ -0,0 +1,236 @@ +import functools +import logging +import typing as t +from enum import Enum +from pkgutil import iter_modules + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Bot, Context, group + +from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + +UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} +EXTENSIONS = frozenset( +    ext.name +    for ext in iter_modules(("bot/cogs",), "bot.cogs.") +    if ext.name[-1] != "_" +) + + +class Action(Enum): +    """Represents an action to perform on an extension.""" + +    # Need to be partial otherwise they are considered to be function definitions. +    LOAD = functools.partial(Bot.load_extension) +    UNLOAD = functools.partial(Bot.unload_extension) +    RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if "." not in argument: +            argument = f"bot.cogs.{argument}" + +        if argument in EXTENSIONS: +            return argument +        else: +            raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): +    """Extension management commands.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) +    async def extensions_group(self, ctx: Context) -> None: +        """Load, unload, reload, and list loaded extensions.""" +        await ctx.invoke(self.bot.get_command("help"), "extensions") + +    @extensions_group.command(name="load", aliases=("l",)) +    async def load_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Load extensions given their fully qualified or unqualified names. + +        If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.invoke(self.bot.get_command("help"), "extensions load") +            return + +        if "*" in extensions or "**" in extensions: +            extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + +        msg = self.batch_manage(Action.LOAD, *extensions) +        await ctx.send(msg) + +    @extensions_group.command(name="unload", aliases=("ul",)) +    async def unload_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Unload currently loaded extensions given their fully qualified or unqualified names. + +        If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.invoke(self.bot.get_command("help"), "extensions unload") +            return + +        blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + +        if blacklisted: +            msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" +        else: +            if "*" in extensions or "**" in extensions: +                extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + +            msg = self.batch_manage(Action.UNLOAD, *extensions) + +        await ctx.send(msg) + +    @extensions_group.command(name="reload", aliases=("r",)) +    async def reload_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Reload extensions given their fully qualified or unqualified names. + +        If an extension fails to be reloaded, it will be rolled-back to the prior working state. + +        If '\*' is given as the name, all currently loaded extensions will be reloaded. +        If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.invoke(self.bot.get_command("help"), "extensions reload") +            return + +        if "**" in extensions: +            extensions = EXTENSIONS +        elif "*" in extensions: +            extensions = set(self.bot.extensions.keys()) | set(extensions) +            extensions.remove("*") + +        msg = self.batch_manage(Action.RELOAD, *extensions) + +        await ctx.send(msg) + +    @extensions_group.command(name="list", aliases=("all",)) +    async def list_command(self, ctx: Context) -> None: +        """ +        Get a list of all extensions, including their loaded status. + +        Grey indicates that the extension is unloaded. +        Green indicates that the extension is currently loaded. +        """ +        embed = Embed() +        lines = [] + +        embed.colour = Colour.blurple() +        embed.set_author( +            name="Extensions List", +            url=URLs.github_bot_repo, +            icon_url=URLs.bot_avatar +        ) + +        for ext in sorted(list(EXTENSIONS)): +            if ext in self.bot.extensions: +                status = Emojis.status_online +            else: +                status = Emojis.status_offline + +            ext = ext.rsplit(".", 1)[1] +            lines.append(f"{status}  {ext}") + +        log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") +        await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) + +    def batch_manage(self, action: Action, *extensions: str) -> str: +        """ +        Apply an action to multiple extensions and return a message with the results. + +        If only one extension is given, it is deferred to `manage()`. +        """ +        if len(extensions) == 1: +            msg, _ = self.manage(action, extensions[0]) +            return msg + +        verb = action.name.lower() +        failures = {} + +        for extension in extensions: +            _, error = self.manage(action, extension) +            if error: +                failures[extension] = error + +        emoji = ":x:" if failures else ":ok_hand:" +        msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + +        if failures: +            failures = "\n".join(f"{ext}\n    {err}" for ext, err in failures.items()) +            msg += f"\nFailures:```{failures}```" + +        log.debug(f"Batch {verb}ed extensions.") + +        return msg + +    def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: +        """Apply an action to an extension and return the status message and any error message.""" +        verb = action.name.lower() +        error_msg = None + +        try: +            action.value(self.bot, ext) +        except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): +            if action is Action.RELOAD: +                # When reloading, just load the extension if it was not loaded. +                return self.manage(Action.LOAD, ext) + +            msg = f":x: Extension `{ext}` is already {verb}ed." +            log.debug(msg[4:]) +        except Exception as e: +            if hasattr(e, "original"): +                e = e.original + +            log.exception(f"Extension '{ext}' failed to {verb}.") + +            error_msg = f"{e.__class__.__name__}: {e}" +            msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" +        else: +            msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." +            log.debug(msg[10:]) + +        return msg, error_msg + +    # This cannot be static (must have a __func__ attribute). +    def cog_check(self, ctx: Context) -> bool: +        """Only allow moderators and core developers to invoke the commands in this cog.""" +        return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer) + +    # This cannot be static (must have a __func__ attribute). +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Handle BadArgument errors locally to prevent the help command from showing.""" +        if isinstance(error, commands.BadArgument): +            await ctx.send(str(error)) +            error.handled = True + + +def setup(bot: Bot) -> None: +    """Load the Extensions cog.""" +    bot.add_cog(Extensions(bot)) +    log.info("Cog loaded: Extensions") diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index 8e47bcc36..c92b619ff 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -15,9 +15,11 @@ class Logging(Cog):      def __init__(self, bot: Bot):          self.bot = bot -    @Cog.listener() -    async def on_ready(self) -> None: +        self.bot.loop.create_task(self.startup_greeting()) + +    async def startup_greeting(self) -> None:          """Announce our presence to the configured devlog channel.""" +        await self.bot.wait_until_ready()          log.info("Bot connected!")          embed = Embed(description="Connected!") diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 2c075f436..592ead60f 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -12,7 +12,6 @@ from discord.ext.commands import Context, command  from bot import constants  from bot.api import ResponseCodeError  from bot.constants import Colours, Event -from bot.converters import Duration  from bot.decorators import respect_role_hierarchy  from bot.utils import time  from bot.utils.checks import with_role_check @@ -113,7 +112,7 @@ class Infractions(Scheduler, commands.Cog):      # region: Temporary infractions      @command(aliases=["mute"]) -    async def tempmute(self, ctx: Context, user: Member, duration: Duration, *, reason: str = None) -> None: +    async def tempmute(self, ctx: Context, user: Member, duration: utils.Expiry, *, reason: str = None) -> None:          """          Temporarily mute a user for the given reason and duration. @@ -126,11 +125,13 @@ class Infractions(Scheduler, commands.Cog):          \u2003`h` - hours          \u2003`M` - minutes∗          \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration.          """          await self.apply_mute(ctx, user, reason, expires_at=duration)      @command() -    async def tempban(self, ctx: Context, user: MemberConverter, duration: Duration, *, reason: str = None) -> None: +    async def tempban(self, ctx: Context, user: MemberConverter, duration: utils.Expiry, *, reason: str = None) -> None:          """          Temporarily ban a user for the given reason and duration. @@ -143,6 +144,8 @@ class Infractions(Scheduler, commands.Cog):          \u2003`h` - hours          \u2003`M` - minutes∗          \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration.          """          await self.apply_ban(ctx, user, reason, expires_at=duration) @@ -172,9 +175,7 @@ class Infractions(Scheduler, commands.Cog):      # region: Temporary shadow infractions      @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"]) -    async def shadow_tempmute( -        self, ctx: Context, user: Member, duration: Duration, *, reason: str = None -    ) -> None: +    async def shadow_tempmute(self, ctx: Context, user: Member, duration: utils.Expiry, *, reason: str = None) -> None:          """          Temporarily mute a user for the given reason and duration without notifying the user. @@ -187,12 +188,19 @@ class Infractions(Scheduler, commands.Cog):          \u2003`h` - hours          \u2003`M` - minutes∗          \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration.          """          await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True)      @command(hidden=True, aliases=["shadowtempban, stempban"])      async def shadow_tempban( -        self, ctx: Context, user: MemberConverter, duration: Duration, *, reason: str = None +        self, +        ctx: Context, +        user: MemberConverter, +        duration: utils.Expiry, +        *, +        reason: str = None      ) -> None:          """          Temporarily ban a user for the given reason and duration without notifying the user. @@ -206,6 +214,8 @@ class Infractions(Scheduler, commands.Cog):          \u2003`h` - hours          \u2003`M` - minutes∗          \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration.          """          await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) @@ -261,7 +271,6 @@ class Infractions(Scheduler, commands.Cog):          if infraction is None:              return -        self.mod_log.ignore(Event.member_ban, user.id)          self.mod_log.ignore(Event.member_remove, user.id)          action = ctx.guild.ban(user, reason=reason, delete_message_days=0) @@ -311,7 +320,8 @@ class Infractions(Scheduler, commands.Cog):          log_content = None          log_text = {              "Member": str(user_id), -            "Actor": str(self.bot.user) +            "Actor": str(self.bot.user), +            "Reason": infraction["reason"]          }          try: @@ -356,6 +366,22 @@ class Infractions(Scheduler, commands.Cog):              log_text["Failure"] = f"HTTPException with code {e.code}."              log_content = mod_role.mention +        # Check if the user is currently being watched by Big Brother. +        try: +            active_watch = await self.bot.api_client.get( +                "bot/infractions", +                params={ +                    "active": "true", +                    "type": "watch", +                    "user__id": user_id +                } +            ) + +            log_text["Watching"] = "Yes" if active_watch else "No" +        except ResponseCodeError: +            log.exception(f"Failed to fetch watch status for user {user_id}") +            log_text["Watching"] = "Unknown - failed to fetch watch status." +          try:              # Mark infraction as inactive in the database.              await self.bot.api_client.patch( @@ -416,7 +442,6 @@ class Infractions(Scheduler, commands.Cog):          expiry_log_text = f"Expires: {expiry}" if expiry else ""          log_title = "applied"          log_content = None -        reason_msg = ""          # DM the user about the infraction if it's not a shadow/hidden infraction.          if not infraction["hidden"]: @@ -432,7 +457,13 @@ class Infractions(Scheduler, commands.Cog):                  log_content = ctx.author.mention          if infraction["actor"] == self.bot.user.id: -            reason_msg = f" (reason: {infraction['reason']})" +            end_msg = f" (reason: {infraction['reason']})" +        else: +            infractions = await self.bot.api_client.get( +                "bot/infractions", +                params={"user__id": str(user.id)} +            ) +            end_msg = f" ({len(infractions)} infractions total)"          # Execute the necessary actions to apply the infraction on Discord.          if action_coro: @@ -449,7 +480,9 @@ class Infractions(Scheduler, commands.Cog):                  log_title = "failed to apply"          # Send a confirmation message to the invoking context. -        await ctx.send(f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{reason_msg}.") +        await ctx.send( +            f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}." +        )          # Send a log message to the mod log.          await self.mod_log.send_log_message( diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index cb266b608..491f6d400 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -8,7 +8,7 @@ from discord.ext import commands  from discord.ext.commands import Context  from bot import constants -from bot.converters import Duration, InfractionSearchQuery +from bot.converters import InfractionSearchQuery  from bot.pagination import LinePaginator  from bot.utils import time  from bot.utils.checks import with_role_check @@ -60,7 +60,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction_id: int, -        expires_at: t.Union[Duration, permanent_duration, None], +        duration: t.Union[utils.Expiry, permanent_duration, None],          *,          reason: str = None      ) -> None: @@ -77,9 +77,10 @@ class ModManagement(commands.Cog):          \u2003`M` - minutes∗          \u2003`s` - seconds -        Use "permanent" to mark the infraction as permanent. +        Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp +        can be provided for the duration.          """ -        if expires_at is None and reason is None: +        if duration is None and reason is None:              # Unlike UserInputError, the error handler will show a specified message for BadArgument              raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") @@ -90,12 +91,12 @@ class ModManagement(commands.Cog):          confirm_messages = []          log_text = "" -        if expires_at == "permanent": +        if duration == "permanent":              request_data['expires_at'] = None              confirm_messages.append("marked as permanent") -        elif expires_at is not None: -            request_data['expires_at'] = expires_at.isoformat() -            expiry = expires_at.strftime(time.INFRACTION_FORMAT) +        elif duration is not None: +            request_data['expires_at'] = duration.isoformat() +            expiry = duration.strftime(time.INFRACTION_FORMAT)              confirm_messages.append(f"set to expire on {expiry}")          else:              confirm_messages.append("expiry unchanged") diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 92e9b0ef1..118503517 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -353,7 +353,7 @@ class ModLog(Cog, name="ModLog"):      @Cog.listener()      async def on_member_ban(self, guild: discord.Guild, member: UserTypes) -> None: -        """Log ban event to mod log.""" +        """Log ban event to user log."""          if guild.id != GuildConstant.id:              return @@ -365,7 +365,7 @@ class ModLog(Cog, name="ModLog"):              Icons.user_ban, Colours.soft_red,              "User banned", f"{member.name}#{member.discriminator} (`{member.id}`)",              thumbnail=member.avatar_url_as(static_format="png"), -            channel_id=Channels.modlog +            channel_id=Channels.userlog          )      @Cog.listener() diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index f3fcf236b..ccc6395d9 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -8,7 +8,6 @@ from discord.errors import Forbidden  from discord.ext.commands import Bot, Cog, Context, command  from bot import constants -from bot.converters import Duration  from bot.utils.checks import with_role_check  from bot.utils.time import format_infraction  from . import utils @@ -144,21 +143,30 @@ class Superstarify(Cog):              )      @command(name='superstarify', aliases=('force_nick', 'star')) -    async def superstarify( -        self, ctx: Context, member: Member, expiration: Duration, reason: str = None -    ) -> None: +    async def superstarify(self, ctx: Context, member: Member, duration: utils.Expiry, reason: str = None) -> None:          """          Force a random superstar name (like Taylor Swift) to be the user's nickname for a specified duration. -        An optional reason can be provided. +        A unit of time should be appended to the duration. +        Units (∗case-sensitive): +        \u2003`y` - years +        \u2003`m` - months∗ +        \u2003`w` - weeks +        \u2003`d` - days +        \u2003`h` - hours +        \u2003`M` - minutes∗ +        \u2003`s` - seconds -        If no reason is given, the original name will be shown in a generated reason. +        Alternatively, an ISO 8601 timestamp can be provided for the duration. + +        An optional reason can be provided. If no reason is given, the original name will be shown +        in a generated reason.          """          if await utils.has_active_infraction(ctx, member, "superstar"):              return          reason = reason or ('old nick: ' + member.display_name) -        infraction = await utils.post_infraction(ctx, member, 'superstar', reason, expires_at=expiration) +        infraction = await utils.post_infraction(ctx, member, 'superstar', reason, expires_at=duration)          forced_nick = self.get_nick(infraction['id'], member.id)          expiry_str = format_infraction(infraction["expires_at"]) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index e9c879b46..788a40d40 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -9,6 +9,7 @@ from discord.ext.commands import Context  from bot.api import ResponseCodeError  from bot.constants import Colours, Icons +from bot.converters import Duration, ISODateTime  log = logging.getLogger(__name__) @@ -26,6 +27,7 @@ APPEALABLE_INFRACTIONS = ("ban", "mute")  UserTypes = t.Union[discord.Member, discord.User]  MemberObject = t.Union[UserTypes, discord.Object]  Infraction = t.Dict[str, t.Union[str, int, bool]] +Expiry = t.Union[Duration, ISODateTime]  def proxy_user(user_id: str) -> discord.Object: diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 16717d523..2977e4ebb 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -75,14 +75,16 @@ class OffTopicNames(Cog):          self.bot = bot          self.updater_task = None +        self.bot.loop.create_task(self.init_offtopic_updater()) +      def cog_unload(self) -> None:          """Cancel any running updater tasks on cog unload."""          if self.updater_task is not None:              self.updater_task.cancel() -    @Cog.listener() -    async def on_ready(self) -> None: +    async def init_offtopic_updater(self) -> None:          """Start off-topic channel updating event loop if it hasn't already started.""" +        await self.bot.wait_until_ready()          if self.updater_task is None:              coro = update_names(self.bot)              self.updater_task = self.bot.loop.create_task(coro) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index bf4403ce4..7b183221c 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,10 +2,10 @@ import asyncio  import logging  import random  import textwrap -from aiohttp import BasicAuth  from datetime import datetime, timedelta  from typing import List +from aiohttp import BasicAuth  from discord import Colour, Embed, Message, TextChannel  from discord.ext import tasks  from discord.ext.commands import Bot, Cog, Context, group @@ -41,11 +41,11 @@ class Reddit(Cog):          self.new_posts_task = None          self.top_weekly_posts_task = None -        self.refresh_access_token.start() +        self.bot.loop.create_task(self.init_reddit_polling())      @tasks.loop(hours=0.99)  # access tokens are valid for one hour      async def refresh_access_token(self) -> None: -        """Refresh the access token""" +        """Refresh Reddits access token."""          headers = {"Authorization": self.client_auth}          data = {              "grant_type": "refresh_token", @@ -53,7 +53,7 @@ class Reddit(Cog):          }          response = await self.bot.http_session.post( -            url = f"{self.URL}/api/v1/access_token", +            url=f"{self.URL}/api/v1/access_token",              headers=headers,              data=data,          ) @@ -67,25 +67,23 @@ class Reddit(Cog):      @refresh_access_token.before_loop      async def get_tokens(self) -> None: -        """Get Reddit access and refresh tokens""" -        await self.bot.wait_until_ready() - +        """Get Reddit access and refresh tokens."""          headers = {"User-Agent": self.USER_AGENT}          data = {              "grant_type": "client_credentials",              "duration": "permanent"          } -        if RedditConfig.client_id and RedditConfig.secret: -            self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) +        self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) -            response = await self.bot.http_session.post( -                url=f"{self.URL}/api/v1/access_token", -                headers=headers, -                auth=self.client_auth, -                data=data -            ) +        response = await self.bot.http_session.post( +            url=f"{self.URL}/api/v1/access_token", +            headers=headers, +            auth=self.client_auth, +            data=data +        ) +        if response.status == 200 and response.content_type == "application/json":              content = await response.json()              self.access_token = content["access_token"]              self.refresh_token = content["refresh_token"] @@ -94,12 +92,9 @@ class Reddit(Cog):                  "User-Agent": self.USER_AGENT              }          else: -            self.client_auth = None -            self.access_token = None -            self.refresh_token = None -            self.headers = None - -            log.error("Unable to find client credentials.") +            log.error("Authentication with Reddit API failed. Unloading extension.") +            self.bot.remove_cog(self.__class__.__name__) +            return      async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:          """A helper method to fetch a certain amount of Reddit posts at a given route.""" @@ -329,10 +324,11 @@ class Reddit(Cog):              max_lines=15          ) -    @Cog.listener() -    async def on_ready(self) -> None: +    async def init_reddit_polling(self) -> None:          """Initiate reddit post event loop.""" +        await self.bot.wait_until_ready()          self.reddit_channel = await self.bot.fetch_channel(Channels.reddit) +        self.refresh_access_token.start()          if self.reddit_channel is not None:              if self.new_posts_task is None: diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 6e91d2c06..b54622306 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -30,9 +30,11 @@ class Reminders(Scheduler, Cog):          self.bot = bot          super().__init__() -    @Cog.listener() -    async def on_ready(self) -> None: +        self.bot.loop.create_task(self.reschedule_reminders()) + +    async def reschedule_reminders(self) -> None:          """Get all current reminders from the API and reschedule them.""" +        await self.bot.wait_until_ready()          response = await self.bot.api_client.get(              'bot/reminders',              params={'active': 'true'} diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index b75fb26cd..aaa581f96 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -29,9 +29,11 @@ class Sync(Cog):      def __init__(self, bot: Bot) -> None:          self.bot = bot -    @Cog.listener() -    async def on_ready(self) -> None: +        self.bot.loop.create_task(self.sync_guild()) + +    async def sync_guild(self) -> None:          """Syncs the roles/users of the guild with the database.""" +        await self.bot.wait_until_ready()          guild = self.bot.get_guild(self.SYNC_SERVER_ID)          if guild is not None:              for syncer in self.ON_READY_SYNCERS: diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 4a655d049..5a0d20e57 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -26,11 +26,11 @@ DELETION_MESSAGE_TEMPLATE = (  DISCORD_EPOCH_TIMESTAMP = datetime(2017, 1, 1)  TOKEN_EPOCH = 1_293_840_000  TOKEN_RE = re.compile( -    r"[^\s\.]+"     # Matches token part 1: The user ID string, encoded as base64 -    r"\."           # Matches a literal dot between the token parts -    r"[^\s\.]+"     # Matches token part 2: The creation timestamp, as an integer -    r"\."           # Matches a literal dot between the token parts -    r"[^\s\.]+"     # Matches token part 3: The HMAC, unused by us, but check that it isn't empty +    r"[^\s\.()\"']+"  # Matches token part 1: The user ID string, encoded as base64 +    r"\."             # Matches a literal dot between the token parts +    r"[^\s\.()\"']+"  # Matches token part 2: The creation timestamp, as an integer +    r"\."             # Matches a literal dot between the token parts +    r"[^\s\.()\"']+"  # Matches token part 3: The HMAC, unused by us, but check that it isn't empty  ) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index acd7a7865..5b115deaa 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -1,10 +1,12 @@  import logging +from datetime import datetime  from discord import Message, NotFound, Object +from discord.ext import tasks  from discord.ext.commands import Bot, Cog, Context, command  from bot.cogs.moderation import ModLog -from bot.constants import Channels, Event, Roles +from bot.constants import Bot as BotConfig, Channels, Event, Roles  from bot.decorators import InChannelCheckFailure, in_channel, without_role  log = logging.getLogger(__name__) @@ -27,12 +29,18 @@ from time to time, you can send `!subscribe` to <#{Channels.bot}> at any time to  If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to <#{Channels.bot}>.  """ +PERIODIC_PING = ( +    f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`." +    f" Ping <@&{Roles.admin}> if you encounter any problems during the verification process." +) +  class Verification(Cog):      """User verification and role self-management."""      def __init__(self, bot: Bot):          self.bot = bot +        self.periodic_ping.start()      @property      def mod_log(self) -> ModLog: @@ -155,6 +163,34 @@ class Verification(Cog):          else:              return True +    @tasks.loop(hours=12) +    async def periodic_ping(self) -> None: +        """Every week, mention @everyone to remind them to verify.""" +        messages = self.bot.get_channel(Channels.verification).history(limit=10) +        need_to_post = True  # True if a new message needs to be sent. + +        async for message in messages: +            if message.author == self.bot.user and message.content == PERIODIC_PING: +                delta = datetime.utcnow() - message.created_at  # Time since last message. +                if delta.days >= 7:  # Message is older than a week. +                    await message.delete() +                else: +                    need_to_post = False + +                break + +        if need_to_post: +            await self.bot.get_channel(Channels.verification).send(PERIODIC_PING) + +    @periodic_ping.before_loop +    async def before_ping(self) -> None: +        """Only start the loop when the bot is ready.""" +        await self.bot.wait_until_ready() + +    def cog_unload(self) -> None: +        """Cancel the periodic ping task when the cog is unloaded.""" +        self.periodic_ping.cancel() +  def setup(bot: Bot) -> None:      """Verification cog load.""" diff --git a/bot/utils/time.py b/bot/utils/time.py index da28f2c76..2aea2c099 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,5 +1,6 @@  import asyncio  import datetime +from typing import Optional  import dateutil.parser  from dateutil.relativedelta import relativedelta @@ -34,6 +35,9 @@ def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units:      precision specifies the smallest unit of time to include (e.g. "seconds", "minutes").      max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).      """ +    if max_units <= 0: +        raise ValueError("max_units must be positive") +      units = (          ("years", delta.years),          ("months", delta.months), @@ -83,15 +87,20 @@ def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max      return f"{humanized} ago" -def parse_rfc1123(time_str: str) -> datetime.datetime: +def parse_rfc1123(stamp: str) -> datetime.datetime:      """Parse RFC1123 time string into datetime.""" -    return datetime.datetime.strptime(time_str, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc) +    return datetime.datetime.strptime(stamp, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc)  # Hey, this could actually be used in the off_topic_names and reddit cogs :) -async def wait_until(time: datetime.datetime) -> None: -    """Wait until a given time.""" -    delay = time - datetime.datetime.utcnow() +async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] = None) -> None: +    """ +    Wait until a given time. + +    :param time: A datetime.datetime object to wait until. +    :param start: The start from which to calculate the waiting duration. Defaults to UTC time. +    """ +    delay = time - (start or datetime.datetime.utcnow())      delay_seconds = delay.total_seconds()      # Incorporate a small delay so we don't rapid-fire the event due to time precision errors diff --git a/docker-compose.yml b/docker-compose.yml index 9684a3c62..f79fdba58 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,7 +6,7 @@ version: "3.7"  services:    postgres: -    image: postgres:11-alpine +    image: postgres:12-alpine      environment:        POSTGRES_DB: pysite        POSTGRES_PASSWORD: pysite diff --git a/tests/helpers.py b/tests/helpers.py index 2908294f7..25059fa3a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,10 @@ __all__ = ('AsyncMock', 'async_test')  # TODO: Remove me on 3.8 +# Allows you to mock a coroutine. Since the default `__call__` of `MagicMock` +# is not a coroutine, trying to mock a coroutine with it will result in errors +# as the default `__call__` is not awaitable. Use this class for monkeypatching +# coroutines instead.  class AsyncMock(MagicMock):      async def __call__(self, *args, **kwargs):          return super(AsyncMock, self).__call__(*args, **kwargs) diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py new file mode 100644 index 000000000..4baa6395c --- /dev/null +++ b/tests/utils/test_time.py @@ -0,0 +1,62 @@ +import asyncio +from datetime import datetime, timezone +from unittest.mock import patch + +import pytest +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +    ('delta', 'precision', 'max_units', 'expected'), +    ( +        (relativedelta(days=2), 'seconds', 1, '2 days'), +        (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), +        (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), +        (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + +        # Does not abort for unknown units, as the unit name is checked +        # against the attribute of the relativedelta instance. +        (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'), + +        # Very high maximum units, but it only ever iterates over +        # each value the relativedelta might have. +        (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'), +    ) +) +def test_humanize_delta( +        delta: relativedelta, +        precision: str, +        max_units: int, +        expected: str +): +    assert time.humanize_delta(delta, precision, max_units) == expected + + [email protected]('max_units', (-1, 0)) +def test_humanize_delta_raises_for_invalid_max_units(max_units: int): +    with pytest.raises(ValueError, match='max_units must be positive'): +        time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + + +    ('stamp', 'expected'), +    ( +        ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)), +    ) +) +def test_parse_rfc1123(stamp: str, expected: str): +    assert time.parse_rfc1123(stamp) == expected + + +@patch('asyncio.sleep', new_callable=AsyncMock) +def test_wait_until(sleep_patch): +    start = datetime(2019, 1, 1, 0, 0) +    then = datetime(2019, 1, 1, 0, 10) + +    # No return value +    assert asyncio.run(time.wait_until(then, start)) is None + +    sleep_patch.assert_called_once_with(10 * 60) | 
