diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | bot/__init__.py | 67 | ||||
| -rw-r--r-- | bot/__main__.py | 78 | ||||
| -rw-r--r-- | bot/bot.py | 60 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | bot/exts/backend/sync/_cog.py | 9 | ||||
| -rw-r--r-- | bot/exts/backend/sync/_syncers.py | 66 | ||||
| -rw-r--r-- | bot/exts/help_channels.py | 11 | ||||
| -rw-r--r-- | bot/exts/info/codeblock/_cog.py | 2 | ||||
| -rw-r--r-- | bot/exts/info/doc.py | 2 | ||||
| -rw-r--r-- | bot/exts/info/help.py | 6 | ||||
| -rw-r--r-- | bot/exts/info/tags.py | 2 | ||||
| -rw-r--r-- | bot/exts/utils/internal.py | 4 | ||||
| -rw-r--r-- | bot/exts/utils/snekbox.py | 4 | ||||
| -rw-r--r-- | bot/interpreter.py | 6 | ||||
| -rw-r--r-- | bot/log.py | 86 | ||||
| -rw-r--r-- | bot/utils/channel.py | 7 | ||||
| -rw-r--r-- | bot/utils/messages.py | 4 | ||||
| -rw-r--r-- | bot/utils/services.py | 8 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_base.py | 29 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py | 34 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_roles.py | 26 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 29 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 4 | ||||
| -rw-r--r-- | tests/bot/utils/test_services.py | 39 | 
25 files changed, 304 insertions, 282 deletions
| diff --git a/.gitignore b/.gitignore index 2074887ad..9186dbe06 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,7 @@ ENV/  # Logfiles  log.*  *.log.* +!log.py  # Custom user configuration  config.yml diff --git a/bot/__init__.py b/bot/__init__.py index 4fce04532..8f880b8e6 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,78 +1,25 @@  import asyncio -import logging  import os -import sys  from functools import partial, partialmethod -from logging import Logger, handlers -from pathlib import Path +from typing import TYPE_CHECKING -import coloredlogs  from discord.ext import commands +from bot import log  from bot.command import Command -TRACE_LEVEL = logging.TRACE = 5 -logging.addLevelName(TRACE_LEVEL, "TRACE") - - -def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: -    """ -    Log 'msg % args' with severity 'TRACE'. - -    To pass exception information, use the keyword argument exc_info with -    a true value, e.g. - -    logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) -    """ -    if self.isEnabledFor(TRACE_LEVEL): -        self._log(TRACE_LEVEL, msg, args, **kwargs) - - -Logger.trace = monkeypatch_trace - -DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") - -log_level = TRACE_LEVEL if DEBUG_MODE else logging.INFO -format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" -log_format = logging.Formatter(format_string) - -log_file = Path("logs", "bot.log") -log_file.parent.mkdir(exist_ok=True) -file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8") -file_handler.setFormatter(log_format) - -root_log = logging.getLogger() -root_log.setLevel(log_level) -root_log.addHandler(file_handler) - -if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: -    coloredlogs.DEFAULT_LEVEL_STYLES = { -        **coloredlogs.DEFAULT_LEVEL_STYLES, -        "trace": {"color": 246}, -        "critical": {"background": "red"}, -        "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"] -    } - -if "COLOREDLOGS_LOG_FORMAT" not in os.environ: -    coloredlogs.DEFAULT_LOG_FORMAT = format_string - -if "COLOREDLOGS_LOG_LEVEL" not in os.environ: -    coloredlogs.DEFAULT_LOG_LEVEL = log_level - -coloredlogs.install(logger=root_log, stream=sys.stdout) - -logging.getLogger("discord").setLevel(logging.WARNING) -logging.getLogger("websockets").setLevel(logging.WARNING) -logging.getLogger("chardet").setLevel(logging.WARNING) -logging.getLogger("async_rediscache").setLevel(logging.WARNING) +if TYPE_CHECKING: +    from bot.bot import Bot +log.setup()  # On Windows, the selector event loop is required for aiodns.  if os.name == "nt":      asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -  # Monkey-patch discord.py decorators to use the Command subclass which supports root aliases.  # Must be patched before any cogs are added.  commands.command = partial(commands.command, cls=Command)  commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=Command) + +instance: "Bot" = None  # Global Bot instance. diff --git a/bot/__main__.py b/bot/__main__.py index 367be1300..257216fa7 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,76 +1,10 @@ -import asyncio -import logging - -import discord -import sentry_sdk -from async_rediscache import RedisSession -from discord.ext.commands import when_mentioned_or -from sentry_sdk.integrations.aiohttp import AioHttpIntegration -from sentry_sdk.integrations.logging import LoggingIntegration -from sentry_sdk.integrations.redis import RedisIntegration - +import bot  from bot import constants  from bot.bot import Bot -from bot.utils.extensions import EXTENSIONS - -# Set up Sentry. -sentry_logging = LoggingIntegration( -    level=logging.DEBUG, -    event_level=logging.WARNING -) - -sentry_sdk.init( -    dsn=constants.Bot.sentry_dsn, -    integrations=[ -        sentry_logging, -        AioHttpIntegration(), -        RedisIntegration(), -    ] -) - -# Create the redis session instance. -redis_session = RedisSession( -    address=(constants.Redis.host, constants.Redis.port), -    password=constants.Redis.password, -    minsize=1, -    maxsize=20, -    use_fakeredis=constants.Redis.use_fakeredis, -    global_namespace="bot", -) - -# Connect redis session to ensure it's connected before we try to access Redis -# from somewhere within the bot. We create the event loop in the same way -# discord.py normally does and pass it to the bot's __init__. -loop = asyncio.get_event_loop() -loop.run_until_complete(redis_session.connect()) - - -# Instantiate the bot. -allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] -intents = discord.Intents().all() -intents.presences = False -intents.dm_typing = False -intents.dm_reactions = False -intents.invites = False -intents.webhooks = False -intents.integrations = False -bot = Bot( -    redis_session=redis_session, -    loop=loop, -    command_prefix=when_mentioned_or(constants.Bot.prefix), -    activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), -    case_insensitive=True, -    max_messages=10_000, -    allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), -    intents=intents, -) - -# Load extensions. -extensions = set(EXTENSIONS)  # Create a mutable copy. -if not constants.HelpChannels.enable: -    extensions.remove("bot.exts.help_channels") +from bot.log import setup_sentry -for extension in extensions: -    bot.load_extension(extension) +setup_sentry() -bot.run(constants.Bot.token) +bot.instance = Bot.create() +bot.instance.load_extensions() +bot.instance.run(constants.Bot.token) diff --git a/bot/bot.py b/bot/bot.py index b2e5237fe..36cf7d30a 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -11,7 +11,7 @@ from async_rediscache import RedisSession  from discord.ext import commands  from sentry_sdk import push_scope -from bot import DEBUG_MODE, api, constants +from bot import api, constants  from bot.async_stats import AsyncStatsClient  log = logging.getLogger('bot') @@ -40,7 +40,7 @@ class Bot(commands.Bot):          statsd_url = constants.Stats.statsd_host -        if DEBUG_MODE: +        if constants.DEBUG_MODE:              # Since statsd is UDP, there are no errors for sending to a down port.              # For this reason, setting the statsd host to 127.0.0.1 for development              # will effectively disable stats. @@ -95,6 +95,43 @@ class Bot(commands.Bot):          # Build the FilterList cache          self.loop.create_task(self.cache_filter_list_data()) +    @classmethod +    def create(cls) -> "Bot": +        """Create and return an instance of a Bot.""" +        loop = asyncio.get_event_loop() +        allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] + +        intents = discord.Intents().all() +        intents.presences = False +        intents.dm_typing = False +        intents.dm_reactions = False +        intents.invites = False +        intents.webhooks = False +        intents.integrations = False + +        return cls( +            redis_session=_create_redis_session(loop), +            loop=loop, +            command_prefix=commands.when_mentioned_or(constants.Bot.prefix), +            activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), +            case_insensitive=True, +            max_messages=10_000, +            allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), +            intents=intents, +        ) + +    def load_extensions(self) -> None: +        """Load all enabled extensions.""" +        # Must be done here to avoid a circular import. +        from bot.utils.extensions import EXTENSIONS + +        extensions = set(EXTENSIONS)  # Create a mutable copy. +        if not constants.HelpChannels.enable: +            extensions.remove("bot.exts.help_channels") + +        for extension in extensions: +            self.load_extension(extension) +      def add_cog(self, cog: commands.Cog) -> None:          """Adds a "cog" to the bot and logs the operation."""          super().add_cog(cog) @@ -243,3 +280,22 @@ class Bot(commands.Bot):          for alias in getattr(command, "root_aliases", ()):              self.all_commands.pop(alias, None) + + +def _create_redis_session(loop: asyncio.AbstractEventLoop) -> RedisSession: +    """ +    Create and connect to a redis session. + +    Ensure the connection is established before returning to prevent race conditions. +    `loop` is the event loop on which to connect. The Bot should use this same event loop. +    """ +    redis_session = RedisSession( +        address=(constants.Redis.host, constants.Redis.port), +        password=constants.Redis.password, +        minsize=1, +        maxsize=20, +        use_fakeredis=constants.Redis.use_fakeredis, +        global_namespace="bot", +    ) +    loop.run_until_complete(redis_session.connect()) +    return redis_session diff --git a/bot/constants.py b/bot/constants.py index 731f06fed..2126b2b37 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -632,7 +632,7 @@ class Event(Enum):  # Debug mode -DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False +DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local")  # Paths  BOT_DIR = os.path.dirname(__file__) diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index 6e85e2b7d..48d2b6f02 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -18,9 +18,6 @@ class Sync(Cog):      def __init__(self, bot: Bot) -> None:          self.bot = bot -        self.role_syncer = _syncers.RoleSyncer(self.bot) -        self.user_syncer = _syncers.UserSyncer(self.bot) -          self.bot.loop.create_task(self.sync_guild())      async def sync_guild(self) -> None: @@ -31,7 +28,7 @@ class Sync(Cog):          if guild is None:              return -        for syncer in (self.role_syncer, self.user_syncer): +        for syncer in (_syncers.RoleSyncer, _syncers.UserSyncer):              await syncer.sync(guild)      async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: @@ -171,10 +168,10 @@ class Sync(Cog):      @commands.has_permissions(administrator=True)      async def sync_roles_command(self, ctx: Context) -> None:          """Manually synchronise the guild's roles with the roles on the site.""" -        await self.role_syncer.sync(ctx.guild, ctx) +        await _syncers.RoleSyncer.sync(ctx.guild, ctx)      @sync_group.command(name='users')      @commands.has_permissions(administrator=True)      async def sync_users_command(self, ctx: Context) -> None:          """Manually synchronise the guild's users with the users on the site.""" -        await self.user_syncer.sync(ctx.guild, ctx) +        await _syncers.UserSyncer.sync(ctx.guild, ctx) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 38468c2b1..2eb9f9971 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -6,8 +6,8 @@ from collections import namedtuple  from discord import Guild  from discord.ext.commands import Context +import bot  from bot.api import ResponseCodeError -from bot.bot import Bot  log = logging.getLogger(__name__) @@ -17,57 +17,60 @@ _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))  _Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) +# Implementation of static abstract methods are not enforced if the subclass is never instantiated. +# However, methods are kept abstract to at least symbolise that they should be abstract.  class Syncer(abc.ABC):      """Base class for synchronising the database with objects in the Discord cache.""" -    def __init__(self, bot: Bot) -> None: -        self.bot = bot - +    @staticmethod      @property      @abc.abstractmethod -    def name(self) -> str: +    def name() -> str:          """The name of the syncer; used in output messages and logging."""          raise NotImplementedError  # pragma: no cover +    @staticmethod      @abc.abstractmethod -    async def _get_diff(self, guild: Guild) -> _Diff: +    async def _get_diff(guild: Guild) -> _Diff:          """Return the difference between the cache of `guild` and the database."""          raise NotImplementedError  # pragma: no cover +    @staticmethod      @abc.abstractmethod -    async def _sync(self, diff: _Diff) -> None: +    async def _sync(diff: _Diff) -> None:          """Perform the API calls for synchronisation."""          raise NotImplementedError  # pragma: no cover -    async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: +    @classmethod +    async def sync(cls, guild: Guild, ctx: t.Optional[Context] = None) -> None:          """          Synchronise the database with the cache of `guild`.          If `ctx` is given, send a message with the results.          """ -        log.info(f"Starting {self.name} syncer.") +        log.info(f"Starting {cls.name} syncer.")          if ctx: -            message = await ctx.send(f"📊 Synchronising {self.name}s.") +            message = await ctx.send(f"📊 Synchronising {cls.name}s.")          else:              message = None -        diff = await self._get_diff(guild) +        diff = await cls._get_diff(guild)          try: -            await self._sync(diff) +            await cls._sync(diff)          except ResponseCodeError as e: -            log.exception(f"{self.name} syncer failed!") +            log.exception(f"{cls.name} syncer failed!")              # Don't show response text because it's probably some really long HTML.              results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" -            content = f":x: Synchronisation of {self.name}s failed: {results}" +            content = f":x: Synchronisation of {cls.name}s failed: {results}"          else:              diff_dict = diff._asdict()              results = (f"{name} `{len(val)}`" for name, val in diff_dict.items() if val is not None)              results = ", ".join(results) -            log.info(f"{self.name} syncer finished: {results}.") -            content = f":ok_hand: Synchronisation of {self.name}s complete: {results}" +            log.info(f"{cls.name} syncer finished: {results}.") +            content = f":ok_hand: Synchronisation of {cls.name}s complete: {results}"          if message:              await message.edit(content=content) @@ -78,10 +81,11 @@ class RoleSyncer(Syncer):      name = "role" -    async def _get_diff(self, guild: Guild) -> _Diff: +    @staticmethod +    async def _get_diff(guild: Guild) -> _Diff:          """Return the difference of roles between the cache of `guild` and the database."""          log.trace("Getting the diff for roles.") -        roles = await self.bot.api_client.get('bot/roles') +        roles = await bot.instance.api_client.get('bot/roles')          # Pack DB roles and guild roles into one common, hashable format.          # They're hashable so that they're easily comparable with sets later. @@ -110,19 +114,20 @@ class RoleSyncer(Syncer):          return _Diff(roles_to_create, roles_to_update, roles_to_delete) -    async def _sync(self, diff: _Diff) -> None: +    @staticmethod +    async def _sync(diff: _Diff) -> None:          """Synchronise the database with the role cache of `guild`."""          log.trace("Syncing created roles...")          for role in diff.created: -            await self.bot.api_client.post('bot/roles', json=role._asdict()) +            await bot.instance.api_client.post('bot/roles', json=role._asdict())          log.trace("Syncing updated roles...")          for role in diff.updated: -            await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) +            await bot.instance.api_client.put(f'bot/roles/{role.id}', json=role._asdict())          log.trace("Syncing deleted roles...")          for role in diff.deleted: -            await self.bot.api_client.delete(f'bot/roles/{role.id}') +            await bot.instance.api_client.delete(f'bot/roles/{role.id}')  class UserSyncer(Syncer): @@ -130,7 +135,8 @@ class UserSyncer(Syncer):      name = "user" -    async def _get_diff(self, guild: Guild) -> _Diff: +    @staticmethod +    async def _get_diff(guild: Guild) -> _Diff:          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") @@ -138,7 +144,7 @@ class UserSyncer(Syncer):          users_to_update = []          seen_guild_users = set() -        async for db_user in self._get_users(): +        async for db_user in UserSyncer._get_users():              # Store user fields which are to be updated.              updated_fields = {} @@ -185,24 +191,26 @@ class UserSyncer(Syncer):          return _Diff(users_to_create, users_to_update, None) -    async def _get_users(self) -> t.AsyncIterable: +    @staticmethod +    async def _get_users() -> t.AsyncIterable:          """GET users from database."""          query_params = {              "page": 1          }          while query_params["page"]: -            res = await self.bot.api_client.get("bot/users", params=query_params) +            res = await bot.instance.api_client.get("bot/users", params=query_params)              for user in res["results"]:                  yield user              query_params["page"] = res["next_page_no"] -    async def _sync(self, diff: _Diff) -> None: +    @staticmethod +    async def _sync(diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...")          if diff.created: -            await self.bot.api_client.post("bot/users", json=diff.created) +            await bot.instance.api_client.post("bot/users", json=diff.created)          log.trace("Syncing updated users...")          if diff.updated: -            await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated) +            await bot.instance.api_client.patch("bot/users/bulk_patch", json=diff.updated) diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py index 062d4fcfe..f5a8b251b 100644 --- a/bot/exts/help_channels.py +++ b/bot/exts/help_channels.py @@ -380,16 +380,13 @@ class HelpChannels(commands.Cog):          try:              self.available_category = await channel_utils.try_get_channel( -                constants.Categories.help_available, -                self.bot +                constants.Categories.help_available              )              self.in_use_category = await channel_utils.try_get_channel( -                constants.Categories.help_in_use, -                self.bot +                constants.Categories.help_in_use              )              self.dormant_category = await channel_utils.try_get_channel( -                constants.Categories.help_dormant, -                self.bot +                constants.Categories.help_dormant              )          except discord.HTTPException:              log.exception("Failed to get a category; cog will be removed") @@ -500,7 +497,7 @@ class HelpChannels(commands.Cog):          options should be avoided, as it may interfere with the category move we perform.          """          # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. -        category = await channel_utils.try_get_channel(category_id, self.bot) +        category = await channel_utils.try_get_channel(category_id)          payload = [{"id": c.id, "position": c.position} for c in category.channels] diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index 1e0feab0d..9094d9d15 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -114,7 +114,7 @@ class CodeBlockCog(Cog, name="Code Block"):          bot_message = await message.channel.send(f"Hey {message.author.mention}!", embed=embed)          self.codeblock_message_ids[message.id] = bot_message.id -        self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,), self.bot)) +        self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,)))          # Increase amount of codeblock correction in stats          self.bot.stats.incr("codeblock_corrections") diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py index 7ec8caa4b..9b5bd6504 100644 --- a/bot/exts/info/doc.py +++ b/bot/exts/info/doc.py @@ -365,7 +365,7 @@ class Doc(commands.Cog):                      await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY)              else:                  msg = await ctx.send(embed=doc_embed) -                await wait_for_deletion(msg, (ctx.author.id,), client=self.bot) +                await wait_for_deletion(msg, (ctx.author.id,))      @docs_group.command(name='set', aliases=('s',))      @commands.has_any_role(*MODERATION_ROLES) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py index 599c5d5c0..461ff82fd 100644 --- a/bot/exts/info/help.py +++ b/bot/exts/info/help.py @@ -186,7 +186,7 @@ class CustomHelpCommand(HelpCommand):          """Send help for a single command."""          embed = await self.command_formatting(command)          message = await self.context.send(embed=embed) -        await wait_for_deletion(message, (self.context.author.id,), self.context.bot) +        await wait_for_deletion(message, (self.context.author.id,))      @staticmethod      def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: @@ -225,7 +225,7 @@ class CustomHelpCommand(HelpCommand):              embed.description += f"\n**Subcommands:**\n{command_details}"          message = await self.context.send(embed=embed) -        await wait_for_deletion(message, (self.context.author.id,), self.context.bot) +        await wait_for_deletion(message, (self.context.author.id,))      async def send_cog_help(self, cog: Cog) -> None:          """Send help for a cog.""" @@ -241,7 +241,7 @@ class CustomHelpCommand(HelpCommand):              embed.description += f"\n\n**Commands:**\n{command_details}"          message = await self.context.send(embed=embed) -        await wait_for_deletion(message, (self.context.author.id,), self.context.bot) +        await wait_for_deletion(message, (self.context.author.id,))      @staticmethod      def _category_key(command: Command) -> str: diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index ae95ac1ef..8f15f932b 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -236,7 +236,6 @@ class Tags(Cog):                  await wait_for_deletion(                      await ctx.send(embed=Embed.from_dict(tag['embed'])),                      [ctx.author.id], -                    self.bot                  )              elif founds and len(tag_name) >= 3:                  await wait_for_deletion( @@ -247,7 +246,6 @@ class Tags(Cog):                          )                      ),                      [ctx.author.id], -                    self.bot                  )          else: diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index 1b4900f42..3521c8fd4 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -30,7 +30,7 @@ class Internal(Cog):          self.ln = 0          self.stdout = StringIO() -        self.interpreter = Interpreter(bot) +        self.interpreter = Interpreter()          self.socket_since = datetime.utcnow()          self.socket_event_total = 0 @@ -195,7 +195,7 @@ async def func():  # (None,) -> Any              truncate_index = newline_truncate_index          if len(out) > truncate_index: -            paste_link = await send_to_paste_service(self.bot.http_session, out, extension="py") +            paste_link = await send_to_paste_service(out, extension="py")              if paste_link is not None:                  paste_text = f"full contents at {paste_link}"              else: diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 41cb00541..9f480c067 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -70,7 +70,7 @@ class Snekbox(Cog):          if len(output) > MAX_PASTE_LEN:              log.info("Full output is too long to upload")              return "too long to upload" -        return await send_to_paste_service(self.bot.http_session, output, extension="txt") +        return await send_to_paste_service(output, extension="txt")      @staticmethod      def prepare_input(code: str) -> str: @@ -219,7 +219,7 @@ class Snekbox(Cog):                  response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")              else:                  response = await ctx.send(msg) -            self.bot.loop.create_task(wait_for_deletion(response, (ctx.author.id,), ctx.bot)) +            self.bot.loop.create_task(wait_for_deletion(response, (ctx.author.id,)))              log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")          return response diff --git a/bot/interpreter.py b/bot/interpreter.py index 8b7268746..b58f7a6b0 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -4,7 +4,7 @@ from typing import Any  from discord.ext.commands import Context -from bot.bot import Bot +import bot  CODE_TEMPLATE = """  async def _func(): @@ -21,8 +21,8 @@ class Interpreter(InteractiveInterpreter):      write_callable = None -    def __init__(self, bot: Bot): -        locals_ = {"bot": bot} +    def __init__(self): +        locals_ = {"bot": bot.instance}          super().__init__(locals_)      async def run(self, code: str, ctx: Context, io: StringIO, *args, **kwargs) -> Any: diff --git a/bot/log.py b/bot/log.py new file mode 100644 index 000000000..13141de40 --- /dev/null +++ b/bot/log.py @@ -0,0 +1,86 @@ +import logging +import os +import sys +from logging import Logger, handlers +from pathlib import Path + +import coloredlogs +import sentry_sdk +from sentry_sdk.integrations.aiohttp import AioHttpIntegration +from sentry_sdk.integrations.logging import LoggingIntegration +from sentry_sdk.integrations.redis import RedisIntegration + +from bot import constants + +TRACE_LEVEL = 5 + + +def setup() -> None: +    """Set up loggers.""" +    logging.TRACE = TRACE_LEVEL +    logging.addLevelName(TRACE_LEVEL, "TRACE") +    Logger.trace = _monkeypatch_trace + +    log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO +    format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" +    log_format = logging.Formatter(format_string) + +    log_file = Path("logs", "bot.log") +    log_file.parent.mkdir(exist_ok=True) +    file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8") +    file_handler.setFormatter(log_format) + +    root_log = logging.getLogger() +    root_log.setLevel(log_level) +    root_log.addHandler(file_handler) + +    if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: +        coloredlogs.DEFAULT_LEVEL_STYLES = { +            **coloredlogs.DEFAULT_LEVEL_STYLES, +            "trace": {"color": 246}, +            "critical": {"background": "red"}, +            "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"] +        } + +    if "COLOREDLOGS_LOG_FORMAT" not in os.environ: +        coloredlogs.DEFAULT_LOG_FORMAT = format_string + +    if "COLOREDLOGS_LOG_LEVEL" not in os.environ: +        coloredlogs.DEFAULT_LOG_LEVEL = log_level + +    coloredlogs.install(logger=root_log, stream=sys.stdout) + +    logging.getLogger("discord").setLevel(logging.WARNING) +    logging.getLogger("websockets").setLevel(logging.WARNING) +    logging.getLogger("chardet").setLevel(logging.WARNING) +    logging.getLogger("async_rediscache").setLevel(logging.WARNING) + + +def setup_sentry() -> None: +    """Set up the Sentry logging integrations.""" +    sentry_logging = LoggingIntegration( +        level=logging.DEBUG, +        event_level=logging.WARNING +    ) + +    sentry_sdk.init( +        dsn=constants.Bot.sentry_dsn, +        integrations=[ +            sentry_logging, +            AioHttpIntegration(), +            RedisIntegration(), +        ] +    ) + + +def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: +    """ +    Log 'msg % args' with severity 'TRACE'. + +    To pass exception information, use the keyword argument exc_info with +    a true value, e.g. + +    logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) +    """ +    if self.isEnabledFor(TRACE_LEVEL): +        self._log(TRACE_LEVEL, msg, args, **kwargs) diff --git a/bot/utils/channel.py b/bot/utils/channel.py index 6bf70bfde..0c072184c 100644 --- a/bot/utils/channel.py +++ b/bot/utils/channel.py @@ -2,6 +2,7 @@ import logging  import discord +import bot  from bot import constants  from bot.constants import Categories @@ -36,14 +37,14 @@ def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:      return getattr(channel, "category_id", None) == category_id -async def try_get_channel(channel_id: int, client: discord.Client) -> discord.abc.GuildChannel: +async def try_get_channel(channel_id: int) -> discord.abc.GuildChannel:      """Attempt to get or fetch a channel and return it."""      log.trace(f"Getting the channel {channel_id}.") -    channel = client.get_channel(channel_id) +    channel = bot.instance.get_channel(channel_id)      if not channel:          log.debug(f"Channel {channel_id} is not in cache; fetching from API.") -        channel = await client.fetch_channel(channel_id) +        channel = await bot.instance.fetch_channel(channel_id)      log.trace(f"Channel #{channel} ({channel_id}) retrieved.")      return channel diff --git a/bot/utils/messages.py b/bot/utils/messages.py index b6c7cab50..42bde358d 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -10,6 +10,7 @@ import discord  from discord.errors import HTTPException  from discord.ext.commands import Context +import bot  from bot.constants import Emojis, NEGATIVE_REPLIES  log = logging.getLogger(__name__) @@ -18,7 +19,6 @@ log = logging.getLogger(__name__)  async def wait_for_deletion(      message: discord.Message,      user_ids: Sequence[discord.abc.Snowflake], -    client: discord.Client,      deletion_emojis: Sequence[str] = (Emojis.trashcan,),      timeout: float = 60 * 5,      attach_emojis: bool = True, @@ -49,7 +49,7 @@ async def wait_for_deletion(          )      with contextlib.suppress(asyncio.TimeoutError): -        await client.wait_for('reaction_add', check=check, timeout=timeout) +        await bot.instance.wait_for('reaction_add', check=check, timeout=timeout)          await message.delete() diff --git a/bot/utils/services.py b/bot/utils/services.py index 087b9f969..5949c9e48 100644 --- a/bot/utils/services.py +++ b/bot/utils/services.py @@ -1,8 +1,9 @@  import logging  from typing import Optional -from aiohttp import ClientConnectorError, ClientSession +from aiohttp import ClientConnectorError +import bot  from bot.constants import URLs  log = logging.getLogger(__name__) @@ -10,11 +11,10 @@ log = logging.getLogger(__name__)  FAILED_REQUEST_ATTEMPTS = 3 -async def send_to_paste_service(http_session: ClientSession, contents: str, *, extension: str = "") -> Optional[str]: +async def send_to_paste_service(contents: str, *, extension: str = "") -> Optional[str]:      """      Upload `contents` to the paste service. -    `http_session` should be the current running ClientSession from aiohttp      `extension` is added to the output URL      When an error occurs, `None` is returned, otherwise the generated URL with the suffix. @@ -24,7 +24,7 @@ async def send_to_paste_service(http_session: ClientSession, contents: str, *, e      paste_url = URLs.paste_service.format(key="documents")      for attempt in range(1, FAILED_REQUEST_ATTEMPTS + 1):          try: -            async with http_session.post(paste_url, data=contents) as response: +            async with bot.instance.http_session.post(paste_url, data=contents) as response:                  response_json = await response.json()          except ClientConnectorError:              log.warning( diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 4953550f9..3ad9db9c3 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -15,28 +15,21 @@ class TestSyncer(Syncer):      _sync = mock.AsyncMock() -class SyncerBaseTests(unittest.TestCase): -    """Tests for the syncer base class.""" - -    def setUp(self): -        self.bot = helpers.MockBot() - -    def test_instantiation_fails_without_abstract_methods(self): -        """The class must have abstract methods implemented.""" -        with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): -            Syncer(self.bot) - -  class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for main function orchestrating the sync."""      def setUp(self): -        self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) -        self.syncer = TestSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot(user=helpers.MockMember(bot=True))) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) +          self.guild = helpers.MockGuild() +        TestSyncer._get_diff.reset_mock(return_value=True, side_effect=True) +        TestSyncer._sync.reset_mock(return_value=True, side_effect=True) +          # Make sure `_get_diff` returns a MagicMock, not an AsyncMock -        self.syncer._get_diff.return_value = mock.MagicMock() +        TestSyncer._get_diff.return_value = mock.MagicMock()      async def test_sync_message_edited(self):          """The message should be edited if one was sent, even if the sync has an API error.""" @@ -48,11 +41,11 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):          for message, side_effect, should_edit in subtests:              with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): -                self.syncer._sync.side_effect = side_effect +                TestSyncer._sync.side_effect = side_effect                  ctx = helpers.MockContext()                  ctx.send.return_value = message -                await self.syncer.sync(self.guild, ctx) +                await TestSyncer.sync(self.guild, ctx)                  if should_edit:                      message.edit.assert_called_once() @@ -67,7 +60,7 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):          for ctx, message in subtests:              with self.subTest(ctx=ctx, message=message): -                await self.syncer.sync(self.guild, ctx) +                await TestSyncer.sync(self.guild, ctx)                  if ctx is not None:                      ctx.send.assert_called_once() diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 063a82754..22a07313e 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -29,24 +29,24 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = helpers.MockBot() -        self.role_syncer_patcher = mock.patch( +        role_syncer_patcher = mock.patch(              "bot.exts.backend.sync._syncers.RoleSyncer",              autospec=Syncer,              spec_set=True          ) -        self.user_syncer_patcher = mock.patch( +        user_syncer_patcher = mock.patch(              "bot.exts.backend.sync._syncers.UserSyncer",              autospec=Syncer,              spec_set=True          ) -        self.RoleSyncer = self.role_syncer_patcher.start() -        self.UserSyncer = self.user_syncer_patcher.start() -        self.cog = Sync(self.bot) +        self.RoleSyncer = role_syncer_patcher.start() +        self.UserSyncer = user_syncer_patcher.start() -    def tearDown(self): -        self.role_syncer_patcher.stop() -        self.user_syncer_patcher.stop() +        self.addCleanup(role_syncer_patcher.stop) +        self.addCleanup(user_syncer_patcher.stop) + +        self.cog = Sync(self.bot)      @staticmethod      def response_error(status: int) -> ResponseCodeError: @@ -73,8 +73,6 @@ class SyncCogTests(SyncCogTestCase):          Sync(self.bot) -        self.RoleSyncer.assert_called_once_with(self.bot) -        self.UserSyncer.assert_called_once_with(self.bot)          sync_guild.assert_called_once_with()          self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) @@ -83,8 +81,8 @@ class SyncCogTests(SyncCogTestCase):          for guild in (helpers.MockGuild(), None):              with self.subTest(guild=guild):                  self.bot.reset_mock() -                self.cog.role_syncer.reset_mock() -                self.cog.user_syncer.reset_mock() +                self.RoleSyncer.reset_mock() +                self.UserSyncer.reset_mock()                  self.bot.get_guild = mock.MagicMock(return_value=guild) @@ -94,11 +92,11 @@ class SyncCogTests(SyncCogTestCase):                  self.bot.get_guild.assert_called_once_with(constants.Guild.id)                  if guild is None: -                    self.cog.role_syncer.sync.assert_not_called() -                    self.cog.user_syncer.sync.assert_not_called() +                    self.RoleSyncer.sync.assert_not_called() +                    self.UserSyncer.sync.assert_not_called()                  else: -                    self.cog.role_syncer.sync.assert_called_once_with(guild) -                    self.cog.user_syncer.sync.assert_called_once_with(guild) +                    self.RoleSyncer.sync.assert_called_once_with(guild) +                    self.UserSyncer.sync.assert_called_once_with(guild)      async def patch_user_helper(self, side_effect: BaseException) -> None:          """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" @@ -394,14 +392,14 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):          ctx = helpers.MockContext()          await self.cog.sync_roles_command(self.cog, ctx) -        self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) +        self.RoleSyncer.sync.assert_called_once_with(ctx.guild, ctx)      async def test_sync_users_command(self):          """sync() should be called on the UserSyncer."""          ctx = helpers.MockContext()          await self.cog.sync_users_command(self.cog, ctx) -        self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) +        self.UserSyncer.sync.assert_called_once_with(ctx.guild, ctx)      async def test_commands_require_admin(self):          """The sync commands should only run if the author has the administrator permission.""" diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 7b9f40cad..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -22,8 +22,9 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between roles in the DB and roles in the Guild cache."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = RoleSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      @staticmethod      def get_guild(*roles): @@ -44,7 +45,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role()]          guild = self.get_guild(fake_role()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = (set(), set(), set())          self.assertEqual(actual_diff, expected_diff) @@ -56,7 +57,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]          guild = self.get_guild(updated_role, fake_role()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = (set(), {_Role(**updated_role)}, set())          self.assertEqual(actual_diff, expected_diff) @@ -68,7 +69,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role()]          guild = self.get_guild(fake_role(), new_role) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = ({_Role(**new_role)}, set(), set())          self.assertEqual(actual_diff, expected_diff) @@ -80,7 +81,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role(), deleted_role]          guild = self.get_guild(fake_role()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = (set(), set(), {_Role(**deleted_role)})          self.assertEqual(actual_diff, expected_diff) @@ -98,7 +99,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          ]          guild = self.get_guild(fake_role(), new, updated) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})          self.assertEqual(actual_diff, expected_diff) @@ -108,8 +109,9 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync roles."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = RoleSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      async def test_sync_created_roles(self):          """Only POST requests should be made with the correct payload.""" @@ -117,7 +119,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          role_tuples = {_Role(**role) for role in roles}          diff = _Diff(role_tuples, set(), set()) -        await self.syncer._sync(diff) +        await RoleSyncer._sync(diff)          calls = [mock.call("bot/roles", json=role) for role in roles]          self.bot.api_client.post.assert_has_calls(calls, any_order=True) @@ -132,7 +134,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          role_tuples = {_Role(**role) for role in roles}          diff = _Diff(set(), role_tuples, set()) -        await self.syncer._sync(diff) +        await RoleSyncer._sync(diff)          calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]          self.bot.api_client.put.assert_has_calls(calls, any_order=True) @@ -147,7 +149,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          role_tuples = {_Role(**role) for role in roles}          diff = _Diff(set(), set(), role_tuples) -        await self.syncer._sync(diff) +        await RoleSyncer._sync(diff)          calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]          self.bot.api_client.delete.assert_has_calls(calls, any_order=True) diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 9f380a15d..61673e1bb 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,4 +1,5 @@  import unittest +from unittest import mock  from bot.exts.backend.sync._syncers import UserSyncer, _Diff  from tests import helpers @@ -19,8 +20,9 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between users in the DB and users in the Guild cache."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = UserSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      @staticmethod      def get_guild(*members): @@ -57,7 +59,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          }          guild = self.get_guild() -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -73,7 +75,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          guild = self.get_guild(fake_user())          guild.get_member.return_value = self.get_mock_member(fake_user()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -94,7 +96,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):              self.get_mock_member(fake_user())          ] -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([], [{"id": 99, "name": "new"}], None)          self.assertEqual(actual_diff, expected_diff) @@ -114,7 +116,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):              self.get_mock_member(fake_user()),              self.get_mock_member(new_user)          ] -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([new_user], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -133,7 +135,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):              None          ] -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([], [{"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff) @@ -157,7 +159,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):              None          ] -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff) @@ -176,7 +178,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):              None          ] -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await UserSyncer._get_diff(guild)          expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -186,15 +188,16 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync users."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = UserSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      async def test_sync_created_users(self):          """Only POST requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)]          diff = _Diff(users, [], None) -        await self.syncer._sync(diff) +        await UserSyncer._sync(diff)          self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created) @@ -206,7 +209,7 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          users = [fake_user(id=111), fake_user(id=222)]          diff = _Diff([], users, None) -        await self.syncer._sync(diff) +        await UserSyncer._sync(diff)          self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 9a42d0610..321a92445 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -42,9 +42,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      async def test_upload_output(self, mock_paste_util):          """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""          await self.cog.upload_output("Test output.") -        mock_paste_util.assert_called_once_with( -            self.bot.http_session, "Test output.", extension="txt" -        ) +        mock_paste_util.assert_called_once_with("Test output.", extension="txt")      def test_prepare_input(self):          cases = ( diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 5e0855704..1b48f6560 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -5,11 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch  from aiohttp import ClientConnectorError  from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from tests.helpers import MockBot  class PasteTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None: -        self.http_session = MagicMock() +        patcher = patch("bot.instance", new=MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}")      async def test_url_and_sent_contents(self): @@ -17,10 +20,10 @@ class PasteTests(unittest.IsolatedAsyncioTestCase):          response = MagicMock(              json=AsyncMock(return_value={"key": ""})          ) -        self.http_session.post().__aenter__.return_value = response -        self.http_session.post.reset_mock() -        await send_to_paste_service(self.http_session, "Content") -        self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") +        self.bot.http_session.post.return_value.__aenter__.return_value = response +        self.bot.http_session.post.reset_mock() +        await send_to_paste_service("Content") +        self.bot.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content")      @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}")      async def test_paste_returns_correct_url_on_success(self): @@ -34,41 +37,41 @@ class PasteTests(unittest.IsolatedAsyncioTestCase):          response = MagicMock(              json=AsyncMock(return_value={"key": key})          ) -        self.http_session.post().__aenter__.return_value = response +        self.bot.http_session.post.return_value.__aenter__.return_value = response          for expected_output, extension in test_cases:              with self.subTest(msg=f"Send contents with extension {repr(extension)}"):                  self.assertEqual( -                    await send_to_paste_service(self.http_session, "", extension=extension), +                    await send_to_paste_service("", extension=extension),                      expected_output                  )      async def test_request_repeated_on_json_errors(self):          """Json with error message and invalid json are handled as errors and requests repeated."""          test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) -        self.http_session.post().__aenter__.return_value = response = MagicMock() -        self.http_session.post.reset_mock() +        self.bot.http_session.post.return_value.__aenter__.return_value = response = MagicMock() +        self.bot.http_session.post.reset_mock()          for error_json in test_cases:              with self.subTest(error_json=error_json):                  response.json = AsyncMock(return_value=error_json) -                result = await send_to_paste_service(self.http_session, "") -                self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +                result = await send_to_paste_service("") +                self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)                  self.assertIsNone(result) -            self.http_session.post.reset_mock() +            self.bot.http_session.post.reset_mock()      async def test_request_repeated_on_connection_errors(self):          """Requests are repeated in the case of connection errors.""" -        self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) -        result = await send_to_paste_service(self.http_session, "") -        self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +        self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) +        result = await send_to_paste_service("") +        self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)          self.assertIsNone(result)      async def test_general_error_handled_and_request_repeated(self):          """All `Exception`s are handled, logged and request repeated.""" -        self.http_session.post = MagicMock(side_effect=Exception) -        result = await send_to_paste_service(self.http_session, "") -        self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +        self.bot.http_session.post = MagicMock(side_effect=Exception) +        result = await send_to_paste_service("") +        self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)          self.assertLogs("bot.utils", logging.ERROR)          self.assertIsNone(result) | 
